from typing import List, Dict, Any, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .sac import SACPolicy
from .dqn import DQNPolicy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('edac')
class EDACPolicy(SACPolicy):
    """
    Overview:
        Policy class of EDAC (Ensemble-Diversified Actor Critic) algorithm, an offline RL method that
        trains an ensemble of Q-networks with a gradient-diversity penalty.
        Paper link: https://arxiv.org/pdf/2110.01548.pdf

    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      edac          | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     True          | Whether to use cuda for network |
        3  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
           | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
           |                                           | buffer when training starts.    | TD3.
        4  | ``model.policy_``  int      256           | Linear layer size for policy    |
           | ``embedding_size``                        | network.                        |
        5  | ``model.soft_q_``  int      256           | Linear layer size for soft q    |
           | ``embedding_size``                        | network.                        |
        6  | ``model.ensemble`` int      10            | Number of Q-ensemble network    |
           | ``_num``                                  |                                 | is False.
        7  | ``learn.learning`` float    3e-4          | Learning rate for soft q        | Default to 1e-3, when
           | ``_rate_q``                               | network.                        | model.value_network
           |                                           |                                 | is True.
        8  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to 1e-3, when
           | ``_rate_policy``                          | network.                        | model.value_network
           |                                           |                                 | is True.
        9  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to None when
           | ``_rate_value``                           | network.                        | model.value_network
           |                                           |                                 | is False.
        10 | ``learn.alpha``    float    1.0           | Entropy regularization          | alpha is initiali-
           |                                           | coefficient.                    | zation for auto
           |                                           |                                 | `alpha`, when
           |                                           |                                 | auto_alpha is True
        11 | ``learn.eta``      float    0.1           | Parameter of EDAC algorithm     | Weight of the
           |                                           | (ensemble gradient diversity    | diversity penalty;
           |                                           | penalty).                       | 0 disables it.
        12 | ``learn.``         bool     True          | Determine whether to use        | Temperature parameter
           | ``auto_alpha``                            | auto temperature parameter     | determines the
           |                                           | `alpha`.                        | relative importance
           |                                           |                                 | of the entropy term
           |                                           |                                 | against the reward.
        13 | ``learn.-``        bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | done flag.                      | in halfcheetah env.
        14 | ``learn.-``        float    0.005         | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                          | target network.                 | factor in polyak aver
           |                                           |                                 | aging for target
           |                                           |                                 | networks.
        == ==================== ======== ============= ================================= =======================
    """
    config = dict(
        # (str) RL policy register name
        type='edac',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) EDAC is an off-policy (offline) algorithm.
        on_policy=False,
        # (bool) Whether the policy is used in a multi-agent setting.
        multi_agent=False,
        # (bool) Whether to use priority sampling in the replay buffer.
        priority=False,
        # (bool) Whether to use importance-sampling weights to correct priority bias.
        priority_IS_weight=False,
        # (int) Number of randomly collected samples before training starts.
        random_collect_size=10000,
        model=dict(
            # (int) ensemble_num: num of Q-network.
            ensemble_num=10,
            # (bool type) value_network: Determine whether to use value network as the
            # original EDAC paper (arXiv 2110.01548).
            # using value_network needs to set learning_rate_value, learning_rate_q,
            # and learning_rate_policy in `cfg.policy.learn`.
            # Default to False.
            # value_network=False,

            # (int) Hidden size for actor network head.
            actor_head_hidden_size=256,

            # (int) Hidden size for critic network head.
            critic_head_hidden_size=256,
        ),
        learn=dict(
            # (bool) Whether to use multi-GPU (data-parallel) training.
            multi_gpu=False,
            # (int) Number of gradient steps per data collection.
            update_per_collect=1,
            # (int) Mini-batch size for one gradient step.
            batch_size=256,
            # (float) Learning rate for the Q (critic) networks.
            learning_rate_q=3e-4,
            # (float) Learning rate for the policy (actor) network.
            learning_rate_policy=3e-4,
            # (float) Learning rate for the value network (only used if value_network is enabled).
            learning_rate_value=3e-4,
            # (float) Learning rate for the entropy temperature `alpha`.
            learning_rate_alpha=3e-4,
            # (float) Interpolation factor for polyak averaging of the target network.
            target_theta=0.005,
            # (float) Reward discount factor gamma.
            discount_factor=0.99,
            # (float) Initial entropy regularization coefficient (initialization when auto_alpha is True).
            alpha=1,
            # (bool) Whether to tune `alpha` automatically.
            auto_alpha=True,
            # (bool type) log_space: Determine whether to use auto `\alpha` in log space.
            log_space=True,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with done is False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer
            init_w=3e-3,
            # (float) Loss weight for conservative item.
            min_q_weight=1.0,
            # (bool) Whether to use entropy in target q.
            with_q_entropy=False,
            # (float) Weight of the EDAC ensemble-gradient diversity penalty (eta in the paper).
            eta=0.1,
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int type) replay_buffer_size: Max size of replay buffer.
                replay_buffer_size=1000000,
                # (int type) max_use: Max use times of one data in the buffer.
                # Data will be removed once used for too many times.
                # Default to infinite.
                # max_use=256,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        return 'edac', ['ding.model.template.edac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For EDAC, in addition \
            to the things that need to be initialized in SAC, it is also necessary to additionally define \
            eta/with_q_entropy/forward_learn_cnt. \
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        # SAC sets up models, optimizers, alpha handling, target model, etc.
        super()._init_learn()
        # EDAC special implementation
        # Weight of the ensemble gradient-diversity penalty; 0 disables the penalty branch in _forward_learn.
        self._eta = self._cfg.learn.eta
        # Whether the entropy term is subtracted from the TD target.
        self._with_q_entropy = self._cfg.learn.with_q_entropy
        # Counter of completed training iterations.
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For EDAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
                ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data type that not supported, the main reason is that the corresponding model does not support it. \
            You can implement your own model rather than use the default model. For more information, please raise an \
            issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for EDACPolicy: \
            ``ding.policy.tests.test_edac``.
        """
        loss_dict = {}
        # Stack the list of sample dicts into batched tensors (and handle priority/done flags).
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        # Promote 1-D (scalar per sample) actions to shape (batch, 1) so the critic gets a trailing action dim.
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']
        acs = data['action']

        # 1. predict q value
        # q_value comes from the Q-ensemble; presumably shaped (ensemble_num, batch) — the dim-0 min below
        # and the `[Ensemble_num,batch_size]` comment in the penalty branch rely on this. TODO confirm in model.
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        # 2. predict the TD target without tracking gradients (target network + current actor).
        with torch.no_grad():
            (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']

            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            next_action = torch.tanh(pred)
            # tanh-squash correction term: log|d tanh(x)/dx| = log(1 - tanh(x)^2); 1e-6 avoids log(0).
            y = 1 - next_action.pow(2) + 1e-6
            next_log_prob = dist.log_prob(pred).unsqueeze(-1)
            next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)

            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # the value of a policy according to the maximum entropy objective
            # Pessimistic target: take the minimum over the ensemble dimension.
            target_q_value, _ = torch.min(target_q_value, dim=0)
            if self._with_q_entropy:
                # Subtract the entropy bonus from the target (soft value), scaled by alpha.
                target_q_value -= self._alpha * next_log_prob.squeeze(-1)
            # Standard 1-step TD target: r + gamma * (1 - done) * V(s').
            target_q_value = self._gamma * (1 - done) * target_q_value + reward

        weight = data['weight']
        if weight is None:
            weight = torch.ones_like(q_value)
        # Mean squared TD error per ensemble member, then summed over the ensemble.
        # NOTE(review): after ``.sum()`` this is a 0-dim scalar, so the ``weight`` multiply below broadcasts
        # over a scalar and the per-sample 'priority' entry in the return dict is a single float, not a list —
        # this only matters if ``priority=True`` (default False); verify before enabling priority sampling.
        td_error_per_sample = nn.MSELoss(reduction='none')(q_value, target_q_value).mean(dim=1).sum()
        loss_dict['critic_loss'] = (td_error_per_sample * weight).mean()

        # penalty term of EDAC: diversify ensemble members by penalizing alignment of
        # their action-gradients (cosine similarity of dQ_i/da across ensemble members).
        if self._eta > 0:
            # [batch_size,dim] -> [Ensemble_num,batch_size,dim]
            pre_obs = obs.unsqueeze(0).repeat_interleave(self._cfg.model.ensemble_num, dim=0)
            pre_acs = acs.unsqueeze(0).repeat_interleave(self._cfg.model.ensemble_num, dim=0).requires_grad_(True)

            # [Ensemble_num,batch_size]
            q_pred_tile = self._learn_model.forward({
                'obs': pre_obs,
                'action': pre_acs
            }, mode='compute_critic')['q_value'].requires_grad_(True)

            # create_graph=True so the penalty itself is differentiable w.r.t. critic parameters.
            q_pred_grads = torch.autograd.grad(q_pred_tile.sum(), pre_acs, retain_graph=True, create_graph=True)[0]
            # Normalize each gradient vector (L2 over the action dim); 1e-10 guards against division by zero.
            q_pred_grads = q_pred_grads / (torch.norm(q_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
            # [Ensemble_num,batch_size,act_dim] -> [batch_size,Ensemble_num,act_dim]
            q_pred_grads = q_pred_grads.transpose(0, 1)

            # Pairwise cosine similarities between ensemble members: (batch, ensemble, ensemble).
            q_pred_grads = q_pred_grads @ q_pred_grads.permute(0, 2, 1)
            # Zero the diagonal (self-similarity) before averaging over off-diagonal pairs.
            masks = torch.eye(
                self._cfg.model.ensemble_num, device=obs.device
            ).unsqueeze(dim=0).repeat(q_pred_grads.size(0), 1, 1)
            q_pred_grads = (1 - masks) * q_pred_grads
            grad_loss = torch.mean(torch.sum(q_pred_grads, dim=(1, 2))) / (self._cfg.model.ensemble_num - 1)
            loss_dict['critic_loss'] += grad_loss * self._eta

        # Update the critic (TD loss + eta-weighted diversity penalty).
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        self._optimizer_q.step()

        # Re-sample an action from the current actor for the policy update (reparameterization trick).
        (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        # Same tanh-squash log-prob correction as in the target computation.
        y = 1 - action.pow(2) + 1e-6
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)

        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        # Pessimistic estimate again: minimum over the ensemble.
        new_q_value, _ = torch.min(new_q_value, dim=0)

        # 8. compute policy loss
        # Maximize soft value: minimize alpha * log_pi(a|s) - Q(s, a).
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()

        loss_dict['policy_loss'] = policy_loss

        # 9. update policy network
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # 10. compute alpha loss (automatic entropy temperature tuning, as in SAC)
        if self._auto_alpha:
            if self._log_space:
                # Optimize log(alpha) for numerical stability; alpha = exp(log_alpha) stays positive.
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()
            else:
                # Optimize alpha directly; clamp to non-negative after the step.
                log_prob = log_prob + self._target_entropy
                loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()

                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = max(0, self._alpha)

        # Scalar sum of all recorded losses, for logging only (each loss was already back-propagated above).
        loss_dict['total_loss'] = sum(loss_dict.values())

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update: polyak-average the learn model into the target model.
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }