ding.policy.iqn

IQNPolicy

Bases: DQNPolicy

Overview

Policy class of the IQN algorithm. Paper link: https://arxiv.org/pdf/1806.06923.pdf. Distributional RL is a direction of RL that estimates the full distribution of the action value instead of only its expectation, which tends to be more stable than classical value-based methods. The difference between IQN and DQN is that IQN uses quantile regression to estimate quantile values of the return distribution, while DQN only estimates its expectation.
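The kappa parameter that appears in the config below is the Huber threshold used inside IQN's quantile regression loss. As a minimal, self-contained sketch (NumPy here instead of the torch implementation inside ding.rl_utils.iqn_nstep_td_error), the quantile Huber loss from the paper can be written as:

```python
import numpy as np

def huber(u, kappa=1.0):
    # Huber loss L_kappa(u): quadratic near zero, linear in the tails.
    abs_u = np.abs(u)
    return np.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))

def quantile_huber_loss(u, tau, kappa=1.0):
    # rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u) / kappa,
    # where u is a TD error and tau a sampled quantile fraction.
    indicator = (u < 0.0).astype(u.dtype)
    return np.abs(tau - indicator) * huber(u, kappa) / kappa

# Example: three TD errors evaluated at their sampled quantile fractions.
u = np.array([0.5, -0.5, 2.0])
tau = np.array([0.25, 0.25, 0.75])
loss = quantile_huber_loss(u, tau)
```

Negative TD errors are down-weighted when tau is small and up-weighted when tau is large, which is what makes the regression target a quantile rather than the mean.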

Config:

== ======================== ======= ============== ========================================== ======================
ID Symbol                   Type    Default Value  Description                                Other(Shape)
== ======================== ======= ============== ========================================== ======================
1  type                     str     iqn            RL policy register name, refer to          this arg is optional,
                                                   registry POLICY_REGISTRY                   a placeholder
2  cuda                     bool    False          Whether to use cuda for network            this arg can be
                                                                                              different from modes
3  on_policy                bool    False          Whether the RL algorithm is on-policy
                                                   or off-policy
4  priority                 bool    True           Whether use priority(PER)                  priority sample,
                                                                                              update priority
6  other.eps.start          float   0.05           Start value for epsilon decay. It's
                                                   small because rainbow uses noisy net.
7  other.eps.end            float   0.05           End value for epsilon decay.
8  discount_factor          float   0.97,          Reward's future discount factor, aka.      may be 1 when sparse
                                    [0.95, 0.999]  gamma                                      reward env
9  nstep                    int     3, [3, 5]      N-step reward discount sum for target
                                                   q_value estimation
10 learn.update_per_collect int     3              How many updates (iterations) to train     this arg can vary
                                                   after collector's one collection. Only     across envs; a bigger
                                                   valid in serial training                   value means more
                                                                                              off-policy
11 learn.kappa              float   /              Threshold of Huber loss
== ======================== ======= ============== ========================================== ======================
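In practice these defaults are rarely used verbatim; a user-specified config is overlaid on IQNPolicy.config. DI-engine has its own config-compilation utilities for this, so the deep_merge helper below is only an illustrative sketch of the idea:

```python
from copy import deepcopy

def deep_merge(default: dict, override: dict) -> dict:
    # Recursively overlay user-specified fields on top of a default config
    # dict, leaving untouched defaults (and the original dicts) intact.
    merged = deepcopy(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# A subset of the IQN defaults shown in the table above.
default_cfg = dict(
    type='iqn',
    cuda=False,
    discount_factor=0.97,
    nstep=1,
    learn=dict(update_per_collect=3, batch_size=64, learning_rate=0.001, kappa=1.0),
)
user_cfg = dict(cuda=True, learn=dict(kappa=0.5))
cfg = deep_merge(default_cfg, user_cfg)
# cfg['learn'] keeps batch_size=64 but uses the overridden kappa=0.5
```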

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.

Returns: - model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.

.. note:: The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For example, for IQN, the registered name is iqn and the import_names is ding.model.template.q_learning.
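The (name, import_names) tuple returned by default_model is consumed by DI-engine's registry machinery. The toy registry below is a hypothetical sketch of that pattern only; DI-engine's real registries are more elaborate and import the import_names modules so that their register decorators run:

```python
# Hypothetical model registry: a dict mapping registered names to classes.
MODEL_REGISTRY = {}

def register(name):
    # Decorator that records a class under `name` at import time.
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register('iqn')
class ToyIQNModel:
    # Stand-in for the real IQN network in ding.model.template.q_learning.
    def __init__(self, obs_shape, action_shape):
        self.obs_shape = obs_shape
        self.action_shape = action_shape

def build_model(model_info, **kwargs):
    name, import_names = model_info
    # DI-engine would import every module in import_names here so the
    # registration decorators execute; this toy version assumes the
    # module defining the model has already been imported.
    return MODEL_REGISTRY[name](**kwargs)

model = build_model(('iqn', ['ding.model.template.q_learning']), obs_shape=4, action_shape=2)
```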

Full Source Code

../ding/policy/iqn.py

from typing import List, Dict, Any, Tuple, Union
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import iqn_nstep_td_data, iqn_nstep_td_error, get_train_sample, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('iqn')
class IQNPolicy(DQNPolicy):
    """
    Overview:
        Policy class of the IQN algorithm. Paper link: https://arxiv.org/pdf/1806.06923.pdf. \
        Distributional RL is a direction of RL that estimates the full distribution of the action value \
        instead of only its expectation, which tends to be more stable than classical value-based methods. \
        The difference between IQN and DQN is that IQN uses quantile regression to estimate quantile values \
        of the return distribution, while DQN only estimates its expectation.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      iqn            | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether use priority(PER)              | priority sample,
                                                                                                 | update priority
        6  | ``other.eps``      float    0.05           | Start value for epsilon decay. It's
           | ``.start``                                 | small because rainbow uses noisy net.
        7  | ``other.eps``      float    0.05           | End value for epsilon decay.
           | ``.end``
        8  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        9  ``nstep``            int      3,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        10 | ``learn.update``   int      3              | How many updates(iterations) to train  | this arg can vary
           | ``per_collect``                            | after collector's one collection. Only | across envs. Bigger
                                                        | valid in serial training               | value means more
                                                                                                 | off-policy
        11 ``learn.kappa``      float    /              | Threshold of Huber loss
        == ==================== ======== ============== ======================================== =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='iqn',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation
        nstep=1,
        learn=dict(
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=3,
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (float) Threshold of Huber loss. In the IQN paper, this is denoted by kappa. Default to 1.0.
            kappa=1.0,
            # (bool) Whether to ignore done (usually for max-step-termination env)
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_step, n_episode] should be set
            # n_sample=8,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                # (int) Decay length(env step)
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, )
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.

        .. note::
            The user can define and use a customized network model, but it must obey the same interface definition \
            indicated by the import_names path. For example, for IQN, the registered name is ``iqn`` and the \
            import_names is ``ding.model.template.q_learning``.
        """
        return 'iqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For IQN, it mainly contains \
            the optimizer, algorithm-specific arguments such as nstep, kappa and gamma, and the main and target model.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._kappa = self._cfg.learn.kappa

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss and priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
              training samples. For each element in the list, the key of the dict is the name of a data item and the \
              value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
              combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
              dimension by some utility functions such as ``default_preprocess_learn``. \
              For IQN, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
              ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
              and ``value_gamma``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
              be recorded in the text log and tensorboard; values must be python scalars or lists of scalars. For the \
              detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for IQNPolicy: ``ding.policy.tests.test_iqn``.
        """
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        ret = self._learn_model.forward(data['obs'])
        q_value = ret['q']
        replay_quantiles = ret['quantiles']
        # Target q value
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['q']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = iqn_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], replay_quantiles,
            data['weight']
        )
        value_gamma = data.get('value_gamma')
        loss, td_error_per_sample = iqn_nstep_td_error(
            data_n, self._gamma, nstep=self._nstep, kappa=self._kappa, value_gamma=value_gamma
        )

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model, target_model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.

        .. tip::
            If you want to only load some parts of the model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])