
ding.policy.sac

DiscreteSACPolicy

Bases: Policy

Overview

Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/abs/1910.07207.
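The discrete variant replaces SAC's Gaussian policy with a categorical one, so the soft state value can be computed exactly by summing over actions: V(s) = Σ_a π(a|s) · (Q(s, a) − α · log π(a|s)). A minimal pure-Python sketch of that quantity (the function name and toy numbers are illustrative, not part of the ding API):

```python
import math

def soft_state_value(q_values, logits, alpha=0.2):
    """Soft value used as the critic target in discrete SAC:
    V(s) = sum_a pi(a|s) * (Q(s, a) - alpha * log pi(a|s))."""
    # Softmax over logits gives the categorical policy pi(a|s).
    m = max(logits)
    exp = [math.exp(l - m) for l in logits]
    z = sum(exp)
    probs = [e / z for e in exp]
    return sum(p * (q - alpha * math.log(p)) for p, q in zip(probs, q_values))

# With a uniform policy over two actions, the entropy bonus is alpha * log(2).
v = soft_state_value([1.0, 1.0], [0.0, 0.0], alpha=0.2)
print(round(v, 6))  # 1.0 + 0.2 * log(2)
```

The same expectation-over-actions form appears in the target-value computation inside `_forward_learn` below.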

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.

Returns: - model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.
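To illustrate the shape of the returned tuple and how a registry-based framework could consume it: the first element is a registered model name, the second a list of module paths to import so registration side effects run. A hedged sketch (since ding may not be installed, a stdlib module stands in for the real import path):

```python
import importlib

# model_info has the shape default_model() returns: (registered_name, import_names).
model_info = ('discrete_qac', ['ding.model.template.qac'])
name, import_names = model_info

# The framework would import each path in import_names to trigger registration;
# here we import a stdlib module as a stand-in for demonstration only.
mod = importlib.import_module('json')
print(name, mod.__name__)
```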

SACPolicy

Bases: Policy

Overview

Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf
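Continuous SAC samples actions with the reparameterization trick (the config below sets action_space='reparameterization'): an action is a = tanh(mu + sigma * eps) with eps drawn from a standard normal, so gradients flow through mu and sigma. A minimal sketch with hypothetical inputs, not the ding implementation:

```python
import math

def reparameterized_action(mu, log_sigma, eps):
    """Reparameterization trick used by continuous SAC:
    a = tanh(mu + sigma * eps), eps ~ N(0, 1), keeping the
    sample differentiable w.r.t. mu and sigma."""
    sigma = math.exp(log_sigma)
    pre_tanh = mu + sigma * eps
    # tanh squashes the action into (-1, 1).
    return math.tanh(pre_tanh)

print(reparameterized_action(0.0, 0.0, 0.0))  # 0.0, the mean action at eps=0
```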

Config

== ==================== ======== ============= ================================= =======================
ID Symbol               Type     Default Value Description                       Other
== ==================== ======== ============= ================================= =======================
1  type                 str      sac           | RL policy register name, refer  | this arg is optional,
                                               | to registry POLICY_REGISTRY     | a placeholder
2  cuda                 bool     True          | Whether to use cuda for network |
3  on_policy            bool     False         | SAC is an off-policy            |
                                               | algorithm.                      |
4  priority             bool     False         | Whether to use priority         |
                                               | sampling in buffer.             |
5  | priority_IS_       bool     False         | Whether use Importance Sampling |
   | weight                                    | weight to correct biased update |
6  | random_            int      10000         | Number of randomly collected    | Default to 10000 for
   | collect_size                              | training samples in replay      | SAC, 25000 for DDPG/
                                               | buffer when training starts.    | TD3.
7  | learn.learning     float    3e-4          | Learning rate for soft q        | Default to 1e-3
   | _rate_q                                   | network.                        |
8  | learn.learning     float    3e-4          | Learning rate for policy        | Default to 1e-3
   | _rate_policy                              | network.                        |
9  | learn.alpha        float    0.2           | Entropy regularization          | alpha is initiali-
                                               | coefficient.                    | zation for auto
                                                                                 | alpha, when
                                                                                 | auto_alpha is True
10 | learn.             bool     False         | Determine whether to use        | Temperature parameter
   | auto_alpha                                | auto temperature parameter      | determines the
                                               | alpha.                          | relative importance
                                                                                 | of the entropy term
                                                                                 | against the reward.
11 | learn.             bool     False         | Determine whether to ignore     | Use ignore_done only
   | ignore_done                               | done flag.                      | in env like Pendulum
12 | learn.             float    0.005         | Used for soft update of the     | aka. Interpolation
   | target_theta                              | target network.                 | factor in polyak aver
                                                                                 | aging for target
                                                                                 | networks.
== ==================== ======== ============= ================================= =======================
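The soft target update controlled by target_theta (row 12 above) is plain polyak averaging: each target parameter moves a small step toward the online parameter. A sketch with toy parameter lists; the function name and numbers are illustrative, not the ding API:

```python
def polyak_update(target, online, theta=0.005):
    """Soft (momentum/EMA) target-network update:
    target <- theta * online + (1 - theta) * target, per parameter."""
    return [theta * o + (1 - theta) * t for t, o in zip(target, online)]

target = [0.0, 0.0]
online = [1.0, 2.0]
# With theta=0.5 the target moves halfway toward the online parameters.
target = polyak_update(target, online, theta=0.5)
print(target)  # [0.5, 1.0]
```

With the default theta=0.005 the target network trails the online network slowly, which stabilizes the bootstrapped Q targets.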

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.

Returns: - model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.

SQILSACPolicy

Bases: SACPolicy

Overview

Policy class of continuous SAC algorithm with SQIL extension.
SAC paper link: https://arxiv.org/pdf/1801.01290.pdf
SQIL paper link: https://arxiv.org/abs/1905.11108
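SQIL's core idea (per the paper above) is reward relabeling: train soft Q-learning / SAC on a mixed batch where demonstration transitions receive reward 1 and the agent's own transitions receive reward 0. A hedged pure-Python sketch of that relabeling step; `sqil_relabel` and the toy transitions are illustrative, not the actual SQILSACPolicy code (which is not shown in this excerpt):

```python
def sqil_relabel(expert_batch, agent_batch):
    """Relabel rewards as in SQIL (arXiv 1905.11108):
    expert/demonstration transitions -> reward 1.0,
    the agent's own transitions -> reward 0.0."""
    relabeled = []
    for t in expert_batch:
        relabeled.append({**t, 'reward': 1.0})
    for t in agent_batch:
        relabeled.append({**t, 'reward': 0.0})
    return relabeled

batch = sqil_relabel([{'obs': 0, 'action': 1}], [{'obs': 2, 'action': 0}])
print([t['reward'] for t in batch])  # [1.0, 0.0]
```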

Full Source Code

../ding/policy/sac.py

1from typing import List, Dict, Any, Tuple, Union 2from collections import namedtuple 3import copy 4import numpy as np 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8from torch.distributions import Normal, Independent 9 10from ding.torch_utils import Adam, to_device 11from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, q_v_1step_td_error, q_v_1step_td_data 12from ding.model import model_wrap 13from ding.utils import POLICY_REGISTRY 14from ding.utils.data import default_collate, default_decollate 15from .base_policy import Policy 16from .common_utils import default_preprocess_learn 17 18 19@POLICY_REGISTRY.register('discrete_sac') 20class DiscreteSACPolicy(Policy): 21 """ 22 Overview: 23 Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/abs/1910.07207. 24 """ 25 26 config = dict( 27 # (str) RL policy register name (refer to function "POLICY_REGISTRY"). 28 type='discrete_sac', 29 # (bool) Whether to use cuda for network and loss computation. 30 cuda=False, 31 # (bool) Whether to belong to on-policy or off-policy algorithm, DiscreteSAC is an off-policy algorithm. 32 on_policy=False, 33 # (bool) Whether to use priority sampling in buffer. Default to False in DiscreteSAC. 34 priority=False, 35 # (bool) Whether use Importance Sampling weight to correct biased update. If True, priority must be True. 36 priority_IS_weight=False, 37 # (int) Number of training samples (randomly collected) in replay buffer when training starts. 38 random_collect_size=10000, 39 # (bool) Whether to need policy-specific data in process transition. 40 transition_with_policy_data=True, 41 # (bool) Whether to enable multi-agent training setting. 42 multi_agent=False, 43 model=dict( 44 # (bool) Whether to use double-soft-q-net for target q computation. 45 # For more details, please refer to TD3 about Clipped Double-Q Learning trick. 
46 twin_critic=True, 47 ), 48 # learn_mode config 49 learn=dict( 50 # (int) How many updates (iterations) to train after collector's one collection. 51 # Bigger "update_per_collect" means bigger off-policy. 52 update_per_collect=1, 53 # (int) Minibatch size for one gradient descent. 54 batch_size=256, 55 # (float) Learning rate for soft q network. 56 learning_rate_q=3e-4, 57 # (float) Learning rate for policy network. 58 learning_rate_policy=3e-4, 59 # (float) Learning rate for auto temperature parameter `\alpha`. 60 learning_rate_alpha=3e-4, 61 # (float) Used for soft update of the target network, 62 # aka. Interpolation factor in EMA update for target network. 63 target_theta=0.005, 64 # (float) Discount factor for the discounted sum of rewards, aka. gamma. 65 discount_factor=0.99, 66 # (float) Entropy regularization coefficient in SAC. 67 # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. 68 # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. 69 alpha=0.2, 70 # (bool) Whether to use auto temperature parameter `\alpha` . 71 # Temperature parameter `\alpha` determines the relative importance of the entropy term against the reward. 72 # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. 73 # Note that: Using auto alpha needs to set the above `learning_rate_alpha`. 74 auto_alpha=True, 75 # (bool) Whether to use auto `\alpha` in log space. 76 log_space=True, 77 # (float) Target policy entropy value for auto temperature (alpha) adjustment. 78 target_entropy=None, 79 # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) 80 # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. 81 # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. 
82 # However, interaction with HalfCheetah always gets done with done is False, 83 # Since we inplace done==True with done==False to keep 84 # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), 85 # when the episode step is greater than max episode step. 86 ignore_done=False, 87 # (float) Weight uniform initialization max range in the last output layer 88 init_w=3e-3, 89 ), 90 # collect_mode config 91 collect=dict( 92 # (int) How many training samples collected in one collection procedure. 93 # Only one of [n_sample, n_episode] shoule be set. 94 n_sample=1, 95 # (int) Split episodes or trajectories into pieces with length `unroll_len`. 96 unroll_len=1, 97 # (bool) Whether to collect logit in `process_transition`. 98 # In some algorithm like guided cost learning, we need to use logit to train the reward model. 99 collector_logit=False, 100 ), 101 eval=dict(), # for compability 102 other=dict( 103 replay_buffer=dict( 104 # (int) Maximum size of replay buffer. Usually, larger buffer size is good 105 # for SAC but cost more storage. 106 replay_buffer_size=1000000, 107 ), 108 ), 109 ) 110 111 def default_model(self) -> Tuple[str, List[str]]: 112 """ 113 Overview: 114 Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ 115 automatically call this method to get the default model setting and create model. 116 Returns: 117 - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. 118 """ 119 if self._cfg.multi_agent: 120 return 'discrete_maqac', ['ding.model.template.maqac'] 121 else: 122 return 'discrete_qac', ['ding.model.template.qac'] 123 124 def _init_learn(self) -> None: 125 """ 126 Overview: 127 Initialize the learn mode of policy, including related attributes and modules. For DiscreteSAC, it mainly \ 128 contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \ 129 model. 
Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here. 130 This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. 131 132 .. note:: 133 For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ 134 and ``_load_state_dict_learn`` methods. 135 136 .. note:: 137 For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. 138 139 .. note:: 140 If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ 141 with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 142 """ 143 self._priority = self._cfg.priority 144 self._priority_IS_weight = self._cfg.priority_IS_weight 145 self._twin_critic = self._cfg.model.twin_critic 146 147 self._optimizer_q = Adam( 148 self._model.critic.parameters(), 149 lr=self._cfg.learn.learning_rate_q, 150 ) 151 self._optimizer_policy = Adam( 152 self._model.actor.parameters(), 153 lr=self._cfg.learn.learning_rate_policy, 154 ) 155 156 # Algorithm-Specific Config 157 self._gamma = self._cfg.learn.discount_factor 158 if self._cfg.learn.auto_alpha: 159 if self._cfg.learn.target_entropy is None: 160 assert 'action_shape' in self._cfg.model, "DiscreteSAC need network model with action_shape variable" 161 self._target_entropy = -np.prod(self._cfg.model.action_shape) 162 else: 163 self._target_entropy = self._cfg.learn.target_entropy 164 if self._cfg.learn.log_space: 165 self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha])) 166 self._log_alpha = self._log_alpha.to(self._device).requires_grad_() 167 self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha) 168 assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad 169 self._alpha = self._log_alpha.detach().exp() 170 self._auto_alpha = True 171 self._log_space = True 172 else: 
173 self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_() 174 self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha) 175 self._auto_alpha = True 176 self._log_space = False 177 else: 178 self._alpha = torch.tensor( 179 [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32 180 ) 181 self._auto_alpha = False 182 183 # Main and target models 184 self._target_model = copy.deepcopy(self._model) 185 self._target_model = model_wrap( 186 self._target_model, 187 wrapper_name='target', 188 update_type='momentum', 189 update_kwargs={'theta': self._cfg.learn.target_theta} 190 ) 191 self._learn_model = model_wrap(self._model, wrapper_name='base') 192 self._learn_model.reset() 193 self._target_model.reset() 194 195 def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: 196 """ 197 Overview: 198 Policy forward function of learn mode (training policy and updating parameters). Forward means \ 199 that the policy inputs some training batch data from the replay buffer and then returns the output \ 200 result, including various training information such as loss, action, priority. 201 Arguments: 202 - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ 203 training samples. For each element in list, the key of the dict is the name of data items and the \ 204 value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ 205 combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ 206 dimension by some utility functions such as ``default_preprocess_learn``. \ 207 For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ 208 ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``. 
209 Returns: 210 - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ 211 recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ 212 detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. 213 214 .. note:: 215 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 216 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 217 You can implement you own model rather than use the default model. For more information, please raise an \ 218 issue in GitHub repo and we will continue to follow up. 219 220 .. note:: 221 For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ 222 ``ding.policy.tests.test_discrete_sac``. 223 """ 224 loss_dict = {} 225 data = default_preprocess_learn( 226 data, 227 use_priority=self._priority, 228 use_priority_IS_weight=self._cfg.priority_IS_weight, 229 ignore_done=self._cfg.learn.ignore_done, 230 use_nstep=False 231 ) 232 if self._cuda: 233 data = to_device(data, self._device) 234 235 self._learn_model.train() 236 self._target_model.train() 237 obs = data['obs'] 238 next_obs = data['next_obs'] 239 reward = data['reward'] 240 done = data['done'] 241 logit = data['logit'] 242 action = data['action'] 243 244 # 1. predict q value 245 q_value = self._learn_model.forward(obs, mode='compute_critic')['q_value'] 246 dist = torch.distributions.categorical.Categorical(logits=logit) 247 dist_entropy = dist.entropy() 248 entropy = dist_entropy.mean() 249 250 # 2. predict target value 251 252 # target q value. 
SARSA: first predict next action, then calculate next q value 253 with torch.no_grad(): 254 policy_output_next = self._learn_model.forward(next_obs, mode='compute_actor') 255 if self._cfg.multi_agent: 256 policy_output_next['logit'][policy_output_next['action_mask'] == 0.0] = -1e8 257 prob = F.softmax(policy_output_next['logit'], dim=-1) 258 log_prob = torch.log(prob + 1e-8) 259 target_q_value = self._target_model.forward(next_obs, mode='compute_critic')['q_value'] 260 # the value of a policy according to the maximum entropy objective 261 if self._twin_critic: 262 # find min one as target q value 263 target_value = ( 264 prob * (torch.min(target_q_value[0], target_q_value[1]) - self._alpha * log_prob.squeeze(-1)) 265 ).sum(dim=-1) 266 else: 267 target_value = (prob * (target_q_value - self._alpha * log_prob.squeeze(-1))).sum(dim=-1) 268 269 # 3. compute q loss 270 if self._twin_critic: 271 q_data0 = q_v_1step_td_data(q_value[0], target_value, action, reward, done, data['weight']) 272 loss_dict['critic_loss'], td_error_per_sample0 = q_v_1step_td_error(q_data0, self._gamma) 273 q_data1 = q_v_1step_td_data(q_value[1], target_value, action, reward, done, data['weight']) 274 loss_dict['twin_critic_loss'], td_error_per_sample1 = q_v_1step_td_error(q_data1, self._gamma) 275 td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2 276 else: 277 q_data = q_v_1step_td_data(q_value, target_value, action, reward, done, data['weight']) 278 loss_dict['critic_loss'], td_error_per_sample = q_v_1step_td_error(q_data, self._gamma) 279 280 # 4. update q network 281 self._optimizer_q.zero_grad() 282 loss_dict['critic_loss'].backward() 283 if self._twin_critic: 284 loss_dict['twin_critic_loss'].backward() 285 self._optimizer_q.step() 286 287 # 5. evaluate to get action distribution 288 policy_output = self._learn_model.forward(obs, mode='compute_actor') 289 # 6. 
apply discrete action mask in multi_agent setting 290 if self._cfg.multi_agent: 291 policy_output['logit'][policy_output['action_mask'] == 0.0] = -1e8 292 logit = policy_output['logit'] 293 prob = F.softmax(logit, dim=-1) 294 log_prob = F.log_softmax(logit, dim=-1) 295 296 with torch.no_grad(): 297 new_q_value = self._learn_model.forward(obs, mode='compute_critic')['q_value'] 298 if self._twin_critic: 299 new_q_value = torch.min(new_q_value[0], new_q_value[1]) 300 # 7. compute policy loss 301 # we need to sum different actions' policy loss and calculate the average value of a batch 302 policy_loss = (prob * (self._alpha * log_prob - new_q_value)).sum(dim=-1).mean() 303 304 loss_dict['policy_loss'] = policy_loss 305 306 # 8. update policy network 307 self._optimizer_policy.zero_grad() 308 loss_dict['policy_loss'].backward() 309 self._optimizer_policy.step() 310 311 # 9. compute alpha loss 312 if self._auto_alpha: 313 if self._log_space: 314 log_prob = log_prob + self._target_entropy 315 loss_dict['alpha_loss'] = (-prob.detach() * (self._log_alpha * log_prob.detach())).sum(dim=-1).mean() 316 317 self._alpha_optim.zero_grad() 318 loss_dict['alpha_loss'].backward() 319 self._alpha_optim.step() 320 self._alpha = self._log_alpha.detach().exp() 321 else: 322 log_prob = log_prob + self._target_entropy 323 loss_dict['alpha_loss'] = (-prob.detach() * (self._alpha * log_prob.detach())).sum(dim=-1).mean() 324 325 self._alpha_optim.zero_grad() 326 loss_dict['alpha_loss'].backward() 327 self._alpha_optim.step() 328 self._alpha.data = torch.where(self._alpha > 0, self._alpha, 329 torch.zeros_like(self._alpha)).requires_grad_() 330 loss_dict['total_loss'] = sum(loss_dict.values()) 331 332 # target update 333 self._target_model.update(self._learn_model.state_dict()) 334 return { 335 'total_loss': loss_dict['total_loss'].item(), 336 'policy_loss': loss_dict['policy_loss'].item(), 337 'critic_loss': loss_dict['critic_loss'].item(), 338 'cur_lr_q': self._optimizer_q.defaults['lr'], 
339 'cur_lr_p': self._optimizer_policy.defaults['lr'], 340 'priority': td_error_per_sample.abs().tolist(), 341 'td_error': td_error_per_sample.detach().mean().item(), 342 'alpha': self._alpha.item(), 343 'q_value_1': target_q_value[0].detach().mean().item(), 344 'q_value_2': target_q_value[1].detach().mean().item(), 345 'target_value': target_value.detach().mean().item(), 346 'entropy': entropy.item(), 347 } 348 349 def _state_dict_learn(self) -> Dict[str, Any]: 350 """ 351 Overview: 352 Return the state_dict of learn mode, usually including model, target_model and optimizers. 353 Returns: 354 - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 355 """ 356 ret = { 357 'model': self._learn_model.state_dict(), 358 'target_model': self._target_model.state_dict(), 359 'optimizer_q': self._optimizer_q.state_dict(), 360 'optimizer_policy': self._optimizer_policy.state_dict(), 361 } 362 if self._auto_alpha: 363 ret.update({'optimizer_alpha': self._alpha_optim.state_dict()}) 364 return ret 365 366 def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: 367 """ 368 Overview: 369 Load the state_dict variable into policy learn mode. 370 Arguments: 371 - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 372 373 .. tip:: 374 If you want to only load some parts of model, you can simply set the ``strict`` argument in \ 375 load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ 376 complicated operation. 
377 """ 378 self._learn_model.load_state_dict(state_dict['model']) 379 self._target_model.load_state_dict(state_dict['target_model']) 380 self._optimizer_q.load_state_dict(state_dict['optimizer_q']) 381 self._optimizer_policy.load_state_dict(state_dict['optimizer_policy']) 382 if self._auto_alpha: 383 self._alpha_optim.load_state_dict(state_dict['optimizer_alpha']) 384 385 def _init_collect(self) -> None: 386 """ 387 Overview: 388 Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ 389 collect_model to balance the exploration and exploitation with the epsilon and multinomial sample \ 390 mechanism, and other algorithm-specific arguments such as unroll_len. \ 391 This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. 392 393 .. note:: 394 If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ 395 with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 396 """ 397 self._unroll_len = self._cfg.collect.unroll_len 398 # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample 399 # and eps_greedy_sample, and we don't divide logit by alpha, 400 # for the details please refer to ding/model/wrapper/model_wrappers 401 self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample') 402 self._collect_model.reset() 403 404 def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: 405 """ 406 Overview: 407 Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ 408 that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ 409 data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ 410 exploration, i.e., classic epsilon-greedy exploration strategy. 
411 Arguments: 412 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 413 key of the dict is environment id and the value is the corresponding data of the env. 414 - eps (:obj:`float`): The epsilon value for exploration. 415 Returns: 416 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ 417 other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ 418 dict is the same as the input data, i.e. environment id. 419 420 .. note:: 421 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 422 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 423 You can implement you own model rather than use the default model. For more information, please raise an \ 424 issue in GitHub repo and we will continue to follow up. 425 426 .. note:: 427 For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ 428 ``ding.policy.tests.test_discrete_sac``. 429 """ 430 data_id = list(data.keys()) 431 data = default_collate(list(data.values())) 432 if self._cuda: 433 data = to_device(data, self._device) 434 self._collect_model.eval() 435 with torch.no_grad(): 436 output = self._collect_model.forward(data, mode='compute_actor', eps=eps) 437 if self._cuda: 438 output = to_device(output, 'cpu') 439 output = default_decollate(output) 440 return {i: d for i, d in zip(data_id, output)} 441 442 def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], 443 timestep: namedtuple) -> Dict[str, torch.Tensor]: 444 """ 445 Overview: 446 Process and pack one timestep transition data into a dict, which can be directly used for training and \ 447 saved in replay buffer. For discrete SAC, it contains obs, next_obs, logit, action, reward, done. 
448 Arguments: 449 - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. 450 - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ 451 as input. For discrete SAC, it contains the action and the logit of the action. 452 - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ 453 except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ 454 reward, done, info, etc. 455 Returns: 456 - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. 457 """ 458 transition = { 459 'obs': obs, 460 'next_obs': timestep.obs, 461 'action': policy_output['action'], 462 'logit': policy_output['logit'], 463 'reward': timestep.reward, 464 'done': timestep.done, 465 } 466 return transition 467 468 def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 469 """ 470 Overview: 471 For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ 472 can be used for training directly. In discrete SAC, a train sample is a processed transition (unroll_len=1). 473 Arguments: 474 - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ 475 the same format as the return value of ``self._process_transition`` method. 476 Returns: 477 - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ 478 as input transitions, but may contain more data for training. 479 """ 480 return get_train_sample(transitions, self._unroll_len) 481 482 def _init_eval(self) -> None: 483 """ 484 Overview: 485 Initialize the eval mode of policy, including related attributes and modules. For DiscreteSAC, it contains \ 486 the eval model to greedily select action type with argmax q_value mechanism. 
487 This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. 488 489 .. note:: 490 If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ 491 with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. 492 """ 493 self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') 494 self._eval_model.reset() 495 496 def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: 497 """ 498 Overview: 499 Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ 500 means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ 501 action to interact with the envs. 502 Arguments: 503 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 504 key of the dict is environment id and the value is the corresponding data of the env. 505 Returns: 506 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ 507 key of the dict is the same as the input data, i.e. environment id. 508 509 .. note:: 510 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 511 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 512 You can implement you own model rather than use the default model. For more information, please raise an \ 513 issue in GitHub repo and we will continue to follow up. 514 515 .. note:: 516 For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ 517 ``ding.policy.tests.test_discrete_sac``. 
518 """ 519 data_id = list(data.keys()) 520 data = default_collate(list(data.values())) 521 if self._cuda: 522 data = to_device(data, self._device) 523 self._eval_model.eval() 524 with torch.no_grad(): 525 output = self._eval_model.forward(data, mode='compute_actor') 526 if self._cuda: 527 output = to_device(output, 'cpu') 528 output = default_decollate(output) 529 return {i: d for i, d in zip(data_id, output)} 530 531 def _monitor_vars_learn(self) -> List[str]: 532 """ 533 Overview: 534 Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ 535 as text logger, tensorboard logger, will use these keys to save the corresponding data. 536 Returns: 537 - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 538 """ 539 twin_critic = ['twin_critic_loss'] if self._twin_critic else [] 540 if self._auto_alpha: 541 return super()._monitor_vars_learn() + [ 542 'alpha_loss', 'policy_loss', 'critic_loss', 'cur_lr_q', 'cur_lr_p', 'target_q_value', 'q_value_1', 543 'q_value_2', 'alpha', 'td_error', 'target_value', 'entropy' 544 ] + twin_critic 545 else: 546 return super()._monitor_vars_learn() + [ 547 'policy_loss', 'critic_loss', 'cur_lr_q', 'cur_lr_p', 'target_q_value', 'q_value_1', 'q_value_2', 548 'alpha', 'td_error', 'target_value', 'entropy' 549 ] + twin_critic 550 551 552@POLICY_REGISTRY.register('sac') 553class SACPolicy(Policy): 554 """ 555 Overview: 556 Policy class of continuous SAC algorithm. 
    Paper link: https://arxiv.org/pdf/1801.01290.pdf

    Config:

    == ==================== ======== ============= ================================= =======================
    ID Symbol               Type     Default Value Description                       Other
    == ==================== ======== ============= ================================= =======================
    1  ``type``             str      sac           | RL policy register name, refer  | this arg is optional,
                                                   | to registry ``POLICY_REGISTRY`` | a placeholder
    2  ``cuda``             bool     True          | Whether to use cuda for network |
    3  ``on_policy``        bool     False         | SAC is an off-policy            |
                                                   | algorithm.                      |
    4  ``priority``         bool     False         | Whether to use priority         |
                                                   | sampling in buffer.             |
    5  | ``priority_IS_``   bool     False         | Whether to use Importance       |
       | ``weight``                                | Sampling weight to correct      |
                                                   | biased update.                  |
    6  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
       | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
                                                   | buffer when training starts.    | TD3.
    7  | ``learn.learning`` float    3e-4          | Learning rate for soft q        | Default to 1e-3
       | ``_rate_q``                               | network.                        |
    8  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to 1e-3
       | ``_rate_policy``                          | network.                        |
    9  | ``learn.alpha``    float    0.2           | Entropy regularization          | alpha is the initial
                                                   | coefficient.                    | value for auto alpha,
                                                                                     | when auto_alpha
                                                                                     | is True
    10 | ``learn.``         bool     False         | Determine whether to use        | Temperature parameter
       | ``auto_alpha``                            | auto temperature parameter      | determines the
                                                   | alpha.                          | relative importance
                                                                                     | of the entropy term
                                                                                     | against the reward.
    11 | ``learn.``         bool     False         | Determine whether to ignore     | Use ignore_done only
       | ``ignore_done``                           | done flag.                      | in env like Pendulum
    12 | ``learn.``         float    0.005         | Used for soft update of the     | aka. interpolation
       | ``target_theta``                          | target network.                 | factor in polyak
                                                                                     | averaging for target
                                                                                     | networks.
    == ==================== ======== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='sac',
        # (bool) Whether to use cuda for network and loss computation.
        cuda=False,
        # (bool) Whether to belong to on-policy or off-policy algorithm; SAC is an off-policy algorithm.
        on_policy=False,
        # (bool) Whether to use priority sampling in buffer. Default to False in SAC.
        priority=False,
        # (bool) Whether to use Importance Sampling weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=10000,
        # (bool) Whether the policy needs policy-specific data in process transition.
        transition_with_policy_data=True,
        # (bool) Whether to enable the multi-agent training setting.
        multi_agent=False,
        model=dict(
            # (bool) Whether to use the double (twin) soft q network for target q computation.
            # For more details, please refer to TD3 about the Clipped Double-Q Learning trick.
            twin_critic=True,
            # (str) Use the reparameterization trick for continuous action.
            action_space='reparameterization',
        ),
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after the collector's one collection.
            # Bigger "update_per_collect" means more off-policy.
            update_per_collect=1,
            # (int) Minibatch size for one gradient descent step.
            batch_size=256,
            # (float) Learning rate for the soft q network.
            learning_rate_q=3e-4,
            # (float) Learning rate for the policy network.
            learning_rate_policy=3e-4,
            # (float) Learning rate for the auto temperature parameter `\alpha`.
            learning_rate_alpha=3e-4,
            # (float) Used for soft update of the target network,
            # aka. interpolation factor in the EMA update of the target network.
            target_theta=0.005,
            # (float) Discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float) Entropy regularization coefficient in SAC.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # If auto_alpha is set to `True`, alpha is the initialization for auto `\alpha`.
            alpha=0.2,
            # (bool) Whether to use the auto temperature parameter `\alpha`.
            # Temperature parameter `\alpha` determines the relative importance of the entropy term
            # against the reward.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # Note: using auto alpha requires setting `learning_rate_alpha` above.
            auto_alpha=True,
            # (bool) Whether to optimize auto `\alpha` in log space.
            log_space=True,
            # (float) Target policy entropy value for auto temperature (alpha) adjustment.
            target_entropy=None,
            # (bool) Whether to ignore the done flag (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, interaction with HalfCheetah always returns done=False.
            # When the episode step is greater than the max episode step, we replace done=True with done=False
            # to keep the TD-error computation accurate (``gamma * (1 - done) * next_v + reward``).
            ignore_done=False,
            # (float) Max range of uniform weight initialization in the last output layer.
            init_w=3e-3,
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples are collected in one collection procedure.
            n_sample=1,
            # (int) Split episodes or trajectories into pieces with length `unroll_len`.
            unroll_len=1,
            # (bool) Whether to collect logit in `process_transition`.
            # In some algorithms like guided cost learning, we need to use logit to train the reward model.
            collector_logit=False,
        ),
        eval=dict(),  # for compatibility
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, a larger buffer size is good
                # for SAC but costs more storage.
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.
        """
        if self._cfg.multi_agent:
            return 'continuous_maqac', ['ding.model.template.maqac']
        else:
            return 'continuous_qac', ['ding.model.template.qac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \
            contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, and the main \
            and target model. Especially, the ``auto_alpha`` mechanism for balancing the max entropy target is \
            also initialized here. This method will be called in ``__init__`` method if ``learn`` field is in \
            ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._twin_critic = self._cfg.model.twin_critic

        # Weight init for the last output layer
        if hasattr(self._model, 'actor_head'):  # keep compatibility
            init_w = self._cfg.learn.init_w
            self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
            self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
            self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
            self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)

        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm-specific config
        self._gamma = self._cfg.learn.discount_factor
        if self._cfg.learn.auto_alpha:
            if self._cfg.learn.target_entropy is None:
                assert 'action_shape' in self._cfg.model, "SAC needs a network model with the action_shape variable"
                self._target_entropy = -np.prod(self._cfg.model.action_shape)
            else:
                self._target_entropy = self._cfg.learn.target_entropy
            if self._cfg.learn.log_space:
                self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
                self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
                assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
                self._alpha = self._log_alpha.detach().exp()
                self._auto_alpha = True
                self._log_space = True
            else:
                self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
                self._auto_alpha = True
                self._log_space = False
        else:
            self._alpha = torch.tensor(
                [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
            )
            self._auto_alpha = False

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of the data item \
                and the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
                For SAC, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
                ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in the text log and tensorboard; values must be python scalars or lists of \
                scalars. For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` \
                method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']

        # 1. predict q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # 2. predict target value
        with torch.no_grad():
            (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']

            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            next_action = torch.tanh(pred)
            y = 1 - next_action.pow(2) + 1e-6
            # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
            next_log_prob = dist.log_prob(pred).unsqueeze(-1)
            next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)

            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # the value of a policy according to the maximum entropy objective
            if self._twin_critic:
                # take the min as the target q value
                target_q_value = torch.min(target_q_value[0],
                                           target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
            else:
                target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)

        # 3. compute q loss
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
            loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
        else:
            q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)

        # 4. update q network
        self._optimizer_q.zero_grad()
        if self._twin_critic:
            (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
        else:
            loss_dict['critic_loss'].backward()
        self._optimizer_q.step()

        # 5. evaluate to get the action distribution
        (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        y = 1 - action.pow(2) + 1e-6
        # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])

        # 6. compute policy loss
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()

        loss_dict['policy_loss'] = policy_loss

        # 7. update policy network
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 8. compute alpha loss
        if self._auto_alpha:
            if self._log_space:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()
            else:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = max(0, self._alpha)

        loss_dict['total_loss'] = sum(loss_dict.values())

        # target update
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            'transformed_log_prob': log_prob.mean().item(),
            **loss_dict
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model, target_model and optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the current policy learn state, for saving and \
                restoring.
        """
        ret = {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_policy': self._optimizer_policy.state_dict(),
        }
        if self._auto_alpha:
            ret.update({'optimizer_alpha': self._alpha_optim.state_dict()})
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved before.

        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
        self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
        if self._auto_alpha:
            self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains \
            the collect_model and other algorithm-specific arguments such as unroll_len. \
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_collect`` method, you'd better name \
            them with the prefix ``_collect_`` to avoid conflicts with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns \
            the output data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is the environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in the ``self._process_transition`` method. The key of \
                the dict is the same as the input data, i.e. the environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            ``logit`` in SAC means the mu and sigma of the Gaussian distribution. Here we use this name for \
            consistency.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            action = torch.tanh(dist.rsample())
        output = {'logit': (mu, sigma), 'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training \
            and saved in the replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. \
            The logit will also be added when ``collector_logit`` is True.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of the current timestep, such as a stacked 2D image \
                in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the \
                observation as input. For continuous SAC, it contains the action and the logit (mu and sigma) of \
                the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
                method, except all the elements have been transformed into tensor data.
                Usually, it contains the next obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        if self._cfg.collect.collector_logit:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'logit': policy_output['logit'],
                'action': policy_output['action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        else:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': policy_output['action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        return transition

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of samples \
            that can be used for training directly. In continuous SAC, a train sample is a processed transition \
            (unroll_len=1).
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions); each element \
                is in the same format as the return value of the ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples; each element is in a similar \
                format to the input transitions, but may contain more data for training.
        """
        return get_train_sample(transitions, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains \
            the eval model, which is equipped with the ``base`` model wrapper to ensure compatibility.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
            with the prefix ``_eval_`` to avoid conflicts with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). \
            Forward means that the policy gets some necessary data (mainly observation) from the envs and then \
            returns the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is the environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
                The key of the dict is the same as the input data, i.e. the environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            ``logit`` in SAC means the mu and sigma of the Gaussian distribution. Here we use this name for \
            consistency.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit']
            action = torch.tanh(mu)  # deterministic_eval
        output = {'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as the text logger and tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        twin_critic = ['twin_critic_loss'] if self._twin_critic else []
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'cur_lr_q',
            'cur_lr_p',
            'target_q_value',
            'alpha',
            'td_error',
            'transformed_log_prob',
        ] + twin_critic + alpha_loss


@POLICY_REGISTRY.register('sqil_sac')
class SQILSACPolicy(SACPolicy):
    """
    Overview:
        Policy class of continuous SAC algorithm with SQIL extension.
        SAC paper link: https://arxiv.org/pdf/1801.01290.pdf
        SQIL paper link: https://arxiv.org/abs/1905.11108
    """

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \
            contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, and the main \
            and target model.
            Especially, the ``auto_alpha`` mechanism for balancing the max entropy target is also initialized here.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._twin_critic = self._cfg.model.twin_critic

        # Weight init for the last output layer
        init_w = self._cfg.learn.init_w
        self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)

        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm-specific config
        self._gamma = self._cfg.learn.discount_factor
        if self._cfg.learn.auto_alpha:
            if self._cfg.learn.target_entropy is None:
                assert 'action_shape' in self._cfg.model, \
                    "SQILSACPolicy needs a network model with the action_shape variable"
                self._target_entropy = -np.prod(self._cfg.model.action_shape)
            else:
                self._target_entropy = self._cfg.learn.target_entropy
            if self._cfg.learn.log_space:
                self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
                self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
                assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
                self._alpha = self._log_alpha.detach().exp()
                self._auto_alpha = True
                self._log_space = True
            else:
                self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
                self._auto_alpha = True
                self._log_space = False
        else:
            self._alpha = torch.tensor(
                [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
            )
            self._auto_alpha = False

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        # switches for monitoring cosine similarity and entropy
        self._monitor_cos = True
        self._monitor_entropy = True

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples.
                For each element in the list, the key of the dict is the name of the data item and the value is \
                the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For SAC, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
                ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in the text log and tensorboard; values must be python scalars or lists of \
                scalars. For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` \
                method.

        .. note::
            For SQIL + SAC, the input data is composed of two parts of the same size: agent data and expert data. \
            Both of them are relabelled with a new reward according to the SQIL algorithm.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For the data types that are not supported, the main reason is that the corresponding model does \
            not support them. You can implement your own model rather than use the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        loss_dict = {}
        if self._monitor_cos:
            agent_data = default_preprocess_learn(
                data[0:len(data) // 2],
                use_priority=self._priority,
                use_priority_IS_weight=self._cfg.priority_IS_weight,
                ignore_done=self._cfg.learn.ignore_done,
                use_nstep=False
            )

            expert_data = default_preprocess_learn(
                data[len(data) // 2:],
                use_priority=self._priority,
                use_priority_IS_weight=self._cfg.priority_IS_weight,
                ignore_done=self._cfg.learn.ignore_done,
                use_nstep=False
            )
            if self._cuda:
                agent_data = to_device(agent_data, self._device)
                expert_data = to_device(expert_data, self._device)

        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']

        # 1. predict q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # 2. predict target value
        with torch.no_grad():
            (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            next_action = torch.tanh(pred)
            y = 1 - next_action.pow(2) + 1e-6
            # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
            next_log_prob = dist.log_prob(pred).unsqueeze(-1)
            next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)

            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # the value of a policy according to the maximum entropy objective
            if self._twin_critic:
                # take the min as the target q value
                target_q_value = torch.min(target_q_value[0],
                                           target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
            else:
                target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)

        # 3. compute q loss
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
            loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
        else:
            q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)

        # 4. update q network
        self._optimizer_q.zero_grad()
        if self._twin_critic:
            (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
        else:
            loss_dict['critic_loss'].backward()
        self._optimizer_q.step()

        # 5. evaluate to get the action distribution
        if self._monitor_cos:
            # agent
            (mu, sigma) = self._learn_model.forward(agent_data['obs'], mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)
            y = 1 - action.pow(2) + 1e-6
            # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
            agent_log_prob = dist.log_prob(pred).unsqueeze(-1)
            agent_log_prob = agent_log_prob - torch.log(y).sum(-1, keepdim=True)

            eval_data = {'obs': agent_data['obs'], 'action': action}
            agent_new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                agent_new_q_value = torch.min(agent_new_q_value[0], agent_new_q_value[1])
            # expert
            (mu, sigma) = self._learn_model.forward(expert_data['obs'], mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)
            y = 1 - action.pow(2) + 1e-6
            # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
            expert_log_prob = dist.log_prob(pred).unsqueeze(-1)
            expert_log_prob = expert_log_prob - torch.log(y).sum(-1, keepdim=True)

            eval_data = {'obs': expert_data['obs'], 'action': action}
            expert_new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
            if self._twin_critic:
                expert_new_q_value = torch.min(expert_new_q_value[0], expert_new_q_value[1])

        (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        # monitor the entropy of the policy
        if self._monitor_entropy:
            dist_entropy = dist.entropy()
            entropy = dist_entropy.mean()

        pred = dist.rsample()
        action = torch.tanh(pred)
        y = 1 - action.pow(2) + 1e-6
        # keep the dimension for loss computation (usually for envs whose action space is 1, e.g. pendulum)
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])

        # 6. compute policy loss
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
        loss_dict['policy_loss'] = policy_loss

        # 7. update policy network
        if self._monitor_cos:
            agent_policy_loss = (self._alpha * agent_log_prob - agent_new_q_value.unsqueeze(-1)).mean()
            expert_policy_loss = (self._alpha * expert_log_prob - expert_new_q_value.unsqueeze(-1)).mean()
            loss_dict['agent_policy_loss'] = agent_policy_loss
            loss_dict['expert_policy_loss'] = expert_policy_loss
            self._optimizer_policy.zero_grad()
            loss_dict['agent_policy_loss'].backward()
            agent_grad = (list(list(self._learn_model.actor.children())[-1].children())[-1].weight.grad).mean()
            self._optimizer_policy.zero_grad()
            loss_dict['expert_policy_loss'].backward()
            expert_grad = (list(list(self._learn_model.actor.children())[-1].children())[-1].weight.grad).mean()
            cos = nn.CosineSimilarity(dim=0)
            cos_similarity = cos(agent_grad, expert_grad)
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 8.
compute alpha loss1420 if self._auto_alpha:1421 if self._log_space:1422 log_prob = log_prob + self._target_entropy1423 loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()14241425 self._alpha_optim.zero_grad()1426 loss_dict['alpha_loss'].backward()1427 self._alpha_optim.step()1428 self._alpha = self._log_alpha.detach().exp()1429 else:1430 log_prob = log_prob + self._target_entropy1431 loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()14321433 self._alpha_optim.zero_grad()1434 loss_dict['alpha_loss'].backward()1435 self._alpha_optim.step()1436 self._alpha = max(0, self._alpha)14371438 loss_dict['total_loss'] = sum(loss_dict.values())14391440 # target update1441 self._target_model.update(self._learn_model.state_dict())1442 var_monitor = {1443 'cur_lr_q': self._optimizer_q.defaults['lr'],1444 'cur_lr_p': self._optimizer_policy.defaults['lr'],1445 'priority': td_error_per_sample.abs().tolist(),1446 'td_error': td_error_per_sample.detach().mean().item(),1447 'agent_td_error': td_error_per_sample.detach().chunk(2, dim=0)[0].mean().item(),1448 'expert_td_error': td_error_per_sample.detach().chunk(2, dim=0)[1].mean().item(),1449 'alpha': self._alpha.item(),1450 'target_q_value': target_q_value.detach().mean().item(),1451 'mu': mu.detach().mean().item(),1452 'sigma': sigma.detach().mean().item(),1453 'q_value0': new_q_value[0].detach().mean().item(),1454 'q_value1': new_q_value[1].detach().mean().item(),1455 **loss_dict,1456 }1457 if self._monitor_cos:1458 var_monitor['cos_similarity'] = cos_similarity.item()1459 if self._monitor_entropy:1460 var_monitor['entropy'] = entropy.item()1461 return var_monitor14621463 def _monitor_vars_learn(self) -> List[str]:1464 """1465 Overview:1466 Return the necessary keys for logging the return dict of ``self._forward_learn``. 
The logger module, such \1467 as text logger, tensorboard logger, will use these keys to save the corresponding data.1468 Returns:1469 - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.1470 """1471 twin_critic = ['twin_critic_loss'] if self._twin_critic else []1472 alpha_loss = ['alpha_loss'] if self._auto_alpha else []1473 cos_similarity = ['cos_similarity'] if self._monitor_cos else []1474 entropy = ['entropy'] if self._monitor_entropy else []1475 return [1476 'value_loss'1477 'alpha_loss',1478 'policy_loss',1479 'critic_loss',1480 'cur_lr_q',1481 'cur_lr_p',1482 'target_q_value',1483 'alpha',1484 'td_error',1485 'agent_td_error',1486 'expert_td_error',1487 'mu',1488 'sigma',1489 'q_value0',1490 'q_value1',1491 ] + twin_critic + alpha_loss + cos_similarity + entropy