
ding.policy.impala


IMPALAPolicy

Bases: Policy

Overview

Policy class of IMPALA algorithm. Paper link: https://arxiv.org/abs/1802.01561.
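The central technique of the linked paper is the v-trace off-policy correction. Below is a minimal scalar sketch of the v-trace target recursion, illustrative only and not ding's actual ``vtrace_error_discrete_action`` / ``vtrace_error_continuous_action`` implementation; the ``gamma``, ``lambda_``, ``rho_clip`` and ``c_clip`` parameters mirror the ``discount_factor``, ``lambda_``, ``rho_clip_ratio`` and ``c_clip_ratio`` config fields documented below.

```python
# Scalar sketch of the v-trace recursion (Espeholt et al. 2018):
#   vs_t = V(x_t) + delta_t + gamma * c_t * (vs_{t+1} - V(x_{t+1}))
#   delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t))
# Illustrative only; not ding's vtrace_error_* implementation.

def vtrace_targets(values, rewards, is_ratios, gamma=0.99, lambda_=0.95,
                   rho_clip=1.0, c_clip=1.0):
    # values: length T+1 (last entry is the bootstrap value V(x_T))
    # rewards, is_ratios: length T; is_ratios are pi/mu importance ratios
    T = len(rewards)
    vs = list(values)  # vs[T] stays V(x_T) as the bootstrap
    for t in reversed(range(T)):
        rho = min(rho_clip, is_ratios[t])        # clipped TD-error weight
        c = lambda_ * min(c_clip, is_ratios[t])  # clipped trace-cutting weight
        delta = rho * (rewards[t] + gamma * values[t + 1] - values[t])
        vs[t] = values[t] + delta + gamma * c * (vs[t + 1] - values[t + 1])
    return vs[:T]

targets = vtrace_targets([0.0, 0.0, 0.0], [1.0, 1.0], [1.0, 1.0])
# with on-policy ratios of 1, v-trace reduces to lambda-returns: [1.9405, 1.0]
```

When the behaviour and target policies coincide (all ratios equal to 1) and the clip ratios are at their defaults, the recursion degenerates to the usual TD(lambda) return, which is why larger ``update_per_collect`` values are described as "more off-policy": the ratios drift further from 1 and clipping does more work.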

Config

== ==================== ======== ============= ======================================== =======================
ID Symbol               Type     Default Value Description                              Other(Shape)
== ==================== ======== ============= ======================================== =======================
1  type                 str      impala        RL policy register name, refer to        this arg is optional,
                                               registry POLICY_REGISTRY                 a placeholder
2  cuda                 bool     False         Whether to use cuda for the network      this arg can differ
                                                                                        between modes
3  on_policy            bool     False         Whether the RL algorithm is on-policy
                                               or off-policy
4  priority             bool     False         Whether to use priority (PER)            priority sample,
                                                                                        update priority
5  priority_IS_weight   bool     False         Whether to use Importance Sampling       If True, priority
                                               Weight                                   must be True
6  unroll_len           int      32            Trajectory length to calculate the
                                               v-trace target
7  learn.update_        int      4             How many updates (iterations) to train   this arg can vary
   per_collect                                 after one collection by the collector;   across envs; a bigger
                                               only valid in serial training            value means more
                                                                                        off-policy
== ==================== ======== ============= ======================================== =======================
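As a concrete illustration of how these fields are consumed, a user config typically overrides only a few of the defaults above. The deep-merge below is a plain-dict sketch, not ding's actual EasyDict-based config utilities; the keys mirror ``IMPALAPolicy.config``.

```python
# Plain-dict sketch of overriding IMPALAPolicy defaults (see the table above).
# Ding's real config pipeline uses its own merge helpers; this is illustrative.
import copy

default_config = dict(
    type='impala',
    cuda=False,
    on_policy=False,
    priority=False,
    priority_IS_weight=False,
    unroll_len=32,
    learn=dict(update_per_collect=4, batch_size=16, learning_rate=0.0005),
)

def merge(base, override):
    """Recursively merge ``override`` into a copy of ``base``."""
    out = copy.deepcopy(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)
        else:
            out[k] = v
    return out

user_config = dict(cuda=True, learn=dict(update_per_collect=8))
cfg = merge(default_config, user_config)
# cfg['learn']['update_per_collect'] is now 8; untouched keys keep defaults
```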

default_model()

Overview

Return this algorithm's default neural network model setting for demonstration. The __init__ method will automatically call this method to get the default model setting and create the model.

Returns:

- model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.

.. note:: The user can define and use a customized network model but must obey the same interface definition indicated by the import_names path. For IMPALA, the registered name is vac and the import_names is ding.model.template.vac.
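To make the (registered name, import_names) contract concrete, here is a hedged sketch: the first element names a class in a model registry, the second lists the module paths where that class is defined (importing them triggers registration). The registry and the VAC class below are stand-ins for illustration, not ding's actual MODEL_REGISTRY implementation.

```python
# Stand-in registry illustrating the (name, import_names) contract of
# ``default_model``. Ding's real registry machinery differs in detail.
MODEL_REGISTRY = {}

def register(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register('vac')
class VAC:
    """Hypothetical stand-in for the actor-critic model in ding.model.template.vac."""

    def __init__(self, obs_shape, action_shape):
        self.obs_shape, self.action_shape = obs_shape, action_shape

def default_model():
    # Same return shape as IMPALAPolicy.default_model
    return 'vac', ['ding.model.template.vac']

name, import_names = default_model()
# In ding, the import_names modules would be imported first so that the
# @register decorator runs; here VAC is defined inline, so lookup works directly.
model = MODEL_REGISTRY[name](obs_shape=4, action_shape=2)
```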

Full Source Code

../ding/policy/impala.py

```python
from collections import namedtuple
from typing import List, Dict, Any, Tuple

import torch
import treetensor.torch as ttorch

from ding.model import model_wrap
from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
from ding.torch_utils import Adam, RMSprop, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate, ttorch_collate
from ding.policy.base_policy import Policy


@POLICY_REGISTRY.register('impala')
class IMPALAPolicy(Policy):
    """
    Overview:
        Policy class of IMPALA algorithm. Paper link: https://arxiv.org/abs/1802.01561.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      impala         | RL policy register name, refer to      | this arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for the network    | this arg can differ
                                                                                                 | between modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)          | priority sample,
                                                                                                 | update priority
        5  | ``priority_``      bool     False          | Whether to use Importance Sampling     | If True, priority
           | ``IS_weight``                              | Weight                                 | must be True
        6  ``unroll_len``       int      32             | Trajectory length to calculate the
                                                        | v-trace target
        7  | ``learn.update``   int      4              | How many updates (iterations) to train | this arg can vary
           | ``per_collect``                            | after one collection by the collector; | across envs; a bigger
                                                        | only valid in serial training          | value means more
                                                                                                 | off-policy
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='impala',
        # (bool) Whether to use cuda in policy.
        cuda=False,
        # (bool) Whether the learning policy is the same as the data-collecting policy (on-policy).
        on_policy=False,
        # (bool) Whether to enable priority experience sample.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (str) Which kind of action space is used in IMPALAPolicy, ['discrete', 'continuous'].
        action_space='discrete',
        # (int) The trajectory length to calculate v-trace target.
        unroll_len=32,
        # (bool) Whether policy data is needed in process transition.
        transition_with_policy_data=True,
        # learn_mode config
        learn=dict(
            # (int) Collect n_sample data, train model update_per_collect times.
            update_per_collect=4,
            # (int) The number of data for a train iteration.
            batch_size=16,
            # (float) The step size of gradient descent.
            learning_rate=0.0005,
            # (float) Loss weight of the value network; the weight of the policy network is set to 1.
            value_weight=0.5,
            # (float) Loss weight of the entropy regularization; the weight of the policy network is set to 1.
            entropy_weight=0.0001,
            # (float) Discount factor for future reward, in [0, 1].
            discount_factor=0.99,
            # (float) Additional discounting parameter.
            lambda_=0.95,
            # (float) Clip ratio of importance weights.
            rho_clip_ratio=1.0,
            # (float) Clip ratio of importance weights.
            c_clip_ratio=1.0,
            # (float) Clip ratio of importance sampling.
            rho_pg_clip_ratio=1.0,
            # (str) The gradient clip operation type used in IMPALA, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
            grad_clip_type=None,
            # (float) The gradient clip target value used in IMPALA.
            # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value.
            clip_value=0.5,
            # (str) Optimizer used to train the network, ['adam', 'rmsprop'].
            optim='adam',
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples are collected in one collection procedure.
            # Only one of [n_sample, n_episode] should be set.
            # n_sample=16,
        ),
        eval=dict(),  # for compatibility
        other=dict(
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, a larger buffer size is better.
                replay_buffer_size=1000,
                # (int) Maximum use times of a sample in the buffer. When reached, the sample is removed.
                max_use=16,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. The ``__init__`` \
            method will automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.

        .. note::
            The user can define and use a customized network model but must obey the same interface definition \
            indicated by the import_names path. For IMPALA, the registered name is ``vac`` and the import_names \
            is ``ding.model.template.vac``.
        """
        return 'vac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of the policy, including related attributes and modules. For IMPALA, it \
            mainly contains the optimizer, algorithm-specific arguments such as loss weights and gamma, and the \
            main (learn) model. This method will be called in the ``__init__`` method if the ``learn`` field is \
            in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space
        self._action_space = self._cfg.action_space
        # Optimizer
        optim_type = self._cfg.learn.optim
        if optim_type == 'rmsprop':
            self._optimizer = RMSprop(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        elif optim_type == 'adam':
            self._optimizer = Adam(
                self._model.parameters(),
                grad_clip_type=self._cfg.learn.grad_clip_type,
                clip_value=self._cfg.learn.clip_value,
                lr=self._cfg.learn.learning_rate
            )
        else:
            raise NotImplementedError("Now only support rmsprop and adam, but input is {}".format(optim_type))
        self._learn_model = model_wrap(self._model, wrapper_name='base')

        self._action_shape = self._cfg.model.action_shape
        self._unroll_len = self._cfg.unroll_len

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._gamma = self._cfg.learn.discount_factor
        self._lambda = self._cfg.learn.lambda_
        self._rho_clip_ratio = self._cfg.learn.rho_clip_ratio
        self._c_clip_ratio = self._cfg.learn.c_clip_ratio
        self._rho_pg_clip_ratio = self._cfg.learn.rho_pg_clip_ratio

        # Main model
        self._learn_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
        """
        Overview:
            Data preprocess function of learn mode. Convert list trajectory data to trajectory data, which is a \
            dict of tensors.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element \
              is a dict whose values are torch.Tensor or np.ndarray or dict/list combinations; keys include at \
              least 'obs', 'next_obs', 'logit', 'action', 'reward', 'done'.
        Returns:
            - data (:obj:`dict`): Dict type data. Values are torch.Tensor or np.ndarray or dict/list combinations.
        ReturnsKeys:
            - necessary: 'logit', 'action', 'reward', 'done', 'weight', 'obs_plus_1'.
            - optional and not used in later computation: 'obs', 'next_obs', 'IS', 'collect_iter', \
              'replay_unique_id', 'replay_buffer_idx', 'priority', 'staleness', 'use'.
        ReturnsShapes:
            - obs_plus_1 (:obj:`torch.FloatTensor`): :math:`(T * B, obs_shape)`, where T is timestep, B is batch \
              size and obs_shape is the shape of a single env observation.
            - logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim.
            - action (:obj:`torch.LongTensor`): :math:`(T, B)`
            - reward (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
            - done (:obj:`torch.FloatTensor`): :math:`(T, B)`
            - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
        """
        elem = data[0]
        if isinstance(elem, dict):  # old pipeline
            data = default_collate(data)
        elif isinstance(elem, list):  # new task pipeline
            data = default_collate(default_collate(data))
        else:
            raise TypeError("not support element type ({}) in IMPALA".format(type(elem)))
        if self._cuda:
            data = to_device(data, self._device)
        if self._priority_IS_weight:
            assert self._priority, "Use IS Weight correction, but Priority is not used."
        if self._priority and self._priority_IS_weight:
            data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)
        if isinstance(elem, dict):  # old pipeline
            for k in data:
                if isinstance(data[k], list):
                    data[k] = default_collate(data[k])
        data['obs_plus_1'] = torch.cat([data['obs'], data['next_obs'][-1:]], dim=0)  # shape (T+1)*B,env_obs_shape
        return data

    def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the replay buffer and then returns the output \
            result, including various training information such as loss and current learning rate.
        Arguments:
            - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
              training samples. For each element in the list, the key of the dict is the name of data items and \
              the value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their \
              dict/list combinations. In the ``_forward_learn`` method, data often needs to first be stacked in \
              the batch dimension by some utility functions such as ``default_preprocess_learn``. \
              For IMPALA, each element in the list is a dict containing at least the following keys: ``obs``, \
              ``action``, ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys \
              such as ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which \
              will be recorded in text log and tensorboard; values must be python scalars or lists of scalars. \
              For the detailed definition of the dict, refer to the code of the ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to the unittest for IMPALAPolicy: \
            ``ding.policy.tests.test_impala``.
        """
        data = self._data_preprocess_learn(data)
        # ====================
        # IMPALA forward
        # ====================
        self._learn_model.train()
        output = self._learn_model.forward(
            data['obs_plus_1'].view((-1, ) + data['obs_plus_1'].shape[2:]), mode='compute_actor_critic'
        )
        target_logit, behaviour_logit, actions, values, rewards, weights = self._reshape_data(output, data)
        # Calculate v-trace error
        data = vtrace_data(target_logit, behaviour_logit, actions, values, rewards, weights)
        g, l, r, c, rg = self._gamma, self._lambda, self._rho_clip_ratio, self._c_clip_ratio, self._rho_pg_clip_ratio
        if self._action_space == 'continuous':
            vtrace_loss = vtrace_error_continuous_action(data, g, l, r, c, rg)
        elif self._action_space == 'discrete':
            vtrace_loss = vtrace_error_discrete_action(data, g, l, r, c, rg)

        wv, we = self._value_weight, self._entropy_weight
        total_loss = vtrace_loss.policy_loss + wv * vtrace_loss.value_loss - we * vtrace_loss.entropy_loss
        # ====================
        # IMPALA update
        # ====================
        self._optimizer.zero_grad()
        total_loss.backward()
        self._optimizer.step()
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': vtrace_loss.policy_loss.item(),
            'value_loss': vtrace_loss.value_loss.item(),
            'entropy_loss': vtrace_loss.entropy_loss.item(),
        }

    def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple:
        """
        Overview:
            Obtain weights for loss calculation, which should be 0 for done positions. Update values and rewards \
            with the weights.
        Arguments:
            - output (:obj:`Dict[int, Any]`): Dict type data, output of learn_model forward. Values are \
              torch.Tensor or np.ndarray or dict/list combinations; keys are 'value' and 'logit'.
            - data (:obj:`Dict[int, Any]`): Dict type data, input of ``policy._forward_learn``. Values are \
              torch.Tensor or np.ndarray or dict/list combinations. Keys include at least ['logit', 'action', \
              'reward', 'done'].
        Returns:
            - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, actions, values, rewards, weights.
        ReturnsShapes:
            - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size \
              and N is action dim.
            - behaviour_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim.
            - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
            - values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
            - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
            - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
        """
        if self._action_space == 'continuous':
            target_logit = {}
            # shape (T+1),B,env_action_shape -> T,B,env_action_shape
            target_logit['mu'] = output['logit']['mu'].reshape(self._unroll_len + 1, -1, self._action_shape)[:-1]
            target_logit['sigma'] = output['logit']['sigma'].reshape(self._unroll_len + 1, -1,
                                                                     self._action_shape)[:-1]
        elif self._action_space == 'discrete':
            # shape (T+1),B,env_action_shape -> T,B,env_action_shape
            target_logit = output['logit'].reshape(self._unroll_len + 1, -1, self._action_shape)[:-1]
        behaviour_logit = data['logit']  # shape T,B
        actions = data['action']  # shape T,B for discrete, T,B,env_action_shape for continuous
        values = output['value'].reshape(self._unroll_len + 1, -1)  # shape T+1,B
        rewards = data['reward']  # shape T,B
        weights_ = 1 - data['done'].float()  # shape T,B
        weights = torch.ones_like(rewards)  # shape T,B
        values[1:] = values[1:] * weights_
        weights[1:] = weights_[:-1]
        rewards = rewards * weights  # shape T,B
        return target_logit, behaviour_logit, actions, values, rewards, weights

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of the policy, including related attributes and modules. For IMPALA, it \
            contains the collect_model to balance exploration and exploitation (e.g. the multinomial sample \
            mechanism in discrete action space), and other algorithm-specific arguments such as unroll_len. This \
            method will be called in the ``__init__`` method if the ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_collect`` method, you'd better name \
            them with the prefix ``_collect_`` to avoid conflicts with other modes, such as ``self._collect_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        self._action_space = self._cfg.action_space
        if self._action_space == 'continuous':
            self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
        elif self._action_space == 'discrete':
            self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')

        self._collect_model.reset()

    def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of collect mode (collecting training data by interacting with envs). Forward \
            means that the policy gets some necessary data (mainly observation) from the envs and then returns \
            the output data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
              The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action \
              and other necessary data (action logit and value) for learn mode, defined in the \
              ``self._process_transition`` method. The key of the dict is the same as in the input data, i.e. \
              environment id.

        .. tip::
            If you want to add more tricks to this policy, like a temperature factor in multinomial sampling, you \
            can pass the related data as extra keyword arguments of this method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to the unittest for IMPALAPolicy: \
            ``ding.policy.tests.test_impala``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        output = {i: d for i, d in zip(data_id, output)}
        return output

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            For a given trajectory (a list of transitions), process it into a list of samples that can be used \
            for training. In IMPALA, a train sample is a processed transition sequence of unroll_len length.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions); each element is in \
              the same format as the return value of the ``self._process_transition`` method.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples; each element is in a similar \
              format to the input transitions but may contain more data for training.
        """
        return get_train_sample(data, self._unroll_len)

    def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
                            timestep: namedtuple) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Process and pack one timestep transition data into a dict, which can be directly used for training \
            and saved in the replay buffer. For IMPALA, it contains obs, next_obs, action, reward, done, logit.
        Arguments:
            - obs (:obj:`torch.Tensor`): The env observation of the current timestep, such as a stacked 2D image \
              in Atari.
            - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the \
              observation as input. For IMPALA, it contains the action and the logit of the action.
            - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
              method, except all the elements have been transformed into tensor data. Usually, it contains the \
              next obs, reward, done, info, etc.
        Returns:
            - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'logit': policy_output['logit'],
            'action': policy_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of the policy, including related attributes and modules. For IMPALA, it \
            contains the eval model to select the optimal action (e.g. greedily selecting the action with the \
            argmax mechanism in discrete action space). This method will be called in the ``__init__`` method if \
            the ``eval`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
            with the prefix ``_eval_`` to avoid conflicts with other modes, such as ``self._eval_attr1``.
        """
        assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space
        self._action_space = self._cfg.action_space
        if self._action_space == 'continuous':
            self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
        elif self._action_space == 'discrete':
            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')

        self._eval_model.reset()

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). \
            Forward means that the policy gets some necessary data (mainly observation) from the envs and then \
            returns the action to interact with the envs. ``_forward_eval`` in IMPALA often uses deterministic \
            sampling to get actions, while ``_forward_collect`` usually uses a stochastic sampling method to \
            balance exploration and exploitation.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. \
              The key of the dict is environment id and the value is the corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. \
              The key of the dict is the same as in the input data, i.e. environment id.

        .. note::
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For a data type that is not supported, the main reason is that the corresponding model does not \
            support it. You can implement your own model rather than use the default model. For more information, \
            please raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to the unittest for IMPALAPolicy: \
            ``ding.policy.tests.test_impala``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        output = {i: d for i, d in zip(data_id, output)}
        return output

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as the text logger or tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss']
```
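To make the done-masking performed in ``_reshape_data`` concrete: timesteps that follow a terminal transition get weight 0, so their rewards and bootstrap values cannot leak across episode boundaries inside one unrolled trajectory. Below is a pure-Python sketch of that logic for a single env (batch dimension omitted; plain lists stand in for tensors); it is illustrative, not the tensorized method itself.

```python
# Pure-Python sketch of the weight/masking logic in _reshape_data.
# Single env (B=1 omitted); lists stand in for (T, B) tensors.

def mask_by_done(values, rewards, done):
    # values: length T+1 (includes bootstrap), rewards/done: length T
    alive = [1.0 - d for d in done]                  # 0 where the episode ended
    # zero out bootstrap values that sit past a terminal step
    values = values[:1] + [v * a for v, a in zip(values[1:], alive)]
    # weight at step t (t >= 1) is 1 - done[t - 1]: shift alive by one step
    weights = [1.0] + alive[:-1]
    rewards = [r * w for r, w in zip(rewards, weights)]
    return values, rewards, weights

v, r, w = mask_by_done([1.0] * 5, [1.0] * 4, [0.0, 1.0, 0.0, 0.0])
# step 1 is terminal, so the value and reward at step 2 are zeroed:
# v == [1.0, 1.0, 0.0, 1.0, 1.0], r == [1.0, 1.0, 0.0, 1.0]
```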