from typing import Optional, Any, List
from collections import namedtuple
from easydict import EasyDict
import copy
import numpy as np
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
    allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions


@SERIAL_COLLECTOR_REGISTRY.register('sample')
class SampleSerialCollector(ISerialCollector):
    """
    Overview:
        Sample collector(n_sample), where a sample is one training sample for updating the model.
        It is usually a single transition like <s, a, s', r, d>, while for RNN-based models it is
        often a trajectory containing several transitions.
    Interfaces:
        __init__, reset, reset_env, reset_policy, collect, close
    Property:
        envstep
    """

    # Default config: whether to deepcopy cached obs, whether to transform obs into torch.Tensor
    # before policy forward, and how often (in training iterations) to print collect logs.
    config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)

    def __init__(
            self,
            cfg: EasyDict,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector'
    ) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict.
            - env (:obj:`BaseEnvManager`): The subclass of vectorized env_manager (BaseEnvManager).
            - policy (:obj:`namedtuple`): The API namedtuple of collect_mode policy.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard handle.
            - exp_name (:obj:`Optional[str]`): Name of the experiment, used to determine the log path.
            - instance_name (:obj:`Optional[str]`): Name of this collector instance.
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = cfg.collect_print_freq
        self._deepcopy_obs = cfg.deepcopy_obs  # whether to deepcopy each data
        self._transform_obs = cfg.transform_obs
        self._cfg = cfg
        self._timer = EasyTimer()
        self._end_flag = False
        self._rank = get_rank()
        self._world_size = get_world_size()

        # Only the rank 0 process keeps a tensorboard logger; other ranks only write text logs.
        if self._rank == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name),
                    name=self._instance_name,
                    need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
                )
        else:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = None

        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
        Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
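            The collector reads the collect-mode attributes ``cfg``, ``n_sample``, ``unroll_len``, \
                ``traj_len_inf`` and ``on_policy`` from the new policy, and uses them to derive the \
                per-env trajectory length ``_traj_len``.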
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): The API namedtuple of collect_mode policy.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
            self._policy_cfg = self._policy.get_attribute('cfg')
            self._default_n_sample = _policy.get_attribute('n_sample')
            self._traj_len_inf = self._policy_cfg.traj_len_inf
            self._unroll_len = _policy.get_attribute('unroll_len')
            self._on_policy = _policy.get_attribute('on_policy')
            if self._default_n_sample is not None and not self._traj_len_inf:
                # Each env contributes about ceil(n_sample / env_num) transitions per collect call,
                # and the trajectory must be at least ``unroll_len`` long.
                self._traj_len = max(
                    self._unroll_len,
                    self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0)
                )
                self._logger.debug(
                    'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format(
                        self._default_n_sample, self._env_num, self._traj_len
                    )
                )
            else:
                self._traj_len = INF
        self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment and policy.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): The API namedtuple of collect_mode policy.
            - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized \
                env_manager (BaseEnvManager).
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        if self._policy_cfg.type == 'dreamer_command':
            self._states = None
            self._resets = np.array([False for i in range(self._env_num)])
        self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        # _traj_buffer is {env_id: TrajBuffer}, used to store up to traj_len transitions for each env
        maxlen = self._traj_len if self._traj_len != INF else None
        self._traj_buffer = {
            env_id: TrajBuffer(maxlen=maxlen, deepcopy=self._deepcopy_obs)
            for env_id in range(self._env_num)
        }
        self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_train_sample_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state for a specific env, including its traj_buffer, obs_pool, \
                policy_output_pool and env_info. You can refer to base_serial_collector for more details.
        Arguments:
            - env_id (:obj:`int`): The id of the env whose state needs to be reset.
        """
        self._traj_buffer[env_id].clear()
        self._obs_pool.reset(env_id)
        self._policy_output_pool.reset(env_id)
        self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}

    @property
    def envstep(self) -> int:
        """
        Overview:
            Get the total envstep count.
        Return:
            - envstep (:obj:`int`): The total envstep count.
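                When DDP is enabled, ``collect`` accumulates the all-reduced step counts into this \
                counter, so it reflects the env steps collected by all processes.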
189 """ 190 return self._total_envstep_count 191 192 @envstep.setter 193 def envstep(self, value: int) -> None: 194 """ 195 Overview: 196 Set the total envstep count. 197 Arguments: 198 - value (:obj:`int`): The total envstep count. 199 """ 200 self._total_envstep_count = value 201 202 def close(self) -> None: 203 """ 204 Overview: 205 Close the collector. If end_flag is False, close the environment, flush the tb_logger\ 206 and close the tb_logger. 207 """ 208 if self._end_flag: 209 return 210 self._end_flag = True 211 self._env.close() 212 if self._tb_logger: 213 self._tb_logger.flush() 214 self._tb_logger.close() 215 216 def __del__(self) -> None: 217 """ 218 Overview: 219 Execute the close command and close the collector. __del__ is automatically called to \ 220 destroy the collector instance when the collector finishes its work 221 """ 222 self.close() 223 224 def collect( 225 self, 226 n_sample: Optional[int] = None, 227 train_iter: int = 0, 228 drop_extra: bool = True, 229 random_collect: bool = False, 230 record_random_collect: bool = True, 231 policy_kwargs: Optional[dict] = None, 232 level_seeds: Optional[List] = None, 233 ) -> List[Any]: 234 """ 235 Overview: 236 Collect `n_sample` data with policy_kwargs, which is already trained `train_iter` iterations. 237 Arguments: 238 - n_sample (:obj:`int`): The number of collecting data sample. 239 - train_iter (:obj:`int`): The number of training iteration when calling collect method. 240 - drop_extra (:obj:`bool`): Whether to drop extra return_data more than `n_sample`. 241 - record_random_collect (:obj:`bool`) :Whether to output logs of random collect. 242 - policy_kwargs (:obj:`dict`): The keyword args for policy forward. 243 - level_seeds (:obj:`dict`): Used in PLR, represents the seed of the environment that \ 244 generate the data 245 Returns: 246 - return_data (:obj:`List`): A list containing training samples. 247 """ 248 if n_sample is None: 249 if self._default_n_sample is None: 250 raise RuntimeError("Please specify collect n_sample") 251 else: 252 n_sample = self._default_n_sample 253 if n_sample % self._env_num != 0: 254 one_time_warning( 255 "Please make sure env_num is divisible by n_sample: {}/{}, ".format(n_sample, self._env_num) + 256 "which may cause convergence problems in a few algorithms" 257 ) 258 if policy_kwargs is None: 259 policy_kwargs = {} 260 collected_sample = 0 261 collected_step = 0 262 collected_episode = 0 263 return_data = [] 264 265 while collected_sample < n_sample: 266 with self._timer: 267 # Get current env obs. 268 obs = self._env.ready_obs 269 # Policy forward. 270 self._obs_pool.update(obs) 271 if self._transform_obs: 272 obs = to_tensor(obs, dtype=torch.float32) 273 if self._policy_cfg.type == 'dreamer_command' and not random_collect: 274 policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states) 275 #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} 276 self._states = [output['state'] for output in policy_output.values()] 277 else: 278 policy_output = self._policy.forward(obs, **policy_kwargs) 279 self._policy_output_pool.update(policy_output) 280 # Interact with env. 
                actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                actions = to_ndarray(actions)
                timesteps = self._env.step(actions)

            # TODO(nyz) this duration may be inaccurate in async env
            interaction_duration = self._timer.value / len(timesteps)

            # TODO(nyz) vectorize this for loop
            for env_id, timestep in timesteps.items():
                with self._timer:
                    if timestep.info.get('abnormal', False):
                        # If there is an abnormal timestep, reset all the related variables (including this env).
                        # Assume there is no special reset param, so just reset this env.
                        self._env.reset({env_id: None})
                        self._policy.reset([env_id])
                        self._reset_stat(env_id)
                        self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
                        continue
                    if self._policy_cfg.type == 'dreamer_command' and not random_collect:
                        self._resets[env_id] = timestep.done
                    if self._policy_cfg.type == 'ngu_command':  # for NGU policy
                        transition = self._policy.process_transition(
                            self._obs_pool[env_id], self._policy_output_pool[env_id], timestep, env_id
                        )
                    else:
                        transition = self._policy.process_transition(
                            self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
                        )
                    if level_seeds is not None:
                        transition['seed'] = level_seeds[env_id]
                    # ``train_iter`` is passed in from ``serial_entry`` and indicates the current collecting model's iteration.
                    transition['collect_iter'] = train_iter
                    self._traj_buffer[env_id].append(transition)
                    self._env_info[env_id]['step'] += 1
                    collected_step += 1
                    # prepare data
                    if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
                        # If policy is r2d2:
                        # 1. For each collect_env, we want to collect data of length self._traj_len=INF
                        # unless the episode enters the 'done' state.
                        # 2. The length of a train (sequence) sample in r2d2 is <burnin + learn_unroll_len>
                        # (please refer to r2d2.py) and in each collect phase,
                        # we collect a total of <n_sample> (sequence) samples.
                        # 3. When timestep is done and we have only collected a few transitions in self._traj_buffer,
                        # by going through self._policy.get_train_sample, it will be padded automatically to get the
                        # sequence sample of length <burnin + learn_unroll_len> (please refer to r2d2.py).

                        # Episode is done or traj_buffer(maxlen=traj_len) is full.
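                        # The buffered trajectory is packed into training samples below: the policy's
                        # ``get_train_sample`` decides the final sample format (e.g. fixed-length unroll
                        # sequences for RNN policies, or per-transition samples with extra n-step /
                        # advantage fields, depending on the concrete policy), and the buffer is cleared
                        # afterwards.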
                        # indicate whether to shallow copy next obs, i.e., overlap of s_t and s_t+1
                        transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
                        train_sample = self._policy.get_train_sample(transitions)
                        return_data.extend(train_sample)
                        self._env_info[env_id]['train_sample'] += len(train_sample)
                        collected_sample += len(train_sample)
                        self._traj_buffer[env_id].clear()

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration

                # If env is done, record episode info and reset
                if timestep.done:
                    collected_episode += 1
                    reward = timestep.info['eval_episode_return']
                    info = {
                        'reward': reward,
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                        'train_sample': self._env_info[env_id]['train_sample'],
                    }
                    self._episode_info.append(info)
                    # Env reset is done by env_manager automatically
                    self._policy.reset([env_id])
                    self._reset_stat(env_id)

        collected_duration = sum([d['time'] for d in self._episode_info])
        # Reduce data when DDP is enabled.
        if self._world_size > 1:
            collected_sample = allreduce_data(collected_sample, 'sum')
            collected_step = allreduce_data(collected_step, 'sum')
            collected_episode = allreduce_data(collected_episode, 'sum')
            collected_duration = allreduce_data(collected_duration, 'sum')
        self._total_envstep_count += collected_step
        self._total_episode_count += collected_episode
        self._total_duration += collected_duration
        self._total_train_sample_count += collected_sample
        # log
        if record_random_collect:  # True by default, but set to False when collecting with a random policy
            self._output_log(train_iter)
        else:
            self._episode_info.clear()
        # on-policy reset
        if self._on_policy:
            for env_id in range(self._env_num):
                self._reset_stat(env_id)

        if drop_extra:
            return return_data[:n_sample]
        else:
            return return_data

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Print the output log information. You can refer to the docs of `Best Practice` to understand \
                the logs and tensorboard records generated during training.
        Arguments:
            - train_iter (:obj:`int`): The number of training iterations.
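        Only the rank 0 process writes the log, and it is emitted only when at least ``collect_print_freq`` \
            training iterations have passed since the last output and some episodes have finished.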
387 """ 388 if self._rank != 0: 389 return 390 if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: 391 self._last_train_iter = train_iter 392 episode_count = len(self._episode_info) 393 envstep_count = sum([d['step'] for d in self._episode_info]) 394 train_sample_count = sum([d['train_sample'] for d in self._episode_info]) 395 duration = sum([d['time'] for d in self._episode_info]) 396 episode_return = [d['reward'] for d in self._episode_info] 397 info = { 398 'episode_count': episode_count, 399 'envstep_count': envstep_count, 400 'train_sample_count': train_sample_count, 401 'avg_envstep_per_episode': envstep_count / episode_count, 402 'avg_sample_per_episode': train_sample_count / episode_count, 403 'avg_envstep_per_sec': envstep_count / duration, 404 'avg_train_sample_per_sec': train_sample_count / duration, 405 'avg_episode_per_sec': episode_count / duration, 406 'reward_mean': np.mean(episode_return), 407 'reward_std': np.std(episode_return), 408 'reward_max': np.max(episode_return), 409 'reward_min': np.min(episode_return), 410 'total_envstep_count': self._total_envstep_count, 411 'total_train_sample_count': self._total_train_sample_count, 412 'total_episode_count': self._total_episode_count, 413 # 'each_reward': episode_return, 414 } 415 self._episode_info.clear() 416 self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) 417 for k, v in info.items(): 418 if k in ['each_reward']: 419 continue 420 self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) 421 if k in ['total_envstep_count']: 422 continue 423 self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)