ding.worker.collector.marine_parallel_collector

MarineParallelCollector

Bases: BaseParallelCollector

Feature
  • one policy (evaluation) or two policies (1v1 collection), many envs
  • async env step and reset
  • batched network evaluation
  • envs with different episode lengths
  • periodic policy update
  • metadata + stepdata sending
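
These features are driven by the class-level config: print_freq, compressor and update_policy_second have defaults, while env, policy, collect_setting, eval_flag and policy_update_path are filled in by the commander at runtime. The sketch below assembles such a config; the env and policy entries are placeholders for real DI-engine experiment sub-configs, and the collect_setting and policy_update_path values are hypothetical.

from easydict import EasyDict

cfg = EasyDict(dict(
    # Class defaults:
    print_freq=5,
    compressor='lz4',        # stepdata compression, resolved via get_data_compressor
    update_policy_second=3,  # period (in seconds) of the policy-update thread
    # Keys set by the commander:
    eval_flag=False,         # False: two collect-mode policies; True: one eval-mode policy
    env=...,                 # placeholder: env + manager sub-config from an experiment config
    policy=[...],            # placeholder: two policy sub-configs when eval_flag is False
    collect_setting=dict(eps=0.9),                # hypothetical forward kwargs for collect_mode
    policy_update_path=['policy_0', 'policy_1'],  # hypothetical checkpoint paths, one per policy
))

# The parallel pipeline builds the collector from the registry, e.g.
# PARALLEL_COLLECTOR_REGISTRY.build('marine', cfg=cfg), which is
# equivalent to MarineParallelCollector(cfg).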

Full Source Code

../ding/worker/collector/marine_parallel_collector.py

from typing import Dict, Any, List
import copy
import time
import uuid
from collections import namedtuple
from threading import Thread
from functools import partial
import numpy as np
import torch
from easydict import EasyDict

from ding.policy import create_policy, Policy
from ding.envs import get_vec_env_setting, create_env_manager
from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
from ding.envs import BaseEnvTimestep, BaseEnvManager
from .base_parallel_collector import BaseParallelCollector
from .base_serial_collector import CachePool, TrajBuffer

INF = float("inf")


@PARALLEL_COLLECTOR_REGISTRY.register('marine')
class MarineParallelCollector(BaseParallelCollector):
    """
    Feature:
      - one policy (evaluation) or two policies (1v1 collection), many envs
      - async env step and reset
      - batched network evaluation
      - envs with different episode lengths
      - periodic policy update
      - metadata + stepdata sending
    """
    config = dict(
        print_freq=5,
        compressor='lz4',
        update_policy_second=3,
        # The following keys are set by the commander:
        # env
        # policy
        # collect_setting
        # eval_flag
        # policy_update_path
    )

    # override
    def __init__(self, cfg: dict) -> None:
        super().__init__(cfg)
        self._update_policy_thread = Thread(
            target=self._update_policy_periodically, args=(), name='update_policy', daemon=True
        )
        self._start_time = time.time()
        self._compressor = get_data_compressor(self._cfg.compressor)

        # create env
        self._env_cfg = self._cfg.env
        env_manager = self._setup_env_manager(self._env_cfg)
        self.env_manager = env_manager

        # create policy
        if self._eval_flag:
            assert len(self._cfg.policy) == 1
            policy = [create_policy(self._cfg.policy[0], enable_field=['eval']).eval_mode]
            self.policy = policy
            self._policy_is_active = [None]
            self._policy_iter = [None]
            self._traj_buffer_length = self._traj_len if self._traj_len != INF else None
            self._traj_buffer = {env_id: [TrajBuffer(self._traj_len)] for env_id in range(self._env_num)}
        else:
            assert len(self._cfg.policy) == 2
            policy = [create_policy(self._cfg.policy[i], enable_field=['collect']).collect_mode for i in range(2)]
            self.policy = policy
            self._policy_is_active = [None for _ in range(2)]
            self._policy_iter = [None for _ in range(2)]
            self._traj_buffer_length = self._traj_len if self._traj_len != INF else None
            self._traj_buffer = {
                env_id: [TrajBuffer(self._traj_buffer_length) for _ in range(len(policy))]
                for env_id in range(self._env_num)
            }
        # self._first_update_policy = True

        self._episode_result = [[] for _ in range(self._env_num)]
        self._obs_pool = CachePool('obs', self._env_num)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        self._total_step = 0
        self._total_sample = 0
        self._total_episode = 0

    @property
    def policy(self) -> List[Policy]:
        return self._policy

    # override
    @policy.setter
    def policy(self, _policy: List[Policy]) -> None:
        self._policy = _policy
        self._n_episode = _policy[0].get_attribute('cfg').collect.get('n_episode', None)
        self._n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
        assert any(
            [t is None for t in [self._n_sample, self._n_episode]]
        ), "n_episode and n_sample in policy cfg can't both be not None"
        # TODO(nyz) the same definition of traj_len in serial and parallel
        if self._n_episode is not None:
            self._traj_len = INF
        elif self._n_sample is not None:
            self._traj_len = self._n_sample

    @property
    def env_manager(self) -> BaseEnvManager:
        return self._env_manager

    # override
    @env_manager.setter
    def env_manager(self, _env_manager: BaseEnvManager) -> None:
        self._env_manager = _env_manager
        self._env_manager.launch()
        self._env_num = self._env_manager.env_num
        self._predefined_episode_count = self._env_num * self._env_manager._episode_num

    def _setup_env_manager(self, cfg: EasyDict) -> BaseEnvManager:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg)
        if self._eval_flag:
            env_cfg = evaluator_env_cfg
        else:
            env_cfg = collector_env_cfg
        env_manager = create_env_manager(cfg.manager, [partial(env_fn, cfg=c) for c in env_cfg])
        return env_manager

    def _start_thread(self) -> None:
        # The evaluator doesn't need to update its policy periodically, only once when it starts.
        if not self._eval_flag:
            self._update_policy_thread.start()

    def _join_thread(self) -> None:
        if not self._eval_flag:
            self._update_policy_thread.join()
        del self._update_policy_thread

    # override
    def close(self) -> None:
        if self._end_flag:
            return
        self._end_flag = True
        time.sleep(1)
        if hasattr(self, '_env_manager'):
            self._env_manager.close()
        self._join_thread()

    # override
    def _policy_inference(self, obs: Dict[int, Any]) -> Dict[int, Any]:
        env_ids = list(obs.keys())
        if len(self._policy) > 1:
            assert not self._eval_flag
            obs = [{id: obs[id][i] for id in env_ids} for i in range(len(self._policy))]
        else:
            assert self._eval_flag
            obs = [obs]
        self._obs_pool.update(obs)
        policy_outputs = []
        for i in range(len(self._policy)):
            if self._eval_flag:
                policy_output = self._policy[i].forward(obs[i])
            else:
                policy_output = self._policy[i].forward(obs[i], **self._cfg.collect_setting)
            policy_outputs.append(policy_output)
        self._policy_output_pool.update(policy_outputs)
        actions = {}
        for env_id in env_ids:
            action = [policy_outputs[i][env_id]['action'] for i in range(len(self._policy))]
            action = torch.stack(action).squeeze()
            actions[env_id] = action
        return actions

    # override
    def _env_step(self, actions: Dict[int, Any]) -> Dict[int, Any]:
        return self._env_manager.step(actions)

    # override
    def _process_timestep(self, timestep: Dict[int, namedtuple]) -> None:
        for env_id, t in timestep.items():
            if t.info.get('abnormal', False):
                # If there is an abnormal timestep, reset all the related variables;
                # this env has already been reset by the env manager.
                for c in self._traj_buffer[env_id]:
                    c.clear()
                self._obs_pool.reset(env_id)
                self._policy_output_pool.reset(env_id)
                for p in self._policy:
                    p.reset([env_id])
                continue
            self._total_step += 1
            t = [BaseEnvTimestep(t.obs[i], t.reward[i], t.done, t.info) for i in range(len(self._policy))]
            if t[0].done:
                self._total_episode += 1
            if not self._eval_flag:
                for i in range(len(self._policy)):
                    if self._policy_is_active[i]:
                        # Only active policies store transitions into the replay buffer.
                        transition = self._policy[i].process_transition(
                            self._obs_pool[env_id][i], self._policy_output_pool[env_id][i], t[i]
                        )
                        self._traj_buffer[env_id][i].append(transition)
                full_indices = []
                for i in range(len(self._traj_buffer[env_id])):
                    if len(self._traj_buffer[env_id][i]) == self._traj_len:
                        full_indices.append(i)
                if t[0].done or len(full_indices) > 0:
                    for i in full_indices:
                        train_sample = self._policy[i].get_train_sample(self._traj_buffer[env_id][i])
                        for s in train_sample:
                            s = self._compressor(s)
                            self._total_sample += 1
                            metadata = self._get_metadata(s, env_id)
                            self.send_stepdata(metadata['data_id'], s)
                            self.send_metadata(metadata)
                        self._traj_buffer[env_id][i].clear()
            if t[0].done:
                # Env reset is done by the env_manager automatically.
                self._obs_pool.reset(env_id)
                self._policy_output_pool.reset(env_id)
                for p in self._policy:
                    p.reset([env_id])
                reward = t[0].info['eval_episode_return']
                # Only the left player's reward is recorded.
                left_reward = reward[0]
                if isinstance(left_reward, torch.Tensor):
                    left_reward = left_reward.item()
                self._episode_result[env_id].append(left_reward)
                self.debug(
                    "Env {} finish episode, final reward: {}, collected episode: {}.".format(
                        env_id, reward, len(self._episode_result[env_id])
                    )
                )
        dones = [t.done for t in timestep.values()]
        if any(dones):
            collector_info = self._get_collector_info()
            self.send_metadata(collector_info)

    # override
    def get_finish_info(self) -> dict:
        duration = max(time.time() - self._start_time, 1e-8)
        game_result = copy.deepcopy(self._episode_result)
        for i, env_result in enumerate(game_result):
            for j, rew in enumerate(env_result):
                if rew < 0:
                    game_result[i][j] = "losses"
                elif rew == 0:
                    game_result[i][j] = "draws"
                else:
                    game_result[i][j] = "wins"

        # Flatten the per-env episode returns before computing statistics.
        episode_return = sum(self._episode_result, [])
        finish_info = {
            # 'finished_task': True,  # flag
            'eval_flag': self._eval_flag,
            # 'episode_num': self._episode_num,
            'env_num': self._env_num,
            'duration': duration,
            'collector_done': self._env_manager.done,
            'predefined_episode_count': self._predefined_episode_count,
            'real_episode_count': self._total_episode,
            'step_count': self._total_step,
            'sample_count': self._total_sample,
            'avg_time_per_episode': duration / max(1, self._total_episode),
            'avg_time_per_step': duration / max(1, self._total_step),
            'avg_time_per_train_sample': duration / max(1, self._total_sample),
            'avg_step_per_episode': self._total_step / max(1, self._total_episode),
            'avg_sample_per_episode': self._total_sample / max(1, self._total_episode),
            'reward_mean': np.mean(episode_return),
            'reward_std': np.std(episode_return),
            'reward_raw': self._episode_result,
            'finish_time': time.time(),
            'game_result': game_result,
        }
        if not self._eval_flag:
            finish_info['collect_setting'] = self._cfg.collect_setting
        self._logger.info('\nFINISH INFO\n{}'.format(pretty_print(finish_info, direct_print=False)))
        return finish_info

    # override
    def _update_policy(self) -> None:
        path = self._cfg.policy_update_path
        self._policy_is_active = self._cfg.policy_update_flag
        for i in range(len(path)):
            # if not self._first_update_policy and not self._policy_is_active[i]:
            if not self._policy_is_active[i]:
                # For the first time, all policies should be updated (i.e. initialized);
                # For other times, only active players' policies should be updated.
                continue
            while True:
                try:
                    policy_update_info = self.get_policy_update_info(path[i])
                    break
                except Exception as e:
                    self.error('Policy {} update error: {}'.format(i + 1, e))
                    time.sleep(1)
            if policy_update_info is None:
                continue
            self._policy_iter[i] = policy_update_info.pop('iter')
            self._policy[i].load_state_dict(policy_update_info)
            self.debug(
                'Update policy {} with {} (iter {}) in {}'.format(i + 1, path[i], self._policy_iter[i], time.time())
            )
        # self._first_update_policy = False

    # ******************************** thread **************************************

    def _update_policy_periodically(self) -> None:
        last = time.time()
        while not self._end_flag:
            cur = time.time()
            interval = cur - last
            if interval < self._cfg.update_policy_second:
                time.sleep(self._cfg.update_policy_second * 0.1)
                continue
            else:
                self._update_policy()
                last = time.time()
            time.sleep(0.1)

    def _get_metadata(self, stepdata: List, env_id: int) -> dict:
        data_id = "env_{}_{}".format(env_id, str(uuid.uuid1()))
        metadata = {
            'eval_flag': self._eval_flag,
            'data_id': data_id,
            'env_id': env_id,
            'policy_iter': self._policy_iter,
            'unroll_len': len(stepdata),
            'compressor': self._cfg.compressor,
            'get_data_time': time.time(),
            # TODO(nyz) the relationship between traj priority and step priority
            'priority': 1.0,
            'cur_episode': self._total_episode,
            'cur_sample': self._total_sample,
            'cur_step': self._total_step,
        }
        return metadata

    def _get_collector_info(self) -> dict:
        return {
            'eval_flag': self._eval_flag,
            'get_info_time': time.time(),
            'collector_done': self._env_manager.done,
            'cur_episode': self._total_episode,
            'cur_sample': self._total_sample,
            'cur_step': self._total_step,
        }

    def __repr__(self) -> str:
        return "MarineParallelCollector"
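
For two-policy collection (eval_flag=False), every env observation carries one entry per policy: _policy_inference splits it into one obs dict per policy, runs each policy's forward, then stacks the two per-env actions into a single tensor for env_manager.step. The following self-contained sketch reproduces that split-and-stack flow with dummy forward functions (hypothetical stand-ins, not the real ding Policy API):

import torch

# Hypothetical stand-ins for two collect-mode policies: each maps
# {env_id: obs} -> {env_id: {'action': tensor}} (not the real ding Policy API).
def dummy_forward(obs_dict):
    return {eid: {'action': (o.sum() % 3).unsqueeze(0)} for eid, o in obs_dict.items()}

policies = [dummy_forward, dummy_forward]

# Two envs; each observation is a pair, one element per player/policy.
obs = {0: [torch.ones(4), torch.zeros(4)], 1: [torch.zeros(4), torch.ones(4)]}
env_ids = list(obs.keys())

# Split the paired observations into one dict per policy, as _policy_inference does.
per_policy_obs = [{eid: obs[eid][i] for eid in env_ids} for i in range(len(policies))]
policy_outputs = [policies[i](per_policy_obs[i]) for i in range(len(policies))]

# Stack both players' actions into a single tensor per env for env_manager.step.
actions = {}
for eid in env_ids:
    action = [policy_outputs[i][eid]['action'] for i in range(len(policies))]
    actions[eid] = torch.stack(action).squeeze()

print(actions)  # {0: tensor([1., 0.]), 1: tensor([0., 1.])}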