from typing import List, Optional, Callable, Tuple
from collections import namedtuple
from easydict import EasyDict
import copy
import numpy as np
import torch

from ding.utils import build_logger, EasyTimer, dicts_to_lists, SERIAL_EVALUATOR_REGISTRY
from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_item
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


@SERIAL_EVALUATOR_REGISTRY.register('battle_interaction')
class BattleInteractionSerialEvaluator(ISerialEvaluator):
    """
    Overview:
        Multi-player battle evaluator class, which evaluates several policies playing against each \
        other in the same vectorized environment.
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
    Property:
        env, policy
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Get the evaluator's default config. The default config is merged with other default \
            configs and the user config to produce the final config.
        Returns:
            - cfg (:obj:`EasyDict`): The evaluator's default config.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        # Evaluate every "eval_freq" training iterations.
        eval_freq=50,
    )

    def __init__(
            self,
            cfg: dict,
            env: BaseEnvManager = None,
            policy: List[namedtuple] = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'evaluator',
    ) -> None:
        """
        Overview:
            Init method. Load the config and use ``self._cfg`` to build common serial evaluator \
            components, e.g. logger helper and timer. If ``policy`` is not passed in here, it can \
            be set afterwards through ``reset_policy``.
        Arguments:
            - cfg (:obj:`EasyDict`): Evaluator config; must contain ``n_episode`` and ``stop_value``.
            - env (:obj:`BaseEnvManager`): Instance of a subclass of the vectorized env manager.
            - policy (:obj:`List[namedtuple]`): The eval_mode api namedtuples of all battle policies.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard logger; if None, a new one is built.
            - exp_name (:obj:`Optional[str]`): Experiment name, used as the log directory prefix.
            - instance_name (:obj:`Optional[str]`): Name of this evaluator instance.
        """
        self._cfg = cfg
        self._exp_name = exp_name
        self._instance_name = instance_name
        if tb_logger is not None:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name),
                name=self._instance_name,
                need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
            )
        self.reset(policy, env)

        self._timer = EasyTimer()
        self._default_n_episode = cfg.n_episode
        self._stop_value = cfg.stop_value

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's environment. In some cases, the evaluator needs to use the same \
            policies in different environments; ``reset_env`` supports that. \
            If ``_env`` is None, reset the old environment. \
            If ``_env`` is not None, replace the old environment with the newly passed-in one and \
            launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env manager (``BaseEnvManager``).
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()
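
    # A minimal usage sketch for the reset helpers (hypothetical names, shown as a
    # comment rather than executable code): swap in a freshly built evaluation env
    # manager while keeping the current policies, e.g.
    #
    #   evaluator.reset_env(new_eval_env_manager)  # replaces and launches the new envs
    #   evaluator.reset_policy(eval_policies)      # battle eval needs more than 1 policy
    #
    # Calling either method with no argument simply resets the existing component.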

    def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
        """
        Overview:
            Reset the evaluator's policies. In some cases, the evaluator needs to work in the same \
            environment but with different policies; ``reset_policy`` supports that. \
            If ``_policy`` is None, reset the old policies. \
            If ``_policy`` is not None, replace the old policies with the newly passed-in ones.
        Arguments:
            - _policy (:obj:`Optional[List[namedtuple]]`): The eval_mode api namedtuples of all battle policies.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            assert len(_policy) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(_policy))
            self._policy = _policy
            self._policy_num = len(self._policy)
        for p in self._policy:
            p.reset()

    def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the evaluator's policies and environment, then use them to collect evaluation data. \
            If ``_env`` is None, reset the old environment. \
            If ``_env`` is not None, replace the old environment with the newly passed-in one and \
            launch it. \
            If ``_policy`` is None, reset the old policies. \
            If ``_policy`` is not None, replace the old policies with the newly passed-in ones.
        Arguments:
            - _policy (:obj:`Optional[List[namedtuple]]`): The eval_mode api namedtuples of all battle policies.
            - _env (:obj:`Optional[BaseEnvManager]`): Instance of a subclass of the vectorized \
                env manager (``BaseEnvManager``).
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)
        self._max_episode_return = float("-inf")
        self._last_eval_iter = 0
        self._end_flag = False

    def close(self) -> None:
        """
        Overview:
            Close the evaluator. If ``_end_flag`` is already True, do nothing; otherwise set it, \
            close the environment, and flush and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        self._tb_logger.flush()
        self._tb_logger.close()

    def __del__(self):
        """
        Overview:
            Execute the close command. ``__del__`` is called automatically to destroy the evaluator \
            instance when the evaluator finishes its work.
        """
        self.close()

    def should_eval(self, train_iter: int) -> bool:
        """
        Overview:
            Determine whether to run evaluation: return True at the very first iteration, or when at \
            least ``eval_freq`` training iterations have passed since the last evaluation; otherwise \
            return False.
        """
        if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True
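
    # A worked illustration of the check above with the default ``eval_freq=50``
    # (traced by hand, not executed):
    #
    #   evaluator.should_eval(0)   # -> True  (train_iter == 0 always evaluates)
    #   evaluator.should_eval(30)  # -> False (30 - 0 < 50)
    #   evaluator.should_eval(50)  # -> True  (50 - 0 >= 50; _last_eval_iter becomes 50)
    #   evaluator.should_eval(99)  # -> False (99 - 50 < 50)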

    def eval(
            self,
            save_ckpt_fn: Callable = None,
            train_iter: int = -1,
            envstep: int = -1,
            n_episode: Optional[int] = None
    ) -> Tuple[bool, List[dict]]:
        """
        Overview:
            Evaluate the policies and save the best checkpoint whenever the mean episode return \
            reaches a new historical maximum.
        Arguments:
            - save_ckpt_fn (:obj:`Callable`): Checkpoint-saving function, triggered when a new best reward is reached.
            - train_iter (:obj:`int`): Current training iteration.
            - envstep (:obj:`int`): Current env interaction step.
            - n_episode (:obj:`int`): Number of evaluation episodes.
        Returns:
            - stop_flag (:obj:`bool`): Whether this training program can be ended.
            - return_info (:obj:`list`): Per-policy environment information of each finished episode.
        """
        if n_episode is None:
            n_episode = self._default_n_episode
        assert n_episode is not None, "please indicate eval n_episode"
        envstep_count = 0
        info = {}
        # TODO replace return_info with episode_info (validated by the league demo case)
        return_info = [[] for _ in range(self._policy_num)]
        eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
        self._env.reset()
        for p in self._policy:
            p.reset()

        with self._timer:
            while not eval_monitor.is_finished():
                obs = self._env.ready_obs
                ready_env_id = obs.keys()
                obs = to_tensor(obs, dtype=torch.float32)
                # Transpose {env_id: [obs_of_policy0, obs_of_policy1, ...]} into one
                # per-policy obs dict, so each policy only sees its own observations.
                obs = dicts_to_lists(obs)
                policy_output = [p.forward(obs[i]) for i, p in enumerate(self._policy)]
                # Gather every policy's action for each ready env.
                actions = {}
                for env_id in ready_env_id:
                    actions[env_id] = []
                    for output in policy_output:
                        actions[env_id].append(output[env_id]['action'])
                actions = to_ndarray(actions)
                timesteps = self._env.step(actions)
                timesteps = to_tensor(timesteps, dtype=torch.float32)
                for env_id, t in timesteps.items():
                    if t.done:
                        # Env reset is done by env_manager automatically.
                        for p in self._policy:
                            p.reset([env_id])
                        # Policy 0 is regarded as the main policy by default.
                        reward = t.info[0]['eval_episode_return']
                        if 'episode_info' in t.info[0]:
                            eval_monitor.update_info(env_id, t.info[0]['episode_info'])
                        eval_monitor.update_reward(env_id, reward)
                        for policy_id in range(self._policy_num):
                            return_info[policy_id].append(t.info[policy_id])
                        self._logger.info(
                            "[EVALUATOR] env {} finished episode, final reward: {}, current episode: {}".format(
                                env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
                            )
                        )
                    envstep_count += 1
        duration = self._timer.value
        episode_return = eval_monitor.get_episode_return()
        info = {
            'train_iter': train_iter,
            'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
            'episode_count': n_episode,
            'envstep_count': envstep_count,
            'avg_envstep_per_episode': envstep_count / n_episode,
            'evaluate_time': duration,
            'avg_envstep_per_sec': envstep_count / duration,
            'avg_time_per_episode': duration / n_episode,
            'reward_mean': np.mean(episode_return),
            'reward_std': np.std(episode_return),
            'reward_max': np.max(episode_return),
            'reward_min': np.min(episode_return),
            # 'each_reward': episode_return,
        }
        episode_info = eval_monitor.get_episode_info()
        if episode_info is not None:
            info.update(episode_info)
        self._logger.info(self._logger.get_tabulate_vars_hor(info))
        # self._logger.info(self._logger.get_tabulate_vars(info))
        for k, v in info.items():
            if k in ['train_iter', 'ckpt_name', 'each_reward']:
                continue
            if not np.isscalar(v):
                continue
            self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
            self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
        episode_return = np.mean(episode_return)
        if episode_return > self._max_episode_return:
            if save_ckpt_fn:
                save_ckpt_fn('ckpt_best.pth.tar')
            self._max_episode_return = episode_return
        stop_flag = episode_return >= self._stop_value and train_iter > 0
        if stop_flag:
            self._logger.info(
                "[DI-engine serial pipeline] " +
                "Current episode_return: {} has reached stop_value: {}".format(episode_return, self._stop_value) +
                ", so your RL agent has converged; you can refer to 'log/evaluator/evaluator_logger.txt' for details."
            )
        return_info = to_item(return_info)
        return stop_flag, return_info
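
# A minimal end-to-end usage sketch (hypothetical names: ``my_env_manager``,
# ``eval_policies``, ``save_ckpt_fn`` and ``max_iterations`` are placeholders, not
# part of this module). Besides ``eval_freq``, the config must provide ``n_episode``
# and ``stop_value``, which ``__init__`` reads:
#
#   cfg = BattleInteractionSerialEvaluator.default_config()
#   cfg.n_episode, cfg.stop_value = 8, 0.99
#   evaluator = BattleInteractionSerialEvaluator(
#       cfg, env=my_env_manager, policy=eval_policies, exp_name='my_exp'
#   )
#   for train_iter in range(max_iterations):
#       ...  # one training step, updating envstep
#       if evaluator.should_eval(train_iter):
#           stop_flag, return_info = evaluator.eval(save_ckpt_fn, train_iter, envstep)
#           if stop_flag:
#               break  # mean episode return reached stop_value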