ding.policy.rainbow

RainbowDQNPolicy

Bases: DQNPolicy

Overview

Rainbow DQN contains several improvements upon DQN, including:

- target network
- dueling architecture
- prioritized experience replay
- n_step return
- noise net
- distribution net

Therefore, the RainbowDQNPolicy class inherits from the DQNPolicy class.
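The "distribution net" improvement means the network outputs a categorical probability distribution over a fixed support of return values rather than a scalar Q-value. The following is an illustrative sketch (not the ding implementation) of how the support is built from `v_min`, `v_max`, and `n_atom` (matching the defaults below) and how a scalar Q-value is recovered from one action's probability vector:

```python
# Illustrative sketch of the categorical value distribution (C51-style):
# atoms are evenly spaced over [v_min, v_max], and Q(s, a) is the
# expectation of the distribution over those atoms.
v_min, v_max, n_atom = -10.0, 10.0, 51

# Evenly spaced atoms z_0 .. z_{n_atom-1} spanning [v_min, v_max].
delta = (v_max - v_min) / (n_atom - 1)
support = [v_min + i * delta for i in range(n_atom)]

def expected_q(probs):
    """Q(s, a) = sum_i p_i * z_i for one action's probability vector."""
    assert len(probs) == n_atom and abs(sum(probs) - 1.0) < 1e-6
    return sum(p * z for p, z in zip(probs, support))

# A uniform distribution over the symmetric support has expected value
# at the midpoint of [-10, 10], i.e. approximately 0.
uniform = [1.0 / n_atom] * n_atom
print(expected_q(uniform))  # ≈ 0.0
```

In the real policy this expectation is what `argmax_sample` acts on when selecting greedy actions; the full projected distributional TD loss lives in `ding.rl_utils.dist_nstep_td_error`.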

Config

== ==================== ======== ============== ======================================== =======================
ID Symbol               Type     Default Value  Description                              Other(Shape)
== ==================== ======== ============== ======================================== =======================
1  type                 str      rainbow        | RL policy register name, refer to      | this arg is optional,
                                                | registry POLICY_REGISTRY               | a placeholder
2  cuda                 bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                         | erent from modes
3  on_policy            bool     False          | Whether the RL algorithm is on-policy
                                                | or off-policy
4  priority             bool     True           | Whether use priority(PER)              | priority sample,
                                                                                         | update priority
5  model.v_min          float    -10            | Value of the smallest atom
                                                | in the support set.
6  model.v_max          float    10             | Value of the largest atom
                                                | in the support set.
7  model.n_atom         int      51             | Number of atoms in the support set
                                                | of the value distribution.
8  other.eps.start      float    0.05           | Start value for epsilon decay. It's
                                                | small because rainbow uses noisy net.
9  other.eps.end        float    0.05           | End value for epsilon decay.
10 discount_factor      float    0.97,          | Reward's future discount factor, aka.  | may be 1 when sparse
                                 [0.95, 0.999]  | gamma                                  | reward env
11 nstep                int      3, [3, 5]      | N-step reward discount sum for target
                                                | q_value estimation
12 learn.update_        int      3              | How many updates(iterations) to train  | this arg can vary from
   per_collect                                  | after collector's one collection. Only | env to env. Bigger value
                                                | valid in serial training               | means more off-policy
== ==================== ======== ============== ======================================== =======================
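Rows 10 and 11 interact: the target value is estimated from an n-step discounted reward sum rather than a single-step reward. A minimal sketch of that sum, using the table's `nstep` and `discount_factor` names (the actual computation is done by `ding.rl_utils.get_nstep_return_data`):

```python
# Hedged sketch of the n-step discounted reward used in the target
# q_value estimation; not the ding implementation.
def nstep_return(rewards, nstep=3, gamma=0.99):
    """Sum of the first `nstep` rewards, each discounted by gamma**k."""
    return sum(gamma ** k * r for k, r in enumerate(rewards[:nstep]))

# With nstep=3 and gamma=0.99: 1 + 0.99 + 0.99**2 = 2.9701
print(nstep_return([1.0, 1.0, 1.0]))
```

Rewards beyond the first `nstep` entries are ignored; in the full algorithm the bootstrap term `gamma**nstep * Q(s_{t+n}, a*)` is added on top of this sum (distributionally, via the projected target).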

Full Source Code

../ding/policy/rainbow.py

from typing import List, Dict, Any, Tuple, Union
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_train_sample, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn, set_noise_mode


@POLICY_REGISTRY.register('rainbow')
class RainbowDQNPolicy(DQNPolicy):
    r"""
    Overview:
        Rainbow DQN contains several improvements upon DQN, including:
            - target network
            - dueling architecture
            - prioritized experience replay
            - n_step return
            - noise net
            - distribution net

        Therefore, the RainbowDQNPolicy class inherits from the DQNPolicy class.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      rainbow        | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether use priority(PER)              | priority sample,
                                                                                                 | update priority
        5  ``model.v_min``      float    -10            | Value of the smallest atom
                                                        | in the support set.
        6  ``model.v_max``      float    10             | Value of the largest atom
                                                        | in the support set.
        7  ``model.n_atom``     int      51             | Number of atoms in the support set
                                                        | of the value distribution.
        8  | ``other.eps``      float    0.05           | Start value for epsilon decay. It's
           | ``.start``                                 | small because rainbow uses noisy net.
        9  | ``other.eps``      float    0.05           | End value for epsilon decay.
           | ``.end``
        10 | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        11 ``nstep``            int      3, [3, 5]      | N-step reward discount sum for target
                                                        | q_value estimation
        12 | ``learn.update``   int      3              | How many updates(iterations) to train  | this arg can vary from
           | ``per_collect``                            | after collector's one collection. Only | env to env. Bigger value
                                                        | valid in serial training               | means more off-policy
        == ==================== ======== ============== ======================================== =======================

    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='rainbow',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=True,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        # random_collect_size=2000,
        model=dict(
            # (float) Value of the smallest atom in the support set.
            # Default to -10.0.
            v_min=-10,
            # (float) Value of the largest atom in the support set.
            # Default to 10.0.
            v_max=10,
            # (int) Number of atoms in the support set of the
            # value distribution. Default to 51.
            n_atom=51,
        ),
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.99,
        # (int) N-step reward for target q_value estimation.
        nstep=3,
        # (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is True.
        noisy_net=True,
        learn=dict(
            # How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            batch_size=32,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max-step termination env).
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set.
            # n_sample=32,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                # (float) Start value for epsilon decay, in [0, 1]. It equals `end` because rainbow uses noisy net.
                start=0.05,
                # (float) End value for epsilon decay, in [0, 1].
                end=0.05,
                # (int) Env steps of epsilon decay.
                decay=100000,
            ),
            replay_buffer=dict(
                # (int) Max size of replay buffer.
                replay_buffer_size=100000,
                # (float) Prioritization exponent.
                alpha=0.6,
                # (float) Importance sample soft coefficient.
                # 0 means no correction, while 1 means full correction.
                beta=0.4,
                # (int) Anneal step for beta: 0 means no annealing. Defaults to 0.
                anneal_step=100000,
            )
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'rainbowdqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Init the learner model of RainbowDQNPolicy

        Arguments:
            - learning_rate (:obj:`float`): the learning rate for the optimizer
            - gamma (:obj:`float`): the discount factor
            - nstep (:obj:`int`): the number of steps in the n-step return
            - v_min (:obj:`float`): value distribution minimum value
            - v_max (:obj:`float`): value distribution maximum value
            - n_atom (:obj:`int`): the number of atom sample points
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._v_max = self._cfg.model.v_max
        self._v_min = self._cfg.model.v_min
        self._n_atom = self._cfg.model.n_atom

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode: acquire the data, calculate the loss and \
            optimize the learner model

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']

        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
                - cur_lr (:obj:`float`): current learning rate
                - total_loss (:obj:`float`): the calculated loss
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Rainbow forward
        # ====================
        self._learn_model.train()
        self._target_model.train()

        # Set noise mode for NoisyNet for exploration in learning if enabled in config
        set_noise_mode(self._learn_model, True)
        set_noise_mode(self._target_model, True)

        # reset noise of noise net for both main model and target model
        self._reset_noise(self._learn_model)
        self._reset_noise(self._target_model)
        q_dist = self._learn_model.forward(data['obs'])['distribution']
        with torch.no_grad():
            target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
            self._reset_noise(self._learn_model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
        value_gamma = data.get('value_gamma', None)
        data = dist_nstep_td_data(
            q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
        )
        loss, td_error_per_sample = dist_nstep_td_error(
            data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
        )
        # ====================
        # Rainbow update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
        }

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.

        .. note::
            rainbow dqn enables eps_greedy_sample, but might not need to use it, \
            as the noise net contains noise that can help exploration
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._nstep = self._cfg.nstep
        self._gamma = self._cfg.discount_factor
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, eps: float) -> dict:
        r"""
        Overview:
            Reset the noise from the noise net and collect output according to the eps_greedy plugin

        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.

        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.

        ReturnsKeys
            - necessary: ``action``
        """
        # Set noise mode for NoisyNet for exploration in collecting if enabled in config.
        # We need to call set_noise_mode in every _forward_xxx because the model is reused across
        # different phases (learn/collect/eval).
        set_noise_mode(self._collect_model, True)

        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, traj: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the trajectory and the n-step return data, then sample from the n-step return data

        Arguments:
            - traj (:obj:`list`): The trajectory's buffer list

        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        data = get_nstep_return_data(traj, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)

    def _reset_noise(self, model: torch.nn.Module):
        r"""
        Overview:
            Reset the noise of the model

        Arguments:
            - model (:obj:`torch.nn.Module`): the model to reset, must contain a reset_noise method
        """
        for m in model.modules():
            if hasattr(m, 'reset_noise'):
                m.reset_noise()
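The `_reset_noise` method relies on duck typing: it walks every submodule and resets any that exposes a `reset_noise` method, so plain (non-noisy) layers are silently skipped. A minimal self-contained sketch of that pattern, using hypothetical toy classes in place of `torch.nn.Module` and a real NoisyNet layer:

```python
# Toy stand-ins (hypothetical, not ding classes) illustrating the
# duck-typed traversal used by RainbowDQNPolicy._reset_noise.
class ToyNoisyLayer:
    """Pretend NoisyNet layer; a real one would resample factorized
    Gaussian noise for its weights and biases on reset_noise()."""

    def __init__(self):
        self.noise_version = 0

    def reset_noise(self):
        self.noise_version += 1

class PlainLayer:
    """Ordinary layer with no noise; has no reset_noise method."""

class ToyModel:
    """Mimics torch.nn.Module.modules() by yielding all submodules."""

    def __init__(self):
        self.layers = [ToyNoisyLayer(), PlainLayer(), ToyNoisyLayer()]

    def modules(self):
        return self.layers

def reset_noise(model):
    # Same loop as in the policy: only modules that implement
    # reset_noise are touched; everything else is skipped.
    for m in model.modules():
        if hasattr(m, 'reset_noise'):
            m.reset_noise()

model = ToyModel()
reset_noise(model)
print([getattr(m, 'noise_version', None) for m in model.modules()])  # [1, None, 1]
```

This is why the policy resets noise before both the online and target forward passes in `_forward_learn`: each reset draws fresh noise, decorrelating the exploration noise used for action selection from the noise used in the loss computation.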