
ding.policy.td3_vae

TD3VAEPolicy

Bases: DDPGPolicy

Overview

Policy class of the TD3 algorithm.

Since DDPG and TD3 share most of their logic, this TD3 class is derived from the DDPG class by changing _actor_update_freq, _twin_critic, and the noise settings in the model wrapper.

https://arxiv.org/pdf/1802.09477.pdf
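The three switches mentioned above are the entire difference between the two algorithms at the config level. A minimal sketch (plain Python dicts and a simplified merge helper stand in for DI-engine's own config machinery; key names follow the config table in this page):

```python
# Assumption: plain dicts standing in for DI-engine Config objects.
# Only the knobs that differ between DDPG and TD3 are shown.

ddpg_like = {
    'model': {'twin_critic': False},       # single critic network
    'learn': {'actor_update_freq': 1,      # actor updates on every critic update
              'noise': False},             # no target policy smoothing
}

td3_like = {
    'model': {'twin_critic': True},        # Clipped Double Q-learning
    'learn': {'actor_update_freq': 2,      # Delayed Policy Updates
              'noise': True,               # Target Policy Smoothing
              'noise_sigma': 0.2,
              'noise_range': {'min': -0.5, 'max': 0.5}},
}

def derive_td3(base_cfg: dict, overrides: dict) -> dict:
    """Recursively merge overrides into a copy of the base config."""
    merged = {k: dict(v) if isinstance(v, dict) else v for k, v in base_cfg.items()}
    for k, v in overrides.items():
        if isinstance(v, dict) and isinstance(merged.get(k), dict):
            merged[k] = derive_td3(merged[k], v)
        else:
            merged[k] = v
    return merged

cfg = derive_td3(ddpg_like, td3_like)
```

Deriving TD3VAEPolicy from DDPGPolicy works the same way: the subclass only overrides these fields (plus the VAE-specific ones) and inherits everything else.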

Property

learn_mode, collect_mode, eval_mode

Config:

== ==================== ======== ================== ================================= =======================
ID Symbol               Type     Default Value      Description                       Other(Shape)
== ==================== ======== ================== ================================= =======================
1  type                 str      td3                | RL policy register name, refer  | this arg is optional,
                                                    | to registry POLICY_REGISTRY     | a placeholder
2  cuda                 bool     True               | Whether to use cuda for network |
3  | random_            int      25000              | Number of randomly collected    | Default to 25000 for
   | collect_size                                   | training samples in replay      | DDPG/TD3, 10000 for
   |                                                | buffer when training starts.    | sac.
4  | model.twin_        bool     True               | Whether to use two critic       | Default True for TD3,
   | critic                                         | networks or only one.           | Clipped Double
   |                                                |                                 | Q-learning method in
   |                                                |                                 | TD3 paper.
5  | learn.learning     float    1e-3               | Learning rate for actor         |
   | _rate_actor                                    | network (aka. policy).          |
6  | learn.learning     float    1e-3               | Learning rate for critic        |
   | _rate_critic                                   | network (aka. Q-network).       |
7  | learn.actor_       int      2                  | When critic network updates     | Default 2 for TD3, 1
   | update_freq                                    | once, how many times will actor | for DDPG. Delayed
   |                                                | network update.                 | Policy Updates method
   |                                                |                                 | in TD3 paper.
8  | learn.noise        bool     True               | Whether to add noise on target  | Default True for TD3,
   |                                                | network's action.               | False for DDPG.
   |                                                |                                 | Target Policy Smoo-
   |                                                |                                 | thing Regularization
   |                                                |                                 | in TD3 paper.
9  | learn.noise_       dict     | dict(min=-0.5,   | Limit for range of target       |
   | range                       | max=0.5,)        | policy smoothing noise,         |
   |                             |                  | aka. noise_clip.                |
10 | learn.-            bool     False              | Determine whether to ignore     | Use ignore_done only
   | ignore_done                                    | done flag.                      | in halfcheetah env.
11 | learn.-            float    0.005              | Used for soft update of the     | aka. Interpolation
   | target_theta                                   | target network.                 | factor in polyak aver-
   |                                                |                                 | aging for target
   |                                                |                                 | networks.
12 | collect.-           float    0.1               | Used to add noise during co-    | Sample noise from dis-
   | noise_sigma                                    | llection, through controlling   | tribution: Ornstein-
   |                                                | the sigma of the distribution   | Uhlenbeck process in
   |                                                |                                 | DDPG paper, Gaussian
   |                                                |                                 | process in ours.
== ==================== ======== ================== ================================= =======================
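Row 11's target_theta drives a polyak (momentum) soft update of the target network after each training step. A minimal sketch, with plain Python floats standing in for parameter tensors (the real update is applied by DI-engine's target model wrapper):

```python
def polyak_update(target_params, online_params, theta=0.005):
    """Soft update: target <- theta * online + (1 - theta) * target."""
    return [theta * p + (1.0 - theta) * t for t, p in zip(target_params, online_params)]

# Each target weight moves only theta (0.5%) of the way toward the online weight
# per update, which keeps the TD target slowly-moving and stable.
target = [0.0, 0.0]
online = [1.0, -1.0]
target = polyak_update(target, online, theta=0.005)
```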

Full Source Code

../ding/policy/td3_vae.py

from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
from .ddpg import DDPGPolicy
from ding.model.template.vae import VanillaVAE
from ding.utils import RunningMeanStd
from torch.nn import functional as F


@POLICY_REGISTRY.register('td3-vae')
class TD3VAEPolicy(DDPGPolicy):
    r"""
    Overview:
        Policy class of the TD3 algorithm.

        Since DDPG and TD3 share most of their logic, this TD3 class is derived from the DDPG
        class by changing ``_actor_update_freq``, ``_twin_critic``, and the noise settings in the model wrapper.

        https://arxiv.org/pdf/1802.09477.pdf

    Property:
        learn_mode, collect_mode, eval_mode

    Config:

        == ==================== ======== ================== ================================= =======================
        ID Symbol               Type     Default Value      Description                       Other(Shape)
        == ==================== ======== ================== ================================= =======================
        1  ``type``             str      td3                | RL policy register name, refer  | this arg is optional,
                                                            | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     True               | Whether to use cuda for network |
        3  | ``random_``        int      25000              | Number of randomly collected    | Default to 25000 for
           | ``collect_size``                               | training samples in replay      | DDPG/TD3, 10000 for
           |                                                | buffer when training starts.    | sac.
        4  | ``model.twin_``    bool     True               | Whether to use two critic       | Default True for TD3,
           | ``critic``                                     | networks or only one.           | Clipped Double
           |                                                |                                 | Q-learning method in
           |                                                |                                 | TD3 paper.
        5  | ``learn.learning`` float    1e-3               | Learning rate for actor         |
           | ``_rate_actor``                                | network (aka. policy).          |
        6  | ``learn.learning`` float    1e-3               | Learning rate for critic        |
           | ``_rate_critic``                               | network (aka. Q-network).       |
        7  | ``learn.actor_``   int      2                  | When critic network updates     | Default 2 for TD3, 1
           | ``update_freq``                                | once, how many times will actor | for DDPG. Delayed
           |                                                | network update.                 | Policy Updates method
           |                                                |                                 | in TD3 paper.
        8  | ``learn.noise``    bool     True               | Whether to add noise on target  | Default True for TD3,
           |                                                | network's action.               | False for DDPG.
           |                                                |                                 | Target Policy Smoo-
           |                                                |                                 | thing Regularization
           |                                                |                                 | in TD3 paper.
        9  | ``learn.noise_``   dict     | dict(min=-0.5,   | Limit for range of target       |
           | ``range``                   | max=0.5,)        | policy smoothing noise,         |
           |                             |                  | aka. noise_clip.                |
        10 | ``learn.-``        bool     False              | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                                | done flag.                      | in halfcheetah env.
        11 | ``learn.-``        float    0.005              | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                               | target network.                 | factor in polyak aver-
           |                                                |                                 | aging for target
           |                                                |                                 | networks.
        12 | ``collect.-``      float    0.1                | Used to add noise during co-    | Sample noise from dis-
           | ``noise_sigma``                                | llection, through controlling   | tribution: Ornstein-
           |                                                | the sigma of the distribution   | Uhlenbeck process in
           |                                                |                                 | DDPG paper, Gaussian
           |                                                |                                 | process in ours.
        == ==================== ======== ================== ================================= =======================
    """

    # You can refer to DDPG's default config for more details.
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='td3',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in TD3.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        # Default False in TD3.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        # Default 25000 in DDPG/TD3.
        random_collect_size=25000,
        # (str) Action space type
        action_space='continuous',  # ['continuous', 'hybrid']
        # (bool) Whether to use batch normalization for reward
        reward_batch_norm=False,
        original_action_shape=2,
        model=dict(
            # (bool) Whether to use two critic networks or only one.
            # Clipped Double Q-Learning for Actor-Critic in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            twin_critic=True,
        ),
        learn=dict(
            # How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rate for actor network (aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rate for critic network (aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, interaction with HalfCheetah always ends with done=False,
            # since we replace done==True with done==False to keep the
            # TD-error computation accurate (``gamma * (1 - done) * next_v + reward``)
            # when the episode step is greater than the max episode step.
            ignore_done=False,
            # (float type) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) When critic network updates once, how many times will actor network update.
            # Delayed Policy Updates in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default 1 for DDPG, 2 for TD3.
            actor_update_freq=2,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper (https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            noise=True,
            # (float) Sigma for smoothing noise added to target policy.
            noise_sigma=0.2,
            # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        collect=dict(
            # n_sample=1,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits "noise" and only sets "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(
            evaluator=dict(
                # (int) Evaluate every "eval_freq" training iterations.
                eval_freq=5000,
            ),
        ),
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer.
                replay_buffer_size=100000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'continuous_qac', ['ding.model.template.qac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init actor and critic optimizers, algorithm config, main and target models.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        self._reward_batch_norm = self._cfg.reward_batch_norm

        self._gamma = self._cfg.learn.discount_factor
        self._actor_update_freq = self._cfg.learn.actor_update_freq
        self._twin_critic = self._cfg.model.twin_critic  # True for TD3, False for DDPG

        # main and target models
        self._target_model = copy.deepcopy(self._model)
        if self._cfg.action_space == 'hybrid':
            self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample')
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        if self._cfg.learn.noise:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': self._cfg.learn.noise_sigma
                },
                noise_range=self._cfg.learn.noise_range
            )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations
        # action_shape, obs_shape, latent_action_dim, hidden_size_list
        self._vae_model = VanillaVAE(
            self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256]
        )
        # self._vae_model = VanillaVAE(2, 8, 6, [256, 256])

        self._optimizer_vae = Adam(
            self._vae_model.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )
        self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
        self.c_percentage_bound_lower = -1 * torch.ones([6])
        self.c_percentage_bound_upper = torch.ones([6])

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
        """
        # warmup phase
        if 'warm_up' in data[0].keys() and data[0]['warm_up'] is True:
            loss_dict = {}
            data = default_preprocess_learn(
                data,
                use_priority=self._cfg.priority,
                use_priority_IS_weight=self._cfg.priority_IS_weight,
                ignore_done=self._cfg.learn.ignore_done,
                use_nstep=False
            )
            if self._cuda:
                data = to_device(data, self._device)

            # ====================
            # train vae
            # ====================
            result = self._vae_model({'action': data['action'], 'obs': data['obs']})

            result['original_action'] = data['action']
            result['true_residual'] = data['next_obs'] - data['obs']

            vae_loss = self._vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01)  # TODO(pu): weight

            loss_dict['vae_loss'] = vae_loss['loss'].item()
            loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss'].item()
            loss_dict['kld_loss'] = vae_loss['kld_loss'].item()
            loss_dict['predict_loss'] = vae_loss['predict_loss'].item()
            self._running_mean_std_predict_loss.update(vae_loss['predict_loss'].unsqueeze(-1).cpu().detach().numpy())

            # vae update
            self._optimizer_vae.zero_grad()
            vae_loss['loss'].backward()
            self._optimizer_vae.step()
            # For compatibility
            loss_dict['actor_loss'] = torch.Tensor([0]).item()
            loss_dict['critic_loss'] = torch.Tensor([0]).item()
            loss_dict['critic_twin_loss'] = torch.Tensor([0]).item()
            loss_dict['total_loss'] = torch.Tensor([0]).item()
            q_value_dict = {}
            q_value_dict['q_value'] = torch.Tensor([0]).item()
            q_value_dict['q_value_twin'] = torch.Tensor([0]).item()
            return {
                'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                'action': torch.Tensor([0]).item(),
                'priority': torch.Tensor([0]).item(),
                'td_error': torch.Tensor([0]).item(),
                **loss_dict,
                **q_value_dict,
            }
        else:
            self._forward_learn_cnt += 1
            loss_dict = {}
            q_value_dict = {}
            data = default_preprocess_learn(
                data,
                use_priority=self._cfg.priority,
                use_priority_IS_weight=self._cfg.priority_IS_weight,
                ignore_done=self._cfg.learn.ignore_done,
                use_nstep=False
            )
            if data['vae_phase'][0].item() is True:
                if self._cuda:
                    data = to_device(data, self._device)

                # ====================
                # train vae
                # ====================
                result = self._vae_model({'action': data['action'], 'obs': data['obs']})

                result['original_action'] = data['action']
                result['true_residual'] = data['next_obs'] - data['obs']

                # latent space constraint (LSC)
                # NOTE: using tanh is important, update latent_action using z, shape (128, 6)
                data['latent_action'] = torch.tanh(result['z'].clone().detach())  # NOTE: tanh
                # data['latent_action'] = result['z'].clone().detach()
                self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
                    result['recons_action'].shape[0] * 0.02
                ), :]  # values, indices
                self.c_percentage_bound_upper = data['latent_action'].sort(
                    dim=0
                )[0][int(result['recons_action'].shape[0] * 0.98), :]

                vae_loss = self._vae_model.loss_function(
                    result, kld_weight=0.01, predict_weight=0.01
                )  # TODO(pu): weight

                loss_dict['vae_loss'] = vae_loss['loss']
                loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
                loss_dict['kld_loss'] = vae_loss['kld_loss']
                loss_dict['predict_loss'] = vae_loss['predict_loss']

                # vae update
                self._optimizer_vae.zero_grad()
                vae_loss['loss'].backward()
                self._optimizer_vae.step()

                return {
                    'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                    'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                    # 'q_value': np.array(q_value).mean(),
                    'action': torch.Tensor([0]).item(),
                    'priority': torch.Tensor([0]).item(),
                    'td_error': torch.Tensor([0]).item(),
                    **loss_dict,
                    **q_value_dict,
                }

            else:
                # ====================
                # critic learn forward
                # ====================
                self._learn_model.train()
                self._target_model.train()
                next_obs = data['next_obs']
                reward = data['reward']

                # ====================
                # relabel latent action
                # ====================
                if self._cuda:
                    data = to_device(data, self._device)
                result = self._vae_model({'action': data['action'], 'obs': data['obs']})
                true_residual = data['next_obs'] - data['obs']

                # Representation shift correction (RSC)
                for i in range(result['recons_action'].shape[0]):
                    if F.mse_loss(result['prediction_residual'][i],
                                  true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
                        # NOTE: using tanh is important, update latent_action using z
                        data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach())  # NOTE: tanh
                        # data['latent_action'][i] = result['z'][i].clone().detach()

                # update all latent action
                # data['latent_action'] = torch.tanh(result['z'].clone().detach())

                if self._reward_batch_norm:
                    reward = (reward - reward.mean()) / (reward.std() + 1e-8)

                # current q value
                q_value = self._learn_model.forward(
                    {
                        'obs': data['obs'],
                        'action': data['latent_action']
                    }, mode='compute_critic'
                )['q_value']
                q_value_dict = {}
                if self._twin_critic:
                    q_value_dict['q_value'] = q_value[0].mean()
                    q_value_dict['q_value_twin'] = q_value[1].mean()
                else:
                    q_value_dict['q_value'] = q_value.mean()
                # target q value
                with torch.no_grad():
                    # NOTE: here next_actor_data['action'] is latent action
                    next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
                    next_actor_data['obs'] = next_obs
                    target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
                if self._twin_critic:
                    # TD3: two critic networks
                    target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
                    # critic network1
                    td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
                    critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
                    loss_dict['critic_loss'] = critic_loss
                    # critic network2 (twin network)
                    td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
                    critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
                    loss_dict['critic_twin_loss'] = critic_twin_loss
                    td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
                else:
                    # DDPG: single critic network
                    td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
                    critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
                    loss_dict['critic_loss'] = critic_loss
                # ================
                # critic update
                # ================
                self._optimizer_critic.zero_grad()
                for k in loss_dict:
                    if 'critic' in k:
                        loss_dict[k].backward()
                self._optimizer_critic.step()
                # ===============================
                # actor learn forward and update
                # ===============================
                # actor updates every ``self._actor_update_freq`` iters
                if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
                    # NOTE: actor_data['action'] is latent action
                    actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
                    actor_data['obs'] = data['obs']
                    if self._twin_critic:
                        actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
                    else:
                        actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()

                    loss_dict['actor_loss'] = actor_loss
                    # actor update
                    self._optimizer_actor.zero_grad()
                    actor_loss.backward()
                    self._optimizer_actor.step()
                # =============
                # after update
                # =============
                loss_dict['total_loss'] = sum(loss_dict.values())
                # self._forward_learn_cnt += 1
                self._target_model.update(self._learn_model.state_dict())
                if self._cfg.action_space == 'hybrid':
                    action_log_value = -1.  # TODO(nyz) better way to viz hybrid action
                else:
                    action_log_value = data['action'].mean()

                return {
                    'cur_lr_actor': self._optimizer_actor.defaults['lr'],
                    'cur_lr_critic': self._optimizer_critic.defaults['lr'],
                    'action': action_log_value,
                    'priority': td_error_per_sample.abs().tolist(),
                    'td_error': td_error_per_sample.abs().mean(),
                    **loss_dict,
                    **q_value_dict,
                }

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
            'optimizer_critic': self._optimizer_critic.state_dict(),
            'vae_model': self._vae_model.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
        self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
        self._vae_model.load_state_dict(state_dict['vae_model'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        # collect model
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_kwargs={
                'mu': 0.0,
                'sigma': self._cfg.collect.noise_sigma
            },
            noise_range=None
        )
        if self._cfg.action_space == 'hybrid':
            self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, **kwargs) -> dict:
        r"""
        Overview:
            Forward function of collect mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor', **kwargs)
            output['latent_action'] = output['action']

            # latent space constraint (LSC)
            for i in range(output['action'].shape[-1]):
                output['action'][:, i].clamp_(
                    self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
                )

            # TODO(pu): decode into original hybrid actions, here data is obs
            # this is very important to generate self.obs_encoding used in the decode phase
            output['action'] = self._vae_model.decode_with_obs(output['action'], data)['reconstruction_action']

            # NOTE: add noise in the original actions
            from ding.rl_utils.exploration import GaussianNoise
            action = output['action']
            gaussian_noise = GaussianNoise(mu=0.0, sigma=0.1)
            noise = gaussian_noise(output['action'].shape, output['action'].device)
            if self._cfg.learn.noise_range is not None:
                noise = noise.clamp(self._cfg.learn.noise_range['min'], self._cfg.learn.noise_range['max'])
            action += noise
            self.action_range = {'min': -1, 'max': 1}
            if self.action_range is not None:
                action = action.clamp(self.action_range['min'], self.action_range['max'])
            output['action'] = action

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Returns:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
        if 'latent_action' in model_output.keys():
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': model_output['action'],
                'latent_action': model_output['latent_action'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        else:  # if random collect at first
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': model_output['action'],
                'latent_action': 999,
                'reward': timestep.reward,
                'done': timestep.done,
            }
        if self._cfg.action_space == 'hybrid':
            transition['logit'] = model_output['logit']
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model. Unlike learn and collect model, eval model does not need noise.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
            - optional: ``logit``
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
            output['latent_action'] = output['action']

            # latent space constraint (LSC)
            for i in range(output['action'].shape[-1]):
                output['action'][:, i].clamp_(
                    self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
                )

            # TODO(pu): decode into original hybrid actions, here data is obs
            # this is very important to generate self.obs_encoding used in the decode phase
            output['action'] = self._vae_model.decode_with_obs(output['action'], data)['reconstruction_action']
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' names if variables are to be used in monitor.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        ret = [
            'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
            'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'
        ]
        if self._twin_critic:
            ret += ['critic_twin_loss']
        return ret
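The latent space constraint (LSC) used in _forward_learn, _forward_collect, and _forward_eval keeps latent actions inside per-dimension 2% / 98% percentile bounds of the tanh-squashed latent batch. A dependency-free sketch of that mechanism (plain Python lists replace the torch tensors; `lsc_bounds` and `lsc_clamp` are illustrative names, not DI-engine APIs):

```python
def lsc_bounds(latent_batch, low_q=0.02, high_q=0.98):
    """Per-dimension percentile bounds, mirroring tensor.sort(dim=0) + row indexing."""
    n, dims = len(latent_batch), len(latent_batch[0])
    lower, upper = [], []
    for d in range(dims):
        col = sorted(row[d] for row in latent_batch)  # sort this dimension across the batch
        lower.append(col[int(n * low_q)])             # 2% quantile
        upper.append(col[int(n * high_q)])            # 98% quantile
    return lower, upper

def lsc_clamp(action, lower, upper):
    """Clamp each action dimension into its bound, like Tensor.clamp_ in the policy."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(action, lower, upper)]

# 100 latent samples in 2 dimensions, spanning roughly [-1, 1]
batch = [[x / 50.0 - 1.0, 1.0 - x / 50.0] for x in range(100)]
lower, upper = lsc_bounds(batch)
clamped = lsc_clamp([5.0, -5.0], lower, upper)  # out-of-range action gets pulled inside
```

Clamping to percentile bounds rather than a fixed [-1, 1] box means the actor can only propose latent codes the VAE has actually seen during training, which keeps the decoder on-distribution.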