ding.policy.atoc
Full Source Code
../ding/policy/atoc.py
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import copy
import torch

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('atoc')
class ATOCPolicy(Policy):
    r"""
    Overview:
        Policy class of the ATOC algorithm: a multi-agent DDPG-style actor-critic with an \
        optional attention-based communication module. When ``communication=False`` it \
        degenerates to plain multi-agent DDPG.
    Interface:
        default_model, _init_learn, _forward_learn, _state_dict_learn, _load_state_dict_learn, \
        _init_collect, _forward_collect, _process_transition, _get_train_sample, \
        _init_eval, _forward_eval, _monitor_vars_learn
    Property:
        learn_mode, collect_mode, eval_mode
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='atoc',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        model=dict(
            # (bool) Whether to use communication module in ATOC, if not, it is a multi-agent DDPG
            communication=True,
            # (int) The number of thought size
            thought_size=8,
            # (int) The number of agent for each communication group
            agent_per_group=2,
        ),
        learn=dict(
            # (int) Collect n_sample data, update model n_iteration time
            update_per_collect=5,
            # (int) The number of data for a train iteration
            batch_size=64,
            # (float) Gradient-descent step size of actor
            learning_rate_actor=0.001,
            # (float) Gradient-descent step size of critic
            learning_rate_critic=0.001,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) Target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
            target_theta=0.005,
            # (float) Discount factor for future reward, defaults int [0, 1]
            discount_factor=0.99,
            # (bool) Whether to use communication module in ATOC, if not, it is a multi-agent DDPG
            communication=True,
            # (int) The frequency of actor update, each critic update
            actor_update_freq=1,
            # (bool) Whether use noise in action output when learning
            noise=True,
            # (float) The std of noise distribution for target policy smooth
            noise_sigma=0.15,
            # (float, float) The minimum and maximum value of noise
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
            # (bool) Whether to use reward batch norm in the total batch
            reward_batch_norm=False,
            # (bool) Whether to ignore the done flag when computing the TD target
            ignore_done=False,
        ),
        collect=dict(
            # (int) Collect n_sample data, update model n_iteration time
            # n_sample=64,
            # (int) Unroll length of a train iteration(gradient update step)
            unroll_len=1,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) The std of noise distribution for exploration
            noise_sigma=0.4,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int) The max size of replay buffer
                replay_buffer_size=100000,
                # (int) The max use count of data, if count is bigger than this value, the data will be removed
                max_use=10,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default model registry name and its import path for this policy.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): Registry name ``'atoc'`` and import path list.
        """
        return 'atoc', ['ding.model.template.atoc']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init actor and critic optimizers, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # ATOC does not support prioritized replay.
        assert not self._priority and not self._priority_IS_weight
        # algorithm config
        self._communication = self._cfg.learn.communication
        self._gamma = self._cfg.learn.discount_factor
        self._actor_update_freq = self._cfg.learn.actor_update_freq
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        # The attention optimizer exists only when the communication module is enabled.
        if self._communication:
            self._optimizer_actor_attention = Adam(
                self._model.actor.attention.parameters(),
                lr=self._cfg.learn.learning_rate_actor,
            )
        self._reward_batch_norm = self._cfg.learn.reward_batch_norm

        # main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        if self._cfg.learn.noise:
            # Gaussian noise on the target action for target policy smoothing.
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': self._cfg.learn.noise_sigma
                },
                noise_range=self._cfg.learn.noise_range
            )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
        """
        loss_dict = {}
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            # Normalize reward over the whole batch; epsilon avoids division by zero.
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)
        # current q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        # target q value.
        with torch.no_grad():
            next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
        # Per-agent q values are averaged over the last dim before the 1-step TD error.
        td_data = v_1step_td_data(q_value.mean(-1), target_q_value.mean(-1), reward, data['done'], data['weight'])
        critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
        loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        critic_loss.backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # actor updates every ``self._actor_update_freq`` iters
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            if self._communication:
                # Train the attention unit with the delta-q targets collected during rollout.
                output = self._learn_model.forward(data['obs'], mode='compute_actor', get_delta_q=False)
                output['delta_q'] = data['delta_q']
                attention_loss = self._learn_model.forward(output, mode='optimize_actor_attention')['loss']
                loss_dict['attention_loss'] = attention_loss
                self._optimizer_actor_attention.zero_grad()
                attention_loss.backward()
                self._optimizer_actor_attention.step()

            output = self._learn_model.forward(data['obs'], mode='compute_actor', get_delta_q=False)

            # Deterministic policy gradient: maximize Q of the actor's own action.
            critic_input = {'obs': data['obs'], 'action': output['action']}
            actor_loss = -self._learn_model.forward(critic_input, mode='compute_critic')['q_value'].mean()
            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'q_value': q_value.mean().item(),
            **loss_dict,
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode: models and optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Checkpoint dict for ``_load_state_dict_learn``.
        """
        state_dict = {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
            'optimizer_critic': self._optimizer_critic.state_dict(),
        }
        # The attention optimizer only exists when communication=True (see ``_init_learn``);
        # saving it unconditionally would raise AttributeError for the multi-agent DDPG variant.
        if self._communication:
            # NOTE: key name 'optimize_actor_attention' is kept as-is for checkpoint
            # backward compatibility, despite the inconsistent spelling.
            state_dict['optimize_actor_attention'] = self._optimizer_actor_attention.state_dict()
        return state_dict

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the checkpoint produced by ``_state_dict_learn`` into models and optimizers.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Checkpoint dict.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
        self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
        # Attention optimizer state is optional: absent when communication=False or when
        # loading an old checkpoint saved without it.
        if self._communication and 'optimize_actor_attention' in state_dict:
            self._optimizer_actor_attention.load_state_dict(state_dict['optimize_actor_attention'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        # collect model with exploration noise on the action output
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_kwargs={
                'mu': 0.0,
                'sigma': self._cfg.collect.noise_sigma
            },
            noise_range=None,  # no noise clip in actor
        )
        self._collect_model.reset()

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            # get_delta_q=True so transitions carry the delta-q signal used to train attention.
            output = self._collect_model.forward(data, mode='compute_actor', get_delta_q=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
        if self._communication:
            # delta_q is only produced (and needed) when the communication module is on.
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': model_output['action'],
                'delta_q': model_output['delta_q'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        else:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': model_output['action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        """
        Overview:
            Post-process a trajectory into train samples. When communication is enabled,
            min-max normalize ``delta_q`` over the whole trajectory before unrolling.
        Arguments:
            - data (:obj:`list`): List of transitions from ``_process_transition``.
        Returns:
            - samples (:obj:`Union[None, List[Any]]`): Train samples split by ``unroll_len``.
        """
        if self._communication:
            delta_q_batch = [d['delta_q'] for d in data]
            delta_min = torch.stack(delta_q_batch).min()
            delta_max = torch.stack(delta_q_batch).max()
            for i in range(len(data)):
                # epsilon avoids division by zero when all delta_q are equal
                data[i]['delta_q'] = (data[i]['delta_q'] - delta_min) / (delta_max - delta_min + 1e-8)
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model. Unlike learn and collect model, eval model does not need noise.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' name if variables are to used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        # NOTE(review): 'attention_loss' and 'actor_loss' only appear in the learn output
        # on actor-update iterations with communication enabled — monitor handles missing
        # keys per framework convention; confirm against the logger implementation.
        return [
            'cur_lr_actor',
            'cur_lr_critic',
            'critic_loss',
            'actor_loss',
            'attention_loss',
            'total_loss',
            'q_value',
        ]