from collections import namedtuple
from typing import List, Dict, Any, Tuple
import copy

import torch

from ding.model import model_wrap
from ding.rl_utils import get_train_sample, compute_q_retraces, acer_policy_error,\
    acer_value_error, acer_trust_region_update
from ding.torch_utils import Adam, RMSprop, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from ding.policy.base_policy import Policy

EPS = 1e-8


@POLICY_REGISTRY.register('acer')
class ACERPolicy(Policy):
    r"""
    Overview:
        Policy class of ACER algorithm (Actor-Critic with Experience Replay).

    Config:
        == ======================= ======== ============== ===================================== =======================
        ID Symbol                  Type     Default Value  Description                           Other(Shape)
        == ======================= ======== ============== ===================================== =======================
        1  ``type``                str      acer           | RL policy register name, refer to   | this arg is optional,
                                                           | registry ``POLICY_REGISTRY``        | a placeholder
        2  ``cuda``                bool     False          | Whether to use cuda for network     | this arg can be diff-
                                                           |                                     | erent from modes
        3  ``on_policy``           bool     False          | Whether the RL algorithm is
                                                           | on-policy or off-policy
        4  ``trust_region``        bool     True           | Whether the RL algorithm use trust  |
                                                           | region constraint                   |
        5  ``trust_region_value``  float    1.0            | maximum range of the trust region   |
        6  ``unroll_len``          int      32             | trajectory length to calculate
                                                           | Q retrace target
        7  ``learn.update``        int      4              | How many updates(iterations) to     | this args can be vary
           ``per_collect``                                 | train after collector's one         | from envs. Bigger val
                                                           | collection. Only                    | means more off-policy
                                                           | valid in serial training
        8  ``c_clip_ratio``        float    1.0            | clip ratio of importance weights    |
        == ======================= ======== ============== ===================================== =======================
    """
    unroll_len = 32
    config = dict(
        type='acer',
        cuda=False,
        # (bool) whether to use on-policy training pipeline (behaviour policy and training policy are the same)
        # here we follow ppo serial pipeline, the original is False
        on_policy=False,
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        learn=dict(
            # (str) the type of gradient clip method
            grad_clip_type=None,
            # (float) max value when ACER use gradient clip
            clip_value=None,

            # (int) collect n_sample data, train model update_per_collect times
            # here we follow ppo serial pipeline
            update_per_collect=4,
            # (int) the number of data for a train iteration
            batch_size=16,
            # (float) loss weight of the value network, the weight of policy network is set to 1
            value_weight=0.5,
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.0001,
            # (float) discount factor for future reward, should be in [0, 1]
            discount_factor=0.9,
            # (float) additional discounting parameter
            lambda_=0.95,
            # (int) the trajectory length to calculate v-trace target
            unroll_len=unroll_len,
            # (float) clip ratio of importance weights
            c_clip_ratio=10,
            trust_region=True,
            trust_region_value=1.0,
            learning_rate_actor=0.0005,
            learning_rate_critic=0.0005,
            target_theta=0.01
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            # n_sample=16,
            # (int) the trajectory length to calculate v-trace target
            unroll_len=unroll_len,
            # (float) discount factor for future reward, should be in [0, 1]
            discount_factor=0.9,
            gae_lambda=0.95,
            collector=dict(
                type='sample',
                collect_print_freq=1000,
            ),
        ),
        eval=dict(evaluator=dict(eval_freq=200, ), ),
        other=dict(replay_buffer=dict(
            replay_buffer_size=1000,
            max_use=16,
        ), ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return this algorithm's default model setting for demonstration: model name and import path.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names.
        """
        return 'acer', ['ding.model.template.acer']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Initialize the optimizer, algorithm config and main model.
        """
        # Separate optimizers for actor and critic; only the actor uses gradient clipping.
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
            grad_clip_type=self._cfg.learn.grad_clip_type,
            clip_value=self._cfg.learn.clip_value
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        # Target (average) model, updated by momentum (EMA) with coefficient ``target_theta``.
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')

        self._action_shape = self._cfg.model.action_shape
        self._unroll_len = self._cfg.learn.unroll_len

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._gamma = self._cfg.learn.discount_factor
        # self._rho_clip_ratio = self._cfg.learn.rho_clip_ratio
        self._c_clip_ratio = self._cfg.learn.c_clip_ratio
        # self._rho_pg_clip_ratio = self._cfg.learn.rho_pg_clip_ratio
        self._use_trust_region = self._cfg.learn.trust_region
        self._trust_region_value = self._cfg.learn.trust_region_value
        # Main model
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
        """
        Overview:
            Data preprocess function of learn mode.
            Convert a list of transition data to stacked trajectory data, which is a dict of tensors.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \
            dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\
            'next_obs', 'logit', 'action', 'reward', 'done'
        Returns:
            - data (:obj:`dict`): Dict type data. Values are torch.Tensor or np.ndarray or dict/list combinations. \
        ReturnsKeys:
            - necessary: 'logit', 'action', 'reward', 'done', 'weight', 'obs_plus_1'.
            - optional and not used in later computation: 'obs', 'next_obs'.'IS', 'collect_iter', 'replay_unique_id', \
            'replay_buffer_idx', 'priority', 'staleness', 'use'.
        ReturnsShapes:
            - obs_plus_1 (:obj:`torch.FloatTensor`): :math:`(T * B, obs_shape)`, where T is timestep, B is batch size \
            and obs_shape is the shape of single env observation
            - logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim
            - action (:obj:`torch.LongTensor`): :math:`(T, B)`
            - reward (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
            - done (:obj:`torch.FloatTensor`): :math:`(T, B)`
            - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
        """
        data = default_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        data['weight'] = data.get('weight', None)
        # Append the last next_obs to obs so the model forward covers all T+1 states at once.
        # shape (T+1)*B,env_obs_shape
        data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0)
        data['logit'] = torch.cat(
            data['logit'], dim=0
        ).reshape(self._unroll_len, -1, self._action_shape)  # shape T,B,env_action_shape
        data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1)  # shape T,B,
        data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float()  # shape T,B,
        data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1)  # shape T,B,
        data['weight'] = torch.cat(
            data['weight'], dim=0
        ).reshape(self._unroll_len, -1) if data['weight'] else None  # shape T,B
        return data

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        r"""
        Overview:
            Forward computation graph of learn mode(updating policy).
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \
            dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\
            'next_obs', 'logit', 'action', 'reward', 'done'
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
            recorded in text log and tensorboard, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
            - optional: 'collect_iter', 'replay_unique_id', 'replay_buffer_idx', 'priority', 'staleness', 'use', 'IS'
        ReturnsKeys:
            - necessary: ``cur_actor_lr``, ``cur_critic_lr``, ``actor_loss``, ``bc_loss``, ``policy_loss``,\
                ``critic_loss``, ``entropy_loss``, ``kl_div``
        """
        data = self._data_preprocess_learn(data)
        self._learn_model.train()
        action_data = self._learn_model.forward(data['obs_plus_1'], mode='compute_actor')
        q_value_data = self._learn_model.forward(data['obs_plus_1'], mode='compute_critic')
        avg_action_data = self._target_model.forward(data['obs_plus_1'], mode='compute_actor')

        target_logit, behaviour_logit, avg_logit, actions, q_values, rewards, weights = self._reshape_data(
            action_data, avg_action_data, q_value_data, data
        )
        # shape (T+1),B,env_action_shape
        target_logit = torch.log_softmax(target_logit, dim=-1)
        # shape T,B,env_action_shape
        behaviour_logit = torch.log_softmax(behaviour_logit, dim=-1)
        # shape (T+1),B,env_action_shape
        avg_logit = torch.log_softmax(avg_logit, dim=-1)
        with torch.no_grad():
            # importance sampling ratio between training policy and behaviour policy
            # shape T,B,env_action_shape
            ratio = torch.exp(target_logit[0:-1] - behaviour_logit)
            # state value as expectation of Q under the current policy, shape (T+1),B,1
            v_pred = (q_values * torch.exp(target_logit)).sum(-1).unsqueeze(-1)
            # Calculate retrace
            q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, self._gamma)

        # the terminal states' weights are 0. it needs to be shift to count valid state
        weights_ext = torch.ones_like(weights)
        weights_ext[1:] = weights[0:-1]
        weights = weights_ext
        q_retraces = q_retraces[0:-1]  # shape T,B,1
        q_values = q_values[0:-1]  # shape T,B,env_action_shape
        v_pred = v_pred[0:-1]  # shape T,B,1
        target_logit = target_logit[0:-1]  # shape T,B,env_action_shape
        avg_logit = avg_logit[0:-1]  # shape T,B,env_action_shape
        total_valid = weights.sum()  # 1
        # ====================
        # policy update
        # ====================
        actor_loss, bc_loss = acer_policy_error(
            q_values, q_retraces, v_pred, target_logit, actions, ratio, self._c_clip_ratio
        )
        actor_loss = actor_loss * weights.unsqueeze(-1)
        bc_loss = bc_loss * weights.unsqueeze(-1)
        dist_new = torch.distributions.categorical.Categorical(logits=target_logit)
        entropy_loss = (dist_new.entropy() * weights).unsqueeze(-1)  # shape T,B,1
        total_actor_loss = (actor_loss + bc_loss + self._entropy_weight * entropy_loss).sum() / total_valid
        self._optimizer_actor.zero_grad()
        # gradients w.r.t. the policy logits, so the trust region projection can be applied before backprop
        actor_gradients = torch.autograd.grad(-total_actor_loss, target_logit, retain_graph=True)
        if self._use_trust_region:
            actor_gradients = acer_trust_region_update(
                actor_gradients, target_logit, avg_logit, self._trust_region_value
            )
        target_logit.backward(actor_gradients)
        self._optimizer_actor.step()

        # ====================
        # critic update
        # ====================
        critic_loss = (acer_value_error(q_values, q_retraces, actions) * weights.unsqueeze(-1)).sum() / total_valid
        self._optimizer_critic.zero_grad()
        critic_loss.backward()
        self._optimizer_critic.step()
        # momentum update of the average (target) model
        self._target_model.update(self._learn_model.state_dict())

        with torch.no_grad():
            kl_div = torch.exp(avg_logit) * (avg_logit - target_logit)
            kl_div = (kl_div.sum(-1) * weights).sum() / total_valid

        return {
            'cur_actor_lr': self._optimizer_actor.defaults['lr'],
            'cur_critic_lr': self._optimizer_critic.defaults['lr'],
            'actor_loss': (actor_loss.sum() / total_valid).item(),
            'bc_loss': (bc_loss.sum() / total_valid).item(),
            'policy_loss': total_actor_loss.item(),
            'critic_loss': critic_loss.item(),
            'entropy_loss': (entropy_loss.sum() / total_valid).item(),
            'kl_div': kl_div.item()
        }

    def _reshape_data(
            self, action_data: Dict[str, Any], avg_action_data: Dict[str, Any], q_value_data: Dict[str, Any],
            data: Dict[str, Any]
    ) -> Tuple[Any, Any, Any, Any, Any, Any, Any]:
        r"""
        Overview:
            Reshape the model outputs and trajectory data for loss calculation, and compute the per-step
            weights, which should be 0 for done positions.
        Arguments:
            - action_data (:obj:`Dict[str, Any]`): output of learn_model 'compute_actor' forward, keys include logit.
            - avg_action_data (:obj:`Dict[str, Any]`): output of target_model 'compute_actor' forward.
            - q_value_data (:obj:`Dict[str, Any]`): output of learn_model 'compute_critic' forward, keys include \
                q_value.
            - data (:obj:`Dict[str, Any]`): Dict type data, input of policy._forward_learn \
                Values are torch.Tensor or np.ndarray or dict/list combinations. Keys includes at \
                least ['logit', 'action', 'reward', 'done',]
        Returns:
            - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, avg_action_logit, actions, \
                values, rewards, weights
        ReturnsShapes:
            - target_logit (:obj:`torch.FloatTensor`): :math:`((T+1), B, N)`, where T is timestep,\
                B is batch size and N is action dim.
            - behaviour_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim.
            - avg_action_logit (:obj:`torch.FloatTensor`): :math: `(T+1, B, N)`, where N is action dim.
            - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
            - values (:obj:`torch.FloatTensor`): :math:`(T+1, B, N)`
            - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
            - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
        """
        target_logit = action_data['logit'].reshape(
            self._unroll_len + 1, -1, self._action_shape
        )  # shape (T+1),B,env_action_shape
        behaviour_logit = data['logit']  # shape T,B,env_action_shape
        avg_action_logit = avg_action_data['logit'].reshape(
            self._unroll_len + 1, -1, self._action_shape
        )  # shape (T+1),B,env_action_shape
        actions = data['action']  # shape T,B
        values = q_value_data['q_value'].reshape(
            self._unroll_len + 1, -1, self._action_shape
        )  # shape (T+1),B,env_action_shape
        rewards = data['reward']  # shape T,B
        # 0 at done positions, 1 elsewhere
        weights = 1 - data['done']  # shape T,B
        return target_logit, behaviour_logit, avg_action_logit, actions, values, rewards, weights

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'actor_optimizer': self._optimizer_actor.state_dict(),
            'critic_optimizer': self._optimizer_critic.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_actor.load_state_dict(state_dict['actor_optimizer'])
        self._optimizer_critic.load_state_dict(state_dict['critic_optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model.
            Use multinomial_sample to choose action.
        """
        self._collect_unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]:
        r"""
        Overview:
            Forward computation graph of collect mode(collect training data).
        Arguments:
            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting \
                action, values are torch.Tensor or np.ndarray or dict/list combinations,keys \
                are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Dict[str, Any]]`): Dict of predicting policy_output(logit, action) for each env.
        ReturnsKeys
            - necessary: ``logit``, ``action``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        output = {i: d for i, d in zip(data_id, output)}
        return output

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        r"""
        Overview:
            For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly.
        Arguments:
            - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
                format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`dict`): List of training samples.
        .. note::
            We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \
            And the user can customize the this data processing procedure by overriding this two methods and collector \
            itself.
        """
        return get_train_sample(data, self._unroll_len)

    def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation,can be torch.Tensor or np.ndarray or dict/list combinations.
            - policy_output (:obj:`dict`): Output of collect model, including ['logit','action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data, including at least ['obs','next_obs', 'logit',\
                'action','reward', 'done']
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'logit': policy_output['logit'],
            'action': policy_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``, initialize eval_model,
            and use argmax_sample to choose action.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        r"""
        Overview:
            Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \
            ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``

        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        output = {i: d for i, d in zip(data_id, output)}
        return output

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the variable names to monitor in learn mode. Each name must be a key of the dict \
            returned by ``self._forward_learn``; the values are recorded in text log and tensorboard.
        Returns:
            - vars (:obj:`List[str]`): List of the monitored variable names.
        """
        return ['actor_loss', 'bc_loss', 'policy_loss', 'critic_loss', 'entropy_loss', 'kl_div']