ding.policy.c51

C51Policy

Bases: DQNPolicy

Overview

Policy class of the C51 algorithm.

Config

== ======================== ======== ============== ========================================= ========================
ID Symbol                   Type     Default Value  Description                               Other(Shape)
== ======================== ======== ============== ========================================= ========================
1  type                     str      c51            RL policy register name, refer to         this arg is optional,
                                                    registry POLICY_REGISTRY                  a placeholder
2  cuda                     bool     False          Whether to use cuda for network           this arg can be diff-
                                                                                              erent across modes
3  on_policy                bool     False          Whether the RL algorithm is on-policy
                                                    or off-policy
4  priority                 bool     False          Whether to use priority (PER)             priority sample,
                                                                                              update priority
5  model.v_min              float    -10            Value of the smallest atom
                                                    in the support set.
6  model.v_max              float    10             Value of the largest atom
                                                    in the support set.
7  model.n_atom             int      51             Number of atoms in the support set
                                                    of the value distribution.
8  other.eps.start          float    0.95           Start value for epsilon decay.
9  other.eps.end            float    0.1            End value for epsilon decay.
10 discount_factor          float    0.97,          Reward's future discount factor, aka.     may be 1 when sparse
                                     [0.95, 0.999]  gamma                                     reward env
11 nstep                    int      1              N-step reward discount sum for target
                                                    q_value estimation
12 learn.update_per_collect int      3              How many updates (iterations) to train    this arg can vary
                                                    after collector's one collection. Only    across envs. Bigger val
                                                    valid in serial training                  means more off-policy
== ======================== ======== ============== ========================================= ========================
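The heart of C51 is representing the return as a categorical distribution over ``n_atom`` fixed atoms spanning ``[v_min, v_max]``, and projecting the Bellman-updated target distribution back onto that fixed support. The sketch below is an illustrative, framework-free version of that projection step using the default config values; the function name ``c51_project`` and its plain-list interface are assumptions for illustration, not part of the ding API (the real computation lives inside ``ding.rl_utils.dist_nstep_td_error``).

```python
def c51_project(next_dist, reward, done, v_min=-10.0, v_max=10.0,
                n_atom=51, gamma=0.97, nstep=1):
    """Project the Bellman-shifted target distribution onto the fixed support.

    next_dist: probabilities over the n_atom support atoms for the next state.
    Returns the projected probabilities for the current state's target.
    """
    delta_z = (v_max - v_min) / (n_atom - 1)
    support = [v_min + i * delta_z for i in range(n_atom)]
    proj = [0.0] * n_atom
    for p_j, z_j in zip(next_dist, support):
        # Bellman backup of each atom, clipped to the support range.
        tz = reward + (0.0 if done else (gamma ** nstep) * z_j)
        tz = min(max(tz, v_min), v_max)
        # Fractional index of tz on the support; split its probability
        # mass between the two neighbouring atoms.
        b = (tz - v_min) / delta_z
        lo, hi = int(b), min(int(b) + 1, n_atom - 1)
        if lo == hi:
            proj[lo] += p_j
        else:
            proj[lo] += p_j * (hi - b)
            proj[hi] += p_j * (b - lo)
    return proj
```

Because the projection only redistributes mass, the result is still a valid probability distribution; for a terminal transition all mass collapses onto the atom nearest the raw reward.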

Full Source Code

../ding/policy/c51.py

from typing import List, Dict, Any, Tuple, Union
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_train_sample, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('c51')
class C51Policy(DQNPolicy):
    r"""
    Overview:
        Policy class of the C51 algorithm.

    Config:
        == ======================== ======== ============== ========================================= ========================
        ID Symbol                   Type     Default Value  Description                               Other(Shape)
        == ======================== ======== ============== ========================================= ========================
        1  ``type``                 str      c51            | RL policy register name, refer to       | this arg is optional,
                                                            | registry ``POLICY_REGISTRY``            | a placeholder
        2  ``cuda``                 bool     False          | Whether to use cuda for network         | this arg can be diff-
                                                                                                      | erent across modes
        3  ``on_policy``            bool     False          | Whether the RL algorithm is on-policy
                                                            | or off-policy
        4  ``priority``             bool     False          | Whether to use priority (PER)           | priority sample,
                                                                                                      | update priority
        5  ``model.v_min``          float    -10            | Value of the smallest atom
                                                            | in the support set.
        6  ``model.v_max``          float    10             | Value of the largest atom
                                                            | in the support set.
        7  ``model.n_atom``         int      51             | Number of atoms in the support set
                                                            | of the value distribution.
        8  | ``other.eps``          float    0.95           | Start value for epsilon decay.
           | ``.start``
        9  | ``other.eps``          float    0.1            | End value for epsilon decay.
           | ``.end``
        10 | ``discount_``          float    0.97,          | Reward's future discount factor, aka.   | may be 1 when sparse
           | ``factor``                      [0.95, 0.999]  | gamma                                   | reward env
        11 ``nstep``                int      1              | N-step reward discount sum for target
                                                            | q_value estimation
        12 | ``learn.update``       int      3              | How many updates (iterations) to train  | this arg can vary
           | ``per_collect``                                | after collector's one collection. Only  | across envs. Bigger val
                                                            | valid in serial training                | means more off-policy
        == ======================== ======== ============== ========================================= ========================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='c51',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation.
        nstep=1,
        model=dict(
            v_min=-10,
            v_max=10,
            n_atom=51,
        ),
        learn=dict(
            # How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means more off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=3,
            batch_size=64,
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max-step termination envs).
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_step, n_episode] should be set.
            # n_sample=8,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Supports ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                # (int) Decay length (env step).
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, )
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'c51dqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._v_max = self._cfg.model.v_max
        self._v_min = self._cfg.model.v_min
        self._n_atom = self._cfg.model.n_atom

        # use wrapper instead of plugin
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
        """
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        output = self._learn_model.forward(data['obs'])
        q_value = output['logit']
        q_value_dist = output['distribution']
        # Target q value
        with torch.no_grad():
            target_output = self._target_model.forward(data['next_obs'])
            target_q_value_dist = target_output['distribution']
            target_q_value = target_output['logit']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        data_n = dist_nstep_td_data(
            q_value_dist, target_q_value_dist, data['action'], target_q_action, data['reward'], data['done'],
            data['weight']
        )
        value_gamma = data.get('value_gamma')
        loss, td_error_per_sample = dist_nstep_td_error(
            data_n, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
        )

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'q_value': q_value.mean().item(),
            'target_q_value': target_q_value.mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _monitor_vars_learn(self) -> List[str]:
        return ['cur_lr', 'total_loss', 'q_value', 'target_q_value']

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Called by ``self.__init__``. Initialize necessary arguments for nstep return
            calculation and collect_model for exploration (eps_greedy_sample).
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        """
        Overview:
            Forward computation graph of collect mode (collect training data), with eps_greedy for exploration.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output (action),
              values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): Epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output (action) for the interaction with
              env and the constructing of transition.
        ArgumentsKeys:
            - necessary: ``obs``
        ReturnsKeys:
            - necessary: ``logit``, ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        """
        Overview:
            Calculate nstep return data and transform a trajectory into many train samples.
        Arguments:
            - data (:obj:`list`): The collected data of a trajectory, which is a list that contains dict elements.
        Returns:
            - samples (:obj:`dict`): The training samples generated.
        """
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)
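In ``_get_train_sample``, ``get_nstep_return_data`` replaces each transition's one-step reward with a discounted n-step sum before the trajectory is sliced into samples. The stand-alone sketch below illustrates that aggregation on a plain list of rewards; ``nstep_rewards`` is a hypothetical helper written for this explanation, not a ding function (the real one also shifts ``next_obs`` and ``done`` by ``nstep`` steps).

```python
def nstep_rewards(rewards, gamma=0.97, nstep=3):
    """Return the discounted n-step reward sum for every timestep.

    out[t] = rewards[t] + gamma * rewards[t+1] + ... + gamma^(nstep-1) * rewards[t+nstep-1],
    truncated at the end of the trajectory.
    """
    out = []
    for t in range(len(rewards)):
        g, discount = 0.0, 1.0
        # Sum up to nstep rewards starting at t, stopping at trajectory end.
        for k in range(t, min(t + nstep, len(rewards))):
            g += discount * rewards[k]
            discount *= gamma
        out.append(g)
    return out
```

With ``nstep=1`` this degenerates to the ordinary one-step reward, which is why the default config (``nstep=1``) behaves like standard single-step C51.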