from typing import Dict, Any, List
from functools import partial

import torch
from torch import Tensor
from torch import nn
from torch.distributions import Normal, Independent

from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
from ding.utils import POLICY_REGISTRY
from ding.policy import SACPolicy
from ding.rl_utils import generalized_lambda_returns
from ding.policy.common_utils import default_preprocess_learn

from .utils import q_evaluation


@POLICY_REGISTRY.register('mbsac')
class MBSACPolicy(SACPolicy):
    """
    Overview:
        Model-based SAC with value expansion (arXiv: 1803.00101)
        and value gradient (arXiv: 1510.09142) w.r.t. the lambda-return.

        https://arxiv.org/pdf/1803.00101.pdf
        https://arxiv.org/pdf/1510.09142.pdf

    Config:
        == ==================== ======== ============= ==================================
        ID Symbol               Type     Default Value Description
        == ==================== ======== ============= ==================================
        1  ``learn.lambda_``    float    0.8           | Lambda for TD-lambda return.
        2  ``learn.grad_clip``  float    100.0         | Max norm of gradients.
        3  | ``learn.sample``   bool     True          | Whether to sample states or
           | ``_state``                                | transitions from the env buffer.
        == ==================== ======== ============= ==================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.8,
            # (float) Max norm of gradients.
            grad_clip=100,
            # (bool) Whether to sample states or transitions from the environment buffer.
            sample_state=True,
        )
    )

    def _init_learn(self) -> None:
        super()._init_learn()
        self._target_model.requires_grad_(False)

        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip
        self._sample_state = self._cfg.learn.sample_state
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha
        assert not self._auto_alpha, "NotImplemented"

        # TODO: TanhTransform leads to NaN
        def actor_fn(obs: Tensor):
            # (mu, sigma) = self._learn_model.forward(
            #     obs, mode='compute_actor')['logit']
            # # enforce action bounds
            # dist = TransformedDistribution(
            #     Independent(Normal(mu, sigma), 1), [TanhTransform()])
            # action = dist.rsample()
            # log_prob = dist.log_prob(action)
            # return action, -self._alpha.detach() * log_prob
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)

            log_prob = dist.log_prob(pred) \
                + 2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
            return action, -self._alpha.detach() * log_prob
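
        # Note: the correction term in ``log_prob`` above is the numerically
        # stable form of the tanh change-of-variables term. It relies on the
        # identity
        #   log(1 - tanh(u)^2) = 2 * (log(2) - u - softplus(-2 * u)),
        # so that log pi(a) = log N(u | mu, sigma) - sum_i log(1 - tanh(u_i)^2).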

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        self._learn_model.train()
        self._target_model.train()

        # TODO: use treetensor
        # rollout length is determined by world_model.rollout_length_scheduler
        if self._sample_state:
            # data['reward'], ... are not used
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['obs'], self._actor_fn, envstep)
        else:
            obss, actions, rewards, aug_rewards, dones = \
                world_model.rollout(data['next_obs'], self._actor_fn, envstep)
            obss = torch.cat([data['obs'].unsqueeze(0), obss])
            actions = torch.cat([data['action'].unsqueeze(0), actions])
            rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
            aug_rewards = torch.cat([torch.zeros_like(data['reward']).unsqueeze(0), aug_rewards])
            dones = torch.cat([data['done'].unsqueeze(0), dones])

        # prepend a zero row so that the first step is never masked out below
        dones = torch.cat([torch.zeros_like(data['done']).unsqueeze(0), dones])

        # (T+1, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T, B)
        lambda_return = generalized_lambda_returns(target_q_values, rewards, self._gamma, self._lambda, dones[1:])
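
        # generalized_lambda_returns computes the standard TD(lambda) target,
        # bootstrapping from the entropy-augmented target Q-values above:
        #   G_T = Q_T,
        #   G_t = r_t + gamma * (1 - d_{t+1}) * ((1 - lambda) * Q_{t+1} + lambda * G_{t+1}).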

        # (T, B)
        # If S_t terminates, we should not consider the loss from t+1 onwards,
        # so later steps are masked with the cumulative product of (1 - done)
        # (see the standalone sketch after this class).
        weight = (1 - dones[:-1].detach()).cumprod(dim=0)

        # (T+1, B)
        q_values = q_evaluation(obss.detach(), actions.detach(), partial(self._critic_fn, model=self._learn_model))
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0][:-1] - lambda_return.detach()) \
                + 0.5 * torch.square(q_values[1][:-1] - lambda_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values[:-1] - lambda_return.detach())

        # value expansion loss
        critic_loss = (critic_loss * weight).mean()

        # value gradient loss
        policy_loss = -(lambda_return * weight).mean()

        # alpha_loss = None

        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            # 'alpha_loss': alpha_loss.detach(),
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # update temperature
        # self._alpha_optim.zero_grad()
        # loss_dict['alpha_loss'].backward()
        # self._alpha_optim.step()
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss
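

# A minimal, self-contained sketch (not part of the policy) of the termination
# masking used in ``MBSACPolicy._forward_learn``. Because a zero row is prepended
# to ``dones``, the step whose transition predicts the termination still has
# weight 1, while every step after the predicted done is zeroed out by the
# cumulative product. The toy tensors and the name ``_demo_termination_weight``
# are illustrative only.
def _demo_termination_weight() -> Tensor:
    # dones after the zero row is prepended: (T+1, B) with T=4, B=2;
    # rollout 0 predicts a termination on its first step.
    dones = torch.tensor([
        [0., 0.],  # prepended zeros: the first step always has weight 1
        [1., 0.],  # rollout 0 terminates here
        [0., 0.],
        [0., 0.],
        [0., 0.],
    ])
    weight = (1 - dones[:-1]).cumprod(dim=0)  # (T, B)
    # weight[:, 0] == [1., 0., 0., 0.]: the terminating step still contributes,
    # every later (invalid) model step is masked out;
    # weight[:, 1] == [1., 1., 1., 1.]: the second rollout is fully used.
    return weight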


@POLICY_REGISTRY.register('stevesac')
class STEVESACPolicy(SACPolicy):
    r"""
    Overview:
        Model-based SAC with stochastic ensemble value expansion (arXiv: 1807.01675).
        This implementation also uses the value gradient w.r.t. the same STEVE target.

        https://arxiv.org/pdf/1807.01675.pdf

    Config:
        == ======================== ======== ============= =====================================
        ID Symbol                   Type     Default Value Description
        == ======================== ======== ============= =====================================
        1  ``learn.grad_clip``      float    100.0         | Max norm of gradients.
        2  ``learn.ensemble_size``  int      1             | The number of ensemble world models.
        == ======================== ======== ============= =====================================

    .. note::
        For other configs, please refer to ding.policy.sac.SACPolicy.
    """

    config = dict(
        learn=dict(
            # (float) Max norm of gradients.
            grad_clip=100,
            # (int) The number of ensemble world models.
            ensemble_size=1,
        )
    )

    def _init_learn(self) -> None:
        super()._init_learn()
        self._target_model.requires_grad_(False)

        self._grad_clip = self._cfg.learn.grad_clip
        self._ensemble_size = self._cfg.learn.ensemble_size
        self._auto_alpha = self._cfg.learn.auto_alpha
        # TODO: auto alpha
        assert not self._auto_alpha, "NotImplemented"

        def actor_fn(obs: Tensor):
            obs, dim = fold_batch(obs, 1)
            (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)

            # numerically stable tanh log-prob correction; see MBSACPolicy._init_learn
            log_prob = dist.log_prob(pred) \
                + 2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
            aug_reward = -self._alpha.detach() * log_prob

            return unfold_batch(action, dim), unfold_batch(aug_reward, dim)

        self._actor_fn = actor_fn

        def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
            eval_data = {'obs': obss, 'action': actions}
            q_values = model.forward(eval_data, mode='compute_critic')['q_value']
            return q_values

        self._critic_fn = critic_fn
        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
        # preprocess data
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        if len(data['action'].shape) == 1:
            data['action'] = data['action'].unsqueeze(1)

        # [B, D] -> [E, B, D]
        data['next_obs'] = unsqueeze_repeat(data['next_obs'], self._ensemble_size)
        data['reward'] = unsqueeze_repeat(data['reward'], self._ensemble_size)
        data['done'] = unsqueeze_repeat(data['done'], self._ensemble_size)

        self._learn_model.train()
        self._target_model.train()

        obss, actions, rewards, aug_rewards, dones = \
            world_model.rollout(data['next_obs'], self._actor_fn, envstep, keep_ensemble=True)
        rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
        dones = torch.cat([data['done'].unsqueeze(0), dones])

        # (T, E, B)
        target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
        if self._twin_critic:
            target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
        else:
            target_q_values = target_q_values + aug_rewards

        # (T+1, E, B)
        discounts = ((1 - dones) * self._gamma).cumprod(dim=0)
        discounts = torch.cat([torch.ones_like(discounts)[:1], discounts])
        # (T, E, B)
        cum_rewards = (rewards * discounts[:-1]).cumsum(dim=0)
        discounted_q_values = target_q_values * discounts[1:]
        steve_return = cum_rewards + discounted_q_values
        # (T, B)
        steve_return_mean = steve_return.mean(1)
        with torch.no_grad():
            steve_return_inv_var = 1 / (1e-8 + steve_return.var(1, unbiased=False))
            steve_return_weight = steve_return_inv_var / (1e-8 + steve_return_inv_var.sum(dim=0))
        # (B, )
        steve_return = (steve_return_mean * steve_return_weight).sum(0)
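
        # Inverse-variance weighting: every rollout horizon t produces an
        # ensemble of candidate targets; horizons with high disagreement across
        # the ensemble (dim 1) receive proportionally less weight. See the
        # standalone sketch at the end of this file.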

        eval_data = {'obs': data['obs'], 'action': data['action']}
        q_values = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            critic_loss = 0.5 * torch.square(q_values[0] - steve_return.detach()) \
                + 0.5 * torch.square(q_values[1] - steve_return.detach())
        else:
            critic_loss = 0.5 * torch.square(q_values - steve_return.detach())

        critic_loss = critic_loss.mean()

        policy_loss = -steve_return.mean()

        # alpha_loss = None

        loss_dict = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            # 'alpha_loss': alpha_loss.detach(),
        }

        norm_dict = self._update(loss_dict)

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update
        self._target_model.update(self._learn_model.state_dict())

        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'alpha': self._alpha.item(),
            'target_q_value': target_q_values.detach().mean().item(),
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update critic
        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_q.step()
        # update policy
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_policy.step()
        # update temperature
        # self._alpha_optim.zero_grad()
        # loss_dict['alpha_loss'].backward()
        # self._alpha_optim.step()
        return {'policy_norm': policy_norm, 'critic_norm': critic_norm}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        alpha_loss = ['alpha_loss'] if self._auto_alpha else []
        return [
            'policy_loss',
            'critic_loss',
            'policy_norm',
            'critic_norm',
            'cur_lr_q',
            'cur_lr_p',
            'alpha',
            'target_q_value',
        ] + alpha_loss
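

# A minimal, self-contained sketch (not part of the policy) of the STEVE target
# combination in ``STEVESACPolicy._forward_learn``: per-horizon ensemble means
# are combined with weights proportional to the inverse of the ensemble
# variance, so horizon lengths on which the world models disagree contribute
# less to the final target. The toy shapes and the name
# ``_demo_steve_weighting`` are illustrative only.
def _demo_steve_weighting() -> Tensor:
    T, E, B = 3, 4, 2  # rollout horizons, ensemble members, batch size
    steve_return = torch.randn(T, E, B)  # candidate targets, (T, E, B)
    steve_return_mean = steve_return.mean(1)  # (T, B)
    inv_var = 1 / (1e-8 + steve_return.var(1, unbiased=False))  # (T, B)
    weight = inv_var / (1e-8 + inv_var.sum(dim=0))  # normalized over horizons
    assert torch.allclose(weight.sum(0), torch.ones(B))
    # (B, ): a single inverse-variance-weighted target per batch element
    return (steve_return_mean * weight).sum(0)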