Skip to content

ding.worker.collector.episode_serial_collector

ding.worker.collector.episode_serial_collector

EpisodeSerialCollector

Bases: ISerialCollector

Overview

Episode collector(n_episode)

Interfaces: `__init__`, `reset`, `reset_env`, `reset_policy`, `collect`, `close` Property: `envstep`

envstep property writable

Overview

Return the total envstep count.

Return: - envstep (:obj:int): The total envstep count.

__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')

Overview

Initialization method.

Arguments: - cfg (:obj:EasyDict): Config dict - env (:obj:BaseEnvManager): the subclass of vectorized env_manager(BaseEnvManager) - policy (:obj:namedtuple): the api namedtuple of collect_mode policy - tb_logger (:obj:SummaryWriter): tensorboard handle

reset_env(_env=None)

Overview

Reset the environment. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed in environment and launch.

Arguments: - env (:obj:Optional[BaseEnvManager]): instance of the subclass of vectorized env_manager(BaseEnvManager)

reset_policy(_policy=None)

Overview

Reset the policy. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy.

Arguments: - policy (:obj:Optional[namedtuple]): the api namedtuple of collect_mode policy

reset(_policy=None, _env=None)

Overview

Reset the environment and policy. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed in environment and launch. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy.

Arguments: - policy (:obj:Optional[namedtuple]): the api namedtuple of collect_mode policy - env (:obj:Optional[BaseEnvManager]): instance of the subclass of vectorized env_manager(BaseEnvManager)

close()

Overview

Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.

__del__()

Overview

Execute the close command and close the collector. `__del__` is automatically called to destroy the collector instance when the collector finishes its work.

collect(n_episode=None, train_iter=0, policy_kwargs=None)

Overview

Collect n_episode episodes of data using policy_kwargs, with a policy that has already been trained for train_iter iterations.

Arguments: - n_episode (:obj:int): the number of collecting data episode - train_iter (:obj:int): the number of training iteration - policy_kwargs (:obj:dict): the keyword args for policy forward Returns: - return_data (:obj:List): A list containing collected episodes if not get_train_sample, otherwise, return train_samples split by unroll_len.

Full Source Code

../ding/worker/collector/episode_serial_collector.py

from typing import Optional, Any, List
from collections import namedtuple
from easydict import EasyDict
import numpy as np
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions


@SERIAL_COLLECTOR_REGISTRY.register('episode')
class EpisodeSerialCollector(ISerialCollector):
    """
    Overview:
        Episode-mode serial collector: interacts a vectorized env manager with a collect-mode
        policy and gathers whole episodes (``n_episode`` of them per ``collect`` call), rather
        than a fixed number of env steps.
    Interfaces:
        __init__, reset, reset_env, reset_policy, collect, close
    Property:
        envstep
    """

    # Default config; merged with the user cfg by the ISerialCollector machinery
    # (presumably — the merge itself is not visible in this file, TODO confirm).
    config = dict(
        deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False, reward_shaping=False
    )

    def __init__(
            self,
            cfg: EasyDict,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector'
    ) -> None:
        """
        Overview:
            Initialization method. Builds (or adopts) the text/tensorboard loggers and then
            delegates env/policy setup to ``self.reset``.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict
            - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
            - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy
            - tb_logger (:obj:`SummaryWriter`): tensorboard handle
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = cfg.collect_print_freq
        self._deepcopy_obs = cfg.deepcopy_obs
        self._transform_obs = cfg.transform_obs
        self._cfg = cfg
        self._timer = EasyTimer()
        self._end_flag = False

        # If an external tb_logger is supplied, reuse it and only build the text logger;
        # otherwise build both from scratch under ./<exp_name>/log/<instance_name>.
        if tb_logger is not None:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
            )
        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
            in environment and launch.
        Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
                env_manager(BaseEnvManager)
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
        """
        # The env must be set first: replacing the policy logs env_num, and the collector
        # cannot operate without an env anyway.
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
            self._policy_cfg = self._policy.get_attribute('cfg')
            self._default_n_episode = _policy.get_attribute('n_episode')
            self._unroll_len = _policy.get_attribute('unroll_len')
            self._on_policy = _policy.get_attribute('on_policy')
            # Episode mode keeps full trajectories, so the per-env buffer is unbounded.
            self._traj_len = INF
            self._logger.debug(
                'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
                    self._default_n_episode, self._env_num, self._traj_len
                )
            )
        self._policy.reset()

    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment and policy.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
            in environment and launch.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
                env_manager(BaseEnvManager)
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        # Per-env caches for the latest obs and policy output (used to build transitions).
        self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions
        self._traj_buffer = {env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num)}
        # Per-env wall-time and step counters for the current (in-progress) episode.
        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\
            and env_info. Reset these states according to env_id. You can refer to base_serial_collector\
            to get more messages.
        Arguments:
            - env_id (:obj:`int`): the id where we need to reset the collector's state
        """
        self._traj_buffer[env_id].clear()
        self._obs_pool.reset(env_id)
        self._policy_output_pool.reset(env_id)
        self._env_info[env_id] = {'time': 0., 'step': 0}

    @property
    def envstep(self) -> int:
        """
        Overview:
            Return the total envstep count accumulated by this collector.
        Return:
            - envstep (:obj:`int`): The total envstep count.
        """
        return self._total_envstep_count

    @envstep.setter
    def envstep(self, value: int) -> None:
        """
        Overview:
            Set the total envstep count (e.g. when restoring a checkpointed counter).
        Arguments:
            - value (:obj:`int`): The total envstep count.
        """
        self._total_envstep_count = value

    def close(self) -> None:
        """
        Overview:
            Close the collector. If end_flag is False, close the environment, flush the tb_logger\
            and close the tb_logger.
        """
        # Idempotent: repeated calls (e.g. from __del__ after an explicit close) are no-ops.
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        self._tb_logger.flush()
        self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the collector. __del__ is automatically called to \
            destroy the collector instance when the collector finishes its work
        """
        self.close()

    def collect(self,
                n_episode: Optional[int] = None,
                train_iter: int = 0,
                policy_kwargs: Optional[dict] = None) -> List[Any]:
        """
        Overview:
            Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations
        Arguments:
            - n_episode (:obj:`int`): the number of collecting data episode
            - train_iter (:obj:`int`): the number of training iteration
            - policy_kwargs (:obj:`dict`): the keyword args for policy forward
        Returns:
            - return_data (:obj:`List`): A list containing collected episodes if not get_train_sample, otherwise, \
                return train_samples split by unroll_len.
        """
        if n_episode is None:
            if self._default_n_episode is None:
                raise RuntimeError("Please specify collect n_episode")
            else:
                n_episode = self._default_n_episode
        # Each env must be able to own at least one episode slot.
        assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num)
        if policy_kwargs is None:
            policy_kwargs = {}
        collected_episode = 0
        return_data = []
        # Env ids currently assigned to one of the n_episode episode slots.
        ready_env_id = set()
        remain_episode = n_episode

        while True:
            with self._timer:
                # Get current env obs.
                obs = self._env.ready_obs
                # Admit newly-ready envs only while episode slots remain, so at most
                # n_episode envs are ever active at once.
                new_available_env_id = set(obs.keys()).difference(ready_env_id)
                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                remain_episode -= min(len(new_available_env_id), remain_episode)
                obs = {env_id: obs[env_id] for env_id in ready_env_id}
                # Policy forward.
                self._obs_pool.update(obs)
                if self._transform_obs:
                    obs = to_tensor(obs, dtype=torch.float32)
                policy_output = self._policy.forward(obs, **policy_kwargs)
                self._policy_output_pool.update(policy_output)
                # Interact with env.
                actions = {env_id: output['action'] for env_id, output in policy_output.items()}
                actions = to_ndarray(actions)
                timesteps = self._env.step(actions)

            # TODO(nyz) this duration may be inaccurate in async env
            interaction_duration = self._timer.value / len(timesteps)

            # TODO(nyz) vectorize this for loop
            for env_id, timestep in timesteps.items():
                with self._timer:
                    if timestep.info.get('abnormal', False):
                        # If there is an abnormal timestep, reset all the related variables(including this env).
                        # suppose there is no reset param, just reset this env
                        self._env.reset({env_id: None})
                        self._policy.reset([env_id])
                        self._reset_stat(env_id)
                        self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info))
                        continue
                    transition = self._policy.process_transition(
                        self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
                    )
                    # ``train_iter`` passed in from ``serial_entry``, indicates current collecting model's iteration.
                    transition['collect_iter'] = train_iter
                    self._traj_buffer[env_id].append(transition)
                    self._env_info[env_id]['step'] += 1
                    self._total_envstep_count += 1
                    # prepare data
                    if timestep.done:
                        # to_tensor_transitions copies obs unless we already deep-copied on the way in.
                        transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
                        if self._cfg.reward_shaping:
                            self._env.reward_shaping(env_id, transitions)
                        if self._cfg.get_train_sample:
                            train_sample = self._policy.get_train_sample(transitions)
                            return_data.extend(train_sample)
                        else:
                            return_data.append(transitions)
                        self._traj_buffer[env_id].clear()

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration

                # If env is done, record episode info and reset
                if timestep.done:
                    self._total_episode_count += 1
                    reward = timestep.info['eval_episode_return']
                    info = {
                        'reward': reward,
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                    }
                    collected_episode += 1
                    self._episode_info.append(info)
                    self._policy.reset([env_id])
                    self._reset_stat(env_id)
                    # Free the slot; the env may be re-admitted above if remain_episode > 0.
                    ready_env_id.remove(env_id)
            if collected_episode >= n_episode:
                break
        # log
        self._output_log(train_iter)
        return return_data

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Print the output log information. You can refer to Docs/Best Practice/How to understand\
            training generated folders/Serial mode/log/collector for more details.
        Arguments:
            - train_iter (:obj:`int`): the number of training iteration.
        """
        # Only emit logs every collect_print_freq training iterations, and only when at
        # least one finished episode has been recorded since the last flush.
        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
            self._last_train_iter = train_iter
            episode_count = len(self._episode_info)
            envstep_count = sum([d['step'] for d in self._episode_info])
            duration = sum([d['time'] for d in self._episode_info])
            episode_return = [d['reward'] for d in self._episode_info]
            self._total_duration += duration
            info = {
                'episode_count': episode_count,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / episode_count,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_episode_per_sec': episode_count / duration,
                'collect_time': duration,
                'reward_mean': np.mean(episode_return),
                'reward_std': np.std(episode_return),
                'reward_max': np.max(episode_return),
                'reward_min': np.min(episode_return),
                'total_envstep_count': self._total_envstep_count,
                'total_episode_count': self._total_episode_count,
                'total_duration': self._total_duration,
                # 'each_reward': episode_return,
            }
            self._episode_info.clear()
            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
            for k, v in info.items():
                if k in ['each_reward']:
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                if k in ['total_envstep_count']:
                    continue
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)