
ding.worker.learner.base_learner


BaseLearner

Bases: object

Overview

Base class for policy learning.

Interface: train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close

Property: learn_info, priority_info, last_iter, train_iter, rank, world_size, policy, monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name

collector_envstep property writable

Overview

Get current collector envstep.

Returns:

- collector_envstep (int): Current collector envstep.

learn_info property

Overview

Get current info dict, which will be sent to commander, e.g. replay buffer priority update, current iteration, hyper-parameter adjustment, whether task is finished, etc.

Returns:

- info (dict): Current learner info dict.

__init__(cfg, policy=None, tb_logger=None, dist_info=None, exp_name='default_experiment', instance_name='learner')

Overview

Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.

Arguments:

- cfg (EasyDict): Learner config; refer to cls.config for details. It should include is_multitask_pipeline to indicate whether the pipeline is multitask (default: False), and only_monitor_rank0 to control whether only rank 0 needs a monitor and tb_logger (default: True).
- policy (namedtuple): A collection of policy functions of learn mode. The policy can also be set at runtime.
- tb_logger (SummaryWriter): Tensorboard summary writer.
- dist_info (Tuple[int, int]): Multi-GPU distributed training information.
- exp_name (str): Experiment name, used to indicate the output directory.
- instance_name (str): Instance name, which should be unique among different learners.

Notes: If you want to debug in sync CUDA mode, add the following code at the beginning of ``__init__``.

.. code:: python

    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # make CUDA kernel launches synchronous for debugging

register_hook(hook)

Overview

Add a new learner hook.

Arguments:

- hook (LearnerHook): The hook to be added.
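The hook mechanism can be pictured with a small self-contained sketch. The classes below are simplified stand-ins for ding's ``LearnerHook``/``add_learner_hook`` (defined in ``learner_hook.py``), not the real API: hooks are grouped by position and dispatched in priority order by ``call_hook``.

```python
# Minimal sketch of the learner hook pattern (hypothetical stand-in classes).

class Hook:

    def __init__(self, name, position, priority=100):
        self.name = name
        self.position = position  # one of the four call sites
        self.priority = priority  # lower value runs earlier

    def __call__(self, learner):
        learner.log.append(self.name)


class MiniLearner:

    def __init__(self):
        self.hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
        self.log = []

    def register_hook(self, hook):
        bucket = self.hooks[hook.position]
        bucket.append(hook)
        bucket.sort(key=lambda h: h.priority)  # keep each position sorted by priority

    def call_hook(self, position):
        for hook in self.hooks[position]:
            hook(self)


learner = MiniLearner()
learner.register_hook(Hook('save_ckpt', 'after_iter', priority=50))
learner.register_hook(Hook('log_show', 'after_iter', priority=20))
learner.call_hook('after_iter')
print(learner.log)  # ['log_show', 'save_ckpt'] -- lower priority value runs first
```

The real ``register_hook`` delegates to ``add_learner_hook``, which performs the same insert-into-position-bucket bookkeeping.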

train(data, envstep=-1, policy_kwargs=None)

Overview

Given training data, implement network update for one iteration and update related variables. Learner's API for serial entry. Also called in start for each iteration's training.

Arguments:

- data (dict): Training data retrieved from the replay buffer.
- envstep (int): Current collector envstep, recorded for logging. Default: -1.
- policy_kwargs (dict): Extra keyword arguments passed to ``_policy.forward``. Default: None.

.. note::

``_policy`` must be set before calling this method.

``_policy.forward`` method contains: forward, backward, grad sync (if in multi-GPU mode) and
parameter update.

``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
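The control flow of one ``train`` call can be sketched in a self-contained way. ``DummyPolicy`` and ``train_once`` below are hypothetical stand-ins; they mirror three steps of the real method: popping ``priority`` out of the policy's ``log_vars`` (it goes back to the replay buffer, not the logger), and routing logged variables into the scalar/scalars/histogram buffers by the ``"[xxx]"`` prefix convention.

```python
# Sketch of the train() iteration body (hypothetical DummyPolicy, simplified buffers).

class DummyPolicy:

    def forward(self, data, **kwargs):
        # A learn-mode forward returns a dict of logged variables; 'priority'
        # is per-sample replay priority, '[histogram]' marks a histogram var.
        return {'total_loss': 0.5, '[histogram]q_value': [1, 2, 3], 'priority': [0.9] * len(data)}


def train_once(policy, data, log_buffer, priority_info):
    log_vars = policy.forward(data)
    # Pop priority so it is not logged; it is sent back to the replay buffer.
    priority = log_vars.pop('priority', None)
    if priority is not None:
        priority_info['priority'] = priority
        priority_info['replay_buffer_idx'] = [d.get('replay_buffer_idx') for d in data]
    # Vars are scalar by default; "[scalars]" / "[histogram]" prefixes re-route them.
    scalars_vars, histogram_vars = {}, {}
    for k in list(log_vars.keys()):
        if '[scalars]' in k:
            scalars_vars[k.split(']')[-1]] = log_vars.pop(k)
        elif '[histogram]' in k:
            histogram_vars[k.split(']')[-1]] = log_vars.pop(k)
    log_buffer['scalar'].update(log_vars)
    log_buffer['scalars'].update(scalars_vars)
    log_buffer['histogram'].update(histogram_vars)


buffer = {'scalar': {}, 'scalars': {}, 'histogram': {}}
info = {}
train_once(DummyPolicy(), [{'replay_buffer_idx': 7}], buffer, info)
print(buffer['scalar'])     # {'total_loss': 0.5}
print(buffer['histogram'])  # {'q_value': [1, 2, 3]}
print(info['priority'])     # [0.9]
```

In the real learner this body is additionally bracketed by ``call_hook('before_iter')`` / ``call_hook('after_iter')`` and followed by ``self._last_iter.add(1)``.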

start()

Overview

[Only Used In Parallel Mode] Learner's API for parallel entry. For each iteration, the learner gets data through _next_data and calls train to train on it.

.. note::

``before_run`` and ``after_run`` hooks are called at the beginning and ending.

setup_dataloader()

Overview

[Only Used In Parallel Mode] Setup learner's dataloader.

.. note::

The attributes ``get_data`` and ``_dataloader`` are used only in parallel mode, to fetch data from the file system;
in the serial version, data is fetched from memory directly.

In parallel mode, ``get_data`` is set by ``LearnerCommHelper``, and should be callable.
Users don't need to know the related details if not necessary.

close()

Overview

[Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.

call_hook(name)

Overview

Call the corresponding hook plugins according to position name.

Arguments:

- name (str): Position of the hooks to call; should be one of ['before_run', 'after_run', 'before_iter', 'after_iter'].

info(s)

Overview

Log string info by self._logger.info.

Arguments:

- s (str): The message to add into the logger.

save_checkpoint(ckpt_name=None)

Overview

Directly call save_ckpt_after_run hook to save checkpoint.

Note: Must guarantee that "save_ckpt_after_run" is registered in the "after_run" hooks. This method is called in:

    - ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for saving a checkpoint whenever an exception is raised.
    - ``serial_pipeline`` (``entry/serial_entry.py``), used to save a checkpoint when reaching a new highest episode return.

TickMonitor

Bases: LoggedModel

Overview

TickMonitor monitors related info during training. The info includes: cur_lr, time (data, train, forward, backward), loss (total, ...). These variables are first recorded in log_buffer; a LearnerHook then updates the variables in this monitor from log_buffer; finally, they are printed to the text logger and tensorboard logger.

Interface: __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__

Property: time, expire

create_learner(cfg, **kwargs)

Overview

Given the key (learner_name), create a new learner instance if the key is registered in learner_mapping, otherwise raise a KeyError. In other words, a derived learner must first be registered; only then can create_learner be called to get an instance.

Arguments:

- cfg (EasyDict): Learner config. Necessary keys: [learner.import_module, learner.learner_type].

Returns:

- learner (BaseLearner): The created new learner, which should be an instance of one of learner_mapping's values.
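The register-then-create contract can be illustrated with a stripped-down registry. This is a hypothetical stand-in for ding's ``LEARNER_REGISTRY`` (which also handles ``import_module`` for lazy imports), showing only the mapping-plus-decorator pattern:

```python
# Sketch of the registry pattern behind create_learner (hypothetical names).

LEARNER_MAPPING = {}


def register(name):
    """Decorator that records a learner class under a string key."""

    def deco(cls):
        LEARNER_MAPPING[name] = cls
        return cls

    return deco


def create_learner(cfg):
    # Unregistered keys raise KeyError, matching the documented behaviour.
    if cfg['type'] not in LEARNER_MAPPING:
        raise KeyError('unregistered learner type: {}'.format(cfg['type']))
    return LEARNER_MAPPING[cfg['type']](cfg)


@register('base')
class MiniBaseLearner:

    def __init__(self, cfg):
        self.cfg = cfg


learner = create_learner({'type': 'base'})
print(type(learner).__name__)  # MiniBaseLearner
```

The real implementation adds one step before the lookup: ``import_module(cfg.get('import_names', []))``, so that registering side effects in user modules run before ``LEARNER_REGISTRY.build`` is called.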

get_simple_monitor_type(properties=[])

Overview

Besides basic training variables provided in TickMonitor, many policies have their own customized ones to record and monitor. This function can return a customized tick monitor. Compared with TickMonitor, SimpleTickMonitor can record extra properties passed in by a policy.

Arguments:

- properties (List[str]): Customized properties to monitor.

Returns:

- simple_tick_monitor (SimpleTickMonitor): A simple customized tick monitor.
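The mechanism here is dynamic class creation with the three-argument form of ``type()``: a subclass named ``SimpleTickMonitor`` is built at runtime with one class attribute per extra property. Below is a plain-Python sketch of that mechanism; ding's version attaches ``LoggedValue`` descriptors rather than the simple float defaults used here:

```python
# Sketch of get_simple_monitor_type's dynamic subclassing (simplified attributes).

class TickMonitor:
    # Stand-in base with two of the built-in monitored variables.
    data_time = 0.0
    train_time = 0.0


def get_simple_monitor_type(properties=()):
    if not properties:
        return TickMonitor  # no extras: return the base class unchanged
    attrs = {name: 0.0 for name in properties}
    # type(name, bases, attrs) builds a new subclass at runtime.
    return type('SimpleTickMonitor', (TickMonitor, ), attrs)


Simple = get_simple_monitor_type(['cur_lr', 'total_loss'])
m = Simple()
print(issubclass(Simple, TickMonitor))  # True
print(hasattr(m, 'cur_lr'))             # True
```

Because the returned value is a class, it is instantiated afterwards, exactly as the policy setter does with ``get_simple_monitor_type(...)(TickTime(), expire=10)``.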

Full Source Code

../ding/worker/learner/base_learner.py

```python
from typing import Any, Union, Callable, List, Dict, Optional, Tuple
from ditk import logging
from collections import namedtuple
from functools import partial
from easydict import EasyDict

import copy

from ding.torch_utils import CountVar, auto_checkpoint, build_log_buffer
from ding.utils import build_logger, EasyTimer, import_module, LEARNER_REGISTRY, get_rank, get_world_size
from ding.utils.autolog import LoggedValue, LoggedModel, TickTime
from ding.utils.data import AsyncDataLoader
from .learner_hook import build_learner_hook_by_cfg, add_learner_hook, merge_hooks, LearnerHook


@LEARNER_REGISTRY.register('base')
class BaseLearner(object):
    r"""
    Overview:
        Base class for policy learning.
    Interface:
        train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close
    Property:
        learn_info, priority_info, last_iter, train_iter, rank, world_size, policy
        monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        train_iterations=int(1e9),
        dataloader=dict(num_workers=0, ),
        log_policy=True,
        is_multitask_pipeline=False,
        only_monitor_rank0=True,
        # --- Hooks ---
        hook=dict(
            load_ckpt_before_run='',
            log_show_after_iter=100,
            save_ckpt_after_iter=10000,
            save_ckpt_after_run=True,
        ),
    )

    _name = "BaseLearner"  # override this variable for sub-class learner

    def __init__(
            self,
            cfg: EasyDict,
            policy: namedtuple = None,
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            dist_info: Tuple[int, int] = None,
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'learner',
    ) -> None:
        """
        Overview:
            Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
        Arguments:
            - cfg (:obj:`EasyDict`): Learner config, you can refer to ``cls.config`` for details. It should include \
                ``is_multitask_pipeline`` to indicate whether the pipeline is multitask (default: False), \
                and ``only_monitor_rank0`` to control whether only rank 0 needs monitor and tb_logger (default: True).
            - policy (:obj:`namedtuple`): A collection of policy functions of learn mode. The policy can also be \
                set at runtime.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard summary writer.
            - dist_info (:obj:`Tuple[int, int]`): Multi-GPU distributed training information.
            - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory.
            - instance_name (:obj:`str`): Instance name, which should be unique among different learners.
        Notes:
            If you want to debug in sync CUDA mode, please add the following code at the beginning of ``__init__``.

            .. code:: python

                os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # make CUDA kernel launches synchronous for debugging
        """
        self._cfg = cfg
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._ckpt_name = None
        self._timer = EasyTimer()
        self._is_multitask_pipeline = self._cfg.is_multitask_pipeline
        self.only_monitor_rank0 = self._cfg.only_monitor_rank0

        # Adjust only_monitor_rank0 based on is_multitask_pipeline
        if self._is_multitask_pipeline:
            self.only_monitor_rank0 = False

        # These 2 attributes are only used in parallel mode.
        self._end_flag = False
        self._learner_done = False
        if dist_info is None:
            self._rank = get_rank()
            self._world_size = get_world_size()
        else:
            # Learner rank. Used to discriminate which GPU it uses.
            self._rank, self._world_size = dist_info
        if self._world_size > 1:
            self._cfg.hook.log_reduce_after_iter = True

        # Logger (Monitor will be initialized in policy setter)
        # In the multitask pipeline, each rank needs its own tb_logger.
        # Otherwise, only rank == 0 learner needs monitor and tb_logger,
        # others only need text_logger to display terminal output.
        if self._rank == 0 or not self.only_monitor_rank0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
                )
        else:
            self._logger, _ = build_logger(
                './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
            )
            self._tb_logger = None

        self._log_buffer = {
            'scalar': build_log_buffer(),
            'scalars': build_log_buffer(),
            'histogram': build_log_buffer(),
        }

        # Setup policy
        if policy is not None:
            self.policy = policy

        # Learner hooks. Used to do specific things at specific time point. Will be set in ``_setup_hook``
        self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
        # Last iteration. Used to record current iter.
        self._last_iter = CountVar(init_val=0)
        # Collector envstep. Used to record current envstep.
        self._collector_envstep = 0

        # Setup time wrapper and hook.
        self._setup_wrapper()
        self._setup_hook()

    def _setup_hook(self) -> None:
        """
        Overview:
            Setup hook for base_learner. Hook is the way to implement some functions at specific time point
            in base_learner. You can refer to ``learner_hook.py``.
        """
        if hasattr(self, '_hooks'):
            self._hooks = merge_hooks(self._hooks, build_learner_hook_by_cfg(self._cfg.hook))
        else:
            self._hooks = build_learner_hook_by_cfg(self._cfg.hook)

    def _setup_wrapper(self) -> None:
        """
        Overview:
            Use ``_time_wrapper`` to get ``train_time``.
        Note:
            ``data_time`` is wrapped in ``setup_dataloader``.
        """
        self._wrapper_timer = EasyTimer()
        self.train = self._time_wrapper(self.train, 'scalar', 'train_time')

    def _time_wrapper(self, fn: Callable, var_type: str, var_name: str) -> Callable:
        """
        Overview:
            Wrap a function and record the time it used in ``_log_buffer``.
        Arguments:
            - fn (:obj:`Callable`): Function to be time_wrapped.
            - var_type (:obj:`str`): Variable type, e.g. ['scalar', 'scalars', 'histogram'].
            - var_name (:obj:`str`): Variable name, e.g. ['cur_lr', 'total_loss'].
        Returns:
            - wrapper (:obj:`Callable`): The wrapper to acquire a function's time.
        """

        def wrapper(*args, **kwargs) -> Any:
            with self._wrapper_timer:
                ret = fn(*args, **kwargs)
            self._log_buffer[var_type][var_name] = self._wrapper_timer.value
            return ret

        return wrapper

    def register_hook(self, hook: LearnerHook) -> None:
        """
        Overview:
            Add a new learner hook.
        Arguments:
            - hook (:obj:`LearnerHook`): The hook to be added.
        """
        add_learner_hook(self._hooks, hook)

    @property
    def collector_envstep(self) -> int:
        """
        Overview:
            Get current collector envstep.
        Returns:
            - collector_envstep (:obj:`int`): Current collector envstep.
        """
        return self._collector_envstep

    @collector_envstep.setter
    def collector_envstep(self, value: int) -> None:
        """
        Overview:
            Set current collector envstep.
        Arguments:
            - value (:obj:`int`): Current collector envstep.
        """
        self._collector_envstep = value

    def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None:
        """
        Overview:
            Given training data, implement network update for one iteration and update related variables.
            Learner's API for serial entry.
            Also called in ``start`` for each iteration's training.
        Arguments:
            - data (:obj:`dict`): Training data which is retrieved from replay buffer.
            - envstep (:obj:`int`): Current collector envstep, recorded for logging. Default: -1.
            - policy_kwargs (:obj:`dict`): Extra keyword arguments passed to ``_policy.forward``. Default: None.

        .. note::

            ``_policy`` must be set before calling this method.

            ``_policy.forward`` method contains: forward, backward, grad sync (if in multi-gpu mode) and
            parameter update.

            ``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
        """
        assert hasattr(self, '_policy'), "please set learner policy"
        self.call_hook('before_iter')

        if policy_kwargs is None:
            policy_kwargs = {}

        # Forward
        log_vars = self._policy.forward(data, **policy_kwargs)

        # Update replay buffer's priority info
        if isinstance(log_vars, dict):
            priority = log_vars.pop('priority', None)
        elif isinstance(log_vars, list):
            priority = log_vars[-1].pop('priority', None)
        else:
            raise TypeError("not support type for log_vars: {}".format(type(log_vars)))
        if priority is not None:
            replay_buffer_idx = [d.get('replay_buffer_idx', None) for d in data]
            replay_unique_id = [d.get('replay_unique_id', None) for d in data]
            self.priority_info = {
                'priority': priority,
                'replay_buffer_idx': replay_buffer_idx,
                'replay_unique_id': replay_unique_id,
            }
        # Discriminate vars in scalar, scalars and histogram type
        # Regard a var as scalar type by default. For scalars and histogram type, must annotate by prefix "[xxx]"
        self._collector_envstep = envstep
        if isinstance(log_vars, dict):
            log_vars = [log_vars]
        for elem in log_vars:
            scalars_vars, histogram_vars = {}, {}
            for k in list(elem.keys()):
                if "[scalars]" in k:
                    new_k = k.split(']')[-1]
                    scalars_vars[new_k] = elem.pop(k)
                elif "[histogram]" in k:
                    new_k = k.split(']')[-1]
                    histogram_vars[new_k] = elem.pop(k)
            # Update log_buffer
            self._log_buffer['scalar'].update(elem)
            self._log_buffer['scalars'].update(scalars_vars)
            self._log_buffer['histogram'].update(histogram_vars)

        self.call_hook('after_iter')
        self._last_iter.add(1)

        return log_vars

    @auto_checkpoint
    def start(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Learner's API for parallel entry.
            For each iteration, learner will get data through ``_next_data`` and call ``train`` to train.

        .. note::

            ``before_run`` and ``after_run`` hooks are called at the beginning and ending.
        """
        self._end_flag = False
        self._learner_done = False
        # before run hook
        self.call_hook('before_run')

        for i in range(self._cfg.train_iterations):
            data = self._next_data()
            if self._end_flag:
                break
            self.train(data)

        self._learner_done = True
        # after run hook
        self.call_hook('after_run')

    def setup_dataloader(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Setup learner's dataloader.

        .. note::

            The attributes ``get_data`` and ``_dataloader`` are used only in parallel mode, to fetch data from
            the file system; in the serial version, data is fetched from memory directly.

            In parallel mode, ``get_data`` is set by ``LearnerCommHelper``, and should be callable.
            Users don't need to know the related details if not necessary.
        """
        cfg = self._cfg.dataloader
        batch_size = self._policy.get_attribute('batch_size')
        device = self._policy.get_attribute('device')
        chunk_size = cfg.chunk_size if 'chunk_size' in cfg else batch_size
        self._dataloader = AsyncDataLoader(
            self.get_data, batch_size, device, chunk_size, collate_fn=lambda x: x, num_workers=cfg.num_workers
        )
        self._next_data = self._time_wrapper(self._next_data, 'scalar', 'data_time')

    def _next_data(self) -> Any:
        """
        Overview:
            [Only Used In Parallel Mode] Call ``_dataloader``'s ``__next__`` method to return next training data.
        Returns:
            - data (:obj:`Any`): Next training data from dataloader.
        """
        return next(self._dataloader)

    def close(self) -> None:
        """
        Overview:
            [Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
        """
        if self._end_flag:
            return
        self._end_flag = True
        if hasattr(self, '_dataloader'):
            self._dataloader.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()

    def __del__(self) -> None:
        self.close()

    def call_hook(self, name: str) -> None:
        """
        Overview:
            Call the corresponding hook plugins according to position name.
        Arguments:
            - name (:obj:`str`): Position of the hooks to call, \
                should be one of ['before_run', 'after_run', 'before_iter', 'after_iter'].
        """
        for hook in self._hooks[name]:
            hook(self)

    def info(self, s: str) -> None:
        """
        Overview:
            Log string info by ``self._logger.info``.
        Arguments:
            - s (:obj:`str`): The message to add into the logger.
        """
        self._logger.info('[RANK{}]: {}'.format(self._rank, s))

    def debug(self, s: str) -> None:
        self._logger.debug('[RANK{}]: {}'.format(self._rank, s))

    def save_checkpoint(self, ckpt_name: str = None) -> None:
        """
        Overview:
            Directly call ``save_ckpt_after_run`` hook to save checkpoint.
        Note:
            Must guarantee that "save_ckpt_after_run" is registered in the "after_run" hooks.
            This method is called in:

                - ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for \
                    saving a checkpoint whenever an exception is raised.
                - ``serial_pipeline`` (``entry/serial_entry.py``), used to save a checkpoint when reaching \
                    a new highest episode return.
        """
        if ckpt_name is not None:
            self.ckpt_name = ckpt_name
        names = [h.name for h in self._hooks['after_run']]
        assert 'save_ckpt_after_run' in names
        idx = names.index('save_ckpt_after_run')
        self._hooks['after_run'][idx](self)
        self.ckpt_name = None

    @property
    def learn_info(self) -> dict:
        """
        Overview:
            Get current info dict, which will be sent to commander, e.g. replay buffer priority update,
            current iteration, hyper-parameter adjustment, whether task is finished, etc.
        Returns:
            - info (:obj:`dict`): Current learner info dict.
        """
        ret = {
            'learner_step': self._last_iter.val,
            'priority_info': self.priority_info,
            'learner_done': self._learner_done,
        }
        return ret

    @property
    def last_iter(self) -> CountVar:
        return self._last_iter

    @property
    def train_iter(self) -> int:
        return self._last_iter.val

    @property
    def monitor(self) -> 'TickMonitor':  # noqa
        return self._monitor

    @property
    def log_buffer(self) -> dict:  # LogDict
        return self._log_buffer

    @log_buffer.setter
    def log_buffer(self, _log_buffer: Dict[str, Dict[str, Any]]) -> None:
        self._log_buffer = _log_buffer

    @property
    def logger(self) -> logging.Logger:
        return self._logger

    @property
    def tb_logger(self) -> 'TensorBoardLogger':  # noqa
        return self._tb_logger

    @property
    def exp_name(self) -> str:
        return self._exp_name

    @property
    def instance_name(self) -> str:
        return self._instance_name

    @property
    def rank(self) -> int:
        return self._rank

    @property
    def world_size(self) -> int:
        return self._world_size

    @property
    def policy(self) -> 'Policy':  # noqa
        return self._policy

    @policy.setter
    def policy(self, _policy: 'Policy') -> None:  # noqa
        """
        Note:
            Policy variable monitor is set alongside with policy, because variables are determined by specific policy.
        """
        self._policy = _policy
        if self._rank == 0 or not self.only_monitor_rank0:
            self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
        if self._cfg.log_policy:
            self.info(self._policy.info())

    @property
    def priority_info(self) -> dict:
        if not hasattr(self, '_priority_info'):
            self._priority_info = {}
        return self._priority_info

    @priority_info.setter
    def priority_info(self, _priority_info: dict) -> None:
        self._priority_info = _priority_info

    @property
    def ckpt_name(self) -> str:
        return self._ckpt_name

    @ckpt_name.setter
    def ckpt_name(self, _ckpt_name: str) -> None:
        self._ckpt_name = _ckpt_name


def create_learner(cfg: EasyDict, **kwargs) -> BaseLearner:
    """
    Overview:
        Given the key (learner_name), create a new learner instance if the key is registered in learner_mapping,
        otherwise raise a KeyError. In other words, a derived learner must first be registered; only then can
        ``create_learner`` be called to get an instance.
    Arguments:
        - cfg (:obj:`EasyDict`): Learner config. Necessary keys: [learner.import_module, learner.learner_type].
    Returns:
        - learner (:obj:`BaseLearner`): The created new learner, which should be an instance of one of \
            learner_mapping's values.
    """
    import_module(cfg.get('import_names', []))
    return LEARNER_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)


class TickMonitor(LoggedModel):
    """
    Overview:
        TickMonitor monitors related info during training.
        Info includes: cur_lr, time (data, train, forward, backward), loss (total, ...)
        These variables are first recorded in ``log_buffer``; a ``LearnerHook`` then updates the variables in
        this monitor from ``log_buffer``; finally, they are printed to text logger and tensorboard logger.
    Interface:
        __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
    Property:
        time, expire
    """
    data_time = LoggedValue(float)
    train_time = LoggedValue(float)
    total_collect_step = LoggedValue(float)
    total_step = LoggedValue(float)
    total_episode = LoggedValue(float)
    total_sample = LoggedValue(float)
    total_duration = LoggedValue(float)

    def __init__(self, time_: 'BaseTime', expire: Union[int, float]):  # noqa
        LoggedModel.__init__(self, time_, expire)
        self.__register()

    def __register(self):

        def __avg_func(prop_name: str) -> float:
            records = self.range_values[prop_name]()
            _list = [_value for (_begin_time, _end_time), _value in records]
            return sum(_list) / len(_list) if len(_list) != 0 else 0

        def __val_func(prop_name: str) -> float:
            records = self.range_values[prop_name]()
            return records[-1][1]

        for k in getattr(self, '_LoggedModel__properties'):
            self.register_attribute_value('avg', k, partial(__avg_func, prop_name=k))
            self.register_attribute_value('val', k, partial(__val_func, prop_name=k))


def get_simple_monitor_type(properties: List[str] = []) -> TickMonitor:
    """
    Overview:
        Besides basic training variables provided in ``TickMonitor``, many policies have their own customized
        ones to record and monitor. This function can return a customized tick monitor.
        Compared with ``TickMonitor``, ``SimpleTickMonitor`` can record extra ``properties`` passed in by a policy.
    Arguments:
        - properties (:obj:`List[str]`): Customized properties to monitor.
    Returns:
        - simple_tick_monitor (:obj:`SimpleTickMonitor`): A simple customized tick monitor.
    """
    if len(properties) == 0:
        return TickMonitor
    else:
        attrs = {}
        properties = [
            'data_time', 'train_time', 'sample_count', 'total_collect_step', 'total_step', 'total_sample',
            'total_episode', 'total_duration'
        ] + properties
        for p_name in properties:
            attrs[p_name] = LoggedValue(float)
        return type('SimpleTickMonitor', (TickMonitor, ), attrs)
```