from typing import Optional, Callable, Tuple, Dict, List
from collections import namedtuple
import numpy as np
import torch

from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_item
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank, broadcast_object_list
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


@SERIAL_EVALUATOR_REGISTRY.register('interaction')
class InteractionSerialEvaluator(ISerialEvaluator):
    """
    Overview:
        Interaction serial evaluator class, in which the policy interacts with the environment \
        to evaluate its performance.
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
    Property:
        env, policy
    """

    config = dict(
        # (int) Evaluate every "eval_freq" training iterations.
        eval_freq=1000,
        render=dict(
            # Tensorboard video render is disabled by default (render_freq=-1).
            render_freq=-1,
            # Which counter the render frequency refers to: 'train_iter' or 'envstep'.
            mode='train_iter',
        ),
        # (str) File path where visualized environment information (e.g. figures) is saved.
        figure_path=None,
    )

    def __init__(
            self,
            cfg: dict,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'evaluator',
    ) -> None:
        """
        Overview:
            Init method. Load config and use the ``self._cfg`` setting to build common serial evaluator \
            components, e.g. logger helper and timer.
        Arguments:
            - cfg (:obj:`EasyDict`): Configuration EasyDict.
            - env (:obj:`BaseEnvManager`): Vectorized environment manager used for evaluation.
            - policy (:obj:`namedtuple`): The API namedtuple of the eval_mode policy.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard logger; if provided, it is shared with the caller.
            - exp_name (:obj:`Optional[str]`): Experiment name, used to build the log path.
            - instance_name (:obj:`Optional[str]`): Name of this evaluator instance.
        """
        self._cfg = cfg
        self._exp_name = exp_name
        self._instance_name = instance_name

        # Logger (Monitor will be initialized in the policy setter).
        # Only the rank == 0 process needs the monitor and tb_logger; other ranks only need a text_logger
        # to display terminal output.
        if get_rank() == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
                )
        else:
            self._logger, self._tb_logger = None, None  # for closing elegantly
        self.reset(policy, env)

        self._timer = EasyTimer()
        self._default_n_episode = cfg.n_episode
        self._stop_value = cfg.stop_value
        # Only one render frequency is used.
        self._render = cfg.render
        assert self._render.mode in ('envstep', 'train_iter'), 'mode should be envstep or train_iter'

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's environment. In some cases, we need the evaluator to use the same policy \
            in different environments; ``reset_env`` handles this.
            If ``_env`` is None, reset the old environment.
            If ``_env`` is not None, replace the old environment in the evaluator with the newly passed-in \
            environment and launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the evaluator's policy. In some cases, we need the evaluator to work in the same environment \
            but with a different policy; ``reset_policy`` handles this.
            If ``_policy`` is None, reset the old policy.
            If ``_policy`` is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
            self._policy_cfg = self._policy.get_attribute('cfg')
        self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's policy and environment, and use them for subsequent evaluation.
            If ``_env`` is None, reset the old environment.
            If ``_env`` is not None, replace the old environment in the evaluator with the newly passed-in \
            environment and launch it.
            If ``_policy`` is None, reset the old policy.
            If ``_policy`` is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): The API namedtuple of the eval_mode policy.
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)
        if self._policy_cfg.type == 'dreamer_command':
            # Dreamer-style policies carry recurrent states across steps, so keep per-env states and reset flags.
            self._states = None
            self._resets = np.array([False for _ in range(self._env_num)])
        self._max_episode_return = float("-inf")
        self._last_eval_iter = -1
        self._end_flag = False
        self._last_render_iter = -1

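    # Illustrative sketch (not part of the module API): re-using one evaluator with a freshly created
    # environment manager. ``new_eval_env`` and ``policy`` are hypothetical objects created elsewhere.
    #
    #   evaluator.reset_env(new_eval_env)          # launch the new envs, keep the current policy
    #   evaluator.reset_policy(policy.eval_mode)   # optionally switch the policy as well
    #   evaluator.reset(policy.eval_mode, new_eval_env)  # or reset both and clear internal counters in one call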

    def close(self) -> None:
        """
        Overview:
            Close the evaluator. If ``end_flag`` is False, close the environment, then flush and close \
            the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self):
        """
        Overview:
            Execute the close command and close the evaluator. ``__del__`` is called automatically to \
            destroy the evaluator instance when the evaluator finishes its work.
        """
        self.close()

    def should_eval(self, train_iter: int) -> bool:
        """
        Overview:
            Determine whether to start evaluation. Return True if this iteration has not been evaluated yet \
            and either ``train_iter == 0`` or at least ``eval_freq`` training iterations have elapsed since \
            the last evaluation; otherwise return False.
        """
        if train_iter == self._last_eval_iter:
            return False
        if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True
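    # Worked example of the gating above (illustrative): with the default eval_freq=1000 and a fresh
    # evaluator (self._last_eval_iter == -1),
    #   should_eval(0)    -> True   (train_iter == 0 always triggers the first evaluation)
    #   should_eval(500)  -> False  (500 - 0 < 1000)
    #   should_eval(1000) -> True   (1000 - 0 >= 1000)
    #   should_eval(1000) -> False  (this iteration has already been evaluated)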

    def _should_render(self, envstep: int, train_iter: int) -> bool:
        if self._render.render_freq == -1:
            return False
        cur_iter = envstep if self._render.mode == 'envstep' else train_iter
        if (cur_iter - self._last_render_iter) < self._render.render_freq:
            return False
        self._last_render_iter = cur_iter
        return True

    def eval(
            self,
            save_ckpt_fn: Callable = None,
            train_iter: int = -1,
            envstep: int = -1,
            n_episode: Optional[int] = None,
            force_render: bool = False,
            policy_kwargs: Optional[Dict] = None,
    ) -> Tuple[bool, Dict[str, List]]:
        '''
        Overview:
            Evaluate the policy and store the best checkpoint whenever it reaches the highest historical reward.
        Arguments:
            - save_ckpt_fn (:obj:`Callable`): Checkpoint-saving function, triggered when a new best reward is reached.
            - train_iter (:obj:`int`): Current training iteration.
            - envstep (:obj:`int`): Current env interaction step.
            - n_episode (:obj:`int`): Number of evaluation episodes.
            - force_render (:obj:`bool`): Whether to render this evaluation regardless of the render frequency.
            - policy_kwargs (:obj:`Optional[Dict]`): Extra keyword arguments passed to ``policy.forward``.
        Returns:
            - stop_flag (:obj:`bool`): Whether this training program can be ended.
            - episode_info (:obj:`Dict[str, List]`): Current evaluation episode information.
        '''
        # The evaluator only works on rank 0.
        stop_flag = False
        episode_info = None  # Initialize to ensure it is defined on all ranks.
        # Avoid a mutable default argument for policy_kwargs.
        if policy_kwargs is None:
            policy_kwargs = {}

        if get_rank() == 0:
            if n_episode is None:
                n_episode = self._default_n_episode
            assert n_episode is not None, "please indicate eval n_episode"
            envstep_count = 0
            info = {}
            eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
            self._env.reset()
            self._policy.reset()

            # force_render overrides the render frequency constraint.
            render = force_render or self._should_render(envstep, train_iter)

            with self._timer:
                while not eval_monitor.is_finished():
                    obs = self._env.ready_obs
                    obs = to_tensor(obs, dtype=torch.float32)

                    # Update videos.
                    if render:
                        eval_monitor.update_video(self._env.ready_imgs)

                    if self._policy_cfg.type == 'dreamer_command':
                        policy_output = self._policy.forward(
                            obs, **policy_kwargs, reset=self._resets, state=self._states
                        )
                        # self._states = {env_id: output['state'] for env_id, output in policy_output.items()}
                        self._states = [output['state'] for output in policy_output.values()]
                    else:
                        policy_output = self._policy.forward(obs, **policy_kwargs)
                    actions = {i: a['action'] for i, a in policy_output.items()}
                    actions = to_ndarray(actions)
                    timesteps = self._env.step(actions)
                    timesteps = to_tensor(timesteps, dtype=torch.float32)
                    for env_id, t in timesteps.items():
                        if t.info.get('abnormal', False):
                            # If there is an abnormal timestep, reset all the related variables (including this env).
                            self._policy.reset([env_id])
                            continue
                        if self._policy_cfg.type == 'dreamer_command':
                            self._resets[env_id] = t.done
                        if t.done:
                            # Env reset is done by the env_manager automatically.
                            if 'figure_path' in self._cfg and self._cfg.figure_path is not None:
                                self._env.enable_save_figure(env_id, self._cfg.figure_path)
                            self._policy.reset([env_id])
                            reward = t.info['eval_episode_return']
                            saved_info = {'eval_episode_return': t.info['eval_episode_return']}
                            if 'episode_info' in t.info:
                                saved_info.update(t.info['episode_info'])
                            eval_monitor.update_info(env_id, saved_info)
                            eval_monitor.update_reward(env_id, reward)
                            self._logger.info(
                                "[EVALUATOR] env {} finished episode, final reward: {:.4f}, current episode: {}".format(
                                    env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
                                )
                            )
                        envstep_count += 1
            duration = self._timer.value
            episode_return = eval_monitor.get_episode_return()
            info = {
                'train_iter': train_iter,
                'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
                'episode_count': n_episode,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / n_episode,
                'evaluate_time': duration,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_time_per_episode': duration / n_episode,
                'reward_mean': np.mean(episode_return),
                'reward_std': np.std(episode_return),
                'reward_max': np.max(episode_return),
                'reward_min': np.min(episode_return),
                # 'each_reward': episode_return,
            }
            episode_info = eval_monitor.get_episode_info()
            if episode_info is not None:
                info.update(episode_info)
            self._logger.info(self._logger.get_tabulate_vars_hor(info))
            # self._logger.info(self._logger.get_tabulate_vars(info))
            for k, v in info.items():
                if k in ['train_iter', 'ckpt_name', 'each_reward']:
                    continue
                if not np.isscalar(v):
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)

            if render:
                video_title = '{}_{}/'.format(self._instance_name, self._render.mode)
                videos = eval_monitor.get_video()
                render_iter = envstep if self._render.mode == 'envstep' else train_iter
                from ding.utils import fps
                self._tb_logger.add_video(video_title, videos, render_iter, fps(self._env))

            episode_return = np.mean(episode_return)
            if episode_return > self._max_episode_return:
                if save_ckpt_fn:
                    save_ckpt_fn('ckpt_best.pth.tar')
                self._max_episode_return = episode_return
            stop_flag = episode_return >= self._stop_value and train_iter > 0
            if stop_flag:
                self._logger.info(
                    "[DI-engine serial pipeline] " +
                    "Current episode_return: {:.4f} has reached stop_value: {}".format(
                        episode_return, self._stop_value
                    ) + ", so the RL agent has converged; see 'log/evaluator/evaluator_logger.txt' for details."
                )

        if get_world_size() > 1:
            objects = [stop_flag, episode_info]
            broadcast_object_list(objects, src=0)
            stop_flag, episode_info = objects

        # Ensure episode_info is converted to the correct format.
        episode_info = to_item(episode_info) if episode_info is not None else {}

        return stop_flag, episode_info
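

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): how the evaluator is
# typically wired into a serial training loop. ``cfg``, ``eval_env``, ``policy``,
# ``tb_logger``, ``learner`` and ``collector`` are hypothetical objects created
# elsewhere in the pipeline.
# ---------------------------------------------------------------------------
#
#   evaluator = InteractionSerialEvaluator(
#       cfg.policy.eval.evaluator,
#       env=eval_env,
#       policy=policy.eval_mode,
#       tb_logger=tb_logger,
#       exp_name=cfg.exp_name,
#   )
#   while True:
#       if evaluator.should_eval(learner.train_iter):
#           stop, episode_info = evaluator.eval(
#               learner.save_checkpoint, learner.train_iter, collector.envstep
#           )
#           if stop:
#               break
#       # ... collect new data and update the policy ...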