
ding.policy.collaq

CollaQPolicy

Bases: Policy

Overview

Policy class of the CollaQ algorithm. CollaQ is a multi-agent reinforcement learning algorithm.

Interface: _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn, _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, _reset_eval, _get_train_sample, default_model

Config:

== ========================= ===== ======= ========================================================= ===========================================
ID Symbol                    Type  Default Description                                               Other (Shape)
== ========================= ===== ======= ========================================================= ===========================================
1  type                      str   collaq  RL policy register name, see POLICY_REGISTRY              Optional, a placeholder
2  cuda                      bool  True    Whether to use CUDA for the network                       Can differ between modes
3  on_policy                 bool  False   Whether the RL algorithm is on-policy or off-policy
4  priority                  bool  False   Whether to use prioritized replay (PER)                   Priority sample, update priority
5  priority_IS_weight        bool  False   Whether to use IS weights to correct biased updates       IS weight; requires priority
6  learn.update_per_collect  int   20      Updates (iterations) per collection, serial training only Varies by env; larger means more off-policy
7  learn.target_update_theta float 0.001   Target network update momentum parameter                  In [0, 1]
8  learn.discount_factor     float 0.99    Future reward discount factor, aka gamma                  May be 1 for sparse-reward envs
9  learn.collaq_loss_weight  float 1.0     Weight of the CollaQ MARA loss
== ========================= ===== ======= ========================================================= ===========================================
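Rows 7 to 9 are easier to read as one learning step. The scalar sketch below mirrors the structure of _forward_learn in the source listing: a one-step TD target on total_q, a MARA penalty weighted by collaq_loss_weight, and a momentum update of the target network. It is a simplification under stated assumptions: plain Python floats replace tensors, the TD loss is taken as an unweighted mean squared error (no priority weights), and the helpers td_target, collaq_loss and momentum_update are hypothetical, not DI-engine functions.

```python
from typing import List, Sequence

# Defaults taken from the config table above.
DISCOUNT_FACTOR = 0.99       # learn.discount_factor (gamma)
COLLAQ_LOSS_WEIGHT = 1.0     # learn.collaq_loss_weight (alpha)
TARGET_UPDATE_THETA = 0.001  # learn.target_update_theta


def td_target(reward: float, done: bool, next_total_q: float, gamma: float = DISCOUNT_FACTOR) -> float:
    """One-step target: bootstrap from the target network's total Q unless the episode ended."""
    return reward + gamma * (1.0 - float(done)) * next_total_q


def collaq_loss(total_q: Sequence[float], targets: Sequence[float],
                colla_alone_q: Sequence[float], alpha: float = COLLAQ_LOSS_WEIGHT) -> float:
    """Hypothetical helper: mean squared TD error plus alpha times the mean squared MARA term."""
    td_loss = sum((q - t) ** 2 for q, t in zip(total_q, targets)) / len(total_q)
    mara_loss = sum(c ** 2 for c in colla_alone_q) / len(colla_alone_q)
    return td_loss + alpha * mara_loss


def momentum_update(target_w: List[float], online_w: List[float],
                    theta: float = TARGET_UPDATE_THETA) -> List[float]:
    """Soft target update: theta * new_w + (1 - theta) * old_w, element-wise."""
    return [theta * w + (1.0 - theta) * tw for tw, w in zip(target_w, online_w)]


# One illustrative step for a batch of two samples (the second one is terminal).
targets = [td_target(1.0, False, 2.0), td_target(0.5, True, 3.0)]
loss = collaq_loss([3.0, 0.4], targets, [0.5, -0.5])
new_target_w = momentum_update([0.0, 1.0], [1.0, 1.0])
```

With theta = 0.001 the target weights move only 0.1% of the way toward the online weights per update, which is why a small value keeps the TD target stable.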

default_model()

Overview

Return this algorithm default model setting for demonstration.

Returns:
- model_info (Tuple[str, List[str]]): the model name and the model's import_names

Note: Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For CollaQ, this is ding.model.template.collaq.CollaQ.
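The role of import_names can be illustrated with a toy registry: a framework following this pattern imports each listed module, which runs that module's registration decorator, and then looks the model up by name. This is a sketch under assumptions, not DI-engine's implementation; TOY_MODEL_REGISTRY, register, resolve_model and ToyCollaQ are hypothetical, and the standard-library json module stands in for ding.model.template.collaq so the example runs without DI-engine installed.

```python
import importlib
from typing import Callable, Dict, List, Tuple

# Hypothetical toy registry; DI-engine's real registries are not used here.
TOY_MODEL_REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Decorator that records a model class under `name` when its module is imported."""
    def decorator(cls: type) -> type:
        TOY_MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def resolve_model(model_info: Tuple[str, List[str]]) -> type:
    """Import each module in import_names (so registrations run), then look up the model name."""
    name, import_names = model_info
    for module_path in import_names:
        importlib.import_module(module_path)
    if name not in TOY_MODEL_REGISTRY:
        raise KeyError(f"model {name!r} was not registered by any of {import_names}")
    return TOY_MODEL_REGISTRY[name]


@register('collaq')
class ToyCollaQ:
    """Hypothetical placeholder for a user-defined network; a real one must match CollaQ's interface."""


# 'json' stands in for 'ding.model.template.collaq' so the sketch runs stand-alone.
model_cls = resolve_model(('collaq', ['json']))
```

Keeping the registration inside the model's own module is what lets a custom network be swapped in: only the import path in import_names has to change.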

Full Source Code

../ding/policy/collaq.py

```python
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import copy

from ding.torch_utils import to_device, RMSprop
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('collaq')
class CollaQPolicy(Policy):
    r"""
    Overview:
        Policy class of the CollaQ algorithm. CollaQ is a multi-agent reinforcement learning algorithm.
    Interface:
        _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn, \
        _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, \
        _reset_eval, _get_train_sample, default_model
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      collaq         | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     True           | Whether to use cuda for network        | this arg can differ
                                                                                                 | between modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)          | priority sample,
                                                                                                 | update priority
        5  | ``priority_``      bool     False          | Whether to use Importance Sampling     | IS weight
           | ``IS_weight``                              | Weight to correct biased update.
        6  | ``learn.update_``  int      20             | How many updates (iterations) to train | this arg can vary
           | ``per_collect``                            | after collector's one collection. Only | by env. A bigger value
                                                        | valid in serial training               | means more off-policy
        7  | ``learn.target_``  float    0.001          | Target network update momentum         | between [0, 1]
           | ``update_theta``                           | parameter.
        8  | ``learn.discount`` float    0.99           | Reward's future discount factor, aka.  | may be 1 when sparse
           | ``_factor``                                | gamma                                  | reward env
        9  | ``learn.collaq``   float    1.0            | The weight of collaq MARA loss
           | ``_loss_weight``
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='collaq',
        # (bool) Whether to use cuda for network.
        cuda=True,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            # (int) Collect n_episode data, update_model n_iteration times
            update_per_collect=20,
            # (int) The number of data for a train iteration
            batch_size=32,
            # (float) Gradient-descent step size
            learning_rate=0.0005,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
            target_update_theta=0.001,
            # (float) Discount factor for future reward, defaults in [0, 1]
            discount_factor=0.99,
            # (float) The weight of collaq MARA loss
            collaq_loss_weight=1.0,
            # (float) Max norm of the gradients, used for gradient clipping
            clip_value=100,
            # (bool) Whether to use double DQN mechanism (target q for suppressing over-estimation)
            double_q=False,
        ),
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set
            # n_episode=32,
            # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
            # in each forward when training. It is greater than 1 here because the model contains an RNN.
            unroll_len=10,
        ),
        eval=dict(),
        other=dict(
            eps=dict(
                # (str) Type of epsilon decay
                type='exp',
                # (float) Start value for epsilon decay, in [0, 1].
                # 0 means not using epsilon decay.
                start=1,
                # (float) End value for epsilon decay, in [0, 1].
                end=0.05,
                # (int) Decay length (env step)
                decay=200000,
            ),
            replay_buffer=dict(
                # (int) Max size of replay buffer
                replay_buffer_size=5000,
                max_reuse=10,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names

        .. note::
            The user can define and use customized network model but must obey the same interface definition \
            indicated by import_names path. For collaq, ``ding.model.template.collaq.CollaQ`` .
        """
        return 'collaq', ['ding.model.template.collaq']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the learner model of CollaQPolicy
        Arguments:
            .. note::

                The _init_learn method takes the argument from the self._cfg.learn in the config file

            - learning_rate (:obj:`float`): The learning rate for the optimizer
            - gamma (:obj:`float`): The discount factor
            - alpha (:obj:`float`): The collaQ loss factor, the weight for calculating MARA loss
            - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num.
            - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = RMSprop(
            params=self._model.parameters(), lr=self._cfg.learn.learning_rate, alpha=0.99, eps=0.00001
        )
        self._gamma = self._cfg.learn.discount_factor
        self._alpha = self._cfg.learn.collaq_loss_weight

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
        )
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(
            self,
            data: List[Any],
            use_priority_IS_weight: bool = False,
            use_priority: bool = False,
    ) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning

        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function

        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, from \
                [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        if use_priority_IS_weight:
            assert use_priority, "Use IS Weight correction, but Priority is not used."
        if use_priority and use_priority_IS_weight:
            if 'priority_IS' in data:
                data['weight'] = data['priority_IS']
            else:  # for compatibility
                data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)
        data['done'] = data['done'].float()
        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which \
                will be recorded in text log and tensorboard, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        data = self._data_preprocess_learn(data, self._cfg.priority_IS_weight, self._cfg.priority)
        # ====================
        # CollaQ forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # for hidden_state plugin, we need to reset the main model and target model
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        inputs = {'obs': data['obs'], 'action': data['action']}
        ret = self._learn_model.forward(inputs, single_step=False)
        total_q = ret['total_q']
        agent_colla_alone_q = ret['agent_colla_alone_q'].sum(-1).sum(-1)

        if self._cfg.learn.double_q:
            next_inputs = {'obs': data['next_obs']}
            logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
            next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
        else:
            next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            target_total_q = self._target_model.forward(next_inputs, single_step=False)['total_q']

        # td_loss calculation
        td_data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
        td_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
        # collaQ loss calculation
        colla_loss = (agent_colla_alone_q ** 2).mean()
        # combine loss with factor
        loss = colla_loss * self._alpha + td_loss
        # ====================
        # CollaQ update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.learn.clip_value)
        self._optimizer.step()
        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'colla_loss': colla_loss.item(),
            'td_loss': td_loss.item(),
            'grad_norm': grad_norm,
            'priority': torch.mean(td_error_per_sample.abs(), dim=0).tolist(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset learn model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The ids of the stored states; the model state will be reset \
                to the state indicated by data_id
        """
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Enable the eps_greedy_sample and the hidden_state plugin.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.collect.env_num,
            save_prev_state=True,
            init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, eps: float) -> dict:
        r"""
        Overview:
            Forward function for collect mode with eps_greedy
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps, data_id=data_id)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset collect model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The ids of the stored states; the model state will be reset \
                to the state indicated by data_id
        """
        self._collect_model.reset(data_id=data_id)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least \
                ['action', 'prev_state', 'agent_colla_alone_q']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'prev_state': model_output['prev_state'],
            'action': model_output['action'],
            'agent_colla_alone_q': model_output['agent_colla_alone_q'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy and the hidden_state plugin.
        """
        self._eval_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.eval.env_num,
            save_prev_state=True,
            init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
        )
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function for eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset eval model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The ids of the stored states; the model state will be reset \
                to the state indicated by data_id
        """
        self._eval_model.reset(data_id=data_id)

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the train sample from trajectory.
        Arguments:
            - data (:obj:`list`): The trajectory's cache
        Returns:
            - samples (:obj:`list`): The training samples generated
        """
        return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' name if variables are to be used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return ['cur_lr', 'total_loss', 'colla_loss', 'td_loss', 'grad_norm']
```
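_get_train_sample hands the collected trajectory to get_train_sample together with unroll_len (10 by default), so each training sample is a run of consecutive timesteps; _forward_learn then restores the recurrent state from the first step's prev_state before replaying the sequence. The plain-Python sketch below shows only the basic slicing idea. split_into_unrolls is a hypothetical helper, and its handling of a leftover tail (reusing the last unroll_len steps) is an assumption that may not match ding.rl_utils.get_train_sample exactly.

```python
from typing import Any, List


def split_into_unrolls(trajectory: List[Any], unroll_len: int) -> List[List[Any]]:
    """Cut a trajectory into consecutive pieces of `unroll_len` timesteps (hypothetical helper).

    Assumption for illustration: a leftover tail shorter than `unroll_len` is replaced
    by the last `unroll_len` steps, so that piece overlaps the previous one. A trajectory
    shorter than `unroll_len` comes back as a single, shorter piece.
    """
    if unroll_len <= 0:
        raise ValueError("unroll_len must be positive")
    n_full = len(trajectory) // unroll_len
    pieces = [trajectory[i * unroll_len:(i + 1) * unroll_len] for i in range(n_full)]
    if len(trajectory) % unroll_len != 0:
        # Slicing from the end keeps every piece at full length when possible.
        pieces.append(trajectory[-unroll_len:])
    return pieces


# 25 timesteps with the default unroll_len=10 from the collect config.
pieces = split_into_unrolls(list(range(25)), unroll_len=10)
```

Here the 25 steps yield pieces covering steps 0 to 9, 10 to 19, and 15 to 24; keeping every piece at full length lets a batch of sequences be stacked into a single [T, B, ...] tensor.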