import numbers
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
from easydict import EasyDict

import ding
from ding.utils import allreduce, read_file, save_file


class Hook(ABC):
    """
    Overview:
        Abstract base class for hooks.
    Interfaces:
        __init__, __call__
    Property:
        name, priority
    """

    def __init__(self, name: str, priority: float, **kwargs) -> None:
        """
        Overview:
            Init method for hooks. Set name and priority.
        Arguments:
            - name (:obj:`str`): The name of hook.
            - priority (:obj:`float`): The priority used in ``call_hook``'s calling sequence. \
                Lower value means higher priority. Must be non-negative.
        """
        self._name = name
        assert priority >= 0, "invalid priority value: {}".format(priority)
        self._priority = priority

    @property
    def name(self) -> str:
        return self._name

    @property
    def priority(self) -> float:
        return self._priority

    @abstractmethod
    def __call__(self, engine: Any) -> Any:
        """
        Overview:
            Should be overwritten by subclass.
        Arguments:
            - engine (:obj:`Any`): For LearnerHook, it should be ``BaseLearner`` or its subclass.
        """
        raise NotImplementedError


class LearnerHook(Hook):
    """
    Overview:
        Abstract class for hooks used in Learner.
    Interfaces:
        __init__
    Property:
        name, priority, position

    .. note::

        Subclass should implement ``self.__call__``.
    """
    positions = ['before_run', 'after_run', 'before_iter', 'after_iter']

    def __init__(self, *args, position: str, **kwargs) -> None:
        """
        Overview:
            Init LearnerHook.
        Arguments:
            - position (:obj:`str`): The position to call hook in learner. \
                Must be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
        """
        super().__init__(*args, **kwargs)
        assert position in self.positions
        self._position = position

    @property
    def position(self) -> str:
        return self._position


class LoadCkptHook(LearnerHook):
    """
    Overview:
        Hook to load checkpoint.
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Optional[EasyDict] = None, **kwargs) -> None:
        """
        Overview:
            Init LoadCkptHook.
        Arguments:
            - ext_args (:obj:`Optional[EasyDict]`): Extended arguments. ``ext_args['load_path']`` is the \
                checkpoint path to load from (an empty string means "do not load").
        """
        super().__init__(*args, **kwargs)
        # Use a None sentinel instead of a mutable default argument.
        ext_args = EasyDict() if ext_args is None else ext_args
        self._load_path = ext_args['load_path']

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Load checkpoint to learner. Checkpoint info includes policy state_dict and iter num.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner to load checkpoint to.
        """
        path = self._load_path
        if path == '':  # not load
            return
        state_dict = read_file(path)
        # 'last_iter' and 'last_step' are training-progress metadata saved by SaveCkptHook;
        # pop them so that only the policy's own state is passed to load_state_dict.
        if 'last_iter' in state_dict:
            last_iter = state_dict.pop('last_iter')
            engine.last_iter.update(last_iter)
        if 'last_step' in state_dict:
            last_step = state_dict.pop('last_step')
            engine._collector_envstep = last_step
        engine.policy.load_state_dict(state_dict)
        engine.info('{} load ckpt in {}'.format(engine.instance_name, path))


class SaveCkptHook(LearnerHook):
    """
    Overview:
        Hook to save checkpoint.
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Optional[EasyDict] = None, **kwargs) -> None:
        """
        Overview:
            Init SaveCkptHook.
        Arguments:
            - ext_args (:obj:`Optional[EasyDict]`): Extended arguments. Use ``ext_args.freq`` to set \
                the checkpoint-saving frequency (in iterations); defaults to 1 when not provided.
        """
        super().__init__(*args, **kwargs)
        # Use a None sentinel instead of a mutable default argument.
        ext_args = EasyDict() if ext_args is None else ext_args
        if ext_args == {}:
            self._freq = 1
        else:
            self._freq = ext_args.freq

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Save checkpoint in corresponding path.
            Checkpoint info includes policy state_dict and iter num.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner which needs to save checkpoint.
        """
        # Only rank 0 writes checkpoints, and only every self._freq iterations.
        if engine.rank == 0 and engine.last_iter.val % self._freq == 0:
            if engine.instance_name == 'learner':
                dirname = './{}/ckpt'.format(engine.exp_name)
            else:
                dirname = './{}/ckpt_{}'.format(engine.exp_name, engine.instance_name)
            if not os.path.exists(dirname):
                try:
                    os.makedirs(dirname)
                except FileExistsError:
                    # Another process may create the directory concurrently; that is fine.
                    pass
            ckpt_name = engine.ckpt_name if engine.ckpt_name else 'iteration_{}.pth.tar'.format(engine.last_iter.val)
            path = os.path.join(dirname, ckpt_name)
            state_dict = engine.policy.state_dict()
            # Attach training-progress metadata, consumed by LoadCkptHook on resume.
            state_dict.update({'last_iter': engine.last_iter.val})
            state_dict.update({'last_step': engine.collector_envstep})
            save_file(path, state_dict)
            engine.info('{} save ckpt in {}'.format(engine.instance_name, path))


class LogShowHook(LearnerHook):
    """
    Overview:
        Hook to show log.
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Optional[EasyDict] = None, **kwargs) -> None:
        """
        Overview:
            Init LogShowHook.
        Arguments:
            - ext_args (:obj:`Optional[EasyDict]`): Extended arguments. Use ``ext_args.freq`` to set the \
                log-showing frequency (in iterations); defaults to 1 when not provided.
        """
        super().__init__(*args, **kwargs)
        # Use a None sentinel instead of a mutable default argument.
        ext_args = EasyDict() if ext_args is None else ext_args
        if ext_args == {}:
            self._freq = 1
        else:
            self._freq = ext_args.freq
        # Read from engine.only_monitor_rank0 on each call; see __call__.
        self._only_monitor_rank0 = None

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Show log, update record and tb_logger if rank is 0 and at interval iterations,
            clear the log buffer for all learners regardless of rank.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner.
        """
        self._only_monitor_rank0 = engine.only_monitor_rank0
        # Only show log for rank 0 learner if _only_monitor_rank0 is True;
        # other ranks still clear their buffers so they do not grow unboundedly.
        if engine.rank != 0 and self._only_monitor_rank0:
            for k in engine.log_buffer:
                engine.log_buffer[k].clear()
            return

        # For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
        for k, v in engine.log_buffer['scalar'].items():
            setattr(engine.monitor, k, v)
        engine.monitor.time.step()

        iters = engine.last_iter.val
        if iters % self._freq == 0:
            engine.info("=== Training Iteration {} Result ===".format(iters))
            # For 'scalar' type variables: tick_monitor -> var_dict -> text_logger & tb_logger
            var_dict = {}
            log_vars = engine.policy.monitor_vars()
            attr = 'avg'
            for k in log_vars:
                k_attr = k + '_' + attr
                var_dict[k_attr] = getattr(engine.monitor, attr)[k]()
            engine.logger.info(engine.logger.get_tabulate_vars_hor(var_dict))
            for k, v in var_dict.items():
                engine.tb_logger.add_scalar('{}_iter/'.format(engine.instance_name) + k, v, iters)
                engine.tb_logger.add_scalar('{}_step/'.format(engine.instance_name) + k, v, engine._collector_envstep)
            # For 'histogram' type variables: log_buffer -> tb_var_dict -> tb_logger
            tb_var_dict = {}
            for k in engine.log_buffer['histogram']:
                new_k = '{}/'.format(engine.instance_name) + k
                tb_var_dict[new_k] = engine.log_buffer['histogram'][k]
            for k, v in tb_var_dict.items():
                engine.tb_logger.add_histogram(k, v, iters)
        # Clear the whole log buffer after (possibly) showing it.
        for k in engine.log_buffer:
            engine.log_buffer[k].clear()


class LogReduceHook(LearnerHook):
    """
    Overview:
        Hook to reduce the distributed (multi-gpu) logs.
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Optional[EasyDict] = None, **kwargs) -> None:
        """
        Overview:
            Init LogReduceHook.
        Arguments:
            - ext_args (:obj:`Optional[EasyDict]`): Extended arguments. Currently unused; kept for \
                interface consistency with the other hooks.
        """
        super().__init__(*args, **kwargs)

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Reduce the logs from distributed (multi-gpu) learners.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner.
        """

        def aggregate(data):
            r"""
            Overview:
                Aggregate the information from all ranks (usually using sync allreduce).
            Arguments:
                - data (:obj:`dict`): Data that needs to be reduced. \
                    Could be dict, list/tuple, torch.Tensor, numbers.Integral, or numbers.Real.
            Returns:
                - new_data (:obj:`dict`): Data after reduction. Note that tuples are returned as lists.
            """

            def should_reduce(key):
                # Check if the key starts with the "noreduce_" prefix.
                # The "noreduce_" prefix is used in the unizero_multitask ddp pipeline
                # to indicate data that should not be reduced.
                return not key.startswith("noreduce_")

            # NOTE(review): this initializes/queries CUDA even on the linklink path;
            # requires a CUDA-capable environment.
            cuda_device = torch.cuda.current_device()

            if isinstance(data, dict):
                new_data = {}
                for k, v in data.items():
                    if should_reduce(k):
                        new_data[k] = aggregate(v)  # Perform allreduce on data that needs reduction.
                    else:
                        new_data[k] = v  # Retain data that does not need reduction.
            elif isinstance(data, list) or isinstance(data, tuple):
                new_data = [aggregate(t) for t in data]
            elif isinstance(data, torch.Tensor):
                # Detach a copy so the allreduce never touches the caller's tensor or autograd graph.
                new_data = data.clone().detach()
                if ding.enable_linklink:
                    allreduce(new_data)
                else:
                    # torch.distributed backend needs the tensor on the current CUDA device.
                    new_data = new_data.to(cuda_device)
                    allreduce(new_data)
                    new_data = new_data.cpu()
            elif isinstance(data, numbers.Integral) or isinstance(data, numbers.Real):
                # Wrap plain numbers in a 1-element tensor, reduce, then unwrap back to a scalar.
                new_data = torch.scalar_tensor(data).reshape([1])
                if ding.enable_linklink:
                    allreduce(new_data)
                else:
                    new_data = new_data.to(cuda_device)
                    allreduce(new_data)
                    new_data = new_data.cpu()
                new_data = new_data.item()
            else:
                raise TypeError("Invalid type in reduce: {}".format(type(data)))
            return new_data

        engine.log_buffer = aggregate(engine.log_buffer)


hook_mapping = {
    'load_ckpt': LoadCkptHook,
    'save_ckpt': SaveCkptHook,
    'log_show': LogShowHook,
    'log_reduce': LogReduceHook,
}


def register_learner_hook(name: str, hook_type: type) -> None:
    """
    Overview:
        Add a new LearnerHook class to hook_mapping, so you can build one instance with `build_learner_hook_by_cfg`.
    Arguments:
        - name (:obj:`str`): name of the register hook
        - hook_type (:obj:`type`): the register hook_type you implemented that realize LearnerHook
    Examples:
        >>> class HookToRegister(LearnerHook):
        >>>     def __init__(*args, **kargs):
        >>>         ...
        >>>         ...
        >>>     def __call__(*args, **kargs):
        >>>         ...
        >>>         ...
        >>> ...
        >>> register_learner_hook('name_of_hook', HookToRegister)
        >>> ...
        >>> hooks = build_learner_hook_by_cfg(cfg)
    """
    assert issubclass(hook_type, LearnerHook)
    hook_mapping[name] = hook_type


simplified_hook_mapping = {
    'log_show_after_iter': lambda freq: hook_mapping['log_show']
    ('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
    'load_ckpt_before_run': lambda path: hook_mapping['load_ckpt']
    ('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': path})),
    'save_ckpt_after_iter': lambda freq: hook_mapping['save_ckpt']
    ('save_ckpt_after_iter', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
    'save_ckpt_after_run': lambda _: hook_mapping['save_ckpt']('save_ckpt_after_run', 20, position='after_run'),
    'log_reduce_after_iter': lambda _: hook_mapping['log_reduce']('log_reduce_after_iter', 10, position='after_iter'),
}


def find_char(s: str, flag: str, num: int, reverse: bool = False) -> int:
    """
    Overview:
        Find the index of the ``num``-th occurrence of character ``flag`` in ``s``.
    Arguments:
        - s (:obj:`str`): The string to search in.
        - flag (:obj:`str`): The character to search for.
        - num (:obj:`int`): Which occurrence to find (1-based). Must be positive.
        - reverse (:obj:`bool`): If True, count occurrences from the end of the string \
            (the returned index is still an index into ``s`` from the left).
    Returns:
        - idx (:obj:`int`): The index of the found occurrence, or -1 if there are fewer \
            than ``num`` occurrences.
    """
    assert num > 0, num
    count = 0
    iterable_obj = reversed(range(len(s))) if reverse else range(len(s))
    for i in iterable_obj:
        if s[i] == flag:
            count += 1
            if count == num:
                return i
    return -1


def build_learner_hook_by_cfg(cfg: EasyDict) -> Dict[str, List[Hook]]:
    """
    Overview:
        Build the learner hooks in hook_mapping by config.
        This function is often used to initialize ``hooks`` according to cfg,
        while add_learner_hook() is often used to add an existing LearnerHook to `hooks`.
    Arguments:
        - cfg (:obj:`EasyDict`): Config dict. Should be like {'hook': xxx}.
    Returns:
        - hooks (:obj:`Dict[str, List[Hook]`): Keys should be in ['before_run', 'after_run', 'before_iter', \
            'after_iter'], each value should be a list containing all hooks in this position.
    Note:
        Lower value means higher priority.
    """
    hooks = {k: [] for k in LearnerHook.positions}
    for key, value in cfg.items():
        if key in simplified_hook_mapping and not isinstance(value, dict):
            # Simplified config style: key encodes hook type and position
            # (e.g. 'save_ckpt_after_iter'), value is the single extended argument.
            # The position is the suffix after the second-to-last underscore.
            pos = key[find_char(key, '_', 2, reverse=True) + 1:]
            hook = simplified_hook_mapping[key](value)
            priority = hook.priority
        else:
            # Full config style: value is a dict with type/name/position and optional
            # priority (default 100, i.e. lowest) and ext_args.
            priority = value.get('priority', 100)
            pos = value.position
            ext_args = value.get('ext_args', {})
            hook = hook_mapping[value.type](value.name, priority, position=pos, ext_args=ext_args)
        # Insert while keeping each position's list sorted by ascending priority,
        # after any existing hook with the same priority (stable insertion).
        idx = 0
        for i in reversed(range(len(hooks[pos]))):
            if priority >= hooks[pos][i].priority:
                idx = i + 1
                break
        hooks[pos].insert(idx, hook)
    return hooks


def add_learner_hook(hooks: Dict[str, List[Hook]], hook: LearnerHook) -> None:
    """
    Overview:
        Add a learner hook(:obj:`LearnerHook`) to hooks(:obj:`Dict[str, List[Hook]`)
    Arguments:
        - hooks (:obj:`Dict[str, List[Hook]`): You can refer to ``build_learner_hook_by_cfg``'s return ``hooks``.
        - hook (:obj:`LearnerHook`): The LearnerHook which will be added to ``hooks``.
    """
    position = hook.position
    priority = hook.priority
    # Insert while keeping the position's list sorted by ascending priority,
    # after any existing hook with the same priority (stable insertion).
    idx = 0
    for i in reversed(range(len(hooks[position]))):
        if priority >= hooks[position][i].priority:
            idx = i + 1
            break
    assert isinstance(hook, LearnerHook)
    hooks[position].insert(idx, hook)


def merge_hooks(hooks1: Dict[str, List[Hook]], hooks2: Dict[str, List[Hook]]) -> Dict[str, List[Hook]]:
    """
    Overview:
        Merge two hooks dict, which have the same keys, and each value is sorted by hook priority with stable method.
    Arguments:
        - hooks1 (:obj:`Dict[str, List[Hook]`): hooks1 to be merged.
        - hooks2 (:obj:`Dict[str, List[Hook]`): hooks2 to be merged.
    Returns:
        - new_hooks (:obj:`Dict[str, List[Hook]`): New merged hooks dict.
    Note:
        This merge function uses stable sort method without disturbing the same priority hook.
    """
    assert set(hooks1.keys()) == set(hooks2.keys())
    new_hooks = {}
    for k in hooks1.keys():
        # sorted() is stable, so equal-priority hooks keep hooks1-before-hooks2 order.
        new_hooks[k] = sorted(hooks1[k] + hooks2[k], key=lambda x: x.priority)
    return new_hooks


def show_hooks(hooks: Dict[str, List[Hook]]) -> None:
    """
    Overview:
        Print the class names of all hooks, grouped by position, for debugging.
    Arguments:
        - hooks (:obj:`Dict[str, List[Hook]`): Hooks dict as returned by ``build_learner_hook_by_cfg``.
    """
    for k in hooks.keys():
        print('{}: {}'.format(k, [x.__class__.__name__ for x in hooks[k]]))