
ding.policy.dqn

DQNPolicy

Bases: Policy

Overview

Policy class of DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.

Config

1. ``type`` (str; default: dqn): RL policy register name, refer to registry ``POLICY_REGISTRY``. This arg is optional, a placeholder.
2. ``cuda`` (bool; default: False): Whether to use cuda for network. This arg can differ across modes.
3. ``on_policy`` (bool; default: False): Whether the RL algorithm is on-policy or off-policy.
4. ``priority`` (bool; default: False): Whether to use priority (PER): priority sampling and priority update.
5. ``priority_IS_weight`` (bool; default: False): Whether to use Importance Sampling Weight to correct the biased update. If True, ``priority`` must be True.
6. ``discount_factor`` (float; default: 0.97, typically in [0.95, 0.999]): Reward's future discount factor, aka. gamma. May be 1 in sparse-reward envs.
7. ``nstep`` (int; default: 1, typically in [3, 5]): N-step reward discount sum for target q_value estimation.
8. ``model.dueling`` (bool; default: True): Whether to use the dueling head architecture.
9. ``model.encoder_hidden_size_list`` (list(int); default: [32, 64, 64, 128]): Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. Default kernel_size is [8, 4, 3]; default stride is [4, 2, 1].
10. ``model.dropout`` (float in [0, 1]; default: None): Dropout rate for dropout layers. ``None`` means no dropout.
11. ``learn.update_per_collect`` (int; default: 3): How many updates (iterations) to train after one collection of the collector. Only valid in serial training. Varies across envs; a bigger value means more off-policy.
12. ``learn.batch_size`` (int; default: 64): The number of samples in one training iteration.
13. ``learn.learning_rate`` (float; default: 0.001): Gradient step length of an iteration.
14. ``learn.target_update_freq`` (int; default: 100): Frequency of target network update. Hard (assign) update.
15. ``learn.target_theta`` (float; default: 0.005): Interpolation factor for target network update. Soft (momentum) update. Only one of [``target_update_freq``, ``target_theta``] should be set.
16. ``learn.ignore_done`` (bool; default: False): Whether to ignore done for target value calculation. Enable it for envs with fake termination.
17. ``collect.n_sample`` (int; typically in [8, 128]): The number of training samples of one call of the collector. It varies across envs.
18. ``collect.n_episode`` (int; default: 8): The number of training episodes of one call of the collector. Only one of [``n_sample``, ``n_episode``] should be set.
19. ``collect.unroll_len`` (int; default: 1): Unroll length of an iteration. In RNN, unroll_len > 1.
20. ``other.eps.type`` (str; default: exp): Exploration rate decay type. Supports ['exp', 'linear'].
21. ``other.eps.start`` (float in [0, 1]; default: 0.95): Start value of the exploration rate.
22. ``other.eps.end`` (float in [0, 1]; default: 0.1): End value of the exploration rate.
23. ``other.eps.decay`` (int > 0; default: 10000): Decay length of exploration; decay=10000 means the exploration rate decays from the start value to the end value over the decay length.
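The ``other.eps`` fields above define an exploration-rate schedule. As a minimal sketch of how such a schedule can be evaluated (the exact decay curve DI-engine uses may differ; ``epsilon`` and its parameters here are illustrative, not part of the library):

```python
import math

def epsilon(step, start=0.95, end=0.1, decay=10000, kind='exp'):
    # Exploration rate at a given env step, mirroring the other.eps config fields.
    if kind == 'linear':
        # Walk linearly from start to end over `decay` steps, then stay at end.
        frac = min(step / decay, 1.0)
        return start + (end - start) * frac
    # 'exp': decay exponentially from start toward end with time constant `decay`.
    return end + (start - end) * math.exp(-step / decay)
```

At step 0 both schedules return ``start``; for steps far beyond ``decay`` they approach ``end``, so the collector gradually shifts from exploration to exploitation.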

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method automatically calls this method to get the default model setting and create the model.

Returns:
    - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.

.. note::
    The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For DQN, the registered name is ``dqn`` and the import_names is ``ding.model.template.q_learning``.
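A sketch of how the returned ``(name, import_names)`` pair can be consumed: importing each module in ``import_names`` triggers the model registration as a side effect, after which the model can be looked up by its registered name (``load_default_model`` is a hypothetical helper for illustration, not part of DI-engine):

```python
import importlib

def load_default_model(model_info):
    # model_info is the (registered_name, import_names) tuple from default_model().
    name, import_names = model_info
    for path in import_names:
        # Importing the module registers the model class as a side effect.
        importlib.import_module(path)
    return name
```

For ``DQNPolicy`` this would be called with ``('dqn', ['ding.model.template.q_learning'])``.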

calculate_priority(data, update_target_model=False)

Overview

Calculate priority for replay buffer.

Arguments:
    - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training.
    - update_target_model (:obj:`bool`): Whether to update the target model.
Returns:
    - priority (:obj:`Dict[str, Any]`): Dict type priority data; values are python scalars or a list of scalars.
ArgumentsKeys:
    - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
    - optional: ``value_gamma``
ReturnsKeys:
    - necessary: ``priority``
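In effect, this method recomputes the (n-step) TD error without taking a gradient step and uses its absolute value as the priority. A simplified one-step sketch in plain Python (the field names in ``batch`` are illustrative, not the exact batch layout used by the policy):

```python
def calculate_priority_sketch(batch, gamma=0.97):
    # priority = |r + gamma * (1 - done) * Q_target(s', a*) - Q(s, a)|,
    # where a* is the argmax action of the main model (Double DQN style).
    priorities = []
    for sample in batch:
        bootstrap = gamma * (1.0 - sample['done']) * sample['target_q_next']
        td_error = sample['reward'] + bootstrap - sample['q_taken']
        priorities.append(abs(td_error))
    return {'priority': priorities}
```

Samples with a larger TD error receive a larger priority and are therefore replayed more often by a PER buffer.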

DQNSTDIMPolicy

Bases: DQNPolicy

Overview

Policy class of DQN algorithm, extended by ST-DIM auxiliary objectives. ST-DIM paper link: https://arxiv.org/abs/1906.08226.

Config:

1. ``type`` (str; default: dqn_stdim): RL policy register name, refer to registry ``POLICY_REGISTRY``. This arg is optional, a placeholder.
2. ``cuda`` (bool; default: False): Whether to use cuda for network. This arg can differ across modes.
3. ``on_policy`` (bool; default: False): Whether the RL algorithm is on-policy or off-policy.
4. ``priority`` (bool; default: False): Whether to use priority (PER): priority sampling and priority update.
5. ``priority_IS_weight`` (bool; default: False): Whether to use Importance Sampling Weight to correct the biased update. If True, ``priority`` must be True.
6. ``discount_factor`` (float; default: 0.97, typically in [0.95, 0.999]): Reward's future discount factor, aka. gamma. May be 1 in sparse-reward envs.
7. ``nstep`` (int; default: 1, typically in [3, 5]): N-step reward discount sum for target q_value estimation.
8. ``learn.update_per_collect`` (int; default: 3): How many updates (iterations) to train after one collection of the collector. Only valid in serial training. Varies across envs; a bigger value means more off-policy.
9. ``learn.multi_gpu`` (bool; default: False): Whether to use multi GPU.
10. ``learn.batch_size`` (int; default: 64): The number of samples in one training iteration.
11. ``learn.learning_rate`` (float; default: 0.001): Gradient step length of an iteration.
12. ``learn.target_update_freq`` (int; default: 100): Frequency of target network update. Hard (assign) update.
13. ``learn.ignore_done`` (bool; default: False): Whether to ignore done for target value calculation. Enable it for envs with fake termination.
14. ``collect.n_sample`` (int; typically in [8, 128]): The number of training samples of one call of the collector. It varies across envs.
15. ``collect.unroll_len`` (int; default: 1): Unroll length of an iteration. In RNN, unroll_len > 1.
16. ``other.eps.type`` (str; default: exp): Exploration rate decay type. Supports ['exp', 'linear'].
17. ``other.eps.start`` (float in [0, 1]; default: 0.95): Start value of the exploration rate.
18. ``other.eps.end`` (float in [0, 1]; default: 0.1): End value of the exploration rate.
19. ``other.eps.decay`` (int > 0; default: 10000): Decay length of exploration; decay=10000 means the exploration rate decays from the start value to the end value over the decay length.
20. ``aux_loss_weight`` (float; default: 0.001): The ratio of the auxiliary loss to the TD loss. Any real value, typically in [-0.1, 0.1].
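The ``aux_loss_weight`` entry controls how the ST-DIM contrastive objective is mixed into the TD loss. A minimal sketch of the combination, with a toy InfoNCE-style loss over a precomputed score matrix (both functions are illustrative simplifications, not DI-engine's ``ContrastiveLoss``):

```python
import math

def infonce_loss(scores):
    # scores[i][j]: similarity of anchor i and candidate j; positives sit on the diagonal.
    n = len(scores)
    total = 0.0
    for i in range(n):
        denom = sum(math.exp(s) for s in scores[i])
        total += -math.log(math.exp(scores[i][i]) / denom)
    return total / n

def stdim_total_loss(td_loss, aux_loss, aux_loss_weight=0.001):
    # The auxiliary representation loss is added to the TD loss with a small weight.
    return td_loss + aux_loss_weight * aux_loss
```

With the default weight of 0.001 the auxiliary term acts as a mild representation-shaping regularizer rather than competing with the TD objective.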

Full Source Code

../ding/policy/dqn.py

1from typing import List, Dict, Any, Tuple 2from collections import namedtuple 3import copy 4import torch 5 6from ding.torch_utils import Adam, to_device, ContrastiveLoss 7from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample 8from ding.model import model_wrap 9from ding.utils import POLICY_REGISTRY 10from ding.utils.data import default_collate, default_decollate 11 12from .base_policy import Policy 13from .common_utils import default_preprocess_learn, set_noise_mode 14 15 16@POLICY_REGISTRY.register('dqn') 17class DQNPolicy(Policy): 18 """ 19 Overview: 20 Policy class of DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD. 21 22 Config: 23 == ===================== ======== ============== ======================================= ======================= 24 ID Symbol Type Default Value Description Other(Shape) 25 == ===================== ======== ============== ======================================= ======================= 26 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional, 27 | registry ``POLICY_REGISTRY`` | a placeholder 28 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff- 29 | erent from modes 30 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy 31 | or off-policy 32 4 ``priority`` bool False | Whether use priority(PER) | Priority sample, 33 | update priority 34 5 | ``priority_IS`` bool False | Whether use Importance Sampling 35 | ``_weight`` | Weight to correct biased update. If 36 | True, priority must be True. 37 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. 
| May be 1 when sparse 38 | ``factor`` [0.95, 0.999] | gamma | reward env 39 7 ``nstep`` int 1, | N-step reward discount sum for target 40 [3, 5] | q_value estimation 41 8 | ``model.dueling`` bool True | dueling head architecture 42 9 | ``model.encoder`` list [32, 64, | Sequence of ``hidden_size`` of | default kernel_size 43 | ``_hidden`` (int) 64, 128] | subsequent conv layers and the | is [8, 4, 3] 44 | ``_size_list`` | final dense layer. | default stride is 45 | [4, 2 ,1] 46 10 | ``model.dropout`` float None | Dropout rate for dropout layers. | [0,1] 47 | If set to ``None`` 48 | means no dropout 49 11 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary 50 | ``per_collect`` | after collector's one collection. | from envs. Bigger val 51 | Only valid in serial training | means more off-policy 52 12 | ``learn.batch_`` int 64 | The number of samples of an iteration 53 | ``size`` 54 13 | ``learn.learning`` float 0.001 | Gradient step length of an iteration. 55 | ``_rate`` 56 14 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update 57 | ``update_freq`` 58 15 | ``learn.target_`` float 0.005 | Frequence of target network update. | Soft(assign) update 59 | ``theta`` | Only one of [target_update_freq, 60 | | target_theta] should be set 61 16 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some 62 | ``done`` | calculation. | fake termination env 63 17 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from 64 | call of collector. | different envs 65 18 ``collect.n_episode`` int 8 | The number of training episodes of a | only one of [n_sample 66 | call of collector | ,n_episode] should 67 | | be set 68 19 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1 69 | ``_len`` 70 20 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp', 71 | 'linear']. 
72 21 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1] 73 | ``start`` 74 22 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1] 75 | ``end`` 76 23 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set 77 | ``decay`` | decay=10000 means 78 | the exploration rate 79 | decay from start 80 | value to end value 81 | during decay length. 82 == ===================== ======== ============== ======================================= ======================= 83 """ 84 85 config = dict( 86 # (str) RL policy register name (refer to function "POLICY_REGISTRY"). 87 type='dqn', 88 # (bool) Whether to use cuda in policy. 89 cuda=False, 90 # (bool) Whether learning policy is the same as collecting data policy(on-policy). 91 on_policy=False, 92 # (bool) Whether to enable priority experience sample. 93 priority=False, 94 # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. 95 priority_IS_weight=False, 96 # (float) Discount factor(gamma) for returns. 97 discount_factor=0.97, 98 # (int) The number of steps for calculating target q_value. 99 nstep=1, 100 # (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is False. 101 noisy_net=False, 102 model=dict( 103 # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. 104 encoder_hidden_size_list=[128, 128, 64], 105 ), 106 # learn_mode config 107 learn=dict( 108 # (int) How many updates(iterations) to train after collector's one collection. 109 # Bigger "update_per_collect" means bigger off-policy. 110 # collect data -> update policy-> collect data -> ... 111 update_per_collect=3, 112 # (int) How many samples in a training batch. 113 batch_size=64, 114 # (float) The step size of gradient descent. 115 learning_rate=0.001, 116 # (int) Frequency of target network update. 117 # Only one of [target_update_freq, target_theta] should be set. 
118 target_update_freq=100, 119 # (float) Used for soft update of the target network. 120 # aka. Interpolation factor in EMA update for target network. 121 # Only one of [target_update_freq, target_theta] should be set. 122 target_theta=0.005, 123 # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time 124 # limits are disregarded. By default, this is set to False. This setting is particularly useful for tasks 125 # that have a predetermined episode length, such as HalfCheetah and various other MuJoCo environments, 126 # where the maximum length is capped at 1000 steps. When enabled, any 'done' signal triggered by reaching 127 # the maximum episode steps will be overridden to 'False'. This ensures the accurate calculation of the 128 # Temporal Difference (TD) error, using the formula `gamma * (1 - done) * next_v + reward`, 129 # even when the episode surpasses the predefined step limit. 130 ignore_done=False, 131 ), 132 # collect_mode config 133 collect=dict( 134 # (int) How many training samples collected in one collection procedure. 135 # Only one of [n_sample, n_episode] should be set. 136 n_sample=8, 137 # (int) Split episodes or trajectories into pieces with length `unroll_len`. 138 unroll_len=1, 139 ), 140 eval=dict(), # for compatibility 141 # other config 142 other=dict( 143 # Epsilon greedy with decay. 144 eps=dict( 145 # (str) Decay type. Support ['exp', 'linear']. 146 type='exp', 147 # (float) Epsilon start value. 148 start=0.95, 149 # (float) Epsilon end value. 150 end=0.1, 151 # (int) Decay length(env step). 152 decay=10000, 153 ), 154 replay_buffer=dict( 155 # (int) Maximum size of replay buffer. Usually, larger buffer size is better. 156 replay_buffer_size=10000, 157 ), 158 ), 159 ) 160 161 def default_model(self) -> Tuple[str, List[str]]: 162 """ 163 Overview: 164 Return this algorithm default neural network model setting for demonstration. 
``__init__`` method will \ 165 automatically call this method to get the default model setting and create model. 166 Returns: 167 - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. 168 169 .. note:: 170 The user can define and use customized network model but must obey the same interface definition indicated \ 171 by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ 172 ``ding.model.template.q_learning``. 173 """ 174 return 'dqn', ['ding.model.template.q_learning'] 175 176 def _init_learn(self) -> None: 177 """ 178 Overview: 179 Initialize the learn mode of policy, including related attributes and modules. For DQN, it mainly contains \ 180 optimizer, algorithm-specific arguments such as nstep and gamma, main and target model. 181 This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. 182 183 .. note:: 184 For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ 185 and ``_load_state_dict_learn`` methods. 186 187 .. note:: 188 For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. 189 190 .. note:: 191 If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ 192 with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 
193 """ 194 self._priority = self._cfg.priority 195 self._priority_IS_weight = self._cfg.priority_IS_weight 196 # Optimizer 197 self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) 198 199 self._gamma = self._cfg.discount_factor 200 self._nstep = self._cfg.nstep 201 202 # use model_wrapper for specialized demands of different modes 203 self._target_model = copy.deepcopy(self._model) 204 if 'target_update_freq' in self._cfg.learn: 205 self._target_model = model_wrap( 206 self._target_model, 207 wrapper_name='target', 208 update_type='assign', 209 update_kwargs={'freq': self._cfg.learn.target_update_freq} 210 ) 211 elif 'target_theta' in self._cfg.learn: 212 self._target_model = model_wrap( 213 self._target_model, 214 wrapper_name='target', 215 update_type='momentum', 216 update_kwargs={'theta': self._cfg.learn.target_theta} 217 ) 218 else: 219 raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta") 220 self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') 221 self._learn_model.reset() 222 self._target_model.reset() 223 224 def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: 225 """ 226 Overview: 227 Policy forward function of learn mode (training policy and updating parameters). Forward means \ 228 that the policy inputs some training batch data from the replay buffer and then returns the output \ 229 result, including various training information such as loss, q value, priority. 230 Arguments: 231 - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ 232 training samples. For each element in list, the key of the dict is the name of data items and the \ 233 value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ 234 combinations. 
In the ``_forward_learn`` method, data often need to first be stacked in the batch \ 235 dimension by some utility functions such as ``default_preprocess_learn``. \ 236 For DQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ 237 ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ 238 and ``value_gamma``. 239 Returns: 240 - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ 241 recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ 242 detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. 243 244 .. note:: 245 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 246 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 247 You can implement your own model rather than use the default model. For more information, please raise an \ 248 issue in GitHub repo and we will continue to follow up. 249 250 .. note:: 251 For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. 252 """ 253 # Set noise mode for NoisyNet for exploration in learning if enabled in config 254 # We need to reset set_noise_mode every _forward_xxx because the model is reused across different 255 # phases (learn/collect/eval). 256 if self._cfg.noisy_net: 257 set_noise_mode(self._learn_model, True) 258 set_noise_mode(self._target_model, True) 259 260 # A noisy network agent samples a new set of parameters after every step of optimisation. 261 # Between optimisation steps, the agent acts according to a fixed set of parameters (weights and biases). 262 # This ensures that the agent always acts according to parameters that are drawn from 263 # the current noise distribution. 
264 if self._cfg.noisy_net: 265 self._reset_noise(self._learn_model) 266 self._reset_noise(self._target_model) 267 268 # Data preprocessing operations, such as stack data, cpu to cuda device 269 data = default_preprocess_learn( 270 data, 271 use_priority=self._priority, 272 use_priority_IS_weight=self._cfg.priority_IS_weight, 273 ignore_done=self._cfg.learn.ignore_done, 274 use_nstep=True 275 ) 276 if self._cuda: 277 data = to_device(data, self._device) 278 # Q-learning forward 279 self._learn_model.train() 280 self._target_model.train() 281 # Current q value (main model) 282 q_value = self._learn_model.forward(data['obs'])['logit'] 283 # Target q value 284 with torch.no_grad(): 285 target_q_value = self._target_model.forward(data['next_obs'])['logit'] 286 # Max q value action (main model), i.e. Double DQN 287 target_q_action = self._learn_model.forward(data['next_obs'])['action'] 288 289 data_n = q_nstep_td_data( 290 q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight'] 291 ) 292 value_gamma = data.get('value_gamma') 293 loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma) 294 295 # Update network parameters 296 self._optimizer.zero_grad() 297 loss.backward() 298 if self._cfg.multi_gpu: 299 self.sync_gradients(self._learn_model) 300 self._optimizer.step() 301 302 # Postprocessing operations, such as updating target model, return logged values and priority. 303 self._target_model.update(self._learn_model.state_dict()) 304 return { 305 'cur_lr': self._optimizer.defaults['lr'], 306 'total_loss': loss.item(), 307 'q_value': q_value.mean().item(), 308 'target_q_value': target_q_value.mean().item(), 309 'priority': td_error_per_sample.abs().tolist(), 310 # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard. 
311 # '[histogram]action_distribution': data['action'], 312 } 313 314 def _monitor_vars_learn(self) -> List[str]: 315 """ 316 Overview: 317 Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ 318 as text logger, tensorboard logger, will use these keys to save the corresponding data. 319 Returns: 320 - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 321 """ 322 return ['cur_lr', 'total_loss', 'q_value', 'target_q_value'] 323 324 def _state_dict_learn(self) -> Dict[str, Any]: 325 """ 326 Overview: 327 Return the state_dict of learn mode, usually including model, target_model and optimizer. 328 Returns: 329 - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 330 """ 331 return { 332 'model': self._learn_model.state_dict(), 333 'target_model': self._target_model.state_dict(), 334 'optimizer': self._optimizer.state_dict(), 335 } 336 337 def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: 338 """ 339 Overview: 340 Load the state_dict variable into policy learn mode. 341 Arguments: 342 - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 343 344 .. tip:: 345 If you want to only load some parts of model, you can simply set the ``strict`` argument in \ 346 load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ 347 complicated operation. 348 """ 349 self._learn_model.load_state_dict(state_dict['model']) 350 self._target_model.load_state_dict(state_dict['target_model']) 351 self._optimizer.load_state_dict(state_dict['optimizer']) 352 353 def _init_collect(self) -> None: 354 """ 355 Overview: 356 Initialize the collect mode of policy, including related attributes and modules. 
For DQN, it contains the \ 357 collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism, and other \ 358 algorithm-specific arguments such as unroll_len and nstep. 359 This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. 360 361 .. note:: 362 If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ 363 with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 364 365 .. tip:: 366 Some variables need to initialize independently in different modes, such as gamma and nstep in DQN. This \ 367 design is for the convenience of parallel execution of different policy modes. 368 """ 369 self._unroll_len = self._cfg.collect.unroll_len 370 self._gamma = self._cfg.discount_factor # necessary for parallel 371 self._nstep = self._cfg.nstep # necessary for parallel 372 self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample') 373 self._collect_model.reset() 374 375 def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: 376 """ 377 Overview: 378 Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ 379 that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ 380 data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ 381 exploration, i.e., classic epsilon-greedy exploration strategy. 382 Arguments: 383 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 384 key of the dict is environment id and the value is the corresponding data of the env. 385 - eps (:obj:`float`): The epsilon value for exploration. 
386 Returns: 387 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ 388 other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ 389 dict is the same as the input data, i.e. environment id. 390 391 .. note:: 392 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 393 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 394 You can implement you own model rather than use the default model. For more information, please raise an \ 395 issue in GitHub repo and we will continue to follow up. 396 397 .. note:: 398 For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. 399 """ 400 # Set noise mode for NoisyNet for exploration in collecting if enabled in config. 401 # We need to reset set_noise_mode every _forward_xxx because the model is reused across different 402 # phases (learn/collect/eval). 403 if self._cfg.noisy_net: 404 set_noise_mode(self._collect_model, True) 405 406 data_id = list(data.keys()) 407 data = default_collate(list(data.values())) 408 if self._cuda: 409 data = to_device(data, self._device) 410 411 self._collect_model.eval() 412 with torch.no_grad(): 413 output = self._collect_model.forward(data, eps=eps) 414 if self._cuda: 415 output = to_device(output, 'cpu') 416 output = default_decollate(output) 417 return {i: d for i, d in zip(data_id, output)} 418 419 def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 420 """ 421 Overview: 422 For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ 423 can be used for training directly. In DQN with nstep TD, a train sample is a processed transition. 
\ 424 This method is usually used in collectors to execute necessary \ 425 RL data preprocessing before training, which can help learner amortize relevant time consumption. \ 426 In addition, you can also implement this method as an identity function and do the data processing \ 427 in ``self._forward_learn`` method. 428 Arguments: 429 - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ 430 in the same format as the return value of ``self._process_transition`` method. 431 Returns: 432 - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is similar in format \ 433 to input transitions, but may contain more data for training, such as nstep reward and target obs. 434 """ 435 transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma) 436 return get_train_sample(transitions, self._unroll_len) 437 438 def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], 439 timestep: namedtuple) -> Dict[str, torch.Tensor]: 440 """ 441 Overview: 442 Process and pack one timestep transition data into a dict, which can be directly used for training and \ 443 saved in replay buffer. For DQN, it contains obs, next_obs, action, reward, done. 444 Arguments: 445 - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. 446 - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ 447 as input. For DQN, it contains the action and the logit (q_value) of the action. 448 - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ 449 except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ 450 reward, done, info, etc. 451 Returns: 452 - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. 
453 """ 454 transition = { 455 'obs': obs, 456 'next_obs': timestep.obs, 457 'action': policy_output['action'], 458 'reward': timestep.reward, 459 'done': timestep.done, 460 } 461 return transition 462 463 def _init_eval(self) -> None: 464 """ 465 Overview: 466 Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \ 467 eval model to greedily select action with argmax q_value mechanism. 468 This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. 469 470 .. note:: 471 If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ 472 with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. 473 """ 474 self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') 475 self._eval_model.reset() 476 477 def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: 478 """ 479 Overview: 480 Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ 481 means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ 482 action to interact with the envs. 483 Arguments: 484 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 485 key of the dict is environment id and the value is the corresponding data of the env. 486 Returns: 487 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ 488 key of the dict is the same as the input data, i.e. environment id. 489 490 .. note:: 491 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 492 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 493 You can implement you own model rather than use the default model. 
For more information, please raise an \ 494 issue in GitHub repo and we will continue to follow up. 495 496 .. note:: 497 For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. 498 """ 499 # We need to reset set_noise_mode every _forward_xxx because the model is reused across different 500 # phases (learn/collect/eval). 501 # Ensure that in evaluation mode noise is disabled. 502 set_noise_mode(self._eval_model, False) 503 504 data_id = list(data.keys()) 505 data = default_collate(list(data.values())) 506 if self._cuda: 507 data = to_device(data, self._device) 508 509 self._eval_model.eval() 510 with torch.no_grad(): 511 output = self._eval_model.forward(data) 512 if self._cuda: 513 output = to_device(output, 'cpu') 514 output = default_decollate(output) 515 return {i: d for i, d in zip(data_id, output)} 516 517 def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]: 518 """ 519 Overview: 520 Calculate priority for replay buffer. 521 Arguments: 522 - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training. 523 - update_target_model (:obj:`bool`): Whether to update target model. 524 Returns: 525 - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars. 
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: ``value_gamma``
        ReturnsKeys:
            - necessary: ``priority``
        """

        if update_target_model:
            self._target_model.load_state_dict(self._learn_model.state_dict())

        data = default_preprocess_learn(
            data,
            use_priority=False,
            use_priority_IS_weight=False,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.eval()
        self._target_model.eval()
        with torch.no_grad():
            # Current q value (main model)
            q_value = self._learn_model.forward(data['obs'])['logit']
            # Target q value
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model), i.e. Double DQN
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
            data_n = q_nstep_td_data(
                q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
            )
            value_gamma = data.get('value_gamma')
            loss, td_error_per_sample = q_nstep_td_error(
                data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
            )
        return {'priority': td_error_per_sample.abs().tolist()}

    def _reset_noise(self, model: torch.nn.Module):
        """
        Overview:
            Reset the noise of model.
        Arguments:
            - model (:obj:`torch.nn.Module`): The model to reset, which must contain a ``reset_noise`` method.
        """
        for m in model.modules():
            if hasattr(m, 'reset_noise'):
                m.reset_noise()


@POLICY_REGISTRY.register('dqn_stdim')
class DQNSTDIMPolicy(DQNPolicy):
    """
    Overview:
        Policy class of DQN algorithm, extended by ST-DIM auxiliary objectives. \
        ST-DIM paper link: https://arxiv.org/abs/1906.08226.
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      dqn_stdim      | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)          | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     False          | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      1,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  | ``learn.update``   int      3              | How many updates(iterations) to train  | This arg can vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        9  | ``learn.multi``    bool     False          | Whether to use multi gpu during
           | ``_gpu``                                   | training
        10 | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        11 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        12 | ``learn.target_``  int      100            | Frequency of target network update.    | Hard(assign) update
           | ``update_freq``
        13 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        14 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of collector.
                                                                                                 | different envs
        15 | ``collect.unroll`` int      1              | unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        16 | ``other.eps.type`` str      exp            | exploration rate decay type            | Support ['exp',
                                                                                                 | 'linear'].
        17 | ``other.eps.``     float    0.95           | start value of exploration rate        | [0,1]
           | ``start``
        18 | ``other.eps.``     float    0.1            | end value of exploration rate          | [0,1]
           | ``end``
        19 | ``other.eps.``     int      10000          | decay length of exploration            | greater than 0. set
           | ``decay``                                                                           | decay=10000 means
                                                                                                 | the exploration rate
                                                                                                 | decays from the start
                                                                                                 | value to the end value
                                                                                                 | over the decay length.
        20 | ``aux_loss``       float    0.001          | the ratio of the auxiliary loss to     | any real value,
           | ``_weight``                                | the TD loss                            | typically in
                                                                                                 | [-0.1, 0.1].
        == ==================== ======== ============== ======================================== =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dqn_stdim',
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether the learning policy is the same as the data collection policy (on-policy).
        on_policy=False,
        # (bool) Whether to enable priority experience sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (float) Discount factor(gamma) for returns.
        discount_factor=0.97,
        # (int) The number of steps for calculating target q_value.
        nstep=1,
        # (float) The weight of auxiliary loss to main loss.
        aux_loss_weight=0.001,
        # learn_mode config
        learn=dict(
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=3,
            # (int) How many samples in a training batch.
            batch_size=64,
            # (float) The step size of gradient descent.
            learning_rate=0.001,
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max-step termination env).
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples collected in one collection procedure.
            # Only one of [n_sample, n_episode] should be set.
            # n_sample=8,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),  # for compatibility
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                # (float) Epsilon start value.
                start=0.95,
                # (float) Epsilon end value.
                end=0.1,
                # (int) Decay length (env step).
                decay=10000,
            ),
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, a larger buffer size is better.
                replay_buffer_size=10000,
            ),
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For DQNSTDIM, it first \
            calls the super class's ``_init_learn`` method, then initializes the extra auxiliary model, its \
            optimizer, and the loss weight. This method will be called in ``__init__`` method if ``learn`` field \
            is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        super()._init_learn()
        x_size, y_size = self._get_encoding_size()
        self._aux_model = ContrastiveLoss(x_size, y_size, **self._cfg.aux_model)
        if self._cuda:
            self._aux_model.cuda()
        self._aux_optimizer = Adam(self._aux_model.parameters(), lr=self._cfg.learn.learning_rate)
        self._aux_loss_weight = self._cfg.aux_loss_weight

    def _get_encoding_size(self) -> Tuple[Tuple[int], Tuple[int]]:
        """
        Overview:
            Get the input encoding size of the ST-DIM auxiliary model.
        Returns:
            - info_dict (:obj:`Tuple[Tuple[int], Tuple[int]]`): The encoding size without the first (batch) dimension.
        """
        obs = self._cfg.model.obs_shape
        if isinstance(obs, int):
            obs = [obs]
        test_data = {
            "obs": torch.randn(1, *obs),
            "next_obs": torch.randn(1, *obs),
        }
        if self._cuda:
            test_data = to_device(test_data, self._device)
        with torch.no_grad():
            x, y = self._model_encode(test_data)
        return x.size()[1:], y.size()[1:]

    def _model_encode(self, data: dict) -> Tuple[torch.Tensor]:
        """
        Overview:
            Get the encoding of the main model as input for the auxiliary model.
        Arguments:
            - data (:obj:`dict`): Dict type data, same as the ``_forward_learn`` input.
        Returns:
            - (:obj:`Tuple[torch.Tensor]`): The tuple of two tensors to apply contrastive embedding learning. \
                In the ST-DIM algorithm, these two variables are the DQN encodings of ``obs`` and ``next_obs`` \
                respectively.
        """
        assert hasattr(self._model, "encoder")
        x = self._model.encoder(data["obs"])
        y = self._model.encoder(data["next_obs"])
        return x, y

    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). \
            Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, q value, priority, aux_loss.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of the data item \
                and the value is the corresponding data. Usually, the value is torch.Tensor, np.ndarray, or their \
                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
                For DQNSTDIM, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
                ``weight`` and ``value_gamma``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in text log and tensorboard; values must be python scalars or a list of scalars. \
                For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in GitHub repo and we will continue to follow up.
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)

        # ======================
        # Auxiliary model update
        # ======================
        # RL network encoding
        # To train the auxiliary network, the gradients of x, y should be 0.
        with torch.no_grad():
            x_no_grad, y_no_grad = self._model_encode(data)
        # the forward function of the auxiliary network
        self._aux_model.train()
        aux_loss_learn = self._aux_model.forward(x_no_grad, y_no_grad)
        # the BP process of the auxiliary network
        self._aux_optimizer.zero_grad()
        aux_loss_learn.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._aux_model)
        self._aux_optimizer.step()

        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        # Target q value
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model), i.e. Double DQN
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = q_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
        )
        value_gamma = data.get('value_gamma')
        bellman_loss, td_error_per_sample = q_nstep_td_error(
            data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
        )

        # ======================
        # Compute auxiliary loss
        # ======================
        x, y = self._model_encode(data)
        self._aux_model.eval()
        aux_loss_eval = self._aux_model.forward(x, y) * self._aux_loss_weight
        loss = aux_loss_eval + bellman_loss

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'bellman_loss': bellman_loss.item(),
            'aux_loss_learn': aux_loss_learn.item(),
            'aux_loss_eval': aux_loss_eval.item(),
            'total_loss': loss.item(),
            'q_value': q_value.mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as the text logger or tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'bellman_loss', 'aux_loss_learn', 'aux_loss_eval', 'total_loss', 'q_value']

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the current policy learn state, for saving and \
                restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
            'aux_optimizer': self._aux_optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved previously.

        .. tip::
            If you want to only load some parts of the model, you can simply set the ``strict`` argument in \
            ``load_state_dict`` to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])
        self._aux_optimizer.load_state_dict(state_dict['aux_optimizer'])
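As a reading aid, the Double DQN target used in ``calculate_priority`` and ``_forward_learn`` above (the next action is *selected* by the main network but *evaluated* by the target network, and the absolute TD error is used as the PER priority) can be sketched for a single one-step transition in plain Python. This is an illustrative sketch, not DI-engine code: the real implementation (``q_nstep_td_error``) operates on batched tensors with n-step returns, and the function and variable names below are invented for the example.

```python
def double_dqn_td_error(q, q_next_target, q_next_main, action, reward, done, gamma=0.97):
    """One-step Double-DQN TD error for a single transition.

    q, q_next_target, q_next_main: per-action Q-value lists.
    The greedy next action is chosen with the main network (argmax over
    q_next_main) but its value is read from the target network, which is
    the Double DQN decoupling that reduces overestimation bias.
    """
    a_star = max(range(len(q_next_main)), key=lambda a: q_next_main[a])
    target = reward + (0.0 if done else gamma * q_next_target[a_star])
    return target - q[action]

# As in PER, the priority of a transition is the absolute TD error.
priority = abs(double_dqn_td_error(
    q=[1.0, 2.0], q_next_target=[0.5, 1.5], q_next_main=[1.0, 0.0],
    action=1, reward=1.0, done=False))
```

Note that with a plain (non-double) target, the same transition would bootstrap from `max(q_next_target)` instead, which is exactly the coupling Double DQN avoids.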
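The ``other.eps`` fields in the config above (``type='exp'``, ``start=0.95``, ``end=0.1``, ``decay=10000``) describe an epsilon-greedy exploration schedule. A common exponential form consistent with those fields is sketched below; the exact curve produced by DI-engine's scheduler may differ, so treat this as an illustration of the config semantics rather than the library's formula.

```python
import math

def eps_exp(step: int, start: float = 0.95, end: float = 0.1, decay: int = 10000) -> float:
    # Decays from `start` toward the asymptote `end` with time
    # constant `decay`, measured in env steps.
    return end + (start - end) * math.exp(-step / decay)
```

At step 0 the rate equals ``start``; after several multiples of ``decay`` it is close to ``end``, matching the table's note that the rate decays from the start value to the end value over the decay length.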