
ding.policy.coma


COMAPolicy

Bases: Policy

Overview

Policy class of the COMA algorithm. COMA (Counterfactual Multi-Agent Policy Gradients) is a multi-agent reinforcement learning algorithm.

Interface:
    _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn,
    _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval,
    _reset_eval, _get_train_sample, default_model, _monitor_vars_learn

Config:

    == ==================== ===== ============= ========================================== ==========================
    ID Symbol               Type  Default Value Description                                Other (Shape)
    == ==================== ===== ============= ========================================== ==========================
    1  type                 str   coma          RL policy register name, refer to          this arg is optional,
                                                registry POLICY_REGISTRY                   a placeholder
    2  cuda                 bool  False         Whether to use cuda for network            this arg can be different
                                                                                           between modes
    3  on_policy            bool  True          Whether the RL algorithm is on-policy
                                                or off-policy
    4  priority             bool  False         Whether to use priority (PER)              priority sample,
                                                                                           update priority
    5  priority_IS_weight   bool  False         Whether to use Importance Sampling         IS weight
                                                weight to correct biased updates
    6  learn.update_        int   1             How many updates (iterations) to train     this arg can vary from
       per_collect                              after one collection by the collector.     env to env; a bigger
                                                Only valid in serial training              value means more off-policy
    7  learn.target_        float 0.001         Target network update momentum             between [0, 1]
       update_theta                             parameter
    8  learn.discount_      float 0.99          Reward's future discount factor, aka       may be 1 in sparse
       factor                                   gamma                                      reward envs
    9  learn.td_lambda      float 0.8           The trade-off factor of td-lambda,
                                                which balances 1-step td and mc
    10 learn.value_weight   float 1.0           The loss weight of the value network       policy network weight
                                                                                           is set to 1
    11 learn.entropy_weight float 0.01          The loss weight of entropy                 policy network weight
                                                regularization                             is set to 1
    == ==================== ===== ============= ========================================== ==========================
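The defaults in the table above are nested (top-level keys plus a `learn` sub-dict), so user overrides have to be merged recursively rather than with a flat `dict.update`. The sketch below illustrates that merge with a hypothetical `deep_merge` helper; it is not part of DI-engine (users normally supply overrides through their experiment config), and only a subset of the documented defaults is reproduced.

```python
import copy

# Default values taken from the config table above (subset only).
default_config = dict(
    cuda=False,
    priority=False,
    learn=dict(
        update_per_collect=20,
        batch_size=32,
        learning_rate=0.0005,
        target_update_theta=0.001,
        discount_factor=0.99,
        td_lambda=0.8,
    ),
)


def deep_merge(base: dict, override: dict) -> dict:
    """Hypothetical helper: recursively merge `override` into a copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Override one nested field; untouched defaults survive the merge.
user_config = dict(cuda=True, learn=dict(discount_factor=0.95))
cfg = deep_merge(default_config, user_config)
print(cfg['learn']['discount_factor'])  # 0.95
print(cfg['learn']['td_lambda'])        # 0.8 (unchanged default)
```

Note that `deep_merge` copies the base dict first, so the shipped defaults are never mutated by a user override.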

default_model()

Overview

Return the default model setting of this algorithm for demonstration.

Returns: - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names

.. note:: The user can define and use a customized network model but must obey the same interface definition indicated by the import_names path. For coma, ``ding.model.template.coma``
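The `(name, import_names)` pair returned by `default_model` is resolved by importing each module path in `import_names` and then looking the model up in a registry. A minimal sketch of that import step is shown below; `load_modules` is a hypothetical illustration of the mechanism, and stdlib module names stand in for `ding.model.template.coma` so the snippet runs without DI-engine installed.

```python
import importlib


def load_modules(import_names):
    """Import each module path listed in import_names and return the modules."""
    return [importlib.import_module(name) for name in import_names]


# For COMA the real pair is ('coma', ['ding.model.template.coma']);
# stdlib modules are substituted here purely for demonstration.
model_name, import_names = 'coma', ['json', 'math']
modules = load_modules(import_names)
print([m.__name__ for m in modules])  # ['json', 'math']
```

Importing the listed modules is what triggers the registry decorators (such as `@POLICY_REGISTRY.register`) at module load time, which is why `import_names` must point at the module defining the model.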

Full Source Code

../ding/policy/coma.py

```python
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import coma_data, coma_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate, timestep_collate
from .base_policy import Policy


@POLICY_REGISTRY.register('coma')
class COMAPolicy(Policy):
    r"""
    Overview:
        Policy class of the COMA algorithm. COMA (Counterfactual Multi-Agent Policy Gradients) is a \
        multi-agent reinforcement learning algorithm.
    Interface:
        _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn\
        _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval\
        _reset_eval, _get_train_sample, default_model, _monitor_vars_learn
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      coma           | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | this arg can be
                                                                                                 | different between
                                                                                                 | modes
        3  ``on_policy``        bool     True           | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)          | priority sample,
                                                                                                 | update priority
        5  | ``priority_``      bool     False          | Whether to use Importance Sampling     | IS weight
           | ``IS_weight``                              | weight to correct biased updates.
        6  | ``learn.update``   int      1              | How many updates (iterations) to train | this arg can vary
           | ``_per_collect``                           | after one collection by the collector. | from env to env; a
                                                        | Only valid in serial training.         | bigger value means
                                                                                                 | more off-policy
        7  | ``learn.target_``  float    0.001          | Target network update momentum         | between [0, 1]
           | ``update_theta``                           | parameter.
        8  | ``learn.discount`` float    0.99           | Reward's future discount factor, aka   | may be 1 in sparse
           | ``_factor``                                | gamma                                  | reward envs
        9  | ``learn.td_``      float    0.8            | The trade-off factor of td-lambda,
           | ``lambda``                                 | which balances 1-step td and mc
        10 | ``learn.value_``   float    1.0            | The loss weight of the value network   | policy network weight
           | ``weight``                                                                          | is set to 1
        11 | ``learn.entropy_`` float    0.01           | The loss weight of entropy             | policy network weight
           | ``weight``                                 | regularization                         | is set to 1
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='coma',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (bool) Whether to use Importance Sampling weight to correct biased updates. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            update_per_collect=20,
            batch_size=32,
            learning_rate=0.0005,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
            target_update_theta=0.001,
            # (float) discount factor for future reward, defaults in [0, 1]
            discount_factor=0.99,
            # (float) the trade-off factor of td-lambda, which balances 1-step td and mc (n-step td in practice)
            td_lambda=0.8,
            # (float) the loss weight of the policy network
            policy_weight=0.001,
            # (float) the loss weight of the value network
            value_weight=1,
            # (float) the loss weight of entropy regularization
            entropy_weight=0.01,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            # n_episode=32,
            # (int) unroll length of a training iteration (gradient update step)
            unroll_len=20,
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default model setting of this algorithm for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names

        .. note::
            The user can define and use a customized network model but must obey the same interface definition \
            indicated by the import_names path. For coma, ``ding.model.template.coma``
        """
        return 'coma', ['ding.model.template.coma']

    def _init_learn(self) -> None:
        """
        Overview:
            Init the learner model of COMAPolicy

        Arguments:
            .. note::

                The _init_learn method takes its arguments from self._cfg.learn in the config file

            - learning_rate (:obj:`float`): The learning rate of the optimizer
            - gamma (:obj:`float`): The discount factor
            - lambda (:obj:`float`): The lambda factor, determining the mix of bootstrapping \
                vs further accumulation of multistep returns at each timestep
            - value_weight (:obj:`float`): The weight of value loss in the total loss
            - entropy_weight (:obj:`float`): The weight of entropy loss in the total loss
            - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num
            - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority, "not implemented priority in COMA"
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.learn.discount_factor
        self._lambda = self._cfg.learn.td_lambda
        self._policy_weight = self._cfg.learn.policy_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Any]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning

        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from the collect function; each Dict \
                in data should contain keys including at least ['obs', 'action', 'reward']

        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, including at least \
                ['obs', 'action', 'reward', 'done', 'weight']
        """
        # data preprocess
        data = timestep_collate(data)
        assert set(data.keys()) > set(['obs', 'action', 'reward'])
        if self._cuda:
            data = to_device(data, self._device)
        data['weight'] = data.get('weight', None)
        data['done'] = data['done'].float()
        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode: acquire the data, calculate the loss and \
            optimize the learner model
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, \
                which will be recorded in text log and tensorboard, values are python scalars or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``done``, ``weight``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``, ``policy_loss``, ``value_loss``, ``entropy_loss``
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
                - policy_loss (:obj:`float`): The policy (actor) loss of coma
                - value_loss (:obj:`float`): The value (critic) loss of coma
                - entropy_loss (:obj:`float`): The entropy loss
        """
        data = self._data_preprocess_learn(data)
        # forward
        self._learn_model.train()
        self._target_model.train()
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        with torch.no_grad():
            target_q_value = self._target_model.forward(data, mode='compute_critic')['q_value']
        logit = self._learn_model.forward(data, mode='compute_actor')['logit']
        logit[data['obs']['action_mask'] == 0.0] = -9999999

        data = coma_data(logit, data['action'], q_value, target_q_value, data['reward'], data['weight'])
        coma_loss = coma_error(data, self._gamma, self._lambda)
        total_loss = self._policy_weight * coma_loss.policy_loss + self._value_weight * coma_loss.q_value_loss - \
            self._entropy_weight * coma_loss.entropy_loss

        # update
        self._optimizer.zero_grad()
        total_loss.backward()
        self._optimizer.step()
        # after update
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': coma_loss.policy_loss.item(),
            'value_loss': coma_loss.q_value_loss.item(),
            'entropy_loss': coma_loss.entropy_loss.item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
            Model has eps_greedy_sample wrapper and hidden_state wrapper
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.collect.env_num,
            save_prev_state=True,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, eps: float) -> dict:
        r"""
        Overview:
            Collect output according to eps_greedy plugin

        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least the inferred action according to \
                input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, eps=eps, data_id=data_id, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        self._collect_model.reset(data_id=data_id)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'prev_state': model_output['prev_state'],
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy and hidden_state plugin.
        """
        self._eval_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.eval.env_num,
            save_prev_state=True,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicted actions for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        self._eval_model.reset(data_id=data_id)

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the train sample from a trajectory

        Arguments:
            - data (:obj:`list`): The trajectory's cache

        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' names if variables are to be used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss']
```
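In `_init_learn`, the target network is wrapped with `update_type='momentum'` and `theta=target_update_theta`, i.e. each parameter is softly updated as `theta * new_w + (1 - theta) * old_w` (as the config comment states). The sketch below illustrates that rule on plain Python floats; `momentum_update` is an illustrative stand-in, not DI-engine's actual wrapper implementation.

```python
def momentum_update(target_params, online_params, theta):
    """Soft target update: target <- theta * online + (1 - theta) * target."""
    return [theta * n + (1 - theta) * o for o, n in zip(target_params, online_params)]


# With theta = 0.001 the target network tracks the online network very slowly,
# which stabilizes the critic targets used in _forward_learn.
target = [0.0, 1.0]
online = [1.0, 0.0]
target = momentum_update(target, online, theta=0.001)
print(target)  # [0.001, 0.999]
```

The same slow-moving target idea appears in many actor-critic algorithms; `theta` close to 0 means the target changes little per update, trading responsiveness for stability.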