from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('pdqn')
class PDQNPolicy(Policy):
    """
    Overview:
        Policy class of PDQN algorithm, which extends the DQN algorithm on discrete-continuous hybrid action spaces.
        Paper link: https://arxiv.org/abs/1810.06394.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      pdqn           | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy  | This value is always
                                                        | or off-policy                          | False for PDQN
        4  ``priority``         bool     False          | Whether use priority(PER)              | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     False          | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.97,          | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      1,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  | ``learn.update``   int      3              | How many updates(iterations) to train  | This args can be vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        9  | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        10 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``                                  | Fallback for the two rates below.
        11 | ``learn.learning`` float    None           | Learning rate of dis_head and
           | ``_rate_dis``                              | cont_encoder. None: use learning_rate
        12 | ``learn.learning`` float    None           | Learning rate of cont_head.
           | ``_rate_cont``                             | None: use learning_rate
        13 | ``learn.update``   int      10             | Period of the alternating update:      | Both run on the
           | ``_circle``                                | cnt % circle in [0, 5) trains Q-net,   | first train step
                                                        | in [5, 10) trains cont_head.
        14 | ``learn.target_``  float    0.005          | Coefficient of the momentum (soft)
           | ``theta``                                  | target network update.
        15 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        16 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of collector.                     | different envs
        17 | ``collect.unroll`` int      1              | unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        18 | ``collect.noise``  float    0.1            | add noise to continuous args
           | ``_sigma``                                 | during collection
        19 | ``other.eps.type`` str      exp            | exploration rate decay type            | Support ['exp',
                                                                                                 | 'linear'].
        20 | ``other.eps.``     float    0.95           | start value of exploration rate        | [0,1]
           | ``start``
        21 | ``other.eps.``     float    0.1            | end value of exploration rate          | [0,1]
           | ``end``
        22 | ``other.eps.``     int      10000          | decay length of exploration            | greater than 0. set
           | ``decay``                                                                           | decay=10000 means
                                                                                                 | the exploration rate
                                                                                                 | decay from start
                                                                                                 | value to end value
                                                                                                 | during decay length.
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='pdqn',
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether learning policy is the same as collecting data policy(on-policy).
        on_policy=False,
        # (bool) Whether to enable priority experience sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (float) Discount factor(gamma) for returns.
        discount_factor=0.97,
        # (int) The number of step for calculating target q_value.
        nstep=1,
        # learn_mode config
        learn=dict(
            # (int) How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=3,
            # (int) How many samples in a training batch.
            batch_size=64,
            # (float) The step size of gradient descent. Used by both optimizers unless
            # ``learning_rate_dis`` / ``learning_rate_cont`` are set.
            learning_rate=0.001,
            # (float) Learning rate of the discrete part optimizer (``dis_head`` and ``cont_encoder``).
            # None means ``learning_rate`` is used.
            learning_rate_dis=None,
            # (float) Learning rate of the continuous part optimizer (``cont_head``).
            # None means ``learning_rate`` is used.
            learning_rate_cont=None,
            # (int) Period of the alternating update schedule. In a train step whose counter modulo
            # ``update_circle`` lies in [0, 5) the discrete part is updated, in [5, 10) the continuous part is
            # updated. Both parts are updated in the very first train step.
            update_circle=10,
            # (float) Coefficient (theta) of the momentum (soft) target network update.
            target_theta=0.005,
            # (bool) Whether ignore done(usually for max step termination env).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with done is False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples collected in one collection procedure.
            # Only one of [n_sample, n_episode] should be set.
            # n_sample=8,
            # (int) Split episodes or trajectories into pieces with length `unroll_len`.
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits noise and only set ``noise_sigma``.
            noise_sigma=0.1,
        ),
        eval=dict(),  # for compatibility
        # other config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Decay type. Support ['exp', 'linear'].
                type='exp',
                # (float) Epsilon start value.
                start=0.95,
                # (float) Epsilon end value.
                end=0.1,
                # (int) Decay length(env step)
                decay=10000,
            ),
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
                replay_buffer_size=10000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.

        .. note::
            The user can define and use customized network model but must obey the same interface definition indicated \
            by import_names path. For example about PDQN, its registered name is ``pdqn`` and the import_names is \
            ``ding.model.template.pdqn``.
        """
        return 'pdqn', ['ding.model.template.pdqn']

    def _get_part_learning_rate(self, key: str) -> float:
        """
        Overview:
            Read a part-specific learning rate (``learning_rate_dis`` or ``learning_rate_cont``) from \
            ``cfg.learn``. When the key is missing or None, fall back to ``cfg.learn.learning_rate``.
        Arguments:
            - key (:obj:`str`): The name of the part-specific learning rate field in ``cfg.learn``.
        Returns:
            - lr (:obj:`float`): The learning rate to use for that optimizer.
        """
        lr = self._cfg.learn.get(key, None)
        return self._cfg.learn.learning_rate if lr is None else lr

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For PDQN, it mainly \
            contains two optimizers, algorithm-specific arguments such as nstep and gamma, main and target model.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # Optimizers. ``cont_encoder`` is deliberately trained by the discrete (Q-learning) optimizer, while the
        # continuous optimizer only holds ``cont_head``.
        self._dis_optimizer = Adam(
            list(self._model.dis_head.parameters()) + list(self._model.cont_encoder.parameters()),
            # this is very important to put cont_encoder.parameters in here.
            lr=self._get_part_learning_rate('learning_rate_dis')
        )
        self._cont_optimizer = Adam(
            list(self._model.cont_head.parameters()), lr=self._get_part_learning_rate('learning_rate_cont')
        )
        # Period of the alternating discrete/continuous update schedule used in ``_forward_learn``.
        self._update_circle = self._cfg.learn.get('update_circle', 10)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep

        # use model_wrapper for specialized demands of different modes
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='hybrid_argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()
        # Counters of continuous-part updates, discrete-part updates and total ``_forward_learn`` calls.
        self.cont_train_cnt = 0
        self.disc_train_cnt = 0
        self.train_cnt = 0

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, q value, target_q_value, priority.
            The discrete part (Q-network) and the continuous part (continuous args network) are updated \
            alternately according to ``learn.update_circle``; both are updated in the very first call.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For PDQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
                ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
                and ``value_gamma``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)

        self.train_cnt += 1
        first_step = self.train_cnt == 1
        phase = self.train_cnt % self._update_circle
        update_cont = first_step or 5 <= phase < 10
        update_dis = first_step or phase < 5

        # Placeholder statistics for a part that is not updated in this step. Binding them up front keeps the
        # returned info dict complete, keeps the real continuous loss of the first step from being overwritten,
        # and avoids unbound names when ``update_circle`` > 10 leaves some steps without any scheduled update.
        cont_loss = torch.zeros(1)
        q_pi_action_value = torch.zeros(1)
        dis_loss = torch.zeros(1)
        td_error_per_sample = torch.zeros(1)
        target_q_value = torch.zeros(1)

        # ``self._model`` is also wrapped by the collect/eval models, which call ``.eval()`` on it, so switch back to
        # train mode before either update.
        self._learn_model.train()
        self._target_model.train()

        if update_cont:
            self.cont_train_cnt += 1
            # ================================
            # Continuous args network forward
            # ================================
            action_args = self._learn_model.forward(data['obs'], mode='compute_continuous')['action_args']
            # Current q value (main model) of the predicted continuous args; the loss maximizes the sum of the
            # q values over all discrete actions.
            discrete_inputs = {'state': data['obs'], 'action_args': action_args}
            q_pi_action_value = self._learn_model.forward(discrete_inputs, mode='compute_discrete')['logit']
            cont_loss = -q_pi_action_value.sum(dim=-1).mean()

            # ================================
            # Continuous args network update
            # ================================
            # Only ``cont_head`` is stepped here; gradients that reach the discrete part are cleared by
            # ``self._dis_optimizer.zero_grad()`` before the next Q-learning backward.
            self._cont_optimizer.zero_grad()
            cont_loss.backward()
            self._cont_optimizer.step()

        if update_dis:
            self.disc_train_cnt += 1
            # ====================
            # Q-learning forward
            # ====================
            # Current q value (main model), evaluated at the continuous args stored in the transition.
            discrete_inputs = {'state': data['obs'], 'action_args': data['action']['action_args']}
            q_data_action_args_value = self._learn_model.forward(discrete_inputs, mode='compute_discrete')['logit']

            # Target q value
            with torch.no_grad():
                next_action_args = self._learn_model.forward(data['next_obs'], mode='compute_continuous')['action_args']
                next_discrete_inputs = {'state': data['next_obs'], 'action_args': next_action_args.clone().detach()}
                target_q_value = self._target_model.forward(next_discrete_inputs, mode='compute_discrete')['logit']
                # Max q value action (main model)
                target_q_discrete_action = self._learn_model.forward(
                    next_discrete_inputs, mode='compute_discrete'
                )['action']['action_type']

            data_n = q_nstep_td_data(
                q_data_action_args_value, target_q_value, data['action']['action_type'], target_q_discrete_action,
                data['reward'], data['done'], data['weight']
            )
            value_gamma = data.get('value_gamma')
            dis_loss, td_error_per_sample = q_nstep_td_error(
                data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
            )

            # ====================
            # Q-learning update
            # ====================
            self._dis_optimizer.zero_grad()
            dis_loss.backward()
            self._dis_optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())

        # NOTE(review): on steps without a Q-learning update ``priority`` is a single 0 rather than one value per
        # sample; confirm this is acceptable when ``priority`` (PER) is enabled.
        return {
            'cur_lr': self._dis_optimizer.defaults['lr'],
            'q_loss': dis_loss.item(),
            'total_loss': cont_loss.item() + dis_loss.item(),
            'continuous_loss': cont_loss.item(),
            'q_value': q_pi_action_value.mean().item(),
            'priority': td_error_per_sample.abs().tolist(),
            'reward': data['reward'].mean().item(),
            'target_q_value': target_q_value.mean().item(),
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model, target model, discrete part optimizer, and \
            continuous part optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'dis_optimizer': self._dis_optimizer.state_dict(),
            'cont_optimizer': self._cont_optimizer.state_dict()
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._dis_optimizer.load_state_dict(state_dict['dis_optimizer'])
        self._cont_optimizer.load_state_dict(state_dict['cont_optimizer'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For PDQN, it contains the \
            collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \
            continuous action mechanism, besides, other algorithm-specific arguments such as unroll_len and nstep are \
            also initialized here.
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.

        .. tip::
            Some variables need to initialize independently in different modes, such as gamma and nstep in PDQN. This \
            design is for the convenience of parallel execution of different policy modes.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.discount_factor  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        # Gaussian noise on the continuous args, then epsilon-greedy/multinomial sampling of the discrete action.
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_kwargs={
                'mu': 0.0,
                'sigma': self._cfg.collect.noise_sigma
            },
            noise_range=None
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
            exploration, i.e., classic epsilon-greedy exploration strategy.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
            - eps (:obj:`float`): The epsilon value for exploration.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
                dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            action_args = self._collect_model.forward(data, 'compute_continuous', eps=eps)['action_args']
            inputs = {'state': data, 'action_args': action_args.clone().detach()}
            output = self._collect_model.forward(inputs, 'compute_discrete', eps=eps)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. In PDQN, a train sample is a processed transition. \
            This method is usually used in collectors to execute necessary \
            RL data preprocessing before training, which can help learner amortize relevant time consumption. \
            In addition, you can also implement this method as an identity function and do the data processing \
            in ``self._forward_learn`` method.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \
                the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
                as input transitions, but may contain more data for training, such as nstep reward and target obs.
        """
        transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
        return get_train_sample(transitions, self._unroll_len)

    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training and \
            saved in replay buffer. For PDQN, it contains obs, next_obs, action, reward, done and logit.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
                as input. For PDQN, it contains the hybrid action and the logit (discrete part q_value) of the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
                except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
                reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': policy_output['action'],
            'logit': policy_output['logit'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For PDQN, it contains the \
            eval model to greedily select action with argmax q_value mechanism.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='hybrid_argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            action_args = self._eval_model.forward(data, mode='compute_continuous')['action_args']
            inputs = {'state': data, 'action_args': action_args.clone().detach()}
            output = self._eval_model.forward(inputs, mode='compute_discrete')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'total_loss', 'q_loss', 'continuous_loss', 'q_value', 'reward', 'target_q_value']