
ding.policy.r2d3

R2D3Policy

Bases: Policy

Overview

Policy class of r2d3, from the paper Making Efficient Use of Demonstrations to Solve Hard Exploration Problems.
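As a rough orientation, r2d3 trains with a DQfD-style combined loss: an n-step TD loss, a one-step TD loss, and a large-margin supervised loss applied only to expert transitions (the L2 term is handled as optimizer weight decay). The sketch below illustrates that combination with toy numbers in plain Python; the margin value 0.8 and the lambda weights are illustrative placeholders (in this policy they come from `cfg.learn.margin_function`, `lambda1`, `lambda2`, and `lambda_one_step_td`), and the real computation lives in `ding.rl_utils.dqfd_nstep_td_error`.

```python
def margin(a: int, expert_a: int, margin_value: float = 0.8) -> float:
    # expert margin l(a_E, a): 0 if the action matches the expert's, else a constant
    return 0.0 if a == expert_a else margin_value

def supervised_loss(q_values, expert_action, margin_value=0.8):
    # J_E(Q) = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E), on expert data only
    best = max(q + margin(a, expert_action, margin_value) for a, q in enumerate(q_values))
    return best - q_values[expert_action]

q = [1.0, 2.0, 0.5]   # Q(s, .) for 3 actions
expert_action = 2     # the demonstrator chose the lowest-valued action
j_e = supervised_loss(q, expert_action)
print(round(j_e, 2))  # 2.3: the margin loss pushes Q(s, a_E) above the others

# total loss = lambda1 * n-step TD + lambda_one_step_td * 1-step TD + lambda2 * J_E;
# lambda3 (L2) is applied as weight decay in the optimizer, not added here
lambda1, lambda_one_step_td, lambda2 = 1.0, 1.0, 1.0
td_nstep, td_1step = 0.4, 0.2  # placeholder TD losses
total = lambda1 * td_nstep + lambda_one_step_td * td_1step + lambda2 * j_e
```

Agent transitions skip the `J_E` term (the `is_expert` flag in the training data gates it), so demonstrations shape the Q-function without constraining the agent's own experience.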

Config

== ======================== ======== ============== ======================================== =======================
ID Symbol                   Type     Default Value  Description                              Other(Shape)
== ======================== ======== ============== ======================================== =======================
1  type                     str      dqn            RL policy register name, refer to        This arg is optional,
                                                    registry POLICY_REGISTRY                 a placeholder
2  cuda                     bool     False          Whether to use cuda for network          This arg can be
                                                                                            different from modes
3  on_policy                bool     False          Whether the RL algorithm is on-policy
                                                    or off-policy
4  priority                 bool     False          Whether to use priority (PER)            Priority sample,
                                                                                            update priority
5  priority_IS_weight       bool     False          Whether to use Importance Sampling
                                                    Weight to correct biased update. If
                                                    True, priority must be True.
6  discount_factor          float    0.997,         Reward's future discount factor,         May be 1 when sparse
                                     [0.95, 0.999]  aka. gamma                               reward env
7  nstep                    int      3, [3, 5]      N-step reward discount sum for target
                                                    q_value estimation
8  burnin_step              int      2              The timestep of the burnin operation,
                                                    designed to fix the RNN hidden state
                                                    difference caused by off-policy training
9  learn.update_per_collect int      1              How many updates (iterations) to train   This arg can vary
                                                    after one collection of the collector.   across envs. A bigger
                                                    Only valid in serial training            value means more
                                                                                            off-policy
10 learn.batch_size         int      64             The number of samples of an iteration
11 learn.learning_rate      float    0.001          Gradient step length of an iteration
12 learn.value_rescale      bool     True           Whether to use the value_rescale
                                                    function for the predicted value
13 learn.target_update_freq int      100            Frequency of target network update       Hard (assign) update
14 learn.ignore_done        bool     False          Whether to ignore done for the target    Enable it for some
                                                    value calculation                        fake termination envs
15 collect.n_sample         int      [8, 128]       The number of training samples of a      It varies across
                                                    call of the collector                    different envs
16 collect.unroll_len       int      1              Unroll length of an iteration            In RNN, unroll_len>1
== ======================== ======== ============== ======================================== =======================
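The table above maps onto a nested config dict. The fragment below mirrors the class defaults as a plain Python dict; the `env_num` values are placeholders, not a tested recipe, and in practice you would merge such a dict into a full DI-engine experiment config.

```python
# Illustrative user-side config fragment mirroring the defaults above.
# Key names follow the class's `config` dict; values here are placeholders.
r2d3_config = dict(
    policy=dict(
        cuda=False,
        priority=True,
        priority_IS_weight=True,
        discount_factor=0.997,
        nstep=5,
        burnin_step=2,
        learn_unroll_len=80,
        learn=dict(
            update_per_collect=1,
            batch_size=64,
            learning_rate=1e-4,
            target_update_theta=0.001,
            value_rescale=True,
        ),
        collect=dict(env_num=8),  # must equal env_num in the env config
        eval=dict(env_num=4),     # likewise for evaluation
    ),
)

# the RNN sequence length seen by the learner is burnin + unroll
policy = r2d3_config['policy']
sequence_len = policy['burnin_step'] + policy['learn_unroll_len']
print(sequence_len)  # 82
```

Note that `collect.n_sample` is deliberately absent: as the source below notes, omitting it keeps the collector trajectory length unbounded so full sequences can be cut later.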

Full Source Code

../ding/policy/r2d3.py

import copy
from collections import namedtuple
from typing import List, Dict, Any, Tuple, Union, Optional

import torch

from ding.model import model_wrap
from ding.rl_utils import q_nstep_td_error_with_rescale, get_nstep_return_data, \
    get_train_sample, dqfd_nstep_td_error, dqfd_nstep_td_error_with_rescale, dqfd_nstep_td_data
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('r2d3')
class R2D3Policy(Policy):
    r"""
    Overview:
        Policy class of r2d3, from paper `Making Efficient Use of Demonstrations to Solve Hard Exploration Problems`.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      dqn            | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether to use priority (PER)          | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     False          | Whether to use Importance Sampling
           | ``_weight``                                | Weight to correct biased update. If
                                                        | True, priority must be True.
        6  | ``discount_``      float    0.997,         | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      3,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  ``burnin_step``      int      2              | The timestep of the burnin operation,
                                                        | designed to fix the RNN hidden state
                                                        | difference caused by off-policy
        9  | ``learn.update``   int      1              | How many updates (iterations) to train | This arg can vary
           | ``per_collect``                            | after one collection of the collector. | across envs. Bigger
                                                        | Only valid in serial training          | value: more off-policy
        10 | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        11 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        12 | ``learn.value_``   bool     True           | Whether to use value_rescale function
           | ``rescale``                                | for the predicted value
        13 | ``learn.target_``  int      100            | Frequency of target network update.    | Hard (assign) update
           | ``update_freq``
        14 | ``learn.ignore_``  bool     False          | Whether to ignore done for the target  | Enable it for some
           | ``done``                                   | value calculation.                     | fake termination env
        15 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of the collector.                 | different envs
        16 | ``collect.unroll`` int      1              | Unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='r2d3',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority)
        priority=True,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # ==============================================================
        # The following configs are algorithm-specific
        # ==============================================================
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.997,
        # (int) N-step reward for target q_value estimation
        nstep=5,
        # (int) the timestep of the burnin operation, designed to fix the RNN hidden state
        # difference caused by off-policy training
        burnin_step=2,
        # (int) the trajectory length to unroll the RNN network minus
        # the timestep of the burnin operation
        learn_unroll_len=80,
        learn=dict(
            update_per_collect=1,
            batch_size=64,
            learning_rate=0.0001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            # target_update_freq=100,
            target_update_theta=0.001,
            # (bool) whether to use the value_rescale function for the predicted value
            value_rescale=True,
            ignore_done=False,
        ),
        collect=dict(
            # NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
            # each_iter_n_sample=32,
            # `env_num` is used in hidden state, should equal the one in the env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        eval=dict(
            # `env_num` is used in hidden state, should equal the one in the env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.05,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'drqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Init the learner model of R2D3Policy.

        Arguments:
            .. note::

                The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file.

            - learning_rate (:obj:`float`): The learning rate of the optimizer
            - gamma (:obj:`float`): The discount factor
            - nstep (:obj:`int`): The num of n-step return
            - value_rescale (:obj:`bool`): Whether to use value-rescaled loss in the algorithm
            - burnin_step (:obj:`int`): The num of steps of burnin
        """
        self.lambda1 = self._cfg.learn.lambda1  # n-step return
        self.lambda2 = self._cfg.learn.lambda2  # supervised loss
        self.lambda3 = self._cfg.learn.lambda3  # L2
        self.lambda_one_step_td = self._cfg.learn.lambda_one_step_td  # 1-step return
        # margin function in JE, here we implement this as a constant
        self.margin_function = self._cfg.learn.margin_function

        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = Adam(
            self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3, optim_type='adamw'
        )
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._value_rescale = self._cfg.learn.value_rescale

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning

        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from the collect function

        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
            - data_info (:obj:`dict`): the data info, such as replay_buffer_idx, replay_unique_id
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)

        if self._priority_IS_weight:
            assert self._priority, "Use IS Weight correction, but Priority is not used."
        if self._priority and self._priority_IS_weight:
            data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)

        bs = self._burnin_step

        # data['done'], data['weight'], data['value_gamma'] are used in _forward_learn() to calculate
        # the q_nstep_td_error; they should have length [self._sequence_len-self._burnin_step-self._nstep]
        ignore_done = self._cfg.learn.ignore_done
        if ignore_done:
            data['done'] = [None for _ in range(self._sequence_len - bs)]
        else:
            data['done'] = data['done'][bs:].float()
        # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample,
        # data['done'][t] is already the n-step done

        # if the data doesn't include 'weight' or 'value_gamma', fill in None in a list
        # with length [self._sequence_len-self._burnin_step-self._nstep];
        # below are two different implementation ways
        if 'value_gamma' not in data:
            data['value_gamma'] = [None for _ in range(self._sequence_len - bs)]
        else:
            data['value_gamma'] = data['value_gamma'][bs:]

        if 'weight' not in data:
            data['weight'] = [None for _ in range(self._sequence_len - bs)]
        else:
            data['weight'] = data['weight'] * torch.ones_like(data['done'])
            # every timestep in the sequence has the same weight, which is the _priority_IS_weight in PER

        data['action'] = data['action'][bs:-self._nstep]
        data['reward'] = data['reward'][bs:-self._nstep]

        # the burnin_nstep_obs is used to calculate the init hidden state of the rnn for the calculation of
        # the q_value, target_q_value, and target_q_action
        data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
        # the main_obs is used to calculate the q_value; [bs:-self._nstep] means using the data from
        # timestep [bs] to timestep [self._sequence_len-self._nstep]
        data['main_obs'] = data['obs'][bs:-self._nstep]
        # the target_obs is used to calculate the target_q_value
        data['target_obs'] = data['obs'][bs + self._nstep:]

        # TODO(pu)
        data['target_obs_one_step'] = data['obs'][bs + 1:]
        if ignore_done:
            data['done_one_step'] = [None for _ in range(self._sequence_len - bs)]
        else:
            data['done_one_step'] = data['done_one_step'][bs:].float()

        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
            Acquire the data, calculate the loss and optimize the learner model.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']

        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        # forward
        data = self._data_preprocess_learn(data)
        self._learn_model.train()
        self._target_model.train()
        # take out the hidden state at timestep=0
        self._learn_model.reset(data_id=None, state=data['prev_state'][0])
        self._target_model.reset(data_id=None, state=data['prev_state'][0])

        if len(data['burnin_nstep_obs']) != 0:
            with torch.no_grad():
                inputs = {'obs': data['burnin_nstep_obs']}
                burnin_output = self._learn_model.forward(
                    inputs,
                    saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep, self._burnin_step + 1]
                )
                burnin_output_target = self._target_model.forward(
                    inputs,
                    saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep, self._burnin_step + 1]
                )

        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
        inputs = {'obs': data['main_obs']}
        q_value = self._learn_model.forward(inputs)['logit']

        # n-step
        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
        self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])

        next_inputs = {'obs': data['target_obs']}
        with torch.no_grad():
            target_q_value = self._target_model.forward(next_inputs)['logit']
            # argmax_action double_dqn
            target_q_action = self._learn_model.forward(next_inputs)['action']

        # one-step
        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][2])
        self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][2])

        next_inputs_one_step = {'obs': data['target_obs_one_step']}
        with torch.no_grad():
            target_q_value_one_step = self._target_model.forward(next_inputs_one_step)['logit']
            # argmax_action double_dqn
            target_q_action_one_step = self._learn_model.forward(next_inputs_one_step)['action']

        action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
        value_gamma = data['value_gamma']
        done_one_step = data['done_one_step']
        # T, B, nstep -> T, nstep, B
        reward = reward.permute(0, 2, 1).contiguous()
        loss = []
        loss_nstep = []
        loss_1step = []
        loss_sl = []
        td_error = []
        for t in range(self._sequence_len - self._burnin_step - self._nstep):
            # here t=0 means timestep <self._burnin_step> in the original sample sequence; we subtract self._nstep
            # because the last <self._nstep> timesteps in the sequence don't have target obs
            td_data = dqfd_nstep_td_data(
                q_value[t],
                target_q_value[t],
                action[t],
                target_q_action[t],
                reward[t],
                done[t],
                done_one_step[t],
                weight[t],
                target_q_value_one_step[t],
                target_q_action_one_step[t],
                data['is_expert'][t],  # is_expert flag (expert 1, agent 0)
            )

            if self._value_rescale:
                l, e, loss_statistics = dqfd_nstep_td_error_with_rescale(
                    td_data,
                    self._gamma,
                    self.lambda1,
                    self.lambda2,
                    self.margin_function,
                    self.lambda_one_step_td,
                    self._nstep,
                    False,
                    value_gamma=value_gamma[t],
                )
                loss.append(l)
                # td_error.append(e.abs())  # first sum then abs
                td_error.append(e)  # first abs then sum
                # loss statistics for debugging
                loss_nstep.append(loss_statistics[0])
                loss_1step.append(loss_statistics[1])
                loss_sl.append(loss_statistics[2])
            else:
                l, e, loss_statistics = dqfd_nstep_td_error(
                    td_data,
                    self._gamma,
                    self.lambda1,
                    self.lambda2,
                    self.margin_function,
                    self.lambda_one_step_td,
                    self._nstep,
                    False,
                    value_gamma=value_gamma[t],
                )
                loss.append(l)
                # td_error.append(e.abs())  # first sum then abs
                td_error.append(e)  # first abs then sum
                # loss statistics for debugging
                loss_nstep.append(loss_statistics[0])
                loss_1step.append(loss_statistics[1])
                loss_sl.append(loss_statistics[2])

        loss = sum(loss) / (len(loss) + 1e-8)
        # loss statistics for debugging
        loss_nstep = sum(loss_nstep) / (len(loss_nstep) + 1e-8)
        loss_1step = sum(loss_1step) / (len(loss_1step) + 1e-8)
        loss_sl = sum(loss_sl) / (len(loss_sl) + 1e-8)

        # use a mixture of the max and mean absolute n-step TD errors as the priority of the sequence
        td_error_per_sample = 0.9 * torch.max(
            torch.stack(td_error), dim=0
        )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
        # td_error shape list(<self._sequence_len-self._burnin_step-self._nstep>, B), for example, (75, 64)
        # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)

        # update
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # after update
        self._target_model.update(self._learn_model.state_dict())

        # the information for debugging
        batch_range = torch.arange(action[0].shape[0])
        q_s_a_t0 = q_value[0][batch_range, action[0]]
        target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]

        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            # loss statistics for debugging
            'nstep_loss': loss_nstep.item(),
            '1step_loss': loss_1step.item(),
            'sl_loss': loss_sl.item(),
            'priority': td_error_per_sample.abs().tolist(),
            # the first timestep in the sequence, which may not be the start of an episode
            'q_s_taken-a_t0': q_s_a_t0.mean().item(),
            'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
            'q_s_a-mean_t0': q_value[0].mean().item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
        """
        assert 'unroll_len' not in self._cfg.collect, "r2d3 uses the default unroll_len"
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._gamma = self._cfg.discount_factor
        self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
        self._unroll_len = self._sequence_len  # for compatibility

        self._collect_model = model_wrap(
            self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, eps: float) -> dict:
        r"""
        Overview:
            Collect output according to the eps_greedy plugin

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - data (:obj:`dict`): The collected data
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._collect_model.eval()
        with torch.no_grad():
            # in the collect phase, inference=True means that each time we only pass one timestep of data,
            # so that we can get the hidden state of the rnn, <prev_state>, at each timestep.
            output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        self._collect_model.reset(data_id=data_id)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of the collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            'prev_state': model_output['prev_state'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the trajectory and the n-step return data, then sample from the n-step return data

        Arguments:
            - data (:obj:`list`): The trajectory's cache

        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        from copy import deepcopy
        data_one_step = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        for i in range(len(data)):
            # here we record the one-step done; we don't need to record the one-step reward,
            # because the n-step reward in data already includes the one-step reward
            data[i]['done_one_step'] = data_one_step[i]['done']
        return get_train_sample(data, self._sequence_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - output (:obj:`dict`): Dict type data, including at least the inferred action according to input obs.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        self._eval_model.reset(data_id=data_id)

    def _monitor_vars_learn(self) -> List[str]:
        return super()._monitor_vars_learn() + [
            'total_loss', 'nstep_loss', '1step_loss', 'sl_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0',
            'q_s_a-mean_t0'
        ]
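The per-sequence replay priority computed in `_forward_learn` follows the R2D2 recipe: a fixed mixture of the maximum and the mean of the absolute per-timestep n-step TD errors, with mixing coefficient 0.9. A minimal plain-Python sketch with toy numbers (the real code operates on stacked torch tensors of shape (T, B)):

```python
# priority = eta * max_t |delta_t| + (1 - eta) * mean_t |delta_t|, eta = 0.9
eta = 0.9
td_errors = [0.1, 0.5, 0.2, 0.4]  # absolute n-step TD errors for one sequence
priority = eta * max(td_errors) + (1 - eta) * (sum(td_errors) / len(td_errors))
print(round(priority, 3))  # 0.48
```

Weighting the max heavily means a sequence with a single large surprise is replayed almost as eagerly as one that is uniformly surprising, which suits sparse-reward hard-exploration settings.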