from typing import Optional, Any, List, Tuple
from collections import namedtuple, deque
from easydict import EasyDict
import numpy as np
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions


@SERIAL_COLLECTOR_REGISTRY.register('episode_1v1')
class BattleEpisodeSerialCollector(ISerialCollector):
    """
    Overview:
        Episode collector (n_episode) for 1v1 battles between two policies.
    Interfaces:
        __init__, reset, reset_env, reset_policy, collect, close
    Property:
        envstep
    """

    config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False)

    def __init__(
            self,
            cfg: EasyDict,
            env: BaseEnvManager = None,
            policy: List[namedtuple] = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector'
    ) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`EasyDict`): config dict of the collector
            - env (:obj:`BaseEnvManager`): an instance of a vectorized env manager (subclass of BaseEnvManager)
            - policy (:obj:`List[namedtuple]`): a list of two collect_mode policy api namedtuples, one per side
            - tb_logger (:obj:`SummaryWriter`): tensorboard logger handle
            - exp_name (:obj:`Optional[str]`): name of the experiment, used as the log directory prefix
            - instance_name (:obj:`Optional[str]`): name of this collector instance
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = cfg.collect_print_freq
        self._deepcopy_obs = cfg.deepcopy_obs
        self._transform_obs = cfg.transform_obs
        self._cfg = cfg
        self._timer = EasyTimer()
        self._end_flag = False

        if tb_logger is not None:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name),
                name=self._instance_name,
                need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
            )
        self._traj_len = float("inf")
        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment.
            If _env is None, reset the current environment.
            If _env is not None, replace the current environment in the collector with the passed-in one and launch it.
        Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): instance of a vectorized env manager (subclass of BaseEnvManager)
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
        """
        Overview:
            Reset the policies.
            If _policy is None, reset the current policies.
            If _policy is not None, replace the current policies in the collector with the passed-in ones.
        Arguments:
            - policy (:obj:`Optional[List[namedtuple]]`): a list of two collect_mode policy api namedtuples
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            assert len(_policy) == 2, "1v1 episode collector needs exactly 2 policies, but found {}".format(
                len(_policy)
            )
            self._policy = _policy
            self._default_n_episode = _policy[0].get_attribute('cfg').collect.get('n_episode', None)
            self._unroll_len = _policy[0].get_attribute('unroll_len')
            self._on_policy = _policy[0].get_attribute('cfg').on_policy
            self._traj_len = INF
            self._logger.debug(
                'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
                    self._default_n_episode, self._env_num, self._traj_len
                )
            )
        for p in self._policy:
            p.reset()

    def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment and policies.
            If _env is None, reset the current environment; otherwise, replace it with the passed-in one and launch.
            If _policy is None, reset the current policies; otherwise, replace them with the passed-in ones.
        Arguments:
            - policy (:obj:`Optional[List[namedtuple]]`): a list of two collect_mode policy api namedtuples
            - env (:obj:`Optional[BaseEnvManager]`): instance of a vectorized env manager (subclass of BaseEnvManager)
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, used to store at most traj_len transitions
        # for each (env, policy) pair.
        self._traj_buffer = {
            env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
                     for policy_id in range(2)}
            for env_id in range(self._env_num)
        }
        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state for one environment, including its traj_buffer, obs_pool,
            policy_output_pool and env_info. Refer to base_serial_collector for more details.
        Arguments:
            - env_id (:obj:`int`): the id of the environment whose collector state needs to be reset
        """
        for i in range(2):
            self._traj_buffer[env_id][i].clear()
        self._obs_pool.reset(env_id)
        self._policy_output_pool.reset(env_id)
        self._env_info[env_id] = {'time': 0., 'step': 0}

    @property
    def envstep(self) -> int:
        """
        Overview:
            Return the total envstep count.
        Return:
            - envstep (:obj:`int`): the total envstep count.
        """
        return self._total_envstep_count

    @envstep.setter
    def envstep(self, value: int) -> None:
        """
        Overview:
            Set the total envstep count.
        Arguments:
            - value (:obj:`int`): the total envstep count.
        """
        self._total_envstep_count = value

    def close(self) -> None:
        """
        Overview:
            Close the collector. If end_flag is False, close the environment, then flush and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        self._tb_logger.flush()
        self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the collector. __del__ is automatically called to
            destroy the collector instance when the collector finishes its work.
        """
        self.close()

    def collect(self,
                n_episode: Optional[int] = None,
                train_iter: int = 0,
                policy_kwargs: Optional[dict] = None) -> Tuple[List[Any], List[Any]]:
        """
        Overview:
            Collect `n_episode` episodes of data with policy_kwargs, using policies that have already been
            trained for `train_iter` iterations.
        Arguments:
            - n_episode (:obj:`int`): the number of episodes to collect
            - train_iter (:obj:`int`): the current number of training iterations
            - policy_kwargs (:obj:`dict`): the keyword args for policy forward
        Returns:
            - return_data (:obj:`Tuple[List, List]`): a tuple of (data, episode info), one list per policy. If
              get_train_sample is False, the data is a list of collected episodes; otherwise it is a list of
              train samples split by unroll_len.
        """
        if n_episode is None:
            if self._default_n_episode is None:
                raise RuntimeError("Please specify collect n_episode")
            else:
                n_episode = self._default_n_episode
        assert n_episode >= self._env_num, "Please make sure n_episode >= env_num"
        if policy_kwargs is None:
            policy_kwargs = {}
        collected_episode = 0
        return_data = [[] for _ in range(2)]
        return_info = [[] for _ in range(2)]
        ready_env_id = set()
        remain_episode = n_episode

        while True:
            with self._timer:
                # Get current env obs.
                obs = self._env.ready_obs
                new_available_env_id = set(obs.keys()).difference(ready_env_id)
                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                remain_episode -= min(len(new_available_env_id), remain_episode)
                obs = {env_id: obs[env_id] for env_id in ready_env_id}
                # Policy forward.
                self._obs_pool.update(obs)
                if self._transform_obs:
                    obs = to_tensor(obs, dtype=torch.float32)
                obs = dicts_to_lists(obs)
                policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
                self._policy_output_pool.update(policy_output)
                # Interact with env: each ready env receives a list of two actions, one per policy.
                actions = {}
                for env_id in ready_env_id:
                    actions[env_id] = []
                    for output in policy_output:
                        actions[env_id].append(output[env_id]['action'])
                actions = to_ndarray(actions)
                timesteps = self._env.step(actions)

            # TODO(nyz) this duration may be inaccurate in async env
            interaction_duration = self._timer.value / len(timesteps)

            # TODO(nyz) vectorize this for loop
            for env_id, timestep in timesteps.items():
                self._env_info[env_id]['step'] += 1
                self._total_envstep_count += 1
                with self._timer:
                    for policy_id, policy in enumerate(self._policy):
                        # Split the shared timestep into a per-policy view: every field except the scalar
                        # done flag is indexed by policy_id.
                        policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
                        policy_timestep = type(timestep)(*policy_timestep_data)
                        transition = self._policy[policy_id].process_transition(
                            self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
                            policy_timestep
                        )
                        transition['collect_iter'] = train_iter
                        self._traj_buffer[env_id][policy_id].append(transition)
                        # Prepare data when the episode ends.
                        if timestep.done:
                            transitions = to_tensor_transitions(
                                self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
                            )
                            if self._cfg.get_train_sample:
                                train_sample = self._policy[policy_id].get_train_sample(transitions)
                                return_data[policy_id].extend(train_sample)
                            else:
                                return_data[policy_id].append(transitions)
                            self._traj_buffer[env_id][policy_id].clear()

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration

                # If env is done, record episode info and reset.
                if timestep.done:
                    self._total_episode_count += 1
                    info = {
                        'reward0': timestep.info[0]['eval_episode_return'],
                        'reward1': timestep.info[1]['eval_episode_return'],
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                    }
                    collected_episode += 1
                    self._episode_info.append(info)
                    for i, p in enumerate(self._policy):
                        p.reset([env_id])
                    self._reset_stat(env_id)
                    ready_env_id.remove(env_id)
                    for policy_id in range(2):
                        return_info[policy_id].append(timestep.info[policy_id])
            if collected_episode >= n_episode:
                break
        # log
        self._output_log(train_iter)
        return return_data, return_info
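
    # Illustrative data layout (comments only, inferred from the logic in ``collect`` above):
    # for two ready envs with ids 0 and 2, the ``actions`` dict passed to ``self._env.step`` looks like
    #     {0: [action_of_policy0, action_of_policy1], 2: [action_of_policy0, action_of_policy1]}
    # and the returned tuple is (return_data, return_info), where index 0/1 of each list holds the
    # episodes (or train samples) and the final episode infos of the corresponding policy.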

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Print the output log information. You can refer to Docs/Best Practice/How to understand
            training generated folders/Serial mode/log/collector for more details.
        Arguments:
            - train_iter (:obj:`int`): the current number of training iterations
        """
        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
            self._last_train_iter = train_iter
            episode_count = len(self._episode_info)
            envstep_count = sum([d['step'] for d in self._episode_info])
            duration = sum([d['time'] for d in self._episode_info])
            episode_return0 = [d['reward0'] for d in self._episode_info]
            episode_return1 = [d['reward1'] for d in self._episode_info]
            self._total_duration += duration
            info = {
                'episode_count': episode_count,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / episode_count,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_episode_per_sec': episode_count / duration,
                'collect_time': duration,
                'reward0_mean': np.mean(episode_return0),
                'reward0_std': np.std(episode_return0),
                'reward0_max': np.max(episode_return0),
                'reward0_min': np.min(episode_return0),
                'reward1_mean': np.mean(episode_return1),
                'reward1_std': np.std(episode_return1),
                'reward1_max': np.max(episode_return1),
                'reward1_min': np.min(episode_return1),
                'total_envstep_count': self._total_envstep_count,
                'total_episode_count': self._total_episode_count,
                'total_duration': self._total_duration,
            }
            self._episode_info.clear()
            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
            for k, v in info.items():
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                if k in ['total_envstep_count']:
                    continue
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
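
# Minimal usage sketch (illustrative comments only, not part of the module). It assumes an already
# launchable 1v1 env manager ``env_manager`` and two collect-mode policy api namedtuples ``policy0``
# and ``policy1`` built elsewhere; these names and the hand-rolled config below are placeholders,
# not DI-engine entry-point APIs.
#
#     cfg = EasyDict(BattleEpisodeSerialCollector.config)
#     collector = BattleEpisodeSerialCollector(
#         cfg, env=env_manager, policy=[policy0, policy1], exp_name='demo_1v1'
#     )
#     return_data, return_info = collector.collect(n_episode=8, train_iter=0)
#     # return_data[0] / return_data[1]: episodes (or train samples) collected for each side
#     collector.close()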