from typing import Any, Union, Tuple
from abc import ABC, abstractmethod
import sys
from ditk import logging
import copy
from collections import namedtuple
from functools import partial, wraps
from easydict import EasyDict
import torch

from ding.policy import Policy
from ding.envs import BaseEnvManager
from ding.utils.autolog import LoggedValue, LoggedModel, TickTime
from ding.utils import build_logger, EasyTimer, get_task_uid, import_module, pretty_print, PARALLEL_COLLECTOR_REGISTRY
from ding.torch_utils import build_log_buffer, to_tensor, to_ndarray


class BaseParallelCollector(ABC):
    """
    Overview:
        Abstract baseclass for collector.
    Interfaces:
        __init__, info, error, debug, get_finish_info, start, close, _setup_timer, _setup_logger, _iter_after_hook,
        _policy_inference, _env_step, _process_timestep, _finish_task, _update_policy, _start_thread, _join_thread
    Property:
        policy, env_manager
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return a deep copy of the class-level ``config`` as an ``EasyDict``, tagged with a ``cfg_type`` field \
            made of the class name plus ``'Dict'``.
        Returns:
            - cfg (:obj:`EasyDict`): Independent copy of ``cls.config`` with ``cfg_type`` set.

        Note:
            ``cls.config`` is not defined in this base class; presumably each concrete subclass provides it.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. This class reads ``eval_flag``, ``exp_name`` and ``print_freq``.
        """
        self._cfg = cfg
        self._eval_flag = cfg.eval_flag
        self._prefix = 'EVALUATOR' if self._eval_flag else 'COLLECTOR'
        self._collector_uid = get_task_uid()
        # The logger must exist before the timer wrappers run, since they write into ``self._log_buffer``.
        self._logger, self._monitor, self._log_buffer = self._setup_logger()
        self._end_flag = False
        self._setup_timer()
        self._iter_count = 0
        self.info("\nCFG INFO:\n{}".format(pretty_print(cfg, direct_print=False)))

    def info(self, s: str) -> None:
        """
        Overview:
            Log ``s`` at INFO level, prefixed with the collector role and uid.
        """
        self._logger.info("[{}({})]: {}".format(self._prefix, self._collector_uid, s))

    def debug(self, s: str) -> None:
        """
        Overview:
            Log ``s`` at DEBUG level, prefixed with the collector role and uid.
        """
        self._logger.debug("[{}({})]: {}".format(self._prefix, self._collector_uid, s))

    def error(self, s: str) -> None:
        """
        Overview:
            Log ``s`` at ERROR level, prefixed with the collector role and uid.
        """
        self._logger.error("[{}({})]: {}".format(self._prefix, self._collector_uid, s))

    def _setup_timer(self) -> None:
        """
        Overview:
            Create an ``EasyTimer`` and replace ``self._policy_inference`` and ``self._env_step`` with wrapped \
            versions that record timing information into ``self._log_buffer`` on every call.
            You can refer to ``ding/utils/time_helper.py``.

        Note:
            - _policy_inference (:obj:`Callable`): After wrapping, records ``policy_time``.
            - _env_step (:obj:`Callable`): After wrapping, records ``env_time``, ``timestep_size`` and \
                ``norm_env_time``.
        """
        self._timer = EasyTimer()

        def policy_wrapper(fn):

            # ``wraps`` keeps the wrapped method's name and docstring visible for debugging and profiling.
            @wraps(fn)
            def wrapper(*args, **kwargs):
                with self._timer:
                    ret = fn(*args, **kwargs)
                self._log_buffer['policy_time'] = self._timer.value
                return ret

            return wrapper

        def env_wrapper(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                with self._timer:
                    ret = fn(*args, **kwargs)
                # NOTE: ``sys.getsizeof`` is shallow, it does not follow references into contained objects,
                # so this is the size of the returned container itself rather than of all its contents.
                size = sys.getsizeof(ret) / (1024 * 1024)  # MB
                self._log_buffer['env_time'] = self._timer.value
                self._log_buffer['timestep_size'] = size
                self._log_buffer['norm_env_time'] = self._timer.value / size
                return ret

            return wrapper

        self._policy_inference = policy_wrapper(self._policy_inference)
        self._env_step = env_wrapper(self._env_step)

    def _setup_logger(self) -> Tuple[logging.Logger, 'TickMonitor', 'LogDict']:  # noqa
        """
        Overview:
            Setup logger for base_collector. Logger includes logger, monitor and log buffer dict.
        Returns:
            - logger (:obj:`logging.Logger`): logger that displays terminal output
            - monitor (:obj:`TickMonitor`): monitor that is related info of one interaction with env
            - log_buffer (:obj:`LogDict`): log buffer dict
        """
        path = './{}/log/{}'.format(self._cfg.exp_name, self._prefix.lower())
        name = '{}'.format(self._collector_uid)
        logger, _ = build_logger(path, name, need_tb=False)
        monitor = TickMonitor(TickTime(), expire=self._cfg.print_freq * 2)
        log_buffer = build_log_buffer()
        return logger, monitor, log_buffer

    def start(self) -> None:
        """
        Overview:
            Run the collection loop: update the policy, start the worker thread hook, then repeatedly get ready \
            observations, run policy inference, step the environments and process the resulting timestep. \
            The loop stops when ``close`` sets the end flag or when the env manager reports ``done``.
        """
        self._end_flag = False
        self._update_policy()
        self._start_thread()
        while not self._end_flag:
            obs = self._env_manager.ready_obs
            obs = to_tensor(obs, dtype=torch.float32)
            action = self._policy_inference(obs)
            action = to_ndarray(action)
            timestep = self._env_step(action)
            timestep = to_tensor(timestep, dtype=torch.float32)
            self._process_timestep(timestep)
            self._iter_after_hook()
            if self._env_manager.done:
                break

    def close(self) -> None:
        """
        Overview:
            Set the end flag so that ``start`` leaves its loop, then call ``_join_thread``. \
            Calling it again after the flag is set is a no-op.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._join_thread()

    def _iter_after_hook(self):
        """
        Overview:
            Per-iteration bookkeeping: copy the values in ``self._log_buffer`` onto the monitor, advance the \
            monitor's time by one step, print the monitored statistics every ``cfg.print_freq`` iterations, \
            then clear the log buffer and increase the iteration counter.
        """
        # log_buffer -> tick_monitor -> monitor.step
        for k, v in self._log_buffer.items():
            setattr(self._monitor, k, v)
        self._monitor.time.step()
        # Print info
        if self._iter_count % self._cfg.print_freq == 0:
            self.debug('{}TimeStep{}{}'.format('=' * 35, self._iter_count, '=' * 35))
            # tick_monitor -> var_dict
            var_dict = {}
            for k in self._log_buffer:
                for attr in self._monitor.get_property_attribute(k):
                    k_attr = k + '_' + attr
                    var_dict[k_attr] = getattr(self._monitor, attr)[k]()
            self._logger.debug(self._logger.get_tabulate_vars_hor(var_dict))
        self._log_buffer.clear()
        self._iter_count += 1

    @abstractmethod
    def get_finish_info(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def _policy_inference(self, obs: Any) -> Any:
        raise NotImplementedError

    @abstractmethod
    def _env_step(self, action: Any) -> Any:
        raise NotImplementedError

    @abstractmethod
    def _process_timestep(self, timestep: namedtuple) -> None:
        raise NotImplementedError

    @abstractmethod
    def _update_policy(self) -> None:
        raise NotImplementedError

    def _start_thread(self) -> None:
        # Hook called at the beginning of ``start``; does nothing by default.
        pass

    def _join_thread(self) -> None:
        # Hook called by ``close``; does nothing by default.
        pass

    @property
    def policy(self) -> Policy:
        """
        Overview:
            The policy object previously assigned through the ``policy`` setter.
        """
        return self._policy

    @policy.setter
    def policy(self, _policy: Policy) -> None:
        self._policy = _policy

    @property
    def env_manager(self) -> BaseEnvManager:
        """
        Overview:
            The env manager previously assigned through the ``env_manager`` setter.
        """
        return self._env_manager

    @env_manager.setter
    def env_manager(self, _env_manager: BaseEnvManager) -> None:
        self._env_manager = _env_manager


def create_parallel_collector(cfg: EasyDict) -> BaseParallelCollector:
    """
    Overview:
        Import the modules listed in ``cfg.import_names`` (if any), then build the collector registered under \
        ``cfg.type`` in ``PARALLEL_COLLECTOR_REGISTRY``, passing ``cfg`` to it.
    Arguments:
        - cfg (:obj:`EasyDict`): Collector config, must contain ``type``; ``import_names`` is optional.
    Returns:
        - collector (:obj:`BaseParallelCollector`): The built collector instance.
    """
    import_module(cfg.get('import_names', []))
    return PARALLEL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg)


def get_parallel_collector_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Import the modules listed in ``cfg.import_names`` (if any), then return the collector class registered \
        under ``cfg.type`` in ``PARALLEL_COLLECTOR_REGISTRY`` without instantiating it.
    Arguments:
        - cfg (:obj:`EasyDict`): Collector config, must contain ``type``; ``import_names`` is optional.
    Returns:
        - cls (:obj:`type`): The registered collector class.
    """
    import_module(cfg.get('import_names', []))
    return PARALLEL_COLLECTOR_REGISTRY.get(cfg.type)


class TickMonitor(LoggedModel):
    """
    Overview:
        TickMonitor is to monitor related info of one interaction with env.
        Info include: policy_time, env_time, norm_env_time, timestep_size...
        These info variables are first recorded in ``log_buffer``; then, in ``_iter_after_hook`` of the collector,
        the variables in this monitor are updated from ``log_buffer`` and printed via the text logger.
    Interface:
        __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
    Property:
        time, expire
    """
    policy_time = LoggedValue(float)
    env_time = LoggedValue(float)
    timestep_size = LoggedValue(float)
    norm_env_time = LoggedValue(float)

    def __init__(self, time_: 'BaseTime', expire: Union[int, float]):  # noqa
        """
        Overview:
            Initialize the underlying ``LoggedModel`` and register the ``avg`` attribute for every logged value.
        Arguments:
            - time_ (:obj:`BaseTime`): Time object passed through to ``LoggedModel``.
            - expire (:obj:`Union[int, float]`): Expire value passed through to ``LoggedModel``.
        """
        LoggedModel.__init__(self, time_, expire)
        self.__register()

    def __register(self):
        """
        Overview:
            Register an ``avg`` attribute for each logged value, computed as the plain mean of the values \
            returned by ``self.range_values`` for that property (0 when there are no records).
        """

        def __avg_func(prop_name: str) -> float:
            records = self.range_values[prop_name]()
            # Each record is ((begin_time, end_time), value); only the value is averaged.
            _list = [_value for (_begin_time, _end_time), _value in records]
            return sum(_list) / len(_list) if _list else 0

        for prop_name in ('policy_time', 'env_time', 'timestep_size', 'norm_env_time'):
            self.register_attribute_value('avg', prop_name, partial(__avg_func, prop_name=prop_name))