ding.worker.collector.zergling_parallel_collector

ZerglingParallelCollector
Bases: BaseParallelCollector
Features
- one policy serving many environments
- asynchronous environment step and reset
- batched network evaluation
- supports environments with different episode lengths
- periodic policy update
- separate metadata and stepdata
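For orientation, here is a minimal schematic of the loop these features describe, written with made-up `policy` and `envs` stand-ins rather than the DI-engine API: one policy serves a dict of environments keyed by `env_id`, observations are batched into a single network evaluation, and each environment is stepped and reset independently, so episodes of different lengths can interleave.

```python
from typing import Any, Dict

def collect_loop(policy, envs: Dict[int, Any], max_iterations: int) -> None:
    """Toy schematic of the collector's core loop (not the DI-engine API)."""
    obs = {env_id: env.reset() for env_id, env in envs.items()}
    for _ in range(max_iterations):
        # batch network eval: one forward pass serves every live env
        actions = policy.forward(obs)
        for env_id, action in actions.items():
            timestep = envs[env_id].step(action)
            if timestep.done:
                # envs finish at different times and are reset independently
                obs[env_id] = envs[env_id].reset()
            else:
                obs[env_id] = timestep.obs
```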
Full Source Code
../ding/worker/collector/zergling_parallel_collector.py
```python
from typing import Dict, Any, List
import time
import uuid
from collections import namedtuple
from threading import Thread
from functools import partial

import numpy as np
import torch
from easydict import EasyDict

from ding.policy import create_policy, Policy
from ding.envs import get_vec_env_setting, create_env_manager, BaseEnvManager
from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
from .base_parallel_collector import BaseParallelCollector
from .base_serial_collector import CachePool, TrajBuffer

INF = float("inf")


@PARALLEL_COLLECTOR_REGISTRY.register('zergling')
class ZerglingParallelCollector(BaseParallelCollector):
    """
    Feature:
        - one policy, many envs
        - async envs (step + reset)
        - batch network eval
        - different episode length envs
        - periodic policy update
        - metadata + stepdata
    """
    config = dict(
        print_freq=5,
        compressor='lz4',
        update_policy_second=3,
        # The following keys are set by the commander:
        # env
        # policy
        # collect_setting
        # eval_flag
        # policy_update_path
    )

    # override
    def __init__(self, cfg: dict) -> None:
        super().__init__(cfg)
        self._update_policy_thread = Thread(
            target=self._update_policy_periodically, args=(), name='update_policy', daemon=True
        )
        self._start_time = time.time()
        self._compressor = get_data_compressor(self._cfg.compressor)

        # create env
        self._env_cfg = self._cfg.env
        env_manager = self._setup_env_manager(self._env_cfg)
        self.env_manager = env_manager

        # create policy
        if self._eval_flag:
            policy = create_policy(self._cfg.policy, enable_field=['eval']).eval_mode
        else:
            policy = create_policy(self._cfg.policy, enable_field=['collect']).collect_mode
        self.policy = policy

        self._episode_result = [[] for _ in range(self._env_num)]
        self._obs_pool = CachePool('obs', self._env_num)
        self._policy_output_pool = CachePool('policy_output', self._env_num)
        self._traj_buffer = {env_id: TrajBuffer(self._traj_len) for env_id in range(self._env_num)}
        self._total_step = 0
        self._total_sample = 0
        self._total_episode = 0

    @property
    def policy(self) -> Policy:
        return self._policy

    # override
    @policy.setter
    def policy(self, _policy: Policy) -> None:
        self._policy = _policy
        self._policy_cfg = self._policy.get_attribute('cfg')
        self._n_sample = _policy.get_attribute('n_sample')
        self._n_episode = _policy.get_attribute('n_episode')
        assert not all(
            [t is None for t in [self._n_sample, self._n_episode]]
        ), "n_episode and n_sample in policy cfg can't both be None"
        # TODO(nyz) the same definition of traj_len in serial and parallel
        if self._n_episode is not None:
            self._traj_len = INF
        elif self._n_sample is not None:
            self._traj_len = self._n_sample

    @property
    def env_manager(self) -> BaseEnvManager:
        return self._env_manager

    # override
    @env_manager.setter
    def env_manager(self, _env_manager: BaseEnvManager) -> None:
        self._env_manager = _env_manager
        self._env_manager.launch()
        self._env_num = self._env_manager.env_num
        self._predefined_episode_count = self._env_num * self._env_manager._episode_num

    def _setup_env_manager(self, cfg: EasyDict) -> BaseEnvManager:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg)
        if self._eval_flag:
            env_cfg = evaluator_env_cfg
        else:
            env_cfg = collector_env_cfg
        env_manager = create_env_manager(cfg.manager, [partial(env_fn, cfg=c) for c in env_cfg])
        return env_manager

    def _start_thread(self) -> None:
        # the evaluator doesn't need periodic policy updates; it only loads the policy once at startup
        if not self._eval_flag:
            self._update_policy_thread.start()

    def _join_thread(self) -> None:
        if not self._eval_flag:
            self._update_policy_thread.join()
            del self._update_policy_thread

    # override
    def close(self) -> None:
        if self._end_flag:
            return
        self._end_flag = True
        time.sleep(1)
        if hasattr(self, '_env_manager'):
            self._env_manager.close()
        self._join_thread()

    # override
    def _policy_inference(self, obs: Dict[int, Any]) -> Dict[int, Any]:
        self._obs_pool.update(obs)
        if self._eval_flag:
            policy_output = self._policy.forward(obs)
        else:
            policy_output = self._policy.forward(obs, **self._cfg.collect_setting)
        self._policy_output_pool.update(policy_output)
        actions = {env_id: output['action'] for env_id, output in policy_output.items()}
        return actions

    # override
    def _env_step(self, actions: Dict[int, Any]) -> Dict[int, Any]:
        return self._env_manager.step(actions)

    # override
    def _process_timestep(self, timestep: Dict[int, namedtuple]) -> None:
        send_data_time = []
        for env_id, t in timestep.items():
            if t.info.get('abnormal', False):
                # on an abnormal timestep, reset all related state; the env itself has already been reset
                self._traj_buffer[env_id].clear()
                self._obs_pool.reset(env_id)
                self._policy_output_pool.reset(env_id)
                self._policy.reset([env_id])
                continue
            self._total_step += 1
            if t.done:  # must be executed before send_metadata
                self._total_episode += 1
            if not self._eval_flag:
                transition = self._policy.process_transition(
                    self._obs_pool[env_id], self._policy_output_pool[env_id], t
                )
                self._traj_buffer[env_id].append(transition)
            if (not self._eval_flag) and (t.done or len(self._traj_buffer[env_id]) == self._traj_len):
                train_sample = self._policy.get_train_sample(self._traj_buffer[env_id])
                for s in train_sample:
                    s = self._compressor(s)
                    self._total_sample += 1
                    with self._timer:
                        metadata = self._get_metadata(s, env_id)
                        object_ref = self.send_stepdata(metadata['data_id'], s)
                        if object_ref:
                            metadata['object_ref'] = object_ref
                        self.send_metadata(metadata)
                    send_data_time.append(self._timer.value)
                self._traj_buffer[env_id].clear()
            if t.done:
                # env reset is done by env_manager automatically
                self._obs_pool.reset(env_id)
                self._policy_output_pool.reset(env_id)
                self._policy.reset([env_id])
                reward = t.info['eval_episode_return']
                if isinstance(reward, torch.Tensor):
                    reward = reward.item()
                self._episode_result[env_id].append(reward)
                self.debug(
                    "env {} finish episode, final reward: {}, collected episode {}".format(
                        env_id, reward, len(self._episode_result[env_id])
                    )
                )
        self.debug(
            "send {} train sample with average time: {:.6f}".format(
                len(send_data_time),
                sum(send_data_time) / (1e-6 + len(send_data_time))
            )
        )
        dones = [t.done for t in timestep.values()]
        if any(dones):
            collector_info = self._get_collector_info()
            self.send_metadata(collector_info)

    # override
    def get_finish_info(self) -> dict:
        duration = max(time.time() - self._start_time, 1e-8)
        episode_result = sum(self._episode_result, [])
        finish_info = {
            'eval_flag': self._eval_flag,
            'env_num': self._env_num,
            'duration': duration,
            'train_iter': self._policy_iter,
            'collector_done': self._env_manager.done,
            'predefined_episode_count': self._predefined_episode_count,
            'real_episode_count': self._total_episode,
            'step_count': self._total_step,
            'sample_count': self._total_sample,
            'avg_time_per_episode': duration / max(1, self._total_episode),
            'avg_time_per_step': duration / max(1, self._total_step),
            'avg_time_per_train_sample': duration / max(1, self._total_sample),
            'avg_step_per_episode': self._total_step / max(1, self._total_episode),
            'avg_sample_per_episode': self._total_sample / max(1, self._total_episode),
            'reward_mean': np.mean(episode_result) if len(episode_result) > 0 else 0,
            'reward_std': np.std(episode_result) if len(episode_result) > 0 else 0,
            'reward_raw': episode_result,
            'finish_time': time.time()
        }
        if not self._eval_flag:
            finish_info['collect_setting'] = self._cfg.collect_setting
        self._logger.info('\nFINISH INFO\n{}'.format(pretty_print(finish_info, direct_print=False)))
        return finish_info

    # override
    def _update_policy(self) -> None:
        path = self._cfg.policy_update_path
        while True:
            try:
                policy_update_info = self.get_policy_update_info(path)
                break
            except Exception as e:
                self.error('Policy update error: {}'.format(e))
                time.sleep(1)
        if policy_update_info is None:
            return

        self._policy_iter = policy_update_info.pop('iter')
        self._policy.load_state_dict(policy_update_info)
        self.debug('update policy with {} (iter {}) at {}'.format(path, self._policy_iter, time.time()))

    # ******************************** thread **************************************

    def _update_policy_periodically(self) -> None:
        last = time.time()
        while not self._end_flag:
            cur = time.time()
            interval = cur - last
            if interval < self._cfg.update_policy_second:
                time.sleep(self._cfg.update_policy_second * 0.1)
                continue
            else:
                self._update_policy()
                last = time.time()
            time.sleep(0.1)

    def _get_metadata(self, stepdata: List, env_id: int) -> dict:
        data_id = "env_{}_{}".format(env_id, str(uuid.uuid1()))
        metadata = {
            'eval_flag': self._eval_flag,
            'data_id': data_id,
            'env_id': env_id,
            'policy_iter': self._policy_iter,
            'unroll_len': len(stepdata),
            'compressor': self._cfg.compressor,
            'get_data_time': time.time(),
            # TODO(nyz) the relationship between traj priority and step priority
            'priority': 1.0,
            'cur_episode': self._total_episode,
            'cur_sample': self._total_sample,
            'cur_step': self._total_step,
        }
        return metadata

    def _get_collector_info(self) -> dict:
        return {
            'eval_flag': self._eval_flag,
            'get_info_time': time.time(),
            'collector_done': self._env_manager.done,
            'cur_episode': self._total_episode,
            'cur_sample': self._total_sample,
            'cur_step': self._total_step,
        }

    def __repr__(self) -> str:
        return "ZerglingParallelCollector"
```
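A note on the `compressor` config key: each train sample is passed through the callable returned by `ding.utils.get_data_compressor` before `send_stepdata`, and the codec name travels with the sample in `metadata['compressor']` so the consumer can choose a matching decompressor. Below is a minimal round-trip sketch, assuming `ding.utils` also exposes `get_data_decompressor` and using a made-up sample payload:

```python
from ding.utils import get_data_compressor, get_data_decompressor  # decompressor assumed

compress = get_data_compressor('lz4')
decompress = get_data_decompressor('lz4')

sample = {'obs': [0.0, 1.0], 'action': 1, 'reward': 0.5}  # made-up train sample
packed = compress(sample)        # what the collector hands to send_stepdata
restored = decompress(packed)    # consumer picks the codec from metadata['compressor']
assert restored == sample
```

Shipping the codec name alongside the data keeps the collector and its consumer decoupled: either side can switch codecs via config without code changes.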