Skip to content

ding.worker.collector.battle_sample_serial_collector

ding.worker.collector.battle_sample_serial_collector

BattleSampleSerialCollector

Bases: ISerialCollector

Overview

Sample collector(n_sample) with multiple(n VS n) policy battle

Interfaces: `__init__`, `reset`, `reset_env`, `reset_policy`, `collect`, `close`. Property: `envstep`

envstep property writable

Overview

Return the total envstep count.

Return: - envstep (:obj:`int`): The total envstep count.

__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')

Overview

Initialization method.

Arguments: - cfg (:obj:EasyDict): Config dict - env (:obj:BaseEnvManager): the subclass of vectorized env_manager(BaseEnvManager) - policy (:obj:List[namedtuple]): the api namedtuple of collect_mode policy - tb_logger (:obj:SummaryWriter): tensorboard handle

reset_env(_env=None)

Overview

Reset the environment. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed in environment and launch.

Arguments: - env (:obj:Optional[BaseEnvManager]): instance of the subclass of vectorized env_manager(BaseEnvManager)

reset_policy(_policy=None)

Overview

Reset the policy. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy.

Arguments: - policy (:obj:Optional[List[namedtuple]]): the api namedtuple of collect_mode policy

reset(_policy=None, _env=None)

Overview

Reset the environment and policy. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed in environment and launch. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy.

Arguments: - policy (:obj:Optional[List[namedtuple]]): the api namedtuple of collect_mode policy - env (:obj:Optional[BaseEnvManager]): instance of the subclass of vectorized env_manager(BaseEnvManager)

close()

Overview

Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.

__del__()

Overview

Execute the close command and close the collector. `__del__` is automatically called to destroy the collector instance when the collector finishes its work.

collect(n_sample=None, train_iter=0, drop_extra=True, policy_kwargs=None)

Overview

Collect n_sample data with policy_kwargs, which is already trained train_iter iterations.

Arguments: - n_sample (:obj:int): The number of collecting data sample. - train_iter (:obj:int): The number of training iteration when calling collect method. - drop_extra (:obj:bool): Whether to drop extra return_data more than n_sample. - policy_kwargs (:obj:dict): The keyword args for policy forward. Returns: - return_data (:obj:List): A list containing training samples.

Full Source Code

../ding/worker/collector/battle_sample_serial_collector.py

from typing import Optional, Any, List, Tuple
from collections import namedtuple
from easydict import EasyDict
import numpy as np
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists, one_time_warning
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions


@SERIAL_COLLECTOR_REGISTRY.register('sample_1v1')
class BattleSampleSerialCollector(ISerialCollector):
    """
    Overview:
        Sample collector(n_sample) with multiple(n VS n) policy battle
    Interfaces:
        __init__, reset, reset_env, reset_policy, collect, close
    Property:
        envstep
    """

    config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)

    def __init__(
            self,
            cfg: EasyDict,
            env: BaseEnvManager = None,
            policy: List[namedtuple] = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'collector'
    ) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict
            - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
            - policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy
            - tb_logger (:obj:`SummaryWriter`): tensorboard handle
            - exp_name (:obj:`Optional[str]`): name of the experiment, used as the log directory root
            - instance_name (:obj:`Optional[str]`): name of this collector instance, used in log paths/tags
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._collect_print_freq = cfg.collect_print_freq
        self._deepcopy_obs = cfg.deepcopy_obs
        self._transform_obs = cfg.transform_obs
        self._cfg = cfg
        self._timer = EasyTimer()
        self._end_flag = False

        # Reuse an externally provided tb_logger (only build a text logger then);
        # otherwise build both text and tensorboard loggers.
        if tb_logger is not None:
            self._logger, _ = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
            )
        self._traj_len = float("inf")
        self.reset(policy, env)

    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
        Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
                env_manager(BaseEnvManager)
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()

    def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
        """
        Overview:
            Reset the policy.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            assert len(_policy) > 1, "battle sample collector needs more than 1 policy, but found {}".format(
                len(_policy)
            )
            self._policy = _policy
            self._policy_num = len(self._policy)
            self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
            self._unroll_len = _policy[0].get_attribute('unroll_len')
            self._on_policy = _policy[0].get_attribute('cfg').on_policy
            # Per-policy flag: whether each policy contributes training samples (defaults to True).
            self._policy_collect_data = [
                getattr(self._policy[i], 'collect_data', True) for i in range(self._policy_num)
            ]
            if self._default_n_sample is not None:
                # Per-env trajectory length: at least unroll_len, and long enough that
                # env_num trajectories cover n_sample (ceil division).
                self._traj_len = max(
                    self._unroll_len,
                    self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0)
                )
                self._logger.debug(
                    'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format(
                        self._default_n_sample, self._env_num, self._traj_len
                    )
                )
            else:
                self._traj_len = INF
        for p in self._policy:
            p.reset()

    def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment and policy.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the collector with the new passed \
                in environment and launch.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the collector with the new passed in policy.
        Arguments:
            - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
            - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
                env_manager(BaseEnvManager)
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)

        self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
        self._traj_buffer = {
            env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
                     for policy_id in range(self._policy_num)}
            for env_id in range(self._env_num)
        }
        self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}

        self._episode_info = []
        self._total_envstep_count = 0
        self._total_episode_count = 0
        self._total_train_sample_count = 0
        self._total_duration = 0
        self._last_train_iter = 0
        self._end_flag = False

    def _reset_stat(self, env_id: int) -> None:
        """
        Overview:
            Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\
                and env_info. Reset these states according to env_id. You can refer to base_serial_collector\
                to get more messages.
        Arguments:
            - env_id (:obj:`int`): the id where we need to reset the collector's state
        """
        # Fix: iterate over all policies instead of the hard-coded ``range(2)``, so that
        # n VS n battles with more than 2 policies also get their buffers cleared.
        for i in range(self._policy_num):
            self._traj_buffer[env_id][i].clear()
        self._obs_pool.reset(env_id)
        self._policy_output_pool.reset(env_id)
        self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}

    @property
    def envstep(self) -> int:
        """
        Overview:
            Return the total envstep count.
        Return:
            - envstep (:obj:`int`): The total envstep count.
        """
        return self._total_envstep_count

    @envstep.setter
    def envstep(self, value: int) -> None:
        """
        Overview:
            Set the total envstep count.
        Arguments:
            - value (:obj:`int`): The total envstep count.
        """
        self._total_envstep_count = value

    def close(self) -> None:
        """
        Overview:
            Close the collector. If end_flag is False, close the environment, flush the tb_logger\
                and close the tb_logger.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        self._tb_logger.flush()
        self._tb_logger.close()

    def __del__(self) -> None:
        """
        Overview:
            Execute the close command and close the collector. __del__ is automatically called to \
                destroy the collector instance when the collector finishes its work
        """
        self.close()

    def collect(
            self,
            n_sample: Optional[int] = None,
            train_iter: int = 0,
            drop_extra: bool = True,
            policy_kwargs: Optional[dict] = None
    ) -> Tuple[List[Any], List[Any]]:
        """
        Overview:
            Collect `n_sample` data with policy_kwargs, which is already trained `train_iter` iterations.
        Arguments:
            - n_sample (:obj:`int`): The number of collecting data sample.
            - train_iter (:obj:`int`): The number of training iteration when calling collect method.
            - drop_extra (:obj:`bool`): Whether to drop extra return_data more than `n_sample`.
            - policy_kwargs (:obj:`dict`): The keyword args for policy forward.
        Returns:
            - return_data (:obj:`List`): A list containing training samples.
        """
        if n_sample is None:
            if self._default_n_sample is None:
                raise RuntimeError("Please specify collect n_sample")
            else:
                n_sample = self._default_n_sample
        if n_sample % self._env_num != 0:
            # Fix: the message previously had the divisibility relation backwards
            # (the condition checks that n_sample is divisible by env_num).
            one_time_warning(
                "Please make sure n_sample is divisible by env_num: {}/{}, ".format(n_sample, self._env_num) +
                "which may cause convergence problems in a few algorithms"
            )
        if policy_kwargs is None:
            policy_kwargs = {}
        collected_sample = [0 for _ in range(self._policy_num)]
        return_data = [[] for _ in range(self._policy_num)]
        return_info = [[] for _ in range(self._policy_num)]

        # Keep collecting until every data-collecting policy has at least n_sample samples.
        while any([c < n_sample for i, c in enumerate(collected_sample) if self._policy_collect_data[i]]):
            with self._timer:
                # Get current env obs.
                obs = self._env.ready_obs
                # Policy forward.
                self._obs_pool.update(obs)
                if self._transform_obs:
                    obs = to_tensor(obs, dtype=torch.float32)
                obs = dicts_to_lists(obs)
                policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
                self._policy_output_pool.update(policy_output)
                # Interact with env: gather every policy's action for each env_id into one list.
                actions = {}
                for policy_output_item in policy_output:
                    for env_id, output in policy_output_item.items():
                        if env_id not in actions:
                            actions[env_id] = []
                        actions[env_id].append(output['action'])
                actions = to_ndarray(actions)
                timesteps = self._env.step(actions)

            # TODO(nyz) this duration may be inaccurate in async env
            interaction_duration = self._timer.value / len(timesteps)

            # TODO(nyz) vectorize this for loop
            for env_id, timestep in timesteps.items():
                self._env_info[env_id]['step'] += 1
                self._total_envstep_count += 1
                with self._timer:
                    for policy_id, policy in enumerate(self._policy):
                        if not self._policy_collect_data[policy_id]:
                            continue
                        # Split the shared timestep into this policy's view
                        # (bool fields such as ``done`` are shared by all policies).
                        policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
                        policy_timestep = type(timestep)(*policy_timestep_data)
                        transition = self._policy[policy_id].process_transition(
                            self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
                            policy_timestep
                        )
                        transition['collect_iter'] = train_iter
                        self._traj_buffer[env_id][policy_id].append(transition)
                        # prepare data
                        if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
                            transitions = to_tensor_transitions(
                                self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
                            )
                            train_sample = self._policy[policy_id].get_train_sample(transitions)
                            return_data[policy_id].extend(train_sample)
                            self._total_train_sample_count += len(train_sample)
                            self._env_info[env_id]['train_sample'] += len(train_sample)
                            collected_sample[policy_id] += len(train_sample)
                            self._traj_buffer[env_id][policy_id].clear()

                self._env_info[env_id]['time'] += self._timer.value + interaction_duration

                # If env is done, record episode info and reset
                if timestep.done:
                    self._total_episode_count += 1
                    info = {
                        'time': self._env_info[env_id]['time'],
                        'step': self._env_info[env_id]['step'],
                        'train_sample': self._env_info[env_id]['train_sample'],
                    }
                    for i in range(self._policy_num):
                        info['reward{}'.format(i)] = timestep.info[i]['eval_episode_return']
                    self._episode_info.append(info)
                    for i, p in enumerate(self._policy):
                        p.reset([env_id])
                    self._reset_stat(env_id)
                    # Fix: record episode info for all policies, not the hard-coded first 2.
                    for policy_id in range(self._policy_num):
                        return_info[policy_id].append(timestep.info[policy_id])
        # log
        self._output_log(train_iter)
        # Fix: ``drop_extra`` must trim each policy's sample list to n_sample; the previous
        # code sliced the outer per-policy list itself (``return_data[:n_sample]``), which
        # never drops samples and ignored the documented meaning of ``drop_extra``.
        if drop_extra:
            return_data = [r[:n_sample] for r in return_data]
        return return_data, return_info

    def _output_log(self, train_iter: int) -> None:
        """
        Overview:
            Print the output log information. You can refer to Docs/Best Practice/How to understand\
                training generated folders/Serial mode/log/collector for more details.
        Arguments:
            - train_iter (:obj:`int`): the number of training iteration.
        """
        if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
            self._last_train_iter = train_iter
            episode_count = len(self._episode_info)
            envstep_count = sum([d['step'] for d in self._episode_info])
            duration = sum([d['time'] for d in self._episode_info])
            episode_return = []
            for i in range(self._policy_num):
                episode_return_item = [d['reward{}'.format(i)] for d in self._episode_info]
                episode_return.append(episode_return_item)
            self._total_duration += duration
            info = {
                'episode_count': episode_count,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / episode_count,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_episode_per_sec': episode_count / duration,
                'collect_time': duration,
                'total_envstep_count': self._total_envstep_count,
                'total_episode_count': self._total_episode_count,
                'total_duration': self._total_duration,
            }
            for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
                for i in range(self._policy_num):
                    # such as reward0_mean
                    info['reward{}_{}'.format(i, k)] = fn(episode_return[i])
            self._episode_info.clear()
            self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
            for k, v in info.items():
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                if k in ['total_envstep_count']:
                    continue
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)