
ding.policy.mdqn

MDQNPolicy

Bases: DQNPolicy

Overview

Policy class of the Munchausen DQN (MDQN) algorithm, which extends DQN with auxiliary objectives. Paper link: https://arxiv.org/abs/2007.14430.
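For orientation, the Munchausen modification described in the paper adds a clipped, scaled log-policy bonus to the reward and bootstraps from a soft (entropy-regularized) value. With the policy $\pi = \operatorname{softmax}(q_{\bar\theta}/\tau)$ induced by the target network, the one-step target is roughly (sketched here, omitting the done mask; see the paper for the exact formulation):

$$
\hat q(s_t, a_t) = r_t + \alpha \big[\tau \ln \pi(a_t \mid s_t)\big]_{l_0}^{0} + \gamma \sum_{a'} \pi(a' \mid s_{t+1}) \big( q_{\bar\theta}(s_{t+1}, a') - \tau \ln \pi(a' \mid s_{t+1}) \big)
$$

Here $\tau$ corresponds to `entropy_tau`, $\alpha$ to `alpha` (`m_alpha` in the config below), and $[\,\cdot\,]_{l_0}^{0}$ clips the log-policy term to $[l_0, 0]$.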

Config:

| ID | Symbol | Type | Default Value | Description | Other (Shape) |
|----|--------|------|---------------|-------------|---------------|
| 1 | `type` | str | mdqn | RL policy register name, refer to registry `POLICY_REGISTRY` | This arg is optional, a placeholder |
| 2 | `cuda` | bool | False | Whether to use cuda for network | This arg can be different from modes |
| 3 | `on_policy` | bool | False | Whether the RL algorithm is on-policy or off-policy | |
| 4 | `priority` | bool | False | Whether to use priority (PER) | Priority sample, update priority |
| 5 | `priority_IS_weight` | bool | False | Whether to use Importance Sampling Weight to correct biased update. If True, `priority` must be True. | |
| 6 | `discount_factor` | float | 0.97, [0.95, 0.999] | Reward's future discount factor, aka. gamma | May be 1 for sparse-reward envs |
| 7 | `nstep` | int | 1, [3, 5] | N-step reward discount sum for target q_value estimation | |
| 8 | `learn.update_per_collect` | int | 1 | How many updates (iterations) to train after one collection by the collector. Only valid in serial training | This arg can vary across envs. A bigger value means more off-policy |
| 9 | `learn.multi_gpu` | bool | False | Whether to use multi gpu | |
| 10 | `learn.batch_size` | int | 32 | The number of samples of an iteration | |
| 11 | `learn.learning_rate` | float | 0.001 | Gradient step length of an iteration | |
| 12 | `learn.target_update_freq` | int | 2000 | Frequency of target network update | Hard (assign) update |
| 13 | `learn.ignore_done` | bool | False | Whether to ignore done for target value calculation | Enable it for some fake-termination envs |
| 14 | `collect.n_sample` | int | 4 | The number of training samples in one call of the collector | It varies across envs |
| 15 | `collect.unroll_len` | int | 1 | Unroll length of an iteration | In RNN, unroll_len > 1 |
| 16 | `other.eps.type` | str | exp | Exploration rate decay type | Support ['exp', 'linear'] |
| 17 | `other.eps.start` | float | 0.01 | Start value of exploration rate | [0, 1] |
| 18 | `other.eps.end` | float | 0.001 | End value of exploration rate | [0, 1] |
| 19 | `other.eps.decay` | int | 250000 | Decay length of exploration | Greater than 0; decay=250000 means the exploration rate decays from start to end over 250000 steps |
| 20 | `entropy_tau` | float | 0.003 | The ratio of entropy in the TD loss | |
| 21 | `alpha` | float | 0.9 | The ratio of the Munchausen term in the TD loss | |
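As a quick orientation, the defaults above can be overridden with an ordinary nested dict before the policy is constructed. How this dict is merged into a full experiment config depends on your DI-engine pipeline, so treat this as a sketch rather than a verbatim recipe:

```python
# Hypothetical override of a few MDQN defaults; the key names follow the
# table above, but the merge entry point is pipeline-specific.
mdqn_override = dict(
    type='mdqn',
    cuda=False,
    discount_factor=0.97,
    nstep=1,
    entropy_tau=0.03,   # tau: entropy temperature in the TD loss
    m_alpha=0.9,        # alpha: weight of the Munchausen log-policy term
    learn=dict(
        update_per_collect=3,
        batch_size=64,
        learning_rate=0.001,
        target_update_freq=100,
    ),
    collect=dict(n_sample=4),
    other=dict(eps=dict(type='exp', start=0.95, end=0.1, decay=10000)),
)
```

Keys left out fall back to the defaults in `MDQNPolicy.config` shown in the full source below.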

Full Source Code

../ding/policy/mdqn.py

```python
from typing import List, Dict, Any
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import m_q_1step_td_data, m_q_1step_td_error
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY

from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('mdqn')
class MDQNPolicy(DQNPolicy):
    """
    Overview:
        Policy class of the Munchausen DQN algorithm, which extends DQN with auxiliary objectives.
        Paper link: https://arxiv.org/abs/2007.14430.
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      mdqn           | RL policy register name, refer to     | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``          | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network       | This arg can be diff-
                                                                                                | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)         | Priority sample,
                                                                                                | update priority
        5  | ``priority_IS``    bool     False          | Whether to use Importance Sampling
           | ``_weight``                                | Weight to correct biased update.
                                                        | If True, priority must be True.
        6  | ``discount_``      float    0.97,          | Reward's future discount factor, aka. | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                 | reward env
        7  ``nstep``            int      1,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  | ``learn.update``   int      1              | How many updates (iterations) to      | This arg can vary
           | ``per_collect``                            | train after one collection by the     | across envs. A bigger
                                                        | collector. Only valid in serial       | value means more
                                                        | training                              | off-policy
        9  | ``learn.multi``    bool     False          | Whether to use multi gpu
           | ``_gpu``
        10 | ``learn.batch_``   int      32             | The number of samples of an iteration
           | ``size``
        11 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        12 | ``learn.target_``  int      2000           | Frequency of target network update.   | Hard (assign) update
           | ``update_freq``
        13 | ``learn.ignore_``  bool     False          | Whether to ignore done for target     | Enable it for some
           | ``done``                                   | value calculation.                    | fake termination envs
        14 ``collect.n_sample`` int      4              | The number of training samples in     | It varies across
                                                        | one call of the collector.            | different envs
        15 | ``collect.unroll`` int      1              | Unroll length of an iteration         | In RNN, unroll_len>1
           | ``_len``
        16 | ``other.eps.type`` str      exp            | Exploration rate decay type           | Support ['exp',
                                                                                                | 'linear'].
        17 | ``other.eps.``     float    0.01           | Start value of exploration rate       | [0,1]
           | ``start``
        18 | ``other.eps.``     float    0.001          | End value of exploration rate         | [0,1]
           | ``end``
        19 | ``other.eps.``     int      250000         | Decay length of exploration           | Greater than 0. Set
           | ``decay``                                                                          | decay=250000 means
                                                                                                | the exploration rate
                                                                                                | decays from start
                                                                                                | value to end value
                                                                                                | during decay length.
        20 | ``entropy_tau``    float    0.003          | The ratio of entropy in the TD loss
        21 | ``alpha``          float    0.9            | The ratio of the Munchausen term in
                                                        | the TD loss
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='mdqn',
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether learning policy is the same as collecting data policy (on-policy).
        on_policy=False,
        # (bool) Whether to enable priority experience sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (float) Discount factor (gamma) for returns.
        discount_factor=0.97,
        # (float) Entropy factor (tau) for Munchausen DQN.
        entropy_tau=0.03,
        # (float) Discount factor (alpha) for the Munchausen term.
        m_alpha=0.9,
        # (int) The number of steps for calculating the target q_value.
        nstep=1,
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after one collection by the collector.
            # A bigger "update_per_collect" means more off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=3,
            # (int) How many samples in a training batch.
            batch_size=64,
            # (float) The step size of gradient descent.
            learning_rate=0.001,
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max-step termination envs).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, interaction with HalfCheetah always ends with done=False, since we
            # replace done==True with done==False when the episode step is greater than the
            # max episode step, to keep the TD-error computation accurate
            # (``gamma * (1 - done) * next_v + reward``).
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples are collected in one collection procedure.
            # Only one of [n_sample, n_episode] should be set.
            n_sample=4,
            # (int) Split episodes or trajectories into pieces with length ``unroll_len``.
            unroll_len=1,
        ),
        eval=dict(),  # for compatibility
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                # (float) Epsilon start value.
                start=0.95,
                # (float) Epsilon end value.
                end=0.1,
                # (int) Decay length (env step).
                decay=10000,
            ),
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, a larger buffer size is better.
                replay_buffer_size=10000,
            ),
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For MDQN, it contains \
            the optimizer, algorithm-specific arguments such as entropy_tau, m_alpha and nstep, and the main and \
            target model. This method will be called in the ``__init__`` method if the ``learn`` field is in \
            ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # Optimizer
        # set eps in order to be consistent with the original paper implementation
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._entropy_tau = self._cfg.entropy_tau
        self._m_alpha = self._cfg.m_alpha

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        if 'target_update_freq' in self._cfg.learn:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='target',
                update_type='assign',
                update_kwargs={'freq': self._cfg.learn.target_update_freq}
            )
        elif 'target_theta' in self._cfg.learn:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='target',
                update_type='momentum',
                update_kwargs={'theta': self._cfg.learn.target_theta}
            )
        else:
            raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta")
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, action_gap, clip_frac, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of the data item \
                and the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
                dict/list combinations. In the ``_forward_learn`` method, the data often needs to first be stacked \
                in the batch dimension by some utility functions such as ``default_preprocess_learn``. \
                For MDQN, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
                ``weight`` and ``value_gamma``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in text log and tensorboard. Values must be python scalars or lists of scalars. \
                For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for MDQNPolicy: ``ding.policy.tests.test_mdqn``.
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        # Target q value
        with torch.no_grad():
            target_q_value_current = self._target_model.forward(data['obs'])['logit']
            target_q_value = self._target_model.forward(data['next_obs'])['logit']

        data_m = m_q_1step_td_data(
            q_value, target_q_value_current, target_q_value, data['action'], data['reward'].squeeze(0), data['done'],
            data['weight']
        )

        loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(
            data_m, self._gamma, self._entropy_tau, self._m_alpha
        )
        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'q_value': q_value.mean().item(),
            'target_q_value': target_q_value.mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            'action_gap': action_gap.item(),
            'clip_frac': clipfrac.mean().item(),
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac']
```
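To make the role of `entropy_tau` and `m_alpha` in `m_q_1step_td_error` concrete, here is a standalone sketch of the one-step Munchausen target in plain PyTorch. `munchausen_target` is an illustrative helper, not part of DI-engine; the library's actual implementation lives in `ding.rl_utils`:

```python
import torch
import torch.nn.functional as F


def munchausen_target(q_tgt_cur, q_tgt_next, action, reward, done,
                      gamma=0.97, tau=0.03, alpha=0.9, l0=-1.0):
    """Illustrative 1-step Munchausen target (not DI-engine's implementation).

    q_tgt_cur / q_tgt_next: target-network Q values for s and s', shape (B, A).
    """
    # Softmax policy induced by the target network (temperature tau).
    log_pi_cur = F.log_softmax(q_tgt_cur / tau, dim=-1)
    log_pi_next = F.log_softmax(q_tgt_next / tau, dim=-1)
    # Munchausen bonus: alpha * tau * log pi(a|s), clipped to [l0, 0].
    log_pi_a = log_pi_cur.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    bonus = alpha * (tau * log_pi_a).clamp(min=l0, max=0.0)
    # Soft next-state value: E_{a' ~ pi}[Q(s', a') - tau * log pi(a'|s')].
    soft_v_next = (log_pi_next.exp() * (q_tgt_next - tau * log_pi_next)).sum(-1)
    return reward + bonus + gamma * (1.0 - done.float()) * soft_v_next
```

The TD loss is then the (optionally importance-weighted) squared error between `q_value.gather(-1, action)` and this target; DI-engine's `m_q_1step_td_error` additionally returns diagnostics such as `action_gap` and `clip_frac`, which `_forward_learn` logs above.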