
ding.policy.madqn

MADQNPolicy

Bases: QMIXPolicy

default_model()

Overview

Return the default model setting of this algorithm for demonstration.

Returns: - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
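
The returned pair is simply a registered model name plus the modules in which that model is defined. As a minimal, hypothetical illustration (not DI-engine's actual loading code), such a pair can be resolved by importing the listed modules:

import importlib

def resolve_default_model(model_info):
    # model_info is a (model_name, import_names) pair such as
    # ('madqn', ['ding.model.template.madqn']); this helper is illustrative only.
    model_name, import_names = model_info
    modules = [importlib.import_module(name) for name in import_names]
    return model_name, modules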

Full Source Code

../ding/policy/madqn.py

from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import copy

from ding.torch_utils import RMSprop, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    v_nstep_td_data, v_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .qmix import QMIXPolicy


@POLICY_REGISTRY.register('madqn')
class MADQNPolicy(QMIXPolicy):
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='madqn',
        # (bool) Whether to use cuda for network.
        cuda=True,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sampling, IS weight, update priority).
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct the biased update. If True, priority must be True.
        priority_IS_weight=False,
        nstep=3,
        learn=dict(
            update_per_collect=20,
            batch_size=32,
            learning_rate=0.0005,
            clip_value=100,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Target network update momentum parameter,
            # in [0, 1].
            target_update_theta=0.008,
            # (float) The discount factor for future rewards,
            # in [0, 1].
            discount_factor=0.99,
            # (bool) Whether to use the double DQN mechanism (target q to suppress overestimation).
            double_q=False,
            weight_decay=1e-5,
        ),
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set.
            n_episode=32,
            # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
            # in each forward when training. In QMIX, it is greater than 1 because there is an RNN.
            unroll_len=10,
        ),
        eval=dict(),
        other=dict(
            eps=dict(
                # (str) Type of epsilon decay.
                type='exp',
                # (float) Start value for epsilon decay, in [0, 1].
                # 0 means not using epsilon decay.
                start=1,
                # (float) End value for epsilon decay, in [0, 1].
                end=0.05,
                # (int) Decay length (env step).
                decay=50000,
            ),
            replay_buffer=dict(
                replay_buffer_size=5000,
                # (int) The maximum number of times each data sample can be reused.
                max_reuse=1e+9,
                max_staleness=1e+9,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default model setting of this algorithm for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
        """
        return 'madqn', ['ding.model.template.madqn']

    def _init_learn(self) -> None:
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in MADQN"
        self._optimizer_current = RMSprop(
            params=self._model.current.parameters(),
            lr=self._cfg.learn.learning_rate,
            alpha=0.99,
            eps=0.00001,
            weight_decay=self._cfg.learn.weight_decay
        )
        self._optimizer_cooperation = RMSprop(
            params=self._model.cooperation.parameters(),
            lr=self._cfg.learn.learning_rate,
            alpha=0.99,
            eps=0.00001,
            weight_decay=self._cfg.learn.weight_decay
        )
        self._gamma = self._cfg.learn.discount_factor
        self._nstep = self._cfg.nstep
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Any]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from the collect function
        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, from \
                [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        data['weight'] = data.get('weight', None)
        data['done'] = data['done'].float()
        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which \
                will be recorded in text log and tensorboard, values are python scalars or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        data = self._data_preprocess_learn(data)
        # ====================
        # Q-mix forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # for the hidden_state plugin, we need to reset the main model and target model
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        inputs = {'obs': data['obs'], 'action': data['action']}

        total_q = self._learn_model.forward(inputs, single_step=False)['total_q']

        if self._cfg.learn.double_q:
            next_inputs = {'obs': data['next_obs']}
            self._learn_model.reset(state=data['prev_state'][1])
            logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
            next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
        else:
            next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            target_total_q = self._target_model.forward(next_inputs, cooperation=True, single_step=False)['total_q']

        if self._nstep == 1:
            v_data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
            loss, td_error_per_sample = v_1step_td_error(v_data, self._gamma)
            # for visualization
            with torch.no_grad():
                if data['done'] is not None:
                    target_v = self._gamma * (1 - data['done']) * target_total_q + data['reward']
                else:
                    target_v = self._gamma * target_total_q + data['reward']
        else:
            data['reward'] = data['reward'].permute(0, 2, 1).contiguous()
            loss = []
            td_error_per_sample = []
            for t in range(self._cfg.collect.unroll_len):
                v_data = v_nstep_td_data(
                    total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None
                )
                # calculate v_nstep_td critic_loss
                loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
                loss.append(loss_i)
                td_error_per_sample.append(td_error_per_sample_i)
            loss = sum(loss) / (len(loss) + 1e-8)
            td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)

        self._optimizer_current.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self._model.current.parameters(), self._cfg.learn.clip_value)
        self._optimizer_current.step()

        # cooperation
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        cooperation_total_q = self._learn_model.forward(inputs, cooperation=True, single_step=False)['total_q']
        next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            cooperation_target_total_q = self._target_model.forward(
                next_inputs, cooperation=True, single_step=False
            )['total_q']

        if self._nstep == 1:
            v_data = v_1step_td_data(
                cooperation_total_q, cooperation_target_total_q, data['reward'], data['done'], data['weight']
            )
            cooperation_loss, _ = v_1step_td_error(v_data, self._gamma)
        else:
            cooperation_loss_all = []
            for t in range(self._cfg.collect.unroll_len):
                v_data = v_nstep_td_data(
                    cooperation_total_q[t],
                    cooperation_target_total_q[t],
                    data['reward'][t],
                    data['done'][t],
                    data['weight'],
                    None,
                )
                cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
                cooperation_loss_all.append(cooperation_loss)
            cooperation_loss = sum(cooperation_loss_all) / (len(cooperation_loss_all) + 1e-8)
        self._optimizer_cooperation.zero_grad()
        cooperation_loss.backward()
        cooperation_grad_norm = torch.nn.utils.clip_grad_norm_(
            self._model.cooperation.parameters(), self._cfg.learn.clip_value
        )
        self._optimizer_cooperation.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer_current.defaults['lr'],
            'total_loss': loss.item(),
            'total_q': total_q.mean().item() / self._cfg.model.agent_num,
            'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num,
            'grad_norm': grad_norm,
            'cooperation_grad_norm': cooperation_grad_norm,
            'cooperation_loss': cooperation_loss.item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the learn model to the state indicated by data_id.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The ids that store the states; the model state will be reset \
                to the states indicated by data_id
        """
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_current': self._optimizer_current.state_dict(),
            'optimizer_cooperation': self._optimizer_cooperation.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
                load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
                complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_current.load_state_dict(state_dict['optimizer_current'])
        self._optimizer_cooperation.load_state_dict(state_dict['optimizer_cooperation'])

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of the collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state', \
                'action', 'reward', 'done'
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'prev_state': model_output['prev_state'],
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the train sample from a trajectory.
        Arguments:
            - data (:obj:`list`): The trajectory's cache
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        if self._cfg.nstep == 1:
            return get_train_sample(data, self._unroll_len)
        else:
            data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
            return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return [
            'cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'target_reward_total_q',
            'cooperation_grad_norm', 'cooperation_loss'
        ]
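
With the default nstep=3, the TD targets for both the current mixer and the cooperation mixer are n-step returns bootstrapped from the target network's total_q. The exact masking and weighting are handled by ding.rl_utils.v_nstep_td_error; the sketch below only illustrates the textbook form of such a target (function name and tensor shapes are assumptions for illustration):

import torch

def nstep_td_target(reward, next_value, done, gamma=0.99, nstep=3):
    # reward: (nstep, B) rewards r_t .. r_{t+n-1}
    # next_value: (B,) bootstrap value from the target network, V(s_{t+n})
    # done: (B,) terminal flag of the bootstrap state
    discounts = torch.tensor(
        [gamma ** i for i in range(nstep)], dtype=reward.dtype, device=reward.device
    ).unsqueeze(1)  # (nstep, 1)
    n_step_return = (discounts * reward).sum(dim=0)  # (B,)
    return n_step_return + (gamma ** nstep) * (1.0 - done) * next_value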
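
The target model is wrapped with update_type='momentum', i.e. a soft (Polyak-style) update controlled by learn.target_update_theta=0.008, applied after every training step via self._target_model.update(...). A minimal stand-alone sketch of that update rule (not the actual ding model_wrap implementation):

import torch

def soft_update(target: torch.nn.Module, online: torch.nn.Module, theta: float = 0.008) -> None:
    # target <- (1 - theta) * target + theta * online, applied parameter-wise
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1.0 - theta).add_(o_param, alpha=theta)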
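
Exploration uses exponential epsilon decay from start=1 to end=0.05 over a decay horizon of 50000 environment steps. One common form of such a schedule, consistent with these parameters (the exact curve produced by ding's epsilon-greedy helpers may differ), is:

import math

def exp_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: int = 50000) -> float:
    # Anneal epsilon from `start` toward `end` as the env step count grows.
    return end + (start - end) * math.exp(-step / decay)

# exp_epsilon(0) == 1.0, exp_epsilon(50000) ~= 0.40, approaching 0.05 for large step counts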