ding.policy.cql

CQLPolicy

Bases: SACPolicy

Overview

Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779.
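The conservative term that CQL adds to the critic loss can be illustrated for a single state. The sketch below is a minimal pure-Python illustration of the importance-sampled estimator used in the source listing on this page (`_min_q_version == 3`). Q-values of uniformly random actions are offset by the log density of the uniform proposal on [-1, 1]^d, and Q-values of policy-sampled actions are offset by their log-probabilities. A log-sum-exp is taken over all of them, and the Q-value of the dataset action is subtracted. The helper names are hypothetical. The listing also samples a second group of policy actions from the next observations, which this sketch folds into a single list.

```python
import math
from typing import Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) over a non-empty sequence."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_penalty_one_state(
    q_rand: Sequence[float],  # Q(s, a) for actions drawn uniformly from [-1, 1]^d
    q_pi: Sequence[float],    # Q(s, a) for actions sampled from the current policy
    log_pi: Sequence[float],  # log pi(a | s) of those policy-sampled actions
    q_data: float,            # Q(s, a) for the action stored in the dataset
    action_dim: int,
    min_q_weight: float = 1.0,
) -> float:
    """Hypothetical helper: importance-sampled CQL penalty for one state."""
    # Log density of the uniform proposal on [-1, 1]^d is log(0.5 ** d).
    random_density = math.log(0.5 ** action_dim)
    # Divide each sample by its proposal density (subtract in log space).
    terms = [q - random_density for q in q_rand]
    terms += [q - lp for q, lp in zip(q_pi, log_pi)]
    # Soft maximum over sampled actions minus the dataset action's Q-value.
    return min_q_weight * (logsumexp(terms) - q_data)


# Usage: one random and one policy action, action dimension d = 1.
penalty = cql_penalty_one_state(
    q_rand=[0.0], q_pi=[0.0], log_pi=[math.log(0.5)], q_data=0.0, action_dim=1
)
print(round(penalty, 4))  # 1.3863
```

Minimizing this quantity pushes Q-values down on sampled actions and up on the dataset action, which is what keeps the learned Q-function conservative on out-of-distribution actions.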

Config

== ============================ ===== ============= ================================= =======================
ID Symbol                       Type  Default Value Description                       Other(Shape)
== ============================ ===== ============= ================================= =======================
1  type                         str   cql           RL policy register name, refer    This arg is optional,
                                                    to registry POLICY_REGISTRY.      a placeholder.
2  cuda                         bool  True          Whether to use cuda for network.
3  random_collect_size          int   10000         Number of randomly collected      Default to 10000 for
                                                    training samples in replay        SAC, 25000 for DDPG/
                                                    buffer when training starts.      TD3.
4  model.policy_embedding_size  int   256           Linear layer size for policy
                                                    network.
5  model.soft_q_embedding_size  int   256           Linear layer size for soft q
                                                    network.
6  model.value_embedding_size   int   256           Linear layer size for value       Default to None when
                                                    network.                          model.value_network
                                                                                      is False.
7  learn.learning_rate_q        float 3e-4          Learning rate for soft q          Default to 1e-3 when
                                                    network.                          model.value_network
                                                                                      is True.
8  learn.learning_rate_policy   float 3e-4          Learning rate for policy          Default to 1e-3 when
                                                    network.                          model.value_network
                                                                                      is True.
9  learn.learning_rate_value    float 3e-4          Learning rate for value           Default to None when
                                                    network.                          model.value_network
                                                                                      is False.
10 learn.alpha                  float 0.2           Entropy regularization            Initial value of auto
                                                    coefficient.                      alpha when auto_alpha
                                                                                      is True.
11 learn.reparameterization     bool  True          Determine whether to use
                                                    reparameterization trick.
12 learn.auto_alpha             bool  False         Determine whether to use          Temperature parameter
                                                    auto temperature parameter        determines the
                                                    alpha.                            relative importance
                                                                                      of the entropy term
                                                                                      against the reward.
13 learn.ignore_done            bool  False         Determine whether to ignore       Use ignore_done only
                                                    done flag.                        in halfcheetah env.
14 learn.target_theta           float 0.005         Used for soft update of the       Aka. interpolation
                                                    target network.                   factor in polyak
                                                                                      averaging for target
                                                                                      networks.
== ============================ ===== ============= ================================= =======================
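The target_theta entry in the table is the interpolation factor of a polyak (soft) target update. Below is a minimal sketch of that rule on plain Python lists, assuming the standard form target <- (1 - theta) * target + theta * online. The helper name soft_update is hypothetical. In the source listing on this page the update is delegated to model_wrap(..., wrapper_name='target', update_type='momentum'), which this sketch assumes applies the same rule to every parameter.

```python
from typing import List


def soft_update(target: List[float], online: List[float], theta: float = 0.005) -> List[float]:
    """Hypothetical helper: one polyak step, target <- (1 - theta) * target + theta * online."""
    if len(target) != len(online):
        raise ValueError("target and online must have the same length")
    return [(1.0 - theta) * t + theta * o for t, o in zip(target, online)]


# Usage: the target drifts slowly toward a fixed online value of 1.0.
target = [0.0]
for _ in range(3):
    target = soft_update(target, [1.0], theta=0.005)
print(target)
```

With theta = 0.005 the target moves only 0.5% of the way toward the online network per update, so the TD targets used by the critic change slowly between iterations.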

DiscreteCQLPolicy

Bases: QRDQNPolicy

Overview

Policy class of discrete CQL algorithm in discrete action space environments. Paper link: https://arxiv.org/abs/2006.04779.
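For discrete actions the conservative term is computed exactly over all actions instead of by sampling. The sketch below mirrors the min_q_loss computation in DiscreteCQLPolicy._forward_learn in the source listing on this page. The QR-DQN quantiles are averaged into one Q-value per action, and a log-sum-exp over actions is averaged across the batch. The mean Q-value of the dataset actions is then subtracted. It is plain Python for illustration, and the function names are hypothetical. In the listing this term is scaled by min_q_weight and added to the QR-DQN TD loss.

```python
import math
from typing import List, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v))) over a non-empty sequence."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def discrete_cql_min_q_loss(quantiles: List[List[List[float]]], actions: List[int]) -> float:
    """Hypothetical helper. quantiles[b][a][k] is the k-th quantile of Q(s_b, a); actions[b] is the dataset action."""
    # Average the quantiles to get Q(s, a), like q_value.mean(-1) in the listing.
    q = [[sum(qs) / len(qs) for qs in per_state] for per_state in quantiles]
    # Log-sum-exp over actions, averaged over the batch ("negative_sampling").
    negative_sampling = sum(logsumexp(row) for row in q) / len(q)
    # Mean Q-value of the actions actually taken in the dataset ("dataset_expec").
    dataset_expec = sum(row[a] for row, a in zip(q, actions)) / len(q)
    return negative_sampling - dataset_expec


# Usage: one state, two actions with equal mean Q-values, so the loss is log(2).
loss = discrete_cql_min_q_loss([[[0.5, 1.5], [1.0, 1.0]]], actions=[0])
print(round(loss, 4))  # 0.6931
```

Because a log-sum-exp is always larger than the largest entry in its row, this loss is positive, and it becomes small only when the dataset action's Q-value dominates the other actions.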

Full Source Code

../ding/policy/cql.py

from typing import List, Dict, Any, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .sac import SACPolicy
from .qrdqn import QRDQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('cql')
class CQLPolicy(SACPolicy):
    """
    Overview:
        Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779.

    Config:
        == ============================ ===== ============= ================================= =======================
        ID Symbol                       Type  Default Value Description                       Other(Shape)
        == ============================ ===== ============= ================================= =======================
        1  type                         str   cql           RL policy register name, refer    This arg is optional,
                                                            to registry POLICY_REGISTRY.      a placeholder.
        2  cuda                         bool  True          Whether to use cuda for network.
        3  random_collect_size          int   10000         Number of randomly collected      Default to 10000 for
                                                            training samples in replay        SAC, 25000 for DDPG/
                                                            buffer when training starts.      TD3.
        4  model.policy_embedding_size  int   256           Linear layer size for policy
                                                            network.
        5  model.soft_q_embedding_size  int   256           Linear layer size for soft q
                                                            network.
        6  model.value_embedding_size   int   256           Linear layer size for value       Default to None when
                                                            network.                          model.value_network
                                                                                              is False.
        7  learn.learning_rate_q        float 3e-4          Learning rate for soft q          Default to 1e-3 when
                                                            network.                          model.value_network
                                                                                              is True.
        8  learn.learning_rate_policy   float 3e-4          Learning rate for policy          Default to 1e-3 when
                                                            network.                          model.value_network
                                                                                              is True.
        9  learn.learning_rate_value    float 3e-4          Learning rate for value           Default to None when
                                                            network.                          model.value_network
                                                                                              is False.
        10 learn.alpha                  float 0.2           Entropy regularization            Initial value of auto
                                                            coefficient.                      alpha when auto_alpha
                                                                                              is True.
        11 learn.reparameterization     bool  True          Determine whether to use
                                                            reparameterization trick.
        12 learn.auto_alpha             bool  False         Determine whether to use          Temperature parameter
                                                            auto temperature parameter        determines the
                                                            alpha.                            relative importance
                                                                                              of the entropy term
                                                                                              against the reward.
        13 learn.ignore_done            bool  False         Determine whether to ignore       Use ignore_done only
                                                            done flag.                        in halfcheetah env.
        14 learn.target_theta           float 0.005         Used for soft update of the       Aka. interpolation
                                                            target network.                   factor in polyak
                                                                                              averaging for target
                                                                                              networks.
        == ============================ ===== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='cql',
        # (bool) Whether to use cuda for policy.
        cuda=False,
        # (bool) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        on_policy=False,
        # (bool) priority: Determine whether to use priority in buffer sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=10000,
        model=dict(
            # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
            # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one.
            # Default to True.
            twin_critic=True,
            # (str type) action_space: Use reparameterization trick for continuous action.
            action_space='reparameterization',
            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,
            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) learning_rate_q: Learning rate for soft q network.
            learning_rate_q=3e-4,
            # (float) learning_rate_policy: Learning rate for policy network.
            learning_rate_policy=3e-4,
            # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``.
            learning_rate_alpha=3e-4,
            # (float) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            target_theta=0.005,
            # (float) Discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float) alpha: Entropy regularization coefficient.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # If auto_alpha is set to `True`, alpha is the initial value of auto `\alpha`.
            # Default to 0.2.
            alpha=0.2,
            # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha`.
            # Temperature parameter determines the relative importance of the entropy term against the reward.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # Default to True.
            # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`.
            auto_alpha=True,
            # (bool) log_space: Determine whether to use auto `\alpha` in log space.
            log_space=True,
            # (bool) Whether to ignore done (usually for max step termination env, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, HalfCheetah has no true terminal state, so when the episode step exceeds the max
            # episode step we replace done==True with done==False in place, to keep the TD-error
            # computation (``gamma * (1 - done) * next_v + reward``) accurate.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
            # (int) Number of actions sampled at every state s, both uniformly at random and from the policy.
            num_actions=10,
            # (bool) Whether to use a Lagrange multiplier in the q value loss.
            with_lagrange=False,
            # (float) The threshold for difference in Q-values.
            lagrange_thresh=-1,
            # (float) Loss weight for the conservative term.
            min_q_weight=1.0,
            # (bool) Whether to use entropy in target q.
            with_q_entropy=False,
        ),
        eval=dict(),  # for compatibility
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
            contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \
            with_q_entropy, main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \
            target is also initialized here.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._twin_critic = self._cfg.model.twin_critic
        self._num_actions = self._cfg.learn.num_actions

        self._min_q_version = 3
        self._min_q_weight = self._cfg.learn.min_q_weight
        # Read the threshold first: the with_lagrange check below depends on it.
        self._lagrange_thresh = self._cfg.learn.lagrange_thresh
        self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0)
        if self._with_lagrange:
            self.target_action_gap = self._lagrange_thresh
            self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_()
            self.alpha_prime_optimizer = Adam(
                [self.log_alpha_prime],
                lr=self._cfg.learn.learning_rate_q,
            )

        self._with_q_entropy = self._cfg.learn.with_q_entropy

        # Weight Init
        init_w = self._cfg.learn.init_w
        self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
        self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
        if self._twin_critic:
            self._model.critic_head[0][-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
            self._model.critic_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
        else:
            # Initialize weight and bias of the same (last) critic layer.
            self._model.critic_head[-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w)

        # Optimizers
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor
        # Init auto alpha
        if self._cfg.learn.auto_alpha:
            if self._cfg.learn.target_entropy is None:
                assert 'action_shape' in self._cfg.model, "CQL need network model with action_shape variable"
                self._target_entropy = -np.prod(self._cfg.model.action_shape)
            else:
                self._target_entropy = self._cfg.learn.target_entropy
            if self._cfg.learn.log_space:
                self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
                self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
                assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
                self._alpha = self._log_alpha.detach().exp()
                self._auto_alpha = True
                self._log_space = True
            else:
                self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
                self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
                self._auto_alpha = True
                self._log_space = False
        else:
            self._alpha = torch.tensor(
                [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
            )
            self._auto_alpha = False

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the offline dataset and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
                ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For unsupported data types, the main reason is that the corresponding model does not support them. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']

        # 1. predict q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # 2. predict target value
        with torch.no_grad():
            (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']

            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            next_action = torch.tanh(pred)
            y = 1 - next_action.pow(2) + 1e-6
            next_log_prob = dist.log_prob(pred).unsqueeze(-1)
            next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)

            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # the value of a policy according to the maximum entropy objective
            if self._twin_critic:
                # find min one as target q value
                if self._with_q_entropy:
                    target_q_value = torch.min(target_q_value[0],
                                               target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
                else:
                    target_q_value = torch.min(target_q_value[0], target_q_value[1])
            else:
                if self._with_q_entropy:
                    target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)

        # 3. compute q loss
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
            loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
        else:
            q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)

        # 4. add CQL

        curr_actions_tensor, curr_log_pis = self._get_policy_actions(data, self._num_actions)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions({'obs': next_obs}, self._num_actions)

        random_actions_tensor = torch.FloatTensor(curr_actions_tensor.shape).uniform_(-1, 1).to(
            curr_actions_tensor.device
        )

        obs_repeat = obs.unsqueeze(1).repeat(1, self._num_actions,
                                             1).view(obs.shape[0] * self._num_actions, obs.shape[1])
        act_repeat = data['action'].unsqueeze(1).repeat(1, self._num_actions, 1).view(
            data['action'].shape[0] * self._num_actions, data['action'].shape[1]
        )

        q_rand = self._get_q_value({'obs': obs_repeat, 'action': random_actions_tensor})
        q_curr_actions = self._get_q_value({'obs': obs_repeat, 'action': curr_actions_tensor})
        q_next_actions = self._get_q_value({'obs': obs_repeat, 'action': new_curr_actions_tensor})

        cat_q1 = torch.cat([q_rand[0], q_value[0].reshape(-1, 1, 1), q_next_actions[0], q_curr_actions[0]], 1)
        cat_q2 = torch.cat([q_rand[1], q_value[1].reshape(-1, 1, 1), q_next_actions[1], q_curr_actions[1]], 1)
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)
        if self._min_q_version == 3:
            # importance sampled version
            random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat(
                [
                    q_rand[0] - random_density, q_next_actions[0] - new_log_pis.detach(),
                    q_curr_actions[0] - curr_log_pis.detach()
                ], 1
            )
            cat_q2 = torch.cat(
                [
                    q_rand[1] - random_density, q_next_actions[1] - new_log_pis.detach(),
                    q_curr_actions[1] - curr_log_pis.detach()
                ], 1
            )

        min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean() * self._min_q_weight
        min_qf2_loss = torch.logsumexp(cat_q2, dim=1).mean() * self._min_q_weight
        # Subtract the mean Q-value of the dataset actions.
        min_qf1_loss = min_qf1_loss - q_value[0].mean() * self._min_q_weight
        min_qf2_loss = min_qf2_loss - q_value[1].mean() * self._min_q_weight

        if self._with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        loss_dict['critic_loss'] += min_qf1_loss
        if self._twin_critic:
            loss_dict['twin_critic_loss'] += min_qf2_loss

        # 5. update q network
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward(retain_graph=True)
        if self._twin_critic:
            loss_dict['twin_critic_loss'].backward()
        self._optimizer_q.step()

        # 6. evaluate to get action distribution
        (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        y = 1 - action.pow(2) + 1e-6
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])

        # 7. compute policy loss
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()

        loss_dict['policy_loss'] = policy_loss

        # 8. update policy network
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 9. compute alpha loss
        if self._auto_alpha:
            if self._log_space:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()
            else:
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                # Clamp in place so alpha stays a tensor tracked by the optimizer.
                with torch.no_grad():
                    self._alpha.clamp_(min=0.)

        loss_dict['total_loss'] = sum(loss_dict.values())

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }

    def _get_policy_actions(self, data: Dict, num_actions: int = 10,
                            epsilon: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
        # evaluate to get action distribution
        obs = data['obs']
        obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)

        # evaluate action log prob depending on Jacobi determinant.
        y = 1 - action.pow(2) + epsilon
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        return action, log_prob.view(-1, num_actions, 1)

    def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor:
        new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value]
        else:
            new_q_value = new_q_value.view(-1, self._num_actions, 1)
        if self._twin_critic and not keep:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])
        return new_q_value


@POLICY_REGISTRY.register('discrete_cql')
class DiscreteCQLPolicy(QRDQNPolicy):
    """
    Overview:
        Policy class of discrete CQL algorithm in discrete action space environments.
        Paper link: https://arxiv.org/abs/2006.04779.
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='discrete_cql',
        # (bool) Whether to use cuda for policy.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=False,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.97,
        # (int) N-step reward for target q_value estimation
        nstep=1,
        # learn_mode config
        learn=dict(
            # (int) How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            update_per_collect=1,
            # (int) Minibatch size for one gradient descent.
            batch_size=64,
            # (float) Learning rate for the Q network.
            learning_rate=0.001,
            # (int) Frequency of target network update.
            target_update_freq=100,
            # (bool) Whether to ignore done (usually for max step termination env).
            ignore_done=False,
            # (float) Loss weight for the conservative term.
            min_q_weight=1.0,
        ),
        eval=dict(),  # for compatibility
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \
            contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \
            target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._min_q_weight = self._cfg.learn.min_q_weight
        self._priority = self._cfg.priority
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep

        # use wrapper instead of plugin
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.learn.target_update_freq}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the offline dataset and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \
                ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \
                and ``value_gamma`` for nstep return computation.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For unsupported data types, the main reason is that the corresponding model does not support them. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.
        """
        data = default_preprocess_learn(
            data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        if data['action'].dim() == 2 and data['action'].shape[-1] == 1:
            data['action'] = data['action'].squeeze(-1)
        # ====================
        # Q-learning forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # Current q value (main model)
        ret = self._learn_model.forward(data['obs'])
        q_value, tau = ret['q'], ret['tau']
        # Target q value
        with torch.no_grad():
            target_q_value = self._target_model.forward(data['next_obs'])['q']
            # Max q value action (main model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        # add CQL
        # 1. choose action and compute q in dataset.
        # 2. compute value loss (negative_sampling - dataset_expec)
        replay_action_one_hot = F.one_hot(data['action'], self._cfg.model.action_shape)
        replay_chosen_q = (q_value.mean(-1) * replay_action_one_hot).sum(dim=1)

        dataset_expec = replay_chosen_q.mean()

        negative_sampling = torch.logsumexp(q_value.mean(-1), dim=1).mean()

        min_q_loss = negative_sampling - dataset_expec

        data_n = qrdqn_nstep_td_data(
            q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight']
        )
        value_gamma = data.get('value_gamma')
        loss, td_error_per_sample = qrdqn_nstep_td_error(
            data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
        )

        loss += self._min_q_weight * min_q_loss

        # ====================
        # Q-learning update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        self._optimizer.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            'q_target': target_q_value.mean().item(),
            'q_value': q_value.mean().item(),
            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
            # '[histogram]action_distribution': data['action'],
        }

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'total_loss', 'q_target', 'q_value']
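Both _get_policy_actions and the actor update in the source listing sample u from a diagonal Gaussian, squash it with tanh, and correct the log-probability by the log-determinant of the tanh Jacobian: log pi(a|s) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2 + epsilon). Below is a minimal pure-Python sketch of that formula for a single sample. The function name is hypothetical, and epsilon = 1e-6 matches the default in the listing.

```python
import math
from typing import Sequence, Tuple


def squashed_gaussian_log_prob(
    u: Sequence[float], mu: Sequence[float], sigma: Sequence[float], epsilon: float = 1e-6
) -> Tuple[Tuple[float, ...], float]:
    """Hypothetical helper: return (tanh(u), log-prob of tanh(u)) for a diagonal Gaussian sample u."""
    # Independent(Normal(mu, sigma), 1).log_prob sums the per-dimension log densities.
    log_prob = sum(
        -0.5 * ((x - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        for x, m, s in zip(u, mu, sigma)
    )
    action = tuple(math.tanh(x) for x in u)
    # Change of variables for a = tanh(u): subtract log |da/du| = log(1 - tanh(u)^2).
    log_prob -= sum(math.log(1.0 - a * a + epsilon) for a in action)
    return action, log_prob


action, log_prob = squashed_gaussian_log_prob([0.0], [0.0], [1.0])
print(action, round(log_prob, 4))  # (0.0,) -0.9189
```

The epsilon keeps the logarithm finite when tanh(u) saturates near +1 or -1, which happens for large pre-squash samples.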