from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
import torch

from ding.torch_utils import Adam, to_device, ContrastiveLoss
from ding.rl_utils import q_nstep_td_data, bdq_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate

from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('bdq')
class BDQPolicy(Policy):
    r"""
    Overview:
        Policy class of BDQ algorithm, extended by PER/multi-step TD. \
        referenced paper Action Branching Architectures for Deep Reinforcement Learning \
        <https://arxiv.org/pdf/1711.08946>
    .. note::
        BDQ algorithm contains a neural architecture featuring a shared decision module \
        followed by several network branches, one for each action dimension.
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      bdq            | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether use priority(PER)              | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     False          | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      1,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  | ``learn.update``   int      3              | How many updates(iterations) to train  | This args can be vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        9  | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        10 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        11 | ``learn.target_``  int      100            | Frequency of target network update.    | Hard(assign) update
           | ``update_freq``
        12 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        13 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of collector.                     | different envs
        14 | ``collect.unroll`` int      1              | unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        15 | ``other.eps.type`` str      exp            | exploration rate decay type            | Support ['exp',
                                                                                                 | 'linear'].
        16 | ``other.eps.``     float    0.95           | start value of exploration rate        | [0,1]
           | ``start``
        17 | ``other.eps.``     float    0.1            | end value of exploration rate          | [0,1]
           | ``end``
        18 | ``other.eps.``     int      10000          | decay length of exploration            | greater than 0. set
           | ``decay``                                                                           | decay=10000 means
                                                                                                 | the exploration rate
                                                                                                 | decay from start
                                                                                                 | value to end value
                                                                                                 | during decay length.
        == ==================== ======== ============== ======================================== =======================
    """

    config = dict(
        type='bdq',
        # (bool) Whether use cuda in policy
        cuda=False,
        # (bool) Whether learning policy is the same as collecting data policy(on-policy)
        on_policy=False,
        # (bool) Whether enable priority experience sample
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (float) Discount factor(gamma) for returns
        discount_factor=0.97,
        # (int) The number of step for calculating target q_value
        nstep=1,
        learn=dict(
            # (int) How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=3,
            # (int) How many samples in a training batch
            batch_size=64,
            # (float) The step size of gradient descent
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether ignore done(usually for max step termination env)
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_sample=8,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                # (float) Epsilon start value
                start=0.95,
                # (float) Epsilon end value
                end=0.1,
                # (int) Decay length(env step)
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names

        .. note::
            The user can define and use customized network model but must obey the same interface definition indicated \
            by import_names path. For BDQ, ``ding.model.template.q_learning.BDQ``
        """
        return 'bdq', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
            and target model.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        # The target network is hard-copied ('assign') from the main model every ``target_update_freq`` updates.
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overview:
            Forward computation graph of learn mode(updating policy).
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
                recorded in text log and tensorboard, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: ``value_gamma``, ``IS``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``, ``priority``
            - optional: ``action_distribution``
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        q_value = self._learn_model.forward(data['obs'])['logit']
        # Bootstrap quantities need no gradients: the target network evaluates next_obs and the main model picks
        # the greedy next action. Running both under no_grad avoids building an unused autograd graph.
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['logit']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
        # Stored actions may lack the trailing dimension the model's action output has; align them.
        if data['action'].shape != target_q_action.shape:
            data['action'] = data['action'].unsqueeze(-1)

        data_n = q_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
        )
        value_gamma = data.get('value_gamma')
        loss, td_error_per_sample = bdq_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        update_info = {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'q_value': q_value.mean().item(),
            'target_q_value': target_q_value.mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }
        # Average over dim 2 (presumably the per-branch action bins -- TODO confirm against the BDQ model), then
        # report one mean q value per branch.
        q_value_per_branch = torch.mean(q_value, 2, keepdim=False)
        for i in range(self._model.num_branches):
            update_info['q_value_b_' + str(i)] = q_value_per_branch[:, i].mean().item()
        return update_info

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the names of the learn-mode variables to be monitored (logged), including one mean q value per \
            action branch.
        Returns:
            - vars (:obj:`List[str]`): Variable names; each is a key of the dict returned by ``_forward_learn``.
        """
        branch_vars = ['q_value_b_' + str(i) for i in range(self._model.num_branches)]
        return ['cur_lr', 'total_loss', 'q_value', 'target_q_value'] + branch_vars

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \
            enable the eps_greedy_sample for exploration.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        """
        Overview:
            Forward computation graph of collect mode(collect training data), with eps_greedy for exploration.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \
                env and the constructing of transition.
        ArgumentsKeys:
            - necessary: ``obs``
        ReturnsKeys
            - necessary: ``logit``, ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. A train sample can be a processed transition(BDQ with nstep TD).
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The trajectory data(a list of transition), each element is the same \
                format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The list of training samples.

        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \
            And the user can customize this data processing procedure by overriding these two methods and collector \
            itself.
        """
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)

    def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
        """
        Overview:
            Generate a transition(e.g.: <s, a, s', r, d>) for this algorithm training.
        Arguments:
            - obs (:obj:`Any`): Env observation.
            - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\
                including at least ``action``.
            - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \
                least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': policy_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``, initialize eval_model.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \
            ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ArgumentsKeys:
            - necessary: ``obs``
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}