from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('ddpg')
class DDPGPolicy(Policy):
    """
    Overview:
        Policy class of DDPG algorithm. Paper link: https://arxiv.org/abs/1509.02971.

    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  | ``type``           str      ddpg          | RL policy register name, refer  | this arg is optional,
           |                                           | to registry ``POLICY_REGISTRY`` | a placeholder
        2  | ``cuda``           bool     False         | Whether to use cuda for network |
        3  | ``random_``        int      25000         | Number of randomly collected    | Default to 25000 for
           | ``collect_size``                          | training samples in replay      | DDPG/TD3, 10000 for
           |                                           | buffer when training starts.    | sac.
        4  | ``model.twin_``    bool     False         | Whether to use two critic       | Default False for
           | ``critic``                                | networks or only one.           | DDPG, Clipped Double
           |                                           |                                 | Q-learning method in
           |                                           |                                 | TD3 paper.
        5  | ``learn.learning`` float    1e-3          | Learning rate for actor         |
           | ``_rate_actor``                           | network(aka. policy).           |
        6  | ``learn.learning`` float    1e-3          | Learning rates for critic       |
           | ``_rate_critic``                          | network (aka. Q-network).       |
        7  | ``learn.actor_``   int      2             | When critic network updates     | Default 1 for DDPG,
           | ``update_freq``                           | once, how many times will actor | 2 for TD3. Delayed
           |                                           | network update.                 | Policy Updates method
           |                                           |                                 | in TD3 paper.
        8  | ``learn.noise``    bool     False         | Whether to add noise on target  | Default False for
           |                                           | network's action.               | DDPG, True for TD3.
           |                                           |                                 | Target Policy Smoo-
           |                                           |                                 | thing Regularization
           |                                           |                                 | in TD3 paper.
        9  | ``learn.-``        bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | done flag.                      | in halfcheetah env.
        10 | ``learn.-``        float    0.005         | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                          | target network.                 | factor in polyak aver-
           |                                           |                                 | aging for target
           |                                           |                                 | networks.
        11 | ``collect.-``      float    0.1           | Used for add noise during co-   | Sample noise from dis-
           | ``noise_sigma``                           | llection, through controlling   | tribution, Ornstein-
           |                                           | the sigma of distribution       | Uhlenbeck process in
           |                                           |                                 | DDPG paper, Gaussian
           |                                           |                                 | process in ours.
        == ==================== ======== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='ddpg',
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether learning policy is the same as collecting data policy(on-policy). Default False in DDPG.
        on_policy=False,
        # (bool) Whether to enable priority experience sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        # Default 25000 in DDPG/TD3.
        random_collect_size=25000,
        # (bool) Whether to need policy data in process transition.
        transition_with_policy_data=False,
        # (str) Action space type, including ['continuous', 'hybrid'].
        action_space='continuous',
        # (bool) Whether use batch normalization for reward.
        reward_batch_norm=False,
        # (bool) Whether to enable multi-agent training setting.
        multi_agent=False,
        # learn_mode config
        learn=dict(
            # (int) How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float) Learning rates for actor network(aka. policy).
            learning_rate_actor=1e-3,
            # (float) Learning rates for critic network(aka. Q-network).
            learning_rate_critic=1e-3,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # However, interaction with HalfCheetah always gets done with False,
            # Since we inplace done==True with done==False to keep
            # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``),
            # when the episode step is greater than max episode step.
            ignore_done=False,
            # (float) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (int) When critic network updates once, how many times will actor network update.
            # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default 1 for DDPG, 2 for TD3.
            actor_update_freq=1,
            # (bool) Whether to add noise on target network's action.
            # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
            # Default True for TD3, False for DDPG.
            noise=False,
            # (float) Sigma of the Gaussian noise added to the target policy's action.
            # Only read when ``noise=True``. Default 0.2, as recommended in the TD3 paper.
            # NOTE: previously this key was missing, so setting ``noise=True`` without a full
            # user-side override crashed in ``_init_learn`` with an AttributeError.
            noise_sigma=0.2,
            # (dict) Clipping range of the target policy smoothing noise.
            # Only read when ``noise=True``. Default (-0.5, 0.5), as in the TD3 paper.
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples collected in one collection procedure.
            # Only one of [n_sample, n_episode] should be set.
            # n_sample=1,
            # (int) Split episodes or trajectories into pieces with length `unroll_len`.
            unroll_len=1,
            # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
            noise_sigma=0.1,
        ),
        eval=dict(),  # for compatibility
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
                replay_buffer_size=100000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        if self._cfg.multi_agent:
            return 'continuous_maqac', ['ding.model.template.maqac']
        else:
            return 'continuous_qac', ['ding.model.template.qac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For DDPG, it mainly \
            contains two optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target model.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        # actor and critic optimizer
        self._optimizer_actor = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_actor,
        )
        self._optimizer_critic = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_critic,
        )
        self._reward_batch_norm = self._cfg.reward_batch_norm

        self._gamma = self._cfg.learn.discount_factor
        self._actor_update_freq = self._cfg.learn.actor_update_freq
        self._twin_critic = self._cfg.model.twin_critic  # True for TD3, False for DDPG

        # main and target models
        self._target_model = copy.deepcopy(self._model)
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample')
            self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample')
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        if self._cfg.learn.noise:
            self._target_model = model_wrap(
                self._target_model,
                wrapper_name='action_noise',
                noise_type='gauss',
                noise_kwargs={
                    'mu': 0.0,
                    'sigma': self._cfg.learn.noise_sigma
                },
                noise_range=self._cfg.learn.noise_range
            )
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0  # count iterations

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss, action, priority.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
                training samples. For each element in list, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
                dimension by some utility functions such as ``default_preprocess_learn``. \
                For DDPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
                ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
                and ``logit`` which is used for hybrid action space.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
                recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data types that are not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._cfg.priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # critic learn forward
        # ====================
        self._learn_model.train()
        self._target_model.train()
        next_obs = data['next_obs']
        reward = data['reward']
        if self._reward_batch_norm:
            reward = (reward - reward.mean()) / (reward.std() + 1e-8)
        # current q value
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # target q value.
        with torch.no_grad():
            next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
            next_actor_data['obs'] = next_obs
            target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']

        q_value_dict = {}
        target_q_value_dict = {}

        if self._twin_critic:
            # TD3: two critic networks
            target_q_value = torch.min(target_q_value[0], target_q_value[1])  # find min one as target q value
            q_value_dict['q_value'] = q_value[0].mean().data.item()
            q_value_dict['q_value_twin'] = q_value[1].mean().data.item()
            target_q_value_dict['target q_value'] = target_q_value.mean().data.item()
            # critic network1
            td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
            # critic network2(twin network)
            td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
            critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
            loss_dict['critic_twin_loss'] = critic_twin_loss
            td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
        else:
            # DDPG: single critic network
            q_value_dict['q_value'] = q_value.mean().data.item()
            target_q_value_dict['target q_value'] = target_q_value.mean().data.item()
            td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
            critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
            loss_dict['critic_loss'] = critic_loss
        # ================
        # critic update
        # ================
        self._optimizer_critic.zero_grad()
        for k in loss_dict:
            if 'critic' in k:
                loss_dict[k].backward()
        self._optimizer_critic.step()
        # ===============================
        # actor learn forward and update
        # ===============================
        # actor updates every ``self._actor_update_freq`` iters
        if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
            actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
            actor_data['obs'] = data['obs']
            if self._twin_critic:
                actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
            else:
                actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()

            loss_dict['actor_loss'] = actor_loss
            # actor update
            self._optimizer_actor.zero_grad()
            actor_loss.backward()
            self._optimizer_actor.step()
        # =============
        # after update
        # =============
        loss_dict['total_loss'] = sum(loss_dict.values())
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        if self._cfg.action_space == 'hybrid':
            action_log_value = -1.  # TODO(nyz) better way to viz hybrid action
        else:
            action_log_value = data['action'].mean()
        return {
            'cur_lr_actor': self._optimizer_actor.defaults['lr'],
            'cur_lr_critic': self._optimizer_critic.defaults['lr'],
            # 'q_value': np.array(q_value).mean(),
            'action': action_log_value,
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.abs().mean(),
            **loss_dict,
            **q_value_dict,
            **target_q_value_dict,
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model, target_model and optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
            'optimizer_critic': self._optimizer_critic.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
        self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For DDPG, it contains the \
            collect_model to balance the exploration and exploitation with the perturbed noise mechanism, and other \
            algorithm-specific arguments such as unroll_len. \
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
            with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        # collect model
        self._collect_model = model_wrap(
            self._model,
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_kwargs={
                'mu': 0.0,
                'sigma': self._cfg.collect.noise_sigma
            },
            noise_range=None
        )
        if self._cfg.action_space == 'hybrid':
            self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
            that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
            data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
                other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
                dict is the same as the input data, i.e., environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data types that are not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor', **kwargs)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training and \
            saved in replay buffer. For DDPG, it contains obs, next_obs, action, reward, done.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
                as input. For DDPG, it contains the action and the logit of the action (in hybrid action space).
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
                except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
                reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': policy_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        if self._cfg.action_space == 'hybrid':
            transition['logit'] = policy_output['logit']
        return transition

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
            can be used for training directly. In DDPG, a train sample is a processed transition (unroll_len=1).
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \
                the same format as the return value of ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
                as input transitions, but may contain more data for training.
        """
        return get_train_sample(transitions, self._unroll_len)

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For DDPG, it contains the \
            eval model to greedily select action type with argmax q_value mechanism for hybrid action space. \
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        if self._cfg.action_space == 'hybrid':
            self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
            action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For the data types that are not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
            as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        ret = [
            'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
            'action', 'td_error'
        ]
        if self._twin_critic:
            ret += ['critic_twin_loss']
        return ret