ding.worker.collector.metric_serial_evaluator

IMetric

Bases: ABC

gt(metric1, metric2) abstractmethod

Overview

Whether metric1 is greater than or equal to metric2 (>=)

.. note:: If metric2 is None, return True
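To make the interface concrete, here is a minimal sketch of an IMetric implementation (a hypothetical accuracy metric, not part of ding); it skips the ABC base so it stays self-contained, and assumes metric results are dicts with a single scalar key:

```python
from typing import Any, List


class AccuracyMetric:
    """Hypothetical IMetric implementation: batch accuracy as {'acc': float}."""

    def eval(self, inputs: Any, label: Any) -> dict:
        # fraction of predictions that match the labels
        correct = sum(1 for p, l in zip(inputs, label) if p == l)
        return {'acc': correct / len(label)}

    def reduce_mean(self, inputs: List[dict]) -> dict:
        # average the per-batch results into one overall result
        return {'acc': sum(r['acc'] for r in inputs) / len(inputs)}

    def gt(self, metric1: dict, metric2: dict) -> bool:
        # per the interface contract: if metric2 is None, return True
        if metric2 is None:
            return True
        return metric1['acc'] >= metric2['acc']
```

The `gt(..., None) -> True` convention lets the evaluator treat the very first evaluation result as the best-so-far without a special case.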

MetricSerialEvaluator

Bases: ISerialEvaluator

Overview

Metric serial evaluator class, policy is evaluated by objective metric(env).

Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval

Property: env, policy

__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='evaluator')

Overview

Init method. Load the config and use the self._cfg settings to build common serial evaluator components, e.g. logger helper and timer.

Arguments:

- cfg (:obj:`EasyDict`): Configuration EasyDict.

reset_env(_env=None)

Overview

Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy in different environments; reset_env can be used to reset the environment. If _env is not None, replace the evaluator's old environment with the new one.

Arguments:

- env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instances of the DataLoader and the metric.

reset_policy(_policy=None)

Overview

Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment but with a different policy; reset_policy can be used to reset the policy. If _policy is None, reset the old policy. If _policy is not None, replace the evaluator's old policy with the newly passed-in policy.

Arguments:

- policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.

reset(_policy=None, _env=None)

Overview

Reset the evaluator's policy and environment, and use the new policy and environment to collect data. If _env is not None, replace the evaluator's old environment with the new one. If _policy is None, reset the old policy; if _policy is not None, replace the old policy with the newly passed-in policy.

Arguments:

- policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.
- env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instances of the DataLoader and the metric.

close()

Overview

Close the evaluator. If end_flag is False, set it, then flush and close the tb_logger.

__del__()

Overview

Execute the close command and close the evaluator. __del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work.

should_eval(train_iter)

Overview

Determine whether to start evaluation mode: return True if the number of training iterations since the last evaluation has reached eval_freq (the first iteration always evaluates).
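The gating check can be sketched as a standalone function, assuming eval_freq = 50 (the default in MetricSerialEvaluator.config):

```python
def should_eval(train_iter: int, last_eval_iter: int, eval_freq: int = 50) -> bool:
    """Sketch of the evaluation-frequency gate used by the evaluator."""
    # never evaluate twice at the same training iteration
    if train_iter == last_eval_iter:
        return False
    # otherwise wait eval_freq iterations, except at iteration 0 (initial eval)
    if (train_iter - last_eval_iter) < eval_freq and train_iter != 0:
        return False
    return True
```

In the real class, last_eval_iter is internal state (initialized to -1 by reset) and is updated to train_iter whenever the gate returns True.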

eval(save_ckpt_fn=None, train_iter=-1, envstep=-1)

Overview

Evaluate policy and store the best policy based on whether it reaches the highest historical reward.

Arguments:

- save_ckpt_fn (:obj:`Callable`): Checkpoint-saving function, triggered when a new best result is reached.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.

Returns:

- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_metric (:obj:`float`): Current evaluation metric result.
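The stop decision at the end of eval can be sketched in isolation (the gt callable stands in for the IMetric.gt comparison, and stop_value for the configured threshold):

```python
from typing import Any, Callable


def compute_stop_flag(avg_eval_result: Any, stop_value: Any, train_iter: int,
                      gt: Callable[[Any, Any], bool]) -> bool:
    """Sketch of the stop condition computed at the end of eval()."""
    # training may only stop after at least one training iteration, and only
    # once the averaged metric reaches the configured stop value
    return gt(avg_eval_result, stop_value) and train_iter > 0
```

Note the train_iter > 0 guard: an untrained policy that happens to clear stop_value on the initial evaluation does not end the run.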

Full Source Code

../ding/worker/collector/metric_serial_evaluator.py

```python
from typing import Optional, Callable, Tuple, Any, List
from abc import ABC, abstractmethod
from collections import namedtuple
import numpy as np
import torch
from torch.utils.data import DataLoader

from ding.torch_utils import to_tensor, to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY, allreduce
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


class IMetric(ABC):

    @abstractmethod
    def eval(self, inputs: Any, label: Any) -> dict:
        raise NotImplementedError

    @abstractmethod
    def reduce_mean(self, inputs: List[Any]) -> Any:
        raise NotImplementedError

    @abstractmethod
    def gt(self, metric1: Any, metric2: Any) -> bool:
        """
        Overview:
            Whether metric1 is greater than metric2 (>=)

        .. note::
            If metric2 is None, return True
        """
        raise NotImplementedError


@SERIAL_EVALUATOR_REGISTRY.register('metric')
class MetricSerialEvaluator(ISerialEvaluator):
    """
    Overview:
        Metric serial evaluator class, policy is evaluated by objective metric(env).
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
    Property:
        env, policy
    """

    config = dict(
        # Evaluate every "eval_freq" training iterations.
        eval_freq=50,
    )

    def __init__(
            self,
            cfg: dict,
            env: Tuple[DataLoader, IMetric] = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'evaluator',
    ) -> None:
        """
        Overview:
            Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
            e.g. logger helper, timer.
        Arguments:
            - cfg (:obj:`EasyDict`): Configuration EasyDict.
        """
        self._cfg = cfg
        self._exp_name = exp_name
        self._instance_name = instance_name
        if tb_logger is not None:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
            )
        self.reset(policy, env)

        self._timer = EasyTimer()
        self._stop_value = cfg.stop_value

    def reset_env(self, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
        """
        Overview:
            Reset evaluator's environment. In some case, we need evaluator use the same policy in different \
            environments. We can use reset_env to reset the environment.
            If _env is not None, replace the old environment in the evaluator with the new one
        Arguments:
            - env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
        """
        if _env is not None:
            self._dataloader, self._metric = _env

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset evaluator's policy. In some case, we need evaluator work in this same environment but use \
            different policy. We can use reset_policy to reset the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
        """
        if _policy is not None:
            self._policy = _policy
        self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
        """
        Overview:
            Reset evaluator's policy and environment. Use new policy and environment to collect data.
            If _env is not None, replace the old environment in the evaluator with the new one
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
            - env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)
        self._max_avg_eval_result = None
        self._last_eval_iter = -1
        self._end_flag = False

    def close(self) -> None:
        """
        Overview:
            Close the evaluator. If end_flag is False, close the environment, flush the tb_logger \
            and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._tb_logger.flush()
        self._tb_logger.close()

    def __del__(self):
        """
        Overview:
            Execute the close command and close the evaluator. __del__ is automatically called \
            to destroy the evaluator instance when the evaluator finishes its work
        """
        self.close()

    def should_eval(self, train_iter: int) -> bool:
        """
        Overview:
            Determine whether you need to start the evaluation mode, if the number of training has reached \
            the maximum number of times to start the evaluator, return True
        """
        if train_iter == self._last_eval_iter:
            return False
        if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True

    def eval(
            self,
            save_ckpt_fn: Callable = None,
            train_iter: int = -1,
            envstep: int = -1,
    ) -> Tuple[bool, Any]:
        '''
        Overview:
            Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
        Arguments:
            - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
            - train_iter (:obj:`int`): Current training iteration.
            - envstep (:obj:`int`): Current env interaction step.
        Returns:
            - stop_flag (:obj:`bool`): Whether this training program can be ended.
            - eval_metric (:obj:`float`): Current evaluation metric result.
        '''
        self._policy.reset()
        eval_results = []

        with self._timer:
            self._logger.info("Evaluation begin...")
            for batch_idx, batch_data in enumerate(self._dataloader):
                inputs, label = to_tensor(batch_data)
                policy_output = self._policy.forward(inputs)
                eval_results.append(self._metric.eval(policy_output, label))
            avg_eval_result = self._metric.reduce_mean(eval_results)
            if self._cfg.multi_gpu:
                device = self._policy.get_attribute('device')
                for k in avg_eval_result.keys():
                    value_tensor = torch.FloatTensor([avg_eval_result[k]]).to(device)
                    allreduce(value_tensor)
                    avg_eval_result[k] = value_tensor.item()

        duration = self._timer.value
        info = {
            'train_iter': train_iter,
            'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
            'data_length': len(self._dataloader),
            'evaluate_time': duration,
            'avg_time_per_data': duration / len(self._dataloader),
        }
        info.update(avg_eval_result)
        self._logger.info(self._logger.get_tabulate_vars_hor(info))
        # self._logger.info(self._logger.get_tabulate_vars(info))
        for k, v in info.items():
            if k in ['train_iter', 'ckpt_name']:
                continue
            if not np.isscalar(v):
                continue
            self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
            self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
        if self._metric.gt(avg_eval_result, self._max_avg_eval_result):
            if save_ckpt_fn:
                save_ckpt_fn('ckpt_best.pth.tar')
            self._max_avg_eval_result = avg_eval_result
        stop_flag = self._metric.gt(avg_eval_result, self._stop_value) and train_iter > 0
        if stop_flag:
            self._logger.info(
                "[DI-engine serial pipeline] " +
                "Current episode_return: {} is greater than stop_value: {}".format(avg_eval_result, self._stop_value) +
                ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
            )
        return stop_flag, avg_eval_result
```