from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import copy

from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
    get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('ngu')
class NGUPolicy(Policy):
    r"""
    Overview:
        Policy class of NGU. The corresponding paper is `never give up: learning directed exploration strategies`.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      ngu            | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether use priority(PER)              | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     True           | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.997,         | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      5,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  ``burnin_step``      int      20             | The timestep of burnin operation,
                                                        | which is designed to RNN hidden state
                                                        | difference caused by off-policy
        9  | ``learn_unroll``   int      80             | Length of the trained part of a        | unroll_len = burnin_
           | ``_len``                                   | sequence sample (burnin part excluded) | step + learn_unroll_len
        10 | ``learn.update``   int      1              | How many updates(iterations) to train  | This args can be vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        11 | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        12 | ``learn.learning`` float    0.0001         | Gradient step length of an iteration.
           | ``_rate``
        13 | ``learn.value_``   bool     True           | Whether use value_rescale function for
           | ``rescale``                                | predicted value
        14 | ``learn.target_``  float    0.001          | Interpolation factor of the soft       | Momentum(polyak)
           | ``update_theta``                           | (polyak) target network update.        | update
        15 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        16 ``collect.n_sample`` int      32             | The number of training samples of a    | It varies from
                                                        | call of collector.                     | different envs
        17 | ``collect.env_``   int      None           | Number of collect envs; used for the   | Must be set in
           | ``num``                                    | hidden state and the beta/eps schedule | user config
        18 | ``eval.env_num``   int      None           | Number of eval envs; used for the      | Must be set in
                                                        | hidden state                           | user config
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='ngu',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=True,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # ==============================================================
        # The following configs are algorithm-specific
        # ==============================================================
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.997,
        # (int) N-step reward for target q_value estimation
        nstep=5,
        # (int) the timestep of burnin operation, which is designed to RNN hidden state difference
        # caused by off-policy
        burnin_step=20,
        # (int) <learn_unroll_len> is the total length of [sequence sample] minus
        # the length of burnin part in [sequence sample],
        # i.e., <sequence sample length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
        learn_unroll_len=80,  # set this key according to the episode length
        learn=dict(
            update_per_collect=1,
            batch_size=64,
            learning_rate=0.0001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float type) target_update_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            target_update_theta=0.001,
            # (bool) whether use value_rescale function for predicted value
            value_rescale=True,
            # (bool) whether to treat every transition as non-terminal when computing the target value
            ignore_done=False,
        ),
        collect=dict(
            # NOTE: It is important that set key traj_len_inf=True here,
            # to make sure self._traj_len=INF in serial_sample_collector.py.
            # In sequence-based policy, for each collect_env,
            # we want to collect data of length self._traj_len=INF
            # unless the episode enters the 'done' state.
            # In each collect phase, we collect a total of <n_sample> sequence samples.
            n_sample=32,
            traj_len_inf=True,
            # `env_num` is used in hidden state, should equal to that one in env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        eval=dict(
            # `env_num` is used in hidden state, should equal to that one in env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.05,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the registered name and import path of the default NGU model.
        """
        return 'ngu', ['ding.model.template.ngu']

    @staticmethod
    def _build_index_to_gamma(env_num: int) -> Dict[int, torch.Tensor]:
        r"""
        Overview:
            Map each beta index ``i`` in ``[0, env_num)`` to its discount factor. ``log(1 - gamma)`` is
            interpolated linearly from ``gamma=0.997`` (index 0) to ``gamma=0.99`` (index ``env_num - 1``).
        Arguments:
            - env_num (:obj:`int`): The number of collect envs, i.e. the number of distinct beta indices.
        Returns:
            - index_to_gamma (:obj:`Dict[int, torch.Tensor]`): 0-dim float tensors keyed by beta index.
        """
        log_one_minus_gamma_max = torch.log(torch.tensor(1 - 0.997))
        log_one_minus_gamma_min = torch.log(torch.tensor(1 - 0.99))
        if env_num == 1:
            # A single actor has no range to interpolate over (the formula below would divide by zero);
            # give it the largest discount factor, i.e. the value index 0 gets in the general case.
            return {0: 1 - torch.exp(log_one_minus_gamma_max)}
        return {
            i: 1 - torch.exp(
                ((env_num - 1 - i) * log_one_minus_gamma_max + i * log_one_minus_gamma_min) / (env_num - 1)
            )
            for i in range(env_num)
        }

    @staticmethod
    def _build_index_to_eps(env_num: int) -> Dict[int, float]:
        r"""
        Overview:
            Per-env exploration epsilon ``0.4 ** (1 + 8 * i / (env_num - 1))`` (base epsilon=0.4, alpha=8).
        Arguments:
            - env_num (:obj:`int`): The number of collect envs.
        Returns:
            - eps (:obj:`Dict[int, float]`): Epsilon keyed by env index.
        """
        if env_num == 1:
            # avoid 0 / 0; a single env gets the exponent of index 0, i.e. the base epsilon
            return {0: 0.4}
        return {i: 0.4 ** (1 + 8 * i / (env_num - 1)) for i in range(env_num)}

    def _init_learn(self) -> None:
        r"""
        Overview:
            Init the learner model of NGUPolicy

        Arguments:
            .. note::

                The _init_learn method takes the argument from the self._cfg.learn in the config file

            - learning_rate (:obj:`float`): The learning rate for the optimizer
            - gamma (:obj:`float`): The discount factor
            - nstep (:obj:`int`): The num of n step return
            - value_rescale (:obj:`bool`): Whether to use value rescaled loss in algorithm
            - burnin_step (:obj:`int`): The num of step of burnin
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._value_rescale = self._cfg.learn.value_rescale
        # ``_forward_learn`` reads the sequence length and the beta->gamma table, so set them here as well:
        # learn mode must not depend on ``_init_collect`` having run.
        self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
        if self._cfg.collect.env_num is not None:
            self.index_to_gamma = self._build_index_to_gamma(self._cfg.collect.env_num)

        self._target_model = copy.deepcopy(self._model)
        # Do not use the 'assign' (hard, freq-based) mode of the target network here because of the reset bug;
        # use a momentum (soft) update instead.
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        self._target_model = model_wrap(
            self._target_model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size, save_prev_state=True
        )
        self._learn_model = model_wrap(
            self._model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size, save_prev_state=True
        )
        self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning

        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function

        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
        """

        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)

        if self._priority_IS_weight:
            assert self._priority, "Use IS Weight correction, but Priority is not used."
        if self._priority and self._priority_IS_weight:
            data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)

        bs = self._burnin_step

        # data['done'], data['weight'], data['value_gamma'] are used in _forward_learn() to calculate
        # the q_nstep_td_error, and should have length [self._sequence_len - self._burnin_step]
        if self._cfg.learn.ignore_done:
            # Ignoring done means treating every step as non-terminal. A tensor of zeros (not a list of None)
            # keeps the weight broadcast below and the ``1 - done`` term of the n-step target working.
            data['done'] = torch.zeros_like(data['done'][bs:]).float()
        else:
            data['done'] = data['done'][bs:].float()  # for computation of online model self._learn_model
            # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample
            # the data['done'] [t] is already the n-step done

        # if the data don't include 'value_gamma' or 'weight' then fill in None in a list
        # with length of [self._sequence_len - self._burnin_step]
        if 'value_gamma' not in data:
            data['value_gamma'] = [None for _ in range(self._sequence_len - bs)]
        else:
            data['value_gamma'] = data['value_gamma'][bs:]

        # The 'weight' key was always assigned above (possibly to None), so test its value, not its presence.
        if data['weight'] is None:
            data['weight'] = [None for _ in range(self._sequence_len - bs)]
        else:
            # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
            data['weight'] = data['weight'] * torch.ones_like(data['done'])

        # the burnin_nstep_* data is used to calculate the init hidden state of rnn for the calculation of the
        # q_value, target_q_value, and target_q_action
        data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
        data['burnin_nstep_action'] = data['action'][:bs + self._nstep]
        data['burnin_nstep_reward'] = data['reward'][:bs + self._nstep]
        data['burnin_nstep_beta'] = data['beta'][:bs + self._nstep]

        # split each sequence into 'main' [bs:-nstep] and 'target' [bs+nstep:] parts
        data['main_obs'] = data['obs'][bs:-self._nstep]
        data['target_obs'] = data['obs'][bs + self._nstep:]

        data['main_action'] = data['action'][bs:-self._nstep]
        data['target_action'] = data['action'][bs + self._nstep:]

        data['main_reward'] = data['reward'][bs:-self._nstep]
        data['target_reward'] = data['reward'][bs + self._nstep:]

        data['main_beta'] = data['beta'][bs:-self._nstep]
        data['target_beta'] = data['beta'][bs + self._nstep:]

        # Note that this must come after the previous slicing operations
        data['action'] = data['action'][bs:-self._nstep]
        data['reward'] = data['reward'][bs:-self._nstep]

        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
            Acquire the data, calculate the loss and optimize learner model.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']

        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss, priority and q-value statistics
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        # forward
        data = self._data_preprocess_learn(data)
        self._learn_model.train()
        self._target_model.train()
        # use the hidden state in timestep=0
        self._learn_model.reset(data_id=None, state=data['prev_state'][0])
        self._target_model.reset(data_id=None, state=data['prev_state'][0])

        if len(data['burnin_nstep_obs']) != 0:
            with torch.no_grad():
                inputs = {
                    'obs': data['burnin_nstep_obs'],
                    'action': data['burnin_nstep_action'],
                    'reward': data['burnin_nstep_reward'],
                    'beta': data['burnin_nstep_beta'],
                }
                tmp = self._learn_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )
                tmp_target = self._target_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )

        inputs = {
            'obs': data['main_obs'],
            'action': data['main_action'],
            'reward': data['main_reward'],
            'beta': data['main_beta'],
        }
        self._learn_model.reset(data_id=None, state=tmp['saved_state'][0])
        q_value = self._learn_model.forward(inputs)['logit']

        self._learn_model.reset(data_id=None, state=tmp['saved_state'][1])
        self._target_model.reset(data_id=None, state=tmp_target['saved_state'][1])

        next_inputs = {
            'obs': data['target_obs'],
            'action': data['target_action'],
            'reward': data['target_reward'],
            'beta': data['target_beta'],
        }
        with torch.no_grad():
            target_q_value = self._target_model.forward(next_inputs)['logit']
            # argmax_action double_dqn
            target_q_action = self._learn_model.forward(next_inputs)['action']

        action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
        # NOTE this is important, because we use different gamma according to their beta in NGU alg.
        value_gamma = [None for _ in range(self._sequence_len - self._burnin_step)]

        # T, B, nstep -> T, nstep, B
        reward = reward.permute(0, 2, 1).contiguous()
        # Per-sample discount factor from each sample's beta index (T, B -> B). Kept local so that learning does
        # not overwrite the scalar ``self._gamma`` attribute.
        gamma = [self.index_to_gamma[int(i)] for i in data['main_beta'][0]]

        td_error_fn = q_nstep_td_error_with_rescale if self._value_rescale else q_nstep_td_error
        loss = []
        td_error = []
        for t in range(self._sequence_len - self._burnin_step - self._nstep):
            # here t=0 means timestep <self._burnin_step> in the original sample sequence, we minus self._nstep
            # because for the last <self._nstep> timestep in the sequence, we don't have their target obs
            td_data = q_nstep_td_data(
                q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
            )
            l, e = td_error_fn(td_data, gamma, self._nstep, value_gamma=value_gamma[t])
            loss.append(l)
            td_error.append(e.abs())
        loss = sum(loss) / (len(loss) + 1e-8)

        # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence;
        # td_error is a list of length <self._sequence_len - self._burnin_step - self._nstep> of (B, ) tensors
        td_error_stack = torch.stack(td_error)
        td_error_per_sample = 0.9 * torch.max(
            td_error_stack, dim=0
        )[0] + (1 - 0.9) * (torch.sum(td_error_stack, dim=0) / (len(td_error) + 1e-8))

        # update
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # after update
        self._target_model.update(self._learn_model.state_dict())

        # the information for debug
        batch_range = torch.arange(action[0].shape[0])
        q_s_a_t0 = q_value[0][batch_range, action[0]]
        target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]

        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            # the first timestep in the sequence, may not be the start of episode
            'q_s_taken-a_t0': q_s_a_t0.mean().item(),
            'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
            'q_s_a-mean_t0': q_value[0].mean().item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the hidden state of the learn model for ``data_id`` (all states when None).
        """
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state dicts of the learn model, target model and optimizer.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the learn model, target model and optimizer from a dict produced by ``_state_dict_learn``.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model, and the per-env beta index / epsilon tables.
        """
        assert 'unroll_len' not in self._cfg.collect, "ngu use default <unroll_len = learn_unroll_len + burnin_step>"
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._gamma = self._cfg.discount_factor
        self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
        self._unroll_len = self._sequence_len
        self._collect_model = model_wrap(
            self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
        )
        self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()
        env_num = self._cfg.collect.env_num
        self.index_to_gamma = self._build_index_to_gamma(env_num)
        # NOTE: for NGU policy collect phase, each collect env starts with a random beta index
        self.beta_index = {i: torch.randint(0, env_num, [1]) for i in range(env_num)}
        self.eps = self._build_index_to_eps(env_num)

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Collect output according to eps_greedy plugin

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - data (:obj:`dict`): The collected data
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))

        obs = data['obs']
        prev_action = data['prev_action'].long()
        prev_reward_extrinsic = data['prev_reward_extrinsic']

        beta_index = default_collate(list(self.beta_index.values()))
        if len(data_id) != self._cfg.collect.env_num:
            # in case, some env is in reset state and only return part data
            beta_index = beta_index[data_id]

        if self._cuda:
            obs = to_device(obs, self._device)
            beta_index = to_device(beta_index, self._device)
            prev_action = to_device(prev_action, self._device)
            prev_reward_extrinsic = to_device(prev_reward_extrinsic, self._device)
        # TODO(pu): add prev_reward_intrinsic to network input,
        # reward uses some kind of embedding instead of 1D value
        data = {
            'obs': obs,
            'prev_action': prev_action,
            'prev_reward_extrinsic': prev_reward_extrinsic,
            'beta': beta_index
        }
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, data_id=data_id, eps=self.eps, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the collect model's hidden state and, for NGU, sample a new beta index for every env
            listed in ``data_id`` (a new episode starts for each of them).
        """
        self._collect_model.reset(data_id=data_id)
        if data_id is not None:
            for env_id in data_id:
                self.beta_index[env_id] = torch.randint(0, self._cfg.collect.env_num, [1])

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple, env_id) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
                (here 'obs' indicates obs after env step).
            - env_id: The env whose current beta index is attached to the transition.
        Returns:
            - transition (:obj:`dict`): Dict type transition data; 'null' defaults to False when the \
                timestep has no 'null' field.
        """
        return {
            'beta': self.beta_index[env_id],
            'obs': obs['obs'],  # NOTE: input obs including obs, prev_action, prev_reward_extrinsic
            'action': model_output['action'],
            'prev_state': model_output['prev_state'],
            'reward': timestep.reward,
            'done': timestep.done,
            'null': getattr(timestep, 'null', False),
        }

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the trajectory and the n step return data, then sample from the n_step return data

        Arguments:
            - data (:obj:`list`): The trajectory's cache

        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        # the n-step return uses the discount factor belonging to this trajectory's beta index
        data = get_nstep_return_data(data, self._nstep, gamma=self.index_to_gamma[int(data[0]['beta'])].item())
        return get_train_sample(data, self._sequence_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()
        # NOTE: for NGU policy eval phase, beta_index = 0 -> beta is approximately 0.
        # Kept in its own attribute: collect and eval share this policy object, and reusing ``self.beta_index``
        # would overwrite the collect-mode random betas (and let collect resets leak into eval).
        self._eval_beta_index = {i: torch.tensor([0]) for i in range(self._cfg.eval.env_num)}

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
        """

        data_id = list(data.keys())
        data = default_collate(list(data.values()))

        obs = data['obs']
        prev_action = data['prev_action'].long()
        prev_reward_extrinsic = data['prev_reward_extrinsic']

        beta_index = default_collate(list(self._eval_beta_index.values()))
        if len(data_id) != self._cfg.eval.env_num:
            # in case, some env is in reset state and only return part data
            beta_index = beta_index[data_id]

        if self._cuda:
            obs = to_device(obs, self._device)
            beta_index = to_device(beta_index, self._device)
            prev_action = to_device(prev_action, self._device)
            prev_reward_extrinsic = to_device(prev_reward_extrinsic, self._device)
        # TODO(pu): add prev_reward_intrinsic to network input,
        # reward uses some kind of embedding instead of 1D value
        data = {
            'obs': obs,
            'prev_action': prev_action,
            'prev_reward_extrinsic': prev_reward_extrinsic,
            'beta': beta_index
        }

        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset the hidden state of the eval model for ``data_id`` (all states when None).
        """
        self._eval_model.reset(data_id=data_id)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Variables recorded by the learner's monitor, extending the base policy's list.
        """
        return super()._monitor_vars_learn() + [
            'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
        ]