ding.policy.qtran

QTRANPolicy

Bases: Policy

Overview

Policy class of the QTRAN algorithm. QTRAN is a multi-agent reinforcement learning algorithm; the paper is available at https://arxiv.org/abs/1803.11485

Config:

== ==================== ======== ============== ======================================== =======================
ID Symbol               Type     Default Value  Description                              Other(Shape)
== ==================== ======== ============== ======================================== =======================
1  type                 str      qtran          | RL policy register name, refer to      | this arg is optional,
                                                | registry POLICY_REGISTRY               | a placeholder
2  cuda                 bool     True           | Whether to use cuda for network        | this arg can be diff-
                                                                                         | erent from modes
3  on_policy            bool     False          | Whether the RL algorithm is on-policy
                                                | or off-policy
4  priority             bool     False          | Whether use priority(PER)              | priority sample,
                                                                                         | update priority
5  priority_IS_weight   bool     False          | Whether use Importance Sampling        | IS weight
                                                | Weight to correct biased update.
6  learn.update_        int      20             | How many updates(iterations) to train  | this arg can vary
   per_collect                                  | after collector's one collection. Only | from env to env. A
                                                | valid in serial training               | bigger value means
                                                                                         | more off-policy
7  learn.target_        float    0.001          | Target network update momentum         | between [0, 1]
   update_theta                                 | parameter.
8  learn.discount       float    0.99           | Reward's future discount factor, aka.  | may be 1 in sparse
   _factor                                      | gamma                                  | reward envs
== ==================== ======== ============== ======================================== =======================
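To make the table concrete, the sketch below builds a plain dict mirroring these defaults and applies user overrides. `make_qtran_config` is a hypothetical helper for illustration only, not DI-engine's actual config-merging machinery:

```python
# Hypothetical helper mirroring the table's defaults; DI-engine's real
# config merging lives elsewhere, this only illustrates the fields above.
def make_qtran_config(**overrides):
    cfg = {
        'type': 'qtran',           # RL policy register name (POLICY_REGISTRY)
        'cuda': True,              # whether to use cuda for the network
        'on_policy': False,        # QTRAN is an off-policy algorithm
        'priority': False,         # PER is not implemented in QTRAN
        'priority_IS_weight': False,
        'learn': {
            'update_per_collect': 20,      # updates per collector collection
            'target_update_theta': 0.001,  # target-network momentum, in [0, 1]
            'discount_factor': 0.99,       # gamma; may be 1 in sparse-reward envs
        },
    }
    # Merge nested 'learn' overrides first, then top-level overrides.
    cfg['learn'].update(overrides.pop('learn', {}))
    cfg.update(overrides)
    return cfg

cfg = make_qtran_config(cuda=False, learn={'discount_factor': 1.0})
```

Overriding `learn.discount_factor` to 1.0, for instance, matches the table's note about sparse-reward environments.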

default_model()

Overview

Return this algorithm's default model setting for demonstration.

Returns:

- model_info (Tuple[str, List[str]]): the model name and the model import_names.

.. note:: The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For QTRAN: ding.model.template.qtran
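As a rough sketch of how such a (name, import_names) pair can be consumed, the snippet below resolves each import path with the stdlib `importlib`. This is illustrative only; DI-engine's registry performs the real lookup, and the `ding` package is assumed to be absent here:

```python
import importlib

# The pair returned by default_model() for QTRAN.
model_name, import_names = 'qtran', ['ding.model.template.qtran']

def resolve_modules(paths):
    # Import each path if available; None marks a package that is not
    # installed (e.g. ding in this standalone sketch).
    resolved = []
    for path in paths:
        try:
            resolved.append(importlib.import_module(path))
        except ModuleNotFoundError:
            resolved.append(None)
    return resolved

modules = resolve_modules(import_names)
```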

Full Source Code

../ding/policy/qtran.py

```python
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import torch.nn.functional as F
import copy
from easydict import EasyDict

from ding.torch_utils import Adam, RMSprop, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_epsilon_greedy_fn, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('qtran')
class QTRANPolicy(Policy):
    """
    Overview:
        Policy class of the QTRAN algorithm. QTRAN is a multi-agent reinforcement learning algorithm; \
        the paper is available at https://arxiv.org/abs/1803.11485
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      qtran          | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     True           | Whether to use cuda for network        | this arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether use priority(PER)              | priority sample,
                                                                                                 | update priority
        5  | ``priority_``      bool     False          | Whether use Importance Sampling        | IS weight
           | ``IS_weight``                              | Weight to correct biased update.
        6  | ``learn.update_``  int      20             | How many updates(iterations) to train  | this arg can vary
           | ``per_collect``                            | after collector's one collection. Only | from env to env. A
                                                        | valid in serial training               | bigger value means
                                                                                                 | more off-policy
        7  | ``learn.target_``  float    0.001          | Target network update momentum         | between [0, 1]
           | ``update_theta``                           | parameter.
        8  | ``learn.discount`` float    0.99           | Reward's future discount factor, aka.  | may be 1 in sparse
           | ``_factor``                                | gamma                                  | reward envs
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='qtran',
        # (bool) Whether to use cuda for network.
        cuda=True,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            update_per_collect=20,
            batch_size=32,
            learning_rate=0.0005,
            clip_value=1.5,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Target network update momentum parameter, in [0, 1].
            target_update_theta=0.008,
            # (float) The discount factor for future rewards, in [0, 1].
            discount_factor=0.99,
            # (float) The loss weight of the TD-error
            td_weight=1,
            # (float) The loss weight of the Opt loss
            opt_weight=0.01,
            # (float) The loss weight of the Nopt loss
            nopt_min_weight=0.0001,
            # (bool) Whether to use the double DQN mechanism (target q to suppress over-estimation)
            double_q=True,
        ),
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_sample=32 * 16,
            # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
            # in each forward when training. In qtran, it is greater than 1 because there is an RNN.
            unroll_len=10,
        ),
        eval=dict(),
        other=dict(
            eps=dict(
                # (str) Type of epsilon decay
                type='exp',
                # (float) Start value for epsilon decay, in [0, 1].
                # 0 means not using epsilon decay.
                start=1,
                # (float) End value for epsilon decay, in [0, 1].
                end=0.05,
                # (int) Decay length (env step)
                decay=50000,
            ),
            replay_buffer=dict(
                replay_buffer_size=5000,
                # (int) The maximum reuse times of each data
                max_reuse=1e+9,
                max_staleness=1e+9,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
        .. note::
            The user can define and use a customized network model, but it must obey the same interface \
            definition indicated by the import_names path. For QTRAN, ``ding.model.template.qtran``
        """
        return 'qtran', ['ding.model.template.qtran']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the learner model of QTRANPolicy.
        Arguments:
            .. note::

                The _init_learn method takes its arguments from self._cfg.learn in the config file

            - learning_rate (:obj:`float`): The learning rate of the optimizer
            - gamma (:obj:`float`): The discount factor
            - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need the number of agents.
            - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in QTRAN"
        self._optimizer = RMSprop(
            params=self._model.parameters(), lr=self._cfg.learn.learning_rate, alpha=0.99, eps=0.00001
        )
        self._gamma = self._cfg.learn.discount_factor
        self._td_weight = self._cfg.learn.td_weight
        self._opt_weight = self._cfg.learn.opt_weight
        self._nopt_min_weight = self._cfg.learn.nopt_min_weight

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Any]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from the collect function
        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, from \
                [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        data['weight'] = data.get('weight', None)
        data['done'] = data['done'].float()
        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which \
                will be recorded in text log and tensorboard, values are python scalars or lists of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        data = self._data_preprocess_learn(data)
        # ====================
        # Q-mix forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # for the hidden_state plugin, we need to reset the main model and target model
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        inputs = {'obs': data['obs'], 'action': data['action']}
        learn_ret = self._learn_model.forward(inputs, single_step=False)
        total_q = learn_ret['total_q']
        vs = learn_ret['vs']
        agent_q_act = learn_ret['agent_q_act']
        logit_detach = learn_ret['logit'].clone()
        logit_detach[data['obs']['action_mask'] == 0.0] = -9999999
        logit_q, logit_action = logit_detach.max(dim=-1, keepdim=False)

        if self._cfg.learn.double_q:
            next_inputs = {'obs': data['next_obs']}
            double_q_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
            _, double_q_action = double_q_detach.max(dim=-1, keepdim=False)
            next_inputs = {'obs': data['next_obs'], 'action': double_q_action}
        else:
            next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            target_total_q = self._target_model.forward(next_inputs, single_step=False)['total_q']

        # -- TD Loss --
        td_data = v_1step_td_data(total_q, target_total_q.detach(), data['reward'], data['done'], data['weight'])
        td_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
        # -- TD Loss --

        # -- Opt Loss --
        if data['weight'] is None:
            weight = torch.ones_like(data['reward'])
        opt_inputs = {'obs': data['obs'], 'action': logit_action}
        max_q = self._learn_model.forward(opt_inputs, single_step=False)['total_q']
        opt_error = logit_q.sum(dim=2) - max_q.detach() + vs
        opt_loss = (opt_error ** 2 * weight).mean()
        # -- Opt Loss --

        # -- Nopt Loss --
        nopt_values = agent_q_act.sum(dim=2) - total_q.detach() + vs
        nopt_error = nopt_values.clamp(max=0)
        nopt_min_loss = (nopt_error ** 2 * weight).mean()
        # -- Nopt Loss --

        total_loss = self._td_weight * td_loss + self._opt_weight * opt_loss + self._nopt_min_weight * nopt_min_loss
        # ====================
        # Q-mix update
        # ====================
        self._optimizer.zero_grad()
        total_loss.backward()
        # just get grad_norm
        grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), 10000000)
        self._optimizer.step()
        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'td_loss': td_loss.item(),
            'opt_loss': opt_loss.item(),
            'nopt_loss': nopt_min_loss.item(),
            'grad_norm': grad_norm,
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the learn model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id that stores the state; the model state will be \
                reset to the state indicated by data_id
        """
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of the current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of the policy learn state saved previously.
        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Enable the eps_greedy_sample and the hidden_state plugin.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.collect.env_num,
            save_prev_state=True,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, eps: float) -> dict:
        r"""
        Overview:
            Forward function for collect mode with eps_greedy
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least the inferred action according to \
                the input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps, data_id=data_id)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the collect model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id that stores the state; the model state will be \
                reset to the state indicated by data_id
        """
        self._collect_model.reset(data_id=data_id)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state', \
                'action', 'reward', 'done'
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'prev_state': model_output['prev_state'],
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy and the hidden_state plugin.
        """
        self._eval_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.eval.env_num,
            save_prev_state=True,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicted actions for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the eval model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id that stores the state; the model state will be \
                reset to the state indicated by data_id
        """
        self._eval_model.reset(data_id=data_id)

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the train sample from a trajectory.
        Arguments:
            - data (:obj:`list`): The trajectory's cache
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return ['cur_lr', 'total_loss', 'td_loss', 'opt_loss', 'nopt_loss', 'grad_norm']
```
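The learn-mode update in `_forward_learn` combines three terms: `total_loss = td_weight * td_loss + opt_weight * opt_loss + nopt_min_weight * nopt_min_loss`, with the default weights from the config (1, 0.01, 0.0001). The sketch below restates that combination with plain Python floats in place of the batched torch tensors; all input values are hypothetical and the TD term is taken as given:

```python
def qtran_total_loss(td_loss, logit_q_sum, max_q, agent_q_sum, total_q, vs,
                     td_weight=1.0, opt_weight=0.01, nopt_min_weight=0.0001):
    # Opt term: squared error between the greedy joint action's summed
    # per-agent Q (shifted by vs) and the mixed max_q, as in _forward_learn.
    opt_error = logit_q_sum - max_q + vs
    opt_loss = opt_error ** 2
    # Nopt term: clamp(max=0) keeps only violations of the inequality
    # constraint for non-optimal joint actions.
    nopt_error = min(agent_q_sum - total_q + vs, 0.0)
    nopt_loss = nopt_error ** 2
    # Weighted sum mirroring td_weight / opt_weight / nopt_min_weight.
    return td_weight * td_loss + opt_weight * opt_loss + nopt_min_weight * nopt_loss

loss = qtran_total_loss(td_loss=1.0, logit_q_sum=2.0, max_q=1.5,
                        agent_q_sum=1.0, total_q=2.0, vs=0.0)
```

With these hypothetical inputs, the opt error is 0.5 and the nopt error is -1.0, so the small default weights keep the constraint terms from dominating the TD loss.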