ding.policy.iql

IQLPolicy

Bases: Policy

Overview

Policy class of Implicit Q-Learning (IQL) algorithm for continuous control. Paper link: https://arxiv.org/abs/2110.06169.
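IQL trains its value network with an asymmetric (expectile) L2 loss on the residual u = Q(s, a) - V(s): with tau > 0.5, cases where V underestimates Q are penalized more than cases where it overestimates, so V(s) approaches an upper expectile of Q over in-dataset actions without ever querying out-of-distribution actions. A minimal pure-Python sketch of this loss (it mirrors the `asymmetric_l2_loss` helper in the source below, written without torch for illustration):

```python
def asymmetric_l2_loss(u, tau):
    # Expectile regression: a residual u = Q(s, a) - V(s) gets weight tau when
    # positive (V below Q) and (1 - tau) when negative (V above Q).
    return sum(abs(tau - (1.0 if x < 0 else 0.0)) * x * x for x in u) / len(u)

residuals = [2.0, -1.0]                         # two sample residuals
mse_like = asymmetric_l2_loss(residuals, 0.5)   # symmetric case: 0.5 * mean(u^2)
expectile = asymmetric_l2_loss(residuals, 0.7)  # IQL default tau = 0.7
```

At tau = 0.5 the loss reduces to a scaled mean-squared error; raising tau toward 1.0 makes the loss increasingly tolerant of V exceeding Q, which is what lets the learned V(s) approximate a max over dataset actions.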

Config

| ID | Symbol | Type | Default Value | Description | Other (Shape) |
|----|--------|------|---------------|-------------|---------------|
| 1 | `type` | str | iql | RL policy register name, refer to registry `POLICY_REGISTRY` | this arg is optional, a placeholder |
| 2 | `cuda` | bool | True | Whether to use cuda for network | |
| 3 | `random_collect_size` | int | 10000 | Number of randomly collected training samples in replay buffer when training starts. | Default to 10000 for SAC, 25000 for DDPG/TD3. |
| 4 | `model.policy_embedding_size` | int | 256 | Linear layer size for policy network. | |
| 5 | `model.soft_q_embedding_size` | int | 256 | Linear layer size for soft q network. | |
| 6 | `model.value_embedding_size` | int | 256 | Linear layer size for value network. | Default to None when `model.value_network` is False. |
| 7 | `learn.learning_rate_q` | float | 3e-4 | Learning rate for soft q network. | Default to 1e-3 when `model.value_network` is True. |
| 8 | `learn.learning_rate_policy` | float | 3e-4 | Learning rate for policy network. | Default to 1e-3 when `model.value_network` is True. |
| 9 | `learn.learning_rate_value` | float | 3e-4 | Learning rate for value network. | Default to None when `model.value_network` is False. |
| 10 | `learn.alpha` | float | 0.2 | Entropy regularization coefficient. | alpha is the initialization for auto `alpha` when `auto_alpha` is True. |
| 11 | `learn.reparameterization` | bool | True | Determine whether to use the reparameterization trick. | |
| 12 | `learn.auto_alpha` | bool | False | Determine whether to use the auto temperature parameter `alpha`. | The temperature parameter determines the relative importance of the entropy term against the reward. |
| 13 | `learn.ignore_done` | bool | False | Determine whether to ignore the done flag. | Use `ignore_done` only in the halfcheetah env. |
| 14 | `learn.target_theta` | float | 0.005 | Used for soft update of the target network. | aka. interpolation factor in polyak averaging for target networks. |
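Two of the knobs above define IQL's policy-extraction step: the advantage A(s, a) = Q(s, a) - V(s) is turned into a per-sample weight exp(A / beta), clamped for numerical stability, and the policy is fit by weighted behavioral cloning on dataset actions. A pure-Python illustration of that weight (the clamp value 20 matches the `_forward_learn` source; the helper name `awr_weight` is ours, not DI-engine's):

```python
import math

def awr_weight(advantage, beta=1.0, clip=20.0):
    # Advantage Weighted Regression weight: exp(A / beta), clamped so a few
    # high-advantage samples cannot dominate the policy gradient.
    return min(math.exp(advantage / beta), clip)

w_zero = awr_weight(0.0)   # zero advantage -> weight 1.0
w_big = awr_weight(10.0)   # large advantage -> saturates at the clamp, 20.0
```

Smaller `beta` sharpens the weighting toward the highest-advantage actions; larger `beta` flattens it toward plain behavioral cloning.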

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The `__init__` method will automatically call this method to get the default model setting and create the model.

Returns:
- model_info (`Tuple[str, List[str]]`): The registered model name and the model's import_names.
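`default_model` returns the registered model name plus the module paths needed to import it; the policy class itself is looked up the same way, via the `@POLICY_REGISTRY.register('iql')` decorator in the source. A toy sketch of that decorator-based registry pattern (a simplified stand-in, not DI-engine's actual `Registry` implementation):

```python
class Registry:
    def __init__(self):
        self._entries = {}

    def register(self, name):
        # Decorator factory: store the class under `name`, return it unchanged.
        def wrap(cls):
            self._entries[name] = cls
            return cls
        return wrap

    def get(self, name):
        return self._entries[name]

POLICY_REGISTRY = Registry()

@POLICY_REGISTRY.register('iql')
class IQLPolicy:
    pass

# Config files can now refer to the policy by the string 'iql'.
policy_cls = POLICY_REGISTRY.get('iql')
```

This indirection is what lets a config's `type='iql'` field select the class without the config importing it directly.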

Full Source Code

../ding/policy/iql.py

```python
from typing import List, Dict, Any, Tuple, Union
import copy
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


def asymmetric_l2_loss(u, tau):
    # Expectile regression loss: weight tau for positive residuals, (1 - tau) for negative ones.
    return torch.mean(torch.abs(tau - (u < 0).float()) * u ** 2)


@POLICY_REGISTRY.register('iql')
class IQLPolicy(Policy):
    """
    Overview:
        Policy class of Implicit Q-Learning (IQL) algorithm for continuous control.
        Paper link: https://arxiv.org/abs/2110.06169.

    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      iql           | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     True          | Whether to use cuda for network |
        3  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
           | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
           |                                           | buffer when training starts.    | TD3.
        4  | ``model.policy_``  int      256           | Linear layer size for policy    |
           | ``embedding_size``                        | network.                        |
        5  | ``model.soft_q_``  int      256           | Linear layer size for soft q    |
           | ``embedding_size``                        | network.                        |
        6  | ``model.value_``   int      256           | Linear layer size for value     | Default to None when
           | ``embedding_size``                        | network.                        | model.value_network
           |                                           |                                 | is False.
        7  | ``learn.learning`` float    3e-4          | Learning rate for soft q        | Default to 1e-3 when
           | ``_rate_q``                               | network.                        | model.value_network
           |                                           |                                 | is True.
        8  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to 1e-3 when
           | ``_rate_policy``                          | network.                        | model.value_network
           |                                           |                                 | is True.
        9  | ``learn.learning`` float    3e-4          | Learning rate for value         | Default to None when
           | ``_rate_value``                           | network.                        | model.value_network
           |                                           |                                 | is False.
        10 | ``learn.alpha``    float    0.2           | Entropy regularization          | alpha is the initiali-
           |                                           | coefficient.                    | zation for auto
           |                                           |                                 | `alpha`, when
           |                                           |                                 | auto_alpha is True
        11 | ``learn.repara_``  bool     True          | Determine whether to use the    |
           | ``meterization``                          | reparameterization trick.       |
        12 | ``learn.``         bool     False         | Determine whether to use the    | Temperature parameter
           | ``auto_alpha``                            | auto temperature parameter      | determines the
           |                                           | `alpha`.                        | relative importance
           |                                           |                                 | of the entropy term
           |                                           |                                 | against the reward.
        13 | ``learn.``         bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | the done flag.                  | in halfcheetah env.
        14 | ``learn.``         float    0.005         | Used for soft update of the     | aka. interpolation
           | ``target_theta``                          | target network.                 | factor in polyak
           |                                           |                                 | averaging for target
           |                                           |                                 | networks.
        == ==================== ======== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='iql',
        # (bool) Whether to use cuda for policy.
        cuda=False,
        # (bool) on_policy: Determine whether on-policy or off-policy.
        # The on-policy setting influences the behaviour of the buffer.
        on_policy=False,
        # (bool) priority: Determine whether to use priority in buffer sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=10000,
        model=dict(
            # (bool) twin_critic: Determine whether to use double-soft-q-net for target q computation.
            # Please refer to TD3 for the Clipped Double-Q Learning trick, which learns two Q-functions instead of one.
            # Default to True.
            twin_critic=True,
            # (str) action_space: Use the reparameterization trick for continuous actions.
            action_space='reparameterization',
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=512,
            actor_head_layer_num=3,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=512,
            critic_head_layer_num=2,
        ),
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after the collector's one collection.
            # Bigger "update_per_collect" means more off-policy.
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) learning_rate_q: Learning rate for soft q network.
            learning_rate_q=3e-4,
            # (float) learning_rate_policy: Learning rate for policy network.
            learning_rate_policy=3e-4,
            # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``.
            learning_rate_alpha=3e-4,
            # (float) target_theta: Used for soft update of the target network,
            # aka. interpolation factor in polyak averaging for target networks.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float) alpha: Entropy regularization coefficient.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # If auto_alpha is set to `True`, alpha is the initialization for auto `\alpha`.
            # Default to 0.2.
            alpha=0.2,
            # (bool) auto_alpha: Determine whether to use the auto temperature parameter `\alpha`.
            # The temperature parameter determines the relative importance of the entropy term against the reward.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # Default to False.
            # Note: using auto alpha requires setting learning_rate_alpha in `cfg.policy.learn`.
            auto_alpha=True,
            # (bool) log_space: Determine whether to use auto `\alpha` in log space.
            log_space=True,
            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max episode length of 1000.
            # However, HalfCheetah itself never terminates, so when the episode step reaches the max
            # episode step we replace done==True with done==False to keep the TD-error computation
            # (``gamma * (1 - done) * next_v + reward``) accurate.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
            # (int) The number of actions sampled uniformly at random at each state s.
            num_actions=10,
            # (bool) Whether to use a lagrange multiplier in the q value loss.
            with_lagrange=False,
            # (float) The threshold for the difference in Q-values.
            lagrange_thresh=-1,
            # (float) Loss weight for the conservative term.
            min_q_weight=1.0,
            # (float) coefficient for the asymmetric loss, range [0.5, 1.0], default to 0.7.
            tau=0.7,
            # (float) temperature coefficient for the Advantage Weighted Regression loss, default to 1.0.
            beta=1.0,
        ),
        eval=dict(),  # for compatibility
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.
        """

        return 'continuous_qvac', ['ding.model.template.qvac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For IQL, it mainly \
            contains the optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange, and \
            the main and target model. Especially, the ``auto_alpha`` mechanism for balancing the max entropy \
            target is also initialized here.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._twin_critic = self._cfg.model.twin_critic
        self._num_actions = self._cfg.learn.num_actions

        self._min_q_version = 3
        self._min_q_weight = self._cfg.learn.min_q_weight
        self._lagrange_thresh = self._cfg.learn.lagrange_thresh
        self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0)
        if self._with_lagrange:
            self.target_action_gap = self._lagrange_thresh
            self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_()
            self.alpha_prime_optimizer = Adam(
                [self.log_alpha_prime],
                lr=self._cfg.learn.learning_rate_q,
            )

        # Weight Init
        init_w = self._cfg.learn.init_w
        self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
        # self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
        # self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
        if self._twin_critic:
            self._model.critic_q_head[0][-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_q_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
            self._model.critic_q_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_q_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
        else:
            self._model.critic_q_head[2].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_q_head[-1].last.bias.data.uniform_(-init_w, init_w)
        self._model.critic_v_head[2].last.weight.data.uniform_(-init_w, init_w)
        self._model.critic_v_head[-1].last.bias.data.uniform_(-init_w, init_w)

        # Optimizers
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor

        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()

        self._forward_learn_cnt = 0

        self._tau = self._cfg.learn.tau
        self._beta = self._cfg.learn.beta
        self._policy_start_training_counter = 10000  # 300000

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the offline dataset and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in the list, the key of the dict is the name of the data item \
                and the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
                dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
                the batch dimension by some utility functions such as ``default_preprocess_learn``. \
                For IQL, each element in the list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
                ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
                will be recorded in text log and tensorboard; values must be python scalars or lists of scalars. \
                For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of \
            them. For data types that are not supported, the main reason is that the corresponding model does not \
            support them. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']

        # 1. predict q and v value
        value = self._learn_model.forward(data, mode='compute_critic')
        q_value, v_value = value['q_value'], value['v_value']

        # 2. predict target value
        with torch.no_grad():
            (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']

            next_obs_dist = TransformedDistribution(
                Independent(Normal(mu, sigma), 1),
                transforms=[TanhTransform(cache_size=1),
                            AffineTransform(loc=0.0, scale=1.05)]
            )
            next_action = next_obs_dist.rsample()
            next_log_prob = next_obs_dist.log_prob(next_action)

            next_data = {'obs': next_obs, 'action': next_action}
            next_value = self._learn_model.forward(next_data, mode='compute_critic')
            next_q_value, next_v_value = next_value['q_value'], next_value['v_value']

            # the value of a policy according to the maximum entropy objective
            if self._twin_critic:
                next_q_value = torch.min(next_q_value[0], next_q_value[1])

        # 3. compute v loss
        if self._twin_critic:
            q_value_min = torch.min(q_value[0], q_value[1]).detach()
            v_loss = asymmetric_l2_loss(q_value_min - v_value, self._tau)
        else:
            advantage = q_value.detach() - v_value
            v_loss = asymmetric_l2_loss(advantage, self._tau)

        # 4. compute q loss
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], next_v_value, reward, done, data['weight'])
            loss_dict['critic_q_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], next_v_value, reward, done, data['weight'])
            loss_dict['twin_critic_q_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            q_loss = (loss_dict['critic_q_loss'] + loss_dict['twin_critic_q_loss']) / 2
        else:
            q_data = v_1step_td_data(q_value, next_v_value, reward, done, data['weight'])
            loss_dict['critic_q_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)
            q_loss = loss_dict['critic_q_loss']

        # 5. update q and v network
        self._optimizer_q.zero_grad()
        v_loss.backward()
        q_loss.backward()
        self._optimizer_q.step()

        # 6. evaluate to get action distribution
        (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']

        dist = TransformedDistribution(
            Independent(Normal(mu, sigma), 1),
            transforms=[TanhTransform(cache_size=1), AffineTransform(loc=0.0, scale=1.05)]
        )
        action = data['action']
        log_prob = dist.log_prob(action)

        eval_data = {'obs': obs, 'action': action}
        new_value = self._learn_model.forward(eval_data, mode='compute_critic')
        new_q_value, new_v_value = new_value['q_value'], new_value['v_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])
        new_advantage = new_q_value - new_v_value

        # 7. compute policy loss
        policy_loss = (-log_prob * torch.exp(new_advantage.detach() / self._beta).clamp(max=20.0)).mean()
        self._policy_start_training_counter -= 1

        loss_dict['policy_loss'] = policy_loss

        # 8. update policy network
        self._optimizer_policy.zero_grad()
        policy_loss.backward()
        policy_grad_norm = torch.nn.utils.clip_grad_norm_(self._model.actor.parameters(), 1)
        self._optimizer_policy.step()

        loss_dict['total_loss'] = sum(loss_dict.values())

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': q_loss.abs().tolist(),
            'q_loss': q_loss.detach().mean().item(),
            'v_loss': v_loss.detach().mean().item(),
            'log_prob': log_prob.detach().mean().item(),
            'next_q_value': next_q_value.detach().mean().item(),
            'next_v_value': next_v_value.detach().mean().item(),
            'policy_loss': policy_loss.detach().mean().item(),
            'total_loss': loss_dict['total_loss'].detach().item(),
            'advantage_max': new_advantage.max().detach().item(),
            'new_q_value': new_q_value.detach().mean().item(),
            'new_v_value': new_v_value.detach().mean().item(),
            'policy_grad_norm': policy_grad_norm,
        }

    def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List:
        # evaluate to get action distribution
        obs = data['obs']
        obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)

        # evaluate action log prob depending on the Jacobian determinant of the tanh transform.
        y = 1 - action.pow(2) + epsilon
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        return action, log_prob.view(-1, num_actions, 1)

    def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor:
        new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value]
        else:
            new_q_value = new_q_value.view(-1, self._num_actions, 1)
        if self._twin_critic and not keep:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])
        return new_q_value

    def _get_v_value(self, data: Dict, keep: bool = True) -> torch.Tensor:
        new_v_value = self._learn_model.forward(data, mode='compute_critic')['v_value']
        if self._twin_critic:
            new_v_value = [value.view(-1, self._num_actions, 1) for value in new_v_value]
        else:
            new_v_value = new_v_value.view(-1, self._num_actions, 1)
        if self._twin_critic and not keep:
            new_v_value = torch.min(new_v_value[0], new_v_value[1])
        return new_v_value

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For IQL, it contains \
            the collect_model and other algorithm-specific arguments such as unroll_len. \
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_collect`` method, you'd better name \
            them with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns \
            the output data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
                dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of \
            them. For data types that are not supported, the main reason is that the corresponding model does not \
            support them. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            ``logit`` here means the mu and sigma of the Gaussian distribution; we use this name for consistency \
            with SAC.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            action = torch.tanh(dist.rsample())
            output = {'logit': (mu, sigma), 'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training \
            and saved in replay buffer. For continuous IQL, it contains obs, next_obs, action, reward, done. The \
            logit will also be added when ``collector_logit`` is True.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
                as input. For continuous IQL, it contains the action and the logit (mu and sigma) of the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
                method, except all the elements have been transformed into tensor data. Usually, it contains the \
                next obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        if self._cfg.collect.collector_logit:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'logit': policy_output['logit'],
                'action': policy_output['action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        else:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': policy_output['action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        return transition

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transitions), process it into a list of samples that \
            can be used for training directly. In continuous IQL, a train sample is a processed transition \
            (unroll_len=1).
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transitions), each element \
                is in the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar \
                format to the input transitions, but may contain more data for training.
        """
        return get_train_sample(transitions, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For IQL, it contains the \
            eval model, which is equipped with the ``base`` model wrapper to ensure compatibility.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
                The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
                The key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and the current policy supports all of \
            them. For data types that are not supported, the main reason is that the corresponding model does not \
            support them. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            ``logit`` here means the mu and sigma of the Gaussian distribution; we use this name for consistency \
            with SAC.

        .. note::
            For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit']
            action = torch.tanh(mu) / 1.05  # deterministic_eval
            output = {'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        twin_critic = ['twin_critic_q_loss'] if self._twin_critic else []
        return [
            'cur_lr_q',
            'cur_lr_p',
            'policy_loss',
            'q_loss',
            'v_loss',
            'log_prob',
            'total_loss',
            'advantage_max',
            'next_q_value',
            'next_v_value',
            'new_q_value',
            'new_v_value',
            'policy_grad_norm',
        ] + twin_critic

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_policy': self._optimizer_policy.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.

        .. tip::
            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operations.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
        self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
```