ding.policy.base_policy

Policy

Bases: ABC

Overview

The basic class of Reinforcement Learning (RL) and Imitation Learning (IL) policy in DI-engine.

Property: cfg, learn_mode, collect_mode, eval_mode

learn_mode property

Overview

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own learn mode.

Returns:
- interfaces (:obj:Policy.learn_function): The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Examples:
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
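The mode pattern above can be sketched standalone. The following is a minimal mimic of the mechanism, not the real DI-engine class: a namedtuple bundles bound internal methods into an immutable interface, so callers in learn mode can only reach the methods that mode exposes. The `ToyPolicy` class and its metric values are illustrative.

```python
from collections import namedtuple

# Hypothetical stand-in for a policy class; DI-engine's real Policy exposes
# learn_mode the same way, as a namedtuple of bound internal methods.
learn_function = namedtuple('learn_function', ['forward', 'state_dict'])


class ToyPolicy:

    def _forward_learn(self, data):
        # Pretend training step: return a dict of scalar metrics.
        return {'total_loss': 0.5, 'batch': len(data)}

    def _state_dict_learn(self):
        return {'model': 'weights-placeholder'}

    @property
    def learn_mode(self):
        # Each access returns an immutable view: fields cannot be reassigned.
        return learn_function(self._forward_learn, self._state_dict_learn)


policy = ToyPolicy()
policy_learn = policy.learn_mode
out = policy_learn.forward([1, 2, 3])
print(out['batch'])  # 3
```

Because namedtuple instances are immutable, code holding `policy_learn` cannot swap out `forward`, which is the "restrict the usage" property the docstring describes.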

collect_mode property

Overview

Return the interfaces of collect mode of policy, which is used to collect training data by interacting with environments. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own collect mode.

Returns:
- interfaces (:obj:Policy.collect_function): The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Examples:
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)

eval_mode property

Overview

Return the interfaces of eval mode of policy, which is used to evaluate the policy performance. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own eval mode.

Returns:
- interfaces (:obj:Policy.eval_function): The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Examples:
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)

default_config() classmethod

Overview

Get the default config of policy. This method is used to create the default config of policy.

Returns: - cfg (:obj:EasyDict): The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

.. tip:: This method will deepcopy the config attribute of the class and return the result. So users don't need to worry about the modification of the returned config.
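The recursive merge described above can be illustrated with a small, self-contained sketch. The `deep_merge_dicts` helper below is a minimal stand-in for `ding.utils.deep_merge_dicts` written for this example, and the config contents are made up; it only shows the merge semantics (derived values win, nested dicts merge recursively, and deepcopy keeps the originals untouched).

```python
import copy


def deep_merge_dicts(base: dict, override: dict) -> dict:
    """Minimal stand-in for ding.utils.deep_merge_dicts: override wins,
    nested dicts merge recursively, inputs are never mutated."""
    result = copy.deepcopy(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(result.get(k), dict):
            result[k] = deep_merge_dicts(result[k], v)
        else:
            result[k] = copy.deepcopy(v)
    return result


# Base-class config and a (made-up) derived-class config, merged the way
# default_config merges Policy.config with a subclass's config attribute.
base_cfg = {'on_policy': False, 'cuda': False, 'model': {}}
dqn_cfg = {'cuda': True, 'model': {'dueling': True}}

cfg = deep_merge_dicts(base_cfg, dqn_cfg)
print(cfg)  # {'on_policy': False, 'cuda': True, 'model': {'dueling': True}}
```

Because everything is deep-copied, modifying the returned `cfg` leaves both `base_cfg` and `dqn_cfg` unchanged, which is exactly the guarantee the tip above makes for `default_config`.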

__init__(cfg, model=None, enable_field=None)

Overview

Initialize policy instance according to the input configuration and model. This method will initialize different fields in policy, including learn, collect, eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field is used to specify which fields to initialize; if it is None, then all fields will be initialized.

Arguments:
- cfg (:obj:EasyDict): The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.
- model (:obj:torch.nn.Module): The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.
- enable_field (:obj:Optional[List[str]]): The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

.. note:: For the derived policy class, it should implement the _init_learn, _init_collect, _init_eval method to initialize the corresponding field.
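The field-dispatch contract above can be shown with a standalone mimic. `MiniPolicy` below is not the real DI-engine base class and its attributes (`_learn_lr`, `_collect_eps`, `_eval_greedy`) are invented; it only demonstrates how `__init__` dispatches to `_init_<field>` for each entry in `enable_field`.

```python
class MiniPolicy:
    """Standalone mimic of Policy's field dispatch; not the real base class."""
    total_field = {'learn', 'collect', 'eval'}

    def __init__(self, cfg: dict, enable_field=None):
        self._cfg = cfg
        self._enable_field = self.total_field if enable_field is None else enable_field
        assert set(self._enable_field).issubset(self.total_field)
        for field in self._enable_field:
            # same getattr dispatch as Policy.__init__
            getattr(self, '_init_' + field)()

    def _init_learn(self):
        self._learn_lr = self._cfg.get('lr', 1e-3)

    def _init_collect(self):
        self._collect_eps = 0.95

    def _init_eval(self):
        self._eval_greedy = True


# Only the collect field is initialized, saving the resources of the other two.
p = MiniPolicy({'lr': 3e-4}, enable_field=['collect'])
print(hasattr(p, '_collect_eps'), hasattr(p, '_learn_lr'))  # True False
```

Note the `_collect_` / `_learn_` attribute prefixes, which follow the naming convention the real class recommends to keep the modes from clobbering each other's state.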

__repr__()

Overview

Get the string representation of the policy.

Returns: - repr (:obj:str): The string representation of the policy.

sync_gradients(model)

Overview

Synchronize (allreduce) gradients of model parameters in data-parallel multi-GPU training. For parameters that did not participate in the forward/backward pass in some GPUs, assign a zero gradient with an indicator of 0. This ensures that only GPUs which contributed to the gradient computation are considered when averaging, thereby avoiding an incorrect division by the total number of GPUs.

Arguments: - model (:obj:torch.nn.Module): The model to synchronize gradients.

.. note:: This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether to synchronize gradients allreduce and optimizer updates.
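The arithmetic behind the indicator trick is easy to show without any distributed setup. The numbers below are made up and the "GPUs" are just list entries; the point is the difference between dividing by the world size and dividing by the number of ranks that actually produced a gradient (which is what `allreduce_with_indicator` achieves across processes).

```python
# Three "GPUs" hold gradients for the same parameter, but only two of them
# actually computed it (indicator 1); the third contributes a zero gradient
# with indicator 0 so the collective call does not deadlock.
grads = [4.0, 2.0, 0.0]
indicators = [1.0, 1.0, 0.0]

# Naive averaging over all ranks wrongly divides by 3: (4 + 2 + 0) / 3 = 2.0.
naive = sum(grads) / len(grads)

# Indicator-weighted averaging divides only by the number of contributing
# ranks: (4 + 2 + 0) / (1 + 1 + 0) = 3.0, the correct mean gradient.
synced = sum(grads) / max(sum(indicators), 1.0)
print(naive, synced)  # 2.0 3.0
```

The `max(..., 1.0)` guard mirrors the need to avoid dividing by zero when no rank contributed; in the real method, the sums are computed collectively by the allreduce rather than over a local list.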

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.

Returns: - model_info (:obj:Tuple[str, List[str]]): The registered model name and model's import_names.

.. note:: The user can define and use a customized network model but must obey the same interface definition indicated by the import_names path. For example, for DQN, its registered name is dqn and the import_names is ding.model.template.q_learning.DQN
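The name-plus-import_names mechanism can be sketched with a toy registry. This is a hypothetical miniature, not the real ding registry API: to keep the example runnable with only the standard library, it registers `json.dumps` under an invented name instead of a neural network class.

```python
import importlib

# Toy registry mapping a registered name to a dotted import path.
MODEL_REGISTRY = {}


def register(name: str, import_path: str) -> None:
    MODEL_REGISTRY[name] = import_path


def create(name: str):
    """Import the module named by import_names, then look up the attribute."""
    module_path, attr = MODEL_REGISTRY[name].rsplit('.', 1)
    module = importlib.import_module(module_path)  # the import_names step
    return getattr(module, attr)                   # the registry-lookup step


register('toy_model', 'json.dumps')
fn = create('toy_model')
print(fn({'a': 1}))  # {"a": 1}
```

In the real pipeline, `default_model` returns the `(name, import_names)` pair and `create_model` performs the import-then-lookup, which is why a custom model only needs to be importable and registered under a known name.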

CommandModePolicy

Bases: Policy

Overview

Policy with command mode, which can be used in the old version of the DI-engine pipeline: serial_pipeline. CommandModePolicy uses the _get_setting_learn, _get_setting_collect, _get_setting_eval methods to exchange information between different workers.

Interface

_init_command, _get_setting_learn, _get_setting_collect, _get_setting_eval

Property: command_mode

command_mode property

Overview

Return the interfaces of command mode of policy, which is used to exchange settings information between different workers. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own command mode.

Returns:
- interfaces (:obj:Policy.command_function): The interfaces of command mode, it is a namedtuple whose values of distinct fields are different internal methods.

Examples:
>>> policy = CommandModePolicy(cfg, model)
>>> policy_command = policy.command_mode
>>> settings = policy_command.get_setting_learn(command_info)

create_policy(cfg, **kwargs)

Overview

Create a policy instance according to cfg and other kwargs.

Arguments:
- cfg (:obj:EasyDict): Final merged policy config.

ArgumentsKeys:
- type (:obj:str): Policy type set in the POLICY_REGISTRY.register method, such as dqn.
- import_names (:obj:List[str]): A list of module names (paths) to import before creating policy, such as ding.policy.dqn.

Returns:
- policy (:obj:Policy): The created policy instance.

.. tip:: kwargs contains other arguments that need to be passed to the policy constructor. You can refer to the __init__ method of the corresponding policy class for details.

.. note:: For more details about how to merge config, please refer to the system document of DI-engine (../03_system/config.html).
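The factory pattern behind create_policy can be sketched in isolation. The names below (`POLICY_REGISTRY` as a plain dict, `DQNPolicyStub`, the `seed` kwarg) are illustrative stand-ins, not the real ding API: modules listed in import_names are imported for their registration side effects, then the type key selects the class and the remaining kwargs go to its constructor.

```python
import importlib

# Toy registry; the real POLICY_REGISTRY is a richer registry object.
POLICY_REGISTRY = {}


def register(name: str):
    def wrap(cls):
        POLICY_REGISTRY[name] = cls
        return cls
    return wrap


@register('dqn')
class DQNPolicyStub:
    """Illustrative stand-in for a registered policy class."""

    def __init__(self, cfg: dict, seed: int = 0):
        self.cfg, self.seed = cfg, seed


def create_policy(cfg: dict, **kwargs):
    # Importing each module runs its @register decorators as a side effect.
    for name in cfg.get('import_names', []):
        importlib.import_module(name)
    # The 'type' key selects the class; extra kwargs reach its constructor.
    return POLICY_REGISTRY[cfg['type']](cfg, **kwargs)


policy = create_policy({'type': 'dqn', 'import_names': []}, seed=7)
print(type(policy).__name__, policy.seed)  # DQNPolicyStub 7
```

Here `import_names` is empty because the stub class is defined in the same file; in a real config it would list the module paths whose import registers the policy.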

get_policy_cls(cfg)

Overview

Get policy class according to cfg, which is used to access related class variables/methods.

Arguments:
- cfg (:obj:EasyDict): Final merged policy config.

ArgumentsKeys:
- type (:obj:str): Policy type set in the POLICY_REGISTRY.register method, such as dqn.
- import_names (:obj:List[str]): A list of module names (paths) to import before creating policy, such as ding.policy.dqn.

Returns:
- policy (:obj:type): The policy class.

Full Source Code

../ding/policy/base_policy.py

1from typing import Optional, List, Dict, Any, Tuple, Union 2from abc import ABC, abstractmethod 3from collections import namedtuple 4from easydict import EasyDict 5 6import copy 7import torch 8 9from ding.model import create_model 10from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, \ 11 synchronize, deep_merge_dicts, POLICY_REGISTRY 12 13 14class Policy(ABC): 15 """ 16 Overview: 17 The basic class of Reinforcement Learning (RL) and Imitation Learning (IL) policy in DI-engine. 18 Property: 19 ``cfg``, ``learn_mode``, ``collect_mode``, ``eval_mode`` 20 """ 21 22 @classmethod 23 def default_config(cls: type) -> EasyDict: 24 """ 25 Overview: 26 Get the default config of policy. This method is used to create the default config of policy. 27 Returns: 28 - cfg (:obj:`EasyDict`): The default config of corresponding policy. For the derived policy class, \ 29 it will recursively merge the default config of base class and its own default config. 30 31 .. tip:: 32 This method will deepcopy the ``config`` attribute of the class and return the result. So users don't need \ 33 to worry about the modification of the returned config. 
34 """ 35 if cls == Policy: 36 raise RuntimeError("Basic class Policy doesn't have completed default_config") 37 38 base_cls = cls.__base__ 39 if base_cls == Policy: 40 base_policy_cfg = EasyDict(copy.deepcopy(Policy.config)) 41 else: 42 base_policy_cfg = copy.deepcopy(base_cls.default_config()) 43 cfg = EasyDict(copy.deepcopy(cls.config)) 44 cfg = deep_merge_dicts(base_policy_cfg, cfg) 45 cfg.cfg_type = cls.__name__ + 'Dict' 46 return cfg 47 48 learn_function = namedtuple( 49 'learn_function', [ 50 'forward', 51 'reset', 52 'info', 53 'monitor_vars', 54 'get_attribute', 55 'set_attribute', 56 'state_dict', 57 'load_state_dict', 58 ] 59 ) 60 collect_function = namedtuple( 61 'collect_function', [ 62 'forward', 63 'process_transition', 64 'get_train_sample', 65 'reset', 66 'get_attribute', 67 'set_attribute', 68 'state_dict', 69 'load_state_dict', 70 ] 71 ) 72 eval_function = namedtuple( 73 'eval_function', [ 74 'forward', 75 'reset', 76 'get_attribute', 77 'set_attribute', 78 'state_dict', 79 'load_state_dict', 80 ] 81 ) 82 total_field = set(['learn', 'collect', 'eval']) 83 config = dict( 84 # (bool) Whether the learning policy is the same as the collecting data policy (on-policy). 85 on_policy=False, 86 # (bool) Whether to use cuda in policy. 87 cuda=False, 88 # (bool) Whether to use data parallel multi-gpu mode in policy. 89 multi_gpu=False, 90 # (bool) Whether to synchronize update the model parameters after allreduce the gradients of model parameters. 91 bp_update_sync=True, 92 # (bool) Whether to enable infinite trajectory length in data collecting. 93 traj_len_inf=False, 94 # neural network model config 95 model=dict(), 96 # If resume_training is True, the environment step count (collector.envstep) and training iteration (train_iter) 97 # will be loaded from the pretrained checkpoint, allowing training to resume seamlessly 98 # from where the ckpt left off. 
99 learn=dict(resume_training=False), 100 ) 101 102 def __init__( 103 self, 104 cfg: EasyDict, 105 model: Optional[torch.nn.Module] = None, 106 enable_field: Optional[List[str]] = None 107 ) -> None: 108 """ 109 Overview: 110 Initialize policy instance according to input configures and model. This method will initialize differnent \ 111 fields in policy, including ``learn``, ``collect``, ``eval``. The ``learn`` field is used to train the \ 112 policy, the ``collect`` field is used to collect data for training, and the ``eval`` field is used to \ 113 evaluate the policy. The ``enable_field`` is used to specify which field to initialize, if it is None, \ 114 then all fields will be initialized. 115 Arguments: 116 - cfg (:obj:`EasyDict`): The final merged config used to initialize policy. For the default config, \ 117 see the ``config`` attribute and its comments of policy class. 118 - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. If it \ 119 is None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \ 120 Otherwise, the model will be set to the ``model`` instance created by outside caller. 121 - enable_field (:obj:`Optional[List[str]]`): The field list to initialize. If it is None, then all fields \ 122 will be initialized. Otherwise, only the fields in ``enable_field`` will be initialized, which is \ 123 beneficial to save resources. 124 125 .. note:: 126 For the derived policy class, it should implement the ``_init_learn``, ``_init_collect``, ``_init_eval`` \ 127 method to initialize the corresponding field. 
128 """ 129 self._cfg = cfg 130 self._on_policy = self._cfg.on_policy 131 if enable_field is None: 132 self._enable_field = self.total_field 133 else: 134 self._enable_field = enable_field 135 assert set(self._enable_field).issubset(self.total_field), self._enable_field 136 137 if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0: 138 model = self._create_model(cfg, model) 139 self._cuda = cfg.cuda and torch.cuda.is_available() 140 # now only support multi-gpu for only enable learn mode 141 if len(set(self._enable_field).intersection(set(['learn']))) > 0: 142 multi_gpu = self._cfg.multi_gpu 143 self._rank = get_rank() if multi_gpu else 0 144 if self._cuda: 145 # model.cuda() is an in-place operation. 146 model.cuda() 147 if multi_gpu: 148 bp_update_sync = self._cfg.bp_update_sync 149 self._bp_update_sync = bp_update_sync 150 self._init_multi_gpu_setting(model, bp_update_sync) 151 else: 152 self._rank = 0 153 if self._cuda: 154 # model.cuda() is an in-place operation. 155 model.cuda() 156 self._model = model 157 self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' 158 else: 159 self._cuda = False 160 self._rank = 0 161 self._device = 'cpu' 162 163 # call the initialization method of different modes, such as ``_init_learn``, ``_init_collect``, ``_init_eval`` 164 for field in self._enable_field: 165 getattr(self, '_init_' + field)() 166 167 def _init_multi_gpu_setting(self, model: torch.nn.Module, bp_update_sync: bool) -> None: 168 """ 169 Overview: 170 Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning \ 171 of the training, and prepare the hook function to allreduce the gradients of model parameters. 172 Arguments: 173 - model (:obj:`torch.nn.Module`): The neural network model to be trained. 174 - bp_update_sync (:obj:`bool`): Whether to synchronize update the model parameters after allreduce the \ 175 gradients of model parameters. 
Async update can be parallel in different network layers like pipeline \ 176 so that it can save time. 177 """ 178 for name, param in model.state_dict().items(): 179 assert isinstance(param.data, torch.Tensor), type(param.data) 180 broadcast(param.data, 0) 181 # here we manually set the gradient to zero tensor at the beginning of the training, which is necessary for 182 # the case that different GPUs have different computation graph. 183 for name, param in model.named_parameters(): 184 setattr(param, 'grad', torch.zeros_like(param)) 185 if not bp_update_sync: 186 187 def make_hook(name, p): 188 189 def hook(*ignore): 190 allreduce_async(name, p.grad.data) 191 192 return hook 193 194 for i, (name, p) in enumerate(model.named_parameters()): 195 if p.requires_grad: 196 p_tmp = p.expand_as(p) 197 grad_acc = p_tmp.grad_fn.next_functions[0][0] 198 grad_acc.register_hook(make_hook(name, p)) 199 200 def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: 201 """ 202 Overview: 203 Create or validate the neural network model according to the input configuration and model. \ 204 If the input model is None, then the model will be created according to ``default_model`` \ 205 method and ``cfg.model`` field. Otherwise, the model will be verified as an instance of \ 206 ``torch.nn.Module`` and set to the ``model`` instance created by outside caller. 207 Arguments: 208 - cfg (:obj:`EasyDict`): The final merged config used to initialize policy. 209 - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. User can refer to \ 210 the default model defined in the corresponding policy to customize its own model. 211 Returns: 212 - model (:obj:`torch.nn.Module`): The created neural network model. The different modes of policy will \ 213 add distinct wrappers and plugins to the model, which is used to train, collect and evaluate. 
214 Raises: 215 - RuntimeError: If the input model is not None and is not an instance of ``torch.nn.Module``. 216 """ 217 if model is None: 218 model_cfg = cfg.model 219 if 'type' not in model_cfg: 220 m_type, import_names = self.default_model() 221 model_cfg.type = m_type 222 model_cfg.import_names = import_names 223 return create_model(model_cfg) 224 else: 225 if isinstance(model, torch.nn.Module): 226 return model 227 else: 228 raise RuntimeError("invalid model: {}".format(type(model))) 229 230 @property 231 def cfg(self) -> EasyDict: 232 return self._cfg 233 234 @abstractmethod 235 def _init_learn(self) -> None: 236 """ 237 Overview: 238 Initialize the learn mode of policy, including related attributes and modules. This method will be \ 239 called in ``__init__`` method if ``learn`` field is in ``enable_field``. Almost different policies have \ 240 its own learn mode, so this method must be overrided in subclass. 241 242 .. note:: 243 For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ 244 and ``_load_state_dict_learn`` methods. 245 246 .. note:: 247 For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. 248 249 .. note:: 250 If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ 251 with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 252 """ 253 raise NotImplementedError 254 255 @abstractmethod 256 def _init_collect(self) -> None: 257 """ 258 Overview: 259 Initialize the collect mode of policy, including related attributes and modules. This method will be \ 260 called in ``__init__`` method if ``collect`` field is in ``enable_field``. Almost different policies have \ 261 its own collect mode, so this method must be overrided in subclass. 262 263 .. 
note:: 264 For the member variables that need to be saved and loaded, please refer to the ``_state_dict_collect`` \ 265 and ``_load_state_dict_collect`` methods. 266 267 .. note:: 268 If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ 269 with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 270 """ 271 raise NotImplementedError 272 273 @abstractmethod 274 def _init_eval(self) -> None: 275 """ 276 Overview: 277 Initialize the eval mode of policy, including related attributes and modules. This method will be \ 278 called in ``__init__`` method if ``eval`` field is in ``enable_field``. Almost different policies have \ 279 its own eval mode, so this method must be override in subclass. 280 281 .. note:: 282 For the member variables that need to be saved and loaded, please refer to the ``_state_dict_eval`` \ 283 and ``_load_state_dict_eval`` methods. 284 285 .. note:: 286 If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ 287 with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. 288 """ 289 raise NotImplementedError 290 291 @property 292 def learn_mode(self) -> 'Policy.learn_function': # noqa 293 """ 294 Overview: 295 Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \ 296 to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived \ 297 subclass can override the interfaces to customize its own learn mode. 298 Returns: 299 - interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy, it is a namedtuple \ 300 whose values of distinct fields are different internal methods. 
301 Examples: 302 >>> policy = Policy(cfg, model) 303 >>> policy_learn = policy.learn_mode 304 >>> train_output = policy_learn.forward(data) 305 >>> state_dict = policy_learn.state_dict() 306 """ 307 return Policy.learn_function( 308 self._forward_learn, 309 self._reset_learn, 310 self.__repr__, 311 self._monitor_vars_learn, 312 self._get_attribute, 313 self._set_attribute, 314 self._state_dict_learn, 315 self._load_state_dict_learn, 316 ) 317 318 @property 319 def collect_mode(self) -> 'Policy.collect_function': # noqa 320 """ 321 Overview: 322 Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple \ 323 to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived \ 324 subclass can override the interfaces to customize its own collect mode. 325 Returns: 326 - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \ 327 namedtuple whose values of distinct fields are different internal methods. 328 Examples: 329 >>> policy = Policy(cfg, model) 330 >>> policy_collect = policy.collect_mode 331 >>> obs = env_manager.ready_obs 332 >>> inference_output = policy_collect.forward(obs) 333 >>> next_obs, rew, done, info = env_manager.step(inference_output.action) 334 """ 335 return Policy.collect_function( 336 self._forward_collect, 337 self._process_transition, 338 self._get_train_sample, 339 self._reset_collect, 340 self._get_attribute, 341 self._set_attribute, 342 self._state_dict_collect, 343 self._load_state_dict_collect, 344 ) 345 346 @property 347 def eval_mode(self) -> 'Policy.eval_function': # noqa 348 """ 349 Overview: 350 Return the interfaces of eval mode of policy, which is used to train the model. Here we use namedtuple \ 351 to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ 352 subclass can override the interfaces to customize its own eval mode. 
353 Returns: 354 - interfaces (:obj:`Policy.eval_function`): The interfaces of eval mode of policy, it is a namedtuple \ 355 whose values of distinct fields are different internal methods. 356 Examples: 357 >>> policy = Policy(cfg, model) 358 >>> policy_eval = policy.eval_mode 359 >>> obs = env_manager.ready_obs 360 >>> inference_output = policy_eval.forward(obs) 361 >>> next_obs, rew, done, info = env_manager.step(inference_output.action) 362 """ 363 return Policy.eval_function( 364 self._forward_eval, 365 self._reset_eval, 366 self._get_attribute, 367 self._set_attribute, 368 self._state_dict_eval, 369 self._load_state_dict_eval, 370 ) 371 372 def _set_attribute(self, name: str, value: Any) -> None: 373 """ 374 Overview: 375 In order to control the access of the policy attributes, we expose different modes to outside rather than \ 376 directly use the policy instance. And we also provide a method to set the attribute of the policy in \ 377 different modes. And the new attribute will name as ``_{name}``. 378 Arguments: 379 - name (:obj:`str`): The name of the attribute. 380 - value (:obj:`Any`): The value of the attribute. 381 """ 382 setattr(self, '_' + name, value) 383 384 def _get_attribute(self, name: str) -> Any: 385 """ 386 Overview: 387 In order to control the access of the policy attributes, we expose different modes to outside rather than \ 388 directly use the policy instance. And we also provide a method to get the attribute of the policy in \ 389 different modes. 390 Arguments: 391 - name (:obj:`str`): The name of the attribute. 392 Returns: 393 - value (:obj:`Any`): The value of the attribute. 394 395 .. note:: 396 DI-engine's policy will first try to access `_get_{name}` method, and then try to access `_{name}` \ 397 attribute. If both of them are not found, it will raise a ``NotImplementedError``. 
398 """ 399 if hasattr(self, '_get_' + name): 400 return getattr(self, '_get_' + name)() 401 elif hasattr(self, '_' + name): 402 return getattr(self, '_' + name) 403 else: 404 raise NotImplementedError 405 406 def __repr__(self) -> str: 407 """ 408 Overview: 409 Get the string representation of the policy. 410 Returns: 411 - repr (:obj:`str`): The string representation of the policy. 412 """ 413 return "DI-engine DRL Policy\n{}".format(repr(self._model)) 414 415 def sync_gradients(self, model: torch.nn.Module) -> None: 416 """ 417 Overview: 418 Synchronize (allreduce) gradients of model parameters in data-parallel multi-GPU training. 419 For parameters that did not participate in the forward/backward pass in some GPUs, 420 assign a zero gradient with an indicator of 0. This ensures that only GPUs which contributed 421 to the gradient computation are considered when averaging, thereby avoiding an incorrect 422 division by the total number of GPUs. 423 Arguments: 424 - model (:obj:`torch.nn.Module`): The model to synchronize gradients. 425 426 .. note:: 427 This method is only used in multi-gpu training, and it should be called after the ``backward`` method and \ 428 before the ``step`` method. The user can also use the ``bp_update_sync`` config to control whether to \ 429 synchronize gradients allreduce and optimizer updates. 430 """ 431 if self._bp_update_sync: 432 for name, param in model.named_parameters(): 433 if param.requires_grad: 434 # Create an indicator tensor on the same device as the parameter (or its gradient) 435 if param.grad is not None: 436 # If the gradient exists, extract its data and set indicator to 1. 437 grad_tensor = param.grad.data 438 indicator = torch.tensor(1.0, device=grad_tensor.device) 439 else: 440 # If the parameter did not participate in the computation (grad is None), 441 # create a zero tensor for the gradient and set the indicator to 0. 
442 grad_tensor = torch.zeros_like(param.data) 443 indicator = torch.tensor(0.0, device=grad_tensor.device) 444 445 # Assign the zero gradient to param.grad to ensure that all GPUs 446 # participate in the subsequent allreduce call (avoiding deadlock). 447 param.grad = grad_tensor 448 449 # Use the custom allreduce function to reduce the gradient using the indicator. 450 allreduce_with_indicator(param.grad, indicator) 451 else: 452 synchronize() 453 454 # don't need to implement default_model method by force 455 def default_model(self) -> Tuple[str, List[str]]: 456 """ 457 Overview: 458 Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ 459 automatically call this method to get the default model setting and create model. 460 Returns: 461 - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. 462 463 .. note:: 464 The user can define and use customized network model but must obey the same inferface definition indicated \ 465 by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ 466 ``ding.model.template.q_learning.DQN`` 467 """ 468 raise NotImplementedError 469 470 # *************************************** learn function ************************************ 471 472 @abstractmethod 473 def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: 474 """ 475 Overview: 476 Policy forward function of learn mode (training policy and updating parameters). Forward means \ 477 that the policy inputs some training batch data from the replay buffer and then returns the output \ 478 result, including various training information such as loss value, policy entropy, q value, priority, \ 479 and so on. This method is left to be implemented by the subclass, and more arguments can be added in \ 480 ``data`` item if necessary. 
481 Arguments: 482 - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ 483 training samples. For each element in list, the key of the dict is the name of data items and the \ 484 value is the corresponding data. Usually, in the ``_forward_learn`` method, data should be stacked in \ 485 the batch dimension by some utility functions such as ``default_preprocess_learn``. 486 Returns: 487 - output (:obj:`Dict[int, Any]`): The training information of policy forward, including some metrics for \ 488 monitoring training such as loss, priority, q value, policy entropy, and some data for next step \ 489 training such as priority. Note the output data item should be Python native scalar rather than \ 490 PyTorch tensor, which is convenient for the outside to use. 491 """ 492 raise NotImplementedError 493 494 # don't need to implement _reset_learn method by force 495 def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: 496 """ 497 Overview: 498 Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \ 499 memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ 500 varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ 501 different trajectories in ``data_id`` will have different hidden state in RNN. 502 Arguments: 503 - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ 504 specified by ``data_id``. 505 506 .. note:: 507 This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. 508 """ 509 pass 510 511 def _monitor_vars_learn(self) -> List[str]: 512 """ 513 Overview: 514 Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ 515 as text logger, tensorboard logger, will use these keys to save the corresponding data. 
516 Returns: 517 - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 518 519 .. tip:: 520 The default implementation is ``['cur_lr', 'total_loss']``. Other derived classes can overwrite this \ 521 method to add their own keys if necessary. 522 """ 523 return ['cur_lr', 'total_loss'] 524 525 def _state_dict_learn(self) -> Dict[str, Any]: 526 """ 527 Overview: 528 Return the state_dict of learn mode, usually including model and optimizer. 529 Returns: 530 - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 531 """ 532 return { 533 'model': self._learn_model.state_dict(), 534 'optimizer': self._optimizer.state_dict(), 535 } 536 537 def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: 538 """ 539 Overview: 540 Load the state_dict variable into policy learn mode. 541 Arguments: 542 - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 543 544 .. tip:: 545 If you want to only load some parts of model, you can simply set the ``strict`` argument in \ 546 load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ 547 complicated operation. 548 """ 549 self._learn_model.load_state_dict(state_dict['model']) 550 self._optimizer.load_state_dict(state_dict['optimizer']) 551 552 def _get_batch_size(self) -> Union[int, Dict[str, int]]: 553 # some specifial algorithms use different batch size for different optimization parts. 554 if 'batch_size' in self._cfg: 555 return self._cfg.batch_size 556 else: # for compatibility 557 return self._cfg.learn.batch_size 558 559 # *************************************** collect function ************************************ 560 561 @abstractmethod 562 def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: 563 """ 564 Overview: 565 Policy forward function of collect mode (collecting training data by interacting with envs). 
Forward means \ 566 that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ 567 data, such as the action to interact with the envs, or the action logits to calculate the loss in learn \ 568 mode. This method is left to be implemented by the subclass, and more arguments can be added in ``kwargs`` \ 569 part if necessary. 570 Arguments: 571 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 572 key of the dict is environment id and the value is the corresponding data of the env. 573 Returns: 574 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ 575 other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ 576 dict is the same as the input data, i.e. environment id. 577 """ 578 raise NotImplementedError 579 580 @abstractmethod 581 def _process_transition( 582 self, obs: Union[torch.Tensor, Dict[str, torch.Tensor]], policy_output: Dict[str, torch.Tensor], 583 timestep: namedtuple 584 ) -> Dict[str, torch.Tensor]: 585 """ 586 Overview: 587 Process and pack one timestep transition data into a dict, such as <s, a, r, s', done>. Some policies \ 588 need to do some special process and pack its own necessary attributes (e.g. hidden state and logit), \ 589 so this method is left to be implemented by the subclass. 590 Arguments: 591 - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The observation of the current timestep. 592 - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ 593 as input. Usually, it contains the action and the logit of the action. 594 - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ 595 except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ 596 reward, done, info, etc. 
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, i.e. a list of transition) data, process it into a list of \
            samples that can be used for training directly. A train sample can be a processed transition (DQN \
            with n-step TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors \
            to execute necessary RL data preprocessing before training, which can help the learner amortize the \
            relevant time consumption. Alternatively, you can implement this method as an identity function and \
            do the data processing in the ``self._forward_learn`` method.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element \
              is in the same format as the return value of the ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar \
              format to the input transitions, but may contain more data for training, such as n-step reward, \
              advantage, etc.

        .. note::
            We will vectorize the ``process_transition`` and ``get_train_sample`` methods in a following release \
            version, and the user can customize this data processing procedure by overriding these two methods \
            and the collector itself.
        """
        raise NotImplementedError

    # don't need to implement _reset_collect method by force
    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \
            memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
            variables. Otherwise, it will reset the stateful variables according to ``data_id``. For example, \
            different environments/episodes in ``data_id`` during collecting will have different hidden states \
            in RNN.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful \
              variables specified by ``data_id``.

        .. note::
            This method is not mandatory to be implemented. The subclass can overwrite this method if necessary.
        """
        pass

    def _state_dict_collect(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of collect mode, usually only including the model, which is necessary for \
            distributed training scenarios to auto-recover collectors.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy collect state, for saving and \
              restoring.

        .. tip::
            Not all scenarios need to auto-recover collectors; sometimes, we can directly shut down the crashed \
            collector and create a new one.
        """
        return {'model': self._collect_model.state_dict()}

    def _load_state_dict_collect(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy collect mode, such as loading a pretrained state_dict, an \
            auto-recovered checkpoint, or a model replica from the learner in distributed training scenarios.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy collect state saved before.

        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument of \
            ``load_state_dict`` to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._collect_model.load_state_dict(state_dict['model'], strict=True)

    def _get_n_sample(self) -> Union[int, None]:
        if 'n_sample' in self._cfg:
            return self._cfg.n_sample
        else:  # for compatibility
            return self._cfg.collect.get('n_sample', None)  # for some adaptive data collection cases

    def _get_n_episode(self) -> Union[int, None]:
        if 'n_episode' in self._cfg:
            return self._cfg.n_episode
        else:  # for compatibility
            return self._cfg.collect.get('n_episode', None)  # for some adaptive data collection cases

    # *************************************** eval function ************************************

    @abstractmethod
    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance, such as interacting with envs or \
            computing metrics on a validation dataset). Forward means that the policy gets some necessary data \
            (mainly observation) from the envs and then returns the output data, such as the action to interact \
            with the envs. This method is left to be implemented by the subclass.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
              The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
              The key of the dict is the same as in the input data, i.e. environment id.
        """
        raise NotImplementedError

    # don't need to implement _reset_eval method by force
    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \
            memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
            variables. Otherwise, it will reset the stateful variables according to ``data_id``. For example, \
            different environments/episodes in ``data_id`` during evaluation will have different hidden states \
            in RNN.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful \
              variables specified by ``data_id``.

        .. note::
            This method is not mandatory to be implemented. The subclass can overwrite this method if necessary.
        """
        pass

    def _state_dict_eval(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of eval mode, usually only including the model, which is necessary for \
            distributed training scenarios to auto-recover evaluators.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy eval state, for saving and restoring.

        .. tip::
            Not all scenarios need to auto-recover evaluators; sometimes, we can directly shut down the crashed \
            evaluator and create a new one.
        """
        return {'model': self._eval_model.state_dict()}

    def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy eval mode, such as loading an auto-recovered checkpoint or \
            a model replica from the learner in distributed training scenarios.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy eval state saved before.

        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument of \
            ``load_state_dict`` to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._eval_model.load_state_dict(state_dict['model'], strict=True)


class CommandModePolicy(Policy):
    """
    Overview:
        Policy with command mode, which can be used in the old version of the DI-engine pipeline: \
        ``serial_pipeline``. ``CommandModePolicy`` uses the ``_get_setting_learn``, ``_get_setting_collect`` and \
        ``_get_setting_eval`` methods to exchange information between different workers.

    Interface:
        ``_init_command``, ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval``
    Property:
        ``command_mode``
    """
    command_function = namedtuple('command_function', ['get_setting_learn', 'get_setting_collect', 'get_setting_eval'])
    total_field = set(['learn', 'collect', 'eval', 'command'])

    @property
    def command_mode(self) -> 'Policy.command_function':  # noqa
        """
        Overview:
            Return the interfaces of command mode of policy, which is used to exchange information between \
            different workers. Here we use namedtuple to define immutable interfaces and restrict the usage of \
            policy in different modes. Moreover, a derived subclass can override the interfaces to customize its \
            own command mode.
        Returns:
            - interfaces (:obj:`Policy.command_function`): The interfaces of command mode, it is a namedtuple \
              whose values of distinct fields are different internal methods.
        Examples:
            >>> policy = CommandModePolicy(cfg, model)
            >>> policy_command = policy.command_mode
            >>> settings = policy_command.get_setting_learn(command_info)
        """
        return CommandModePolicy.command_function(
            self._get_setting_learn, self._get_setting_collect, self._get_setting_eval
        )

    @abstractmethod
    def _init_command(self) -> None:
        """
        Overview:
            Initialize the command mode of policy, including related attributes and modules. This method will be \
            called in the ``__init__`` method if the ``command`` field is in ``enable_field``. Different policies \
            usually have their own command mode, so this method must be overridden in the subclass.

        .. note::
            If you want to set some special member variables in the ``_init_command`` method, you'd better name \
            them with the prefix ``_command_`` to avoid conflicts with other modes, such as ``self._command_attr1``.
        """
        raise NotImplementedError

    # *************************************** command function ************************************
    @abstractmethod
    def _get_setting_learn(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            According to ``command_info``, i.e., global training information (e.g. training iteration, collected \
            env steps, evaluation results, etc.), return the setting of learn mode, which contains dynamically \
            changed hyperparameters for learn mode, such as ``batch_size``, ``learning_rate``, etc.
        Arguments:
            - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in \
              ``commander``.
        Returns:
            - setting (:obj:`Dict[str, Any]`): The latest setting of learn mode, which is usually used as extra \
              arguments of the ``policy._forward_learn`` method.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            According to ``command_info``, i.e., global training information (e.g. training iteration, collected \
            env steps, evaluation results, etc.), return the setting of collect mode, which contains dynamically \
            changed hyperparameters for collect mode, such as ``eps``, ``temperature``, etc.
        Arguments:
            - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in \
              ``commander``.
        Returns:
            - setting (:obj:`Dict[str, Any]`): The latest setting of collect mode, which is usually used as extra \
              arguments of the ``policy._forward_collect`` method.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_setting_eval(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            According to ``command_info``, i.e., global training information (e.g. training iteration, collected \
            env steps, evaluation results, etc.), return the setting of eval mode, which contains dynamically \
            changed hyperparameters for eval mode, such as ``temperature``, etc.
        Arguments:
            - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in \
              ``commander``.
        Returns:
            - setting (:obj:`Dict[str, Any]`): The latest setting of eval mode, which is usually used as extra \
              arguments of the ``policy._forward_eval`` method.
        """
        raise NotImplementedError


def create_policy(cfg: EasyDict, **kwargs) -> Policy:
    """
    Overview:
        Create a policy instance according to ``cfg`` and other kwargs.
    Arguments:
        - cfg (:obj:`EasyDict`): Final merged policy config.
    ArgumentsKeys:
        - type (:obj:`str`): Policy type set in the ``POLICY_REGISTRY.register`` method, such as ``dqn``.
        - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating the policy, \
          such as ``ding.policy.dqn``.
    Returns:
        - policy (:obj:`Policy`): The created policy instance.

    .. tip::
        ``kwargs`` contains other arguments that need to be passed to the policy constructor. You can refer to \
        the ``__init__`` method of the corresponding policy class for details.

    .. note::
        For more details about how to merge config, please refer to the system document of DI-engine \
        (`en link <../03_system/config.html>`_).
    """
    import_module(cfg.get('import_names', []))
    return POLICY_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)


def get_policy_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Get the policy class according to ``cfg``, which is used to access related class variables/methods.
    Arguments:
        - cfg (:obj:`EasyDict`): Final merged policy config.
    ArgumentsKeys:
        - type (:obj:`str`): Policy type set in the ``POLICY_REGISTRY.register`` method, such as ``dqn``.
        - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating the policy, \
          such as ``ding.policy.dqn``.
    Returns:
        - policy (:obj:`type`): The policy class.
    """
    import_module(cfg.get('import_names', []))
    return POLICY_REGISTRY.get(cfg.type)
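To make the `_process_transition` / `_get_train_sample` contract concrete, here is a minimal, framework-free sketch: one helper packs a `<s, a, r, s', done>` transition, and a trajectory post-processor attaches a toy n-step discounted reward. The `TimeStep` namedtuple, the 0.99 discount, and all field names are illustrative assumptions, not part of the DI-engine API.

```python
from collections import namedtuple
from typing import Any, Dict, List

# Illustrative stand-in for the env step result (real envs return tensor data).
TimeStep = namedtuple('TimeStep', ['obs', 'reward', 'done', 'info'])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: TimeStep) -> Dict[str, Any]:
    # Pack one <s, a, r, s', done> transition, mirroring _process_transition.
    return {
        'obs': obs,
        'action': policy_output['action'],
        'reward': timestep.reward,
        'next_obs': timestep.obs,
        'done': timestep.done,
    }


def get_train_sample(transitions: List[Dict[str, Any]], nstep: int = 1, gamma: float = 0.99) -> List[Dict[str, Any]]:
    # Toy analogue of _get_train_sample: attach an n-step discounted reward to
    # every transition so the learner does not redo this work at training time.
    samples = []
    for i, trans in enumerate(transitions):
        ret, discount = 0.0, 1.0
        for j in range(i, min(i + nstep, len(transitions))):
            ret += discount * transitions[j]['reward']
            discount *= gamma
        samples.append({**trans, 'nstep_reward': ret})
    return samples
```

A DQN-style policy would return dicts like these from `_process_transition` and run whole trajectories through `_get_train_sample` before pushing the resulting samples into the replay buffer.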
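The `_get_setting_*` methods of `CommandModePolicy` typically map global training progress onto hyperparameters. A minimal sketch, assuming an `envstep` key in `command_info` and an arbitrary linear epsilon schedule (the numbers are illustrative, not DI-engine defaults):

```python
from typing import Any, Dict


def get_setting_collect(command_info: Dict[str, Any]) -> Dict[str, Any]:
    # Toy analogue of _get_setting_collect: linearly decay the exploration
    # epsilon from `start` to `end` over `decay_steps` collected env steps.
    start, end, decay_steps = 0.95, 0.05, 10000  # illustrative schedule
    step = command_info.get('envstep', 0)
    frac = min(step / decay_steps, 1.0)
    return {'eps': start - (start - end) * frac}
```

The returned dict would then be forwarded by the collector as extra keyword arguments of `policy._forward_collect`.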
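`create_policy` and `get_policy_cls` are thin wrappers over a name-to-class registry. The following self-contained sketch mimics that mechanism with a plain dict; `_REGISTRY`, `register`, and `ToyDQNPolicy` are invented for illustration and are not the real `POLICY_REGISTRY`:

```python
from typing import Any, Dict, Type

# Toy stand-in for POLICY_REGISTRY: maps a ``type`` string to a policy class.
_REGISTRY: Dict[str, Type] = {}


def register(name: str):
    # Decorator analogue of POLICY_REGISTRY.register.
    def _wrap(cls):
        _REGISTRY[name] = cls
        return cls
    return _wrap


@register('dqn')
class ToyDQNPolicy:
    def __init__(self, cfg: Dict[str, Any], **kwargs) -> None:
        self.cfg = cfg


def create_policy(cfg: Dict[str, Any], **kwargs):
    # Look up cfg['type'] and instantiate, like ding.policy.create_policy.
    return _REGISTRY[cfg['type']](cfg=cfg, **kwargs)


def get_policy_cls(cfg: Dict[str, Any]) -> type:
    # Return the class itself, like ding.policy.get_policy_cls.
    return _REGISTRY[cfg['type']]
```

The real functions additionally call `import_module` on `cfg.import_names` first, so that registering decorators in those modules run before the lookup.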