ding.policy.r2d2

R2D2Policy

Bases: Policy

Overview

Policy class of R2D2, from the paper Recurrent Experience Replay in Distributed Reinforcement Learning. R2D2 improves upon DRQN with several techniques, namely recurrent experience replay tricks and a burn-in mechanism for off-policy training.
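The burn-in mechanism can be illustrated without an RNN. The plain-Python sketch below reproduces the time-axis slicing that `_data_preprocess_learn` (in the full source code) applies to each stored sequence of length `burnin_step + learn_unroll_len`. It is an illustration, not DI-engine code: `split_r2d2_sequence` is a hypothetical helper and integers stand in for observations.

```python
from typing import Dict, List, Sequence


def split_r2d2_sequence(obs: Sequence[int], burnin_step: int, nstep: int) -> Dict[str, List[int]]:
    """Slice one stored sequence along time, mirroring the R2D2 learn-mode preprocessing.

    Integers stand in for the per-timestep observations of a real trajectory.
    """
    if nstep < 1:
        raise ValueError("nstep must be at least 1, otherwise obs[burnin_step:-nstep] is empty")
    if len(obs) <= burnin_step + nstep:
        raise ValueError("sequence too short for burn-in plus n-step targets")
    return {
        # Replayed without gradients only to rebuild the RNN hidden state; the states
        # after burnin_step and burnin_step + nstep steps seed the two passes below.
        'burnin_nstep_obs': list(obs[:burnin_step + nstep]),
        # Inputs of the online network whose Q-values receive the TD loss.
        'main_obs': list(obs[burnin_step:-nstep]),
        # Inputs for the bootstrap targets, shifted nstep steps into the future.
        'target_obs': list(obs[burnin_step + nstep:]),
    }


# Tiny example: 10 timesteps, 3 burn-in steps, 2-step returns.
parts = split_r2d2_sequence(list(range(10)), burnin_step=3, nstep=2)
for key, value in parts.items():
    print(key, value)
```

With `burnin_step=3` and `nstep=2`, `main_obs[t]` and `target_obs[t]` are exactly `nstep` steps apart, and the number of trained steps (10 - 3 - 2 = 5) matches the loop length `self._sequence_len - self._burnin_step - self._nstep` in `_forward_learn`.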

Config:

== ==================== ======== ============== ======================================== =======================
ID Symbol               Type     Default Value  Description                              Other(Shape)
== ==================== ======== ============== ======================================== =======================
1  ``type``             str      r2d2           | RL policy register name, refer to      | This arg is optional,
                                                | registry ``POLICY_REGISTRY``           | a placeholder
2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be
                                                                                         | different from modes
3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                | or off-policy
4  ``priority``         bool     True           | Whether to use priority (PER)          | Priority sample,
                                                                                         | update priority
5  | ``priority_IS``    bool     True           | Whether to use Importance Sampling
   | ``_weight``                                | Weight to correct biased update. If
                                                | True, priority must be True.
6  | ``discount_``      float    0.997,         | Reward's future discount factor, aka.  | May be 1 when sparse
   | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
7  ``nstep``            int      5,             | N-step reward discount sum for target
                                 [3, 5]         | q_value estimation
8  ``burnin_step``      int      20             | Number of burn-in timesteps, designed
                                                | to mitigate the RNN hidden-state
                                                | mismatch caused by off-policy replay
9  | ``learn.update``   int      1              | How many updates (iterations) to train | This arg can vary
   | ``per_collect``                            | after each collection. Only valid in   | across envs. A bigger
                                                | serial training                        | value means more
                                                                                         | off-policy
10 | ``learn.batch_``   int      64             | The number of samples in an iteration
   | ``size``
11 | ``learn.learning`` float    0.0001         | Gradient step length of an iteration
   | ``_rate``
12 | ``learn.value_``   bool     True           | Whether to use the value_rescale
   | ``rescale``                                | function for the predicted value
13 | ``learn.target_``  float    0.001          | Momentum coefficient of the soft       | Soft (momentum) update
   | ``update_theta``                           | target network update
14 | ``learn.ignore_``  bool     False          | Whether to ignore done for target      | Enable it for some
   | ``done``                                   | value calculation                      | fake termination env
15 ``collect.n_sample`` int      32             | The number of training samples of a    | It varies across
                                                | call of collector                      | different envs
16 ``learn_unroll_len`` int      80             | RNN unroll length used for training,   | Sequence length is
                                                | excluding the burn-in timesteps        | burnin_step + this
== ==================== ======== ============== ======================================== =======================
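A user config only needs to override the entries that differ from these defaults, plus the `collect.env_num` and `eval.env_num` placeholders that the source code says the user should specify. The sketch below illustrates the idea with a simple recursive dict merge; `merge_config` and `check_r2d2_config` are hypothetical helpers, not DI-engine functions, and the defaults are a subset copied from the `config` dict in the full source code.

```python
import copy
from typing import Any, Dict

# A subset of the defaults from ``R2D2Policy.config`` in the source listing.
DEFAULTS: Dict[str, Any] = {
    'type': 'r2d2',
    'priority': True,
    'priority_IS_weight': True,
    'discount_factor': 0.997,
    'nstep': 5,
    'burnin_step': 20,
    'learn_unroll_len': 80,
    'learn': {'update_per_collect': 1, 'batch_size': 64, 'learning_rate': 0.0001,
              'target_update_theta': 0.001, 'value_rescale': True, 'ignore_done': False},
    'collect': {'n_sample': 32, 'env_num': None},
    'eval': {'env_num': None},
}


def merge_config(default: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively overlay ``user`` on ``default`` without mutating either (hypothetical helper)."""
    merged = copy.deepcopy(default)
    for key, value in user.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def check_r2d2_config(cfg: Dict[str, Any]) -> None:
    """Reject settings that the table and the policy code rule out (hypothetical helper)."""
    if cfg['priority_IS_weight'] and not cfg['priority']:
        raise ValueError("priority_IS_weight=True requires priority=True")
    for mode in ('collect', 'eval'):
        if cfg[mode]['env_num'] is None:
            raise ValueError(f"{mode}.env_num must match the number of {mode} environments")


user_cfg = {'nstep': 3, 'collect': {'env_num': 8}, 'eval': {'env_num': 4}}
cfg = merge_config(DEFAULTS, user_cfg)
check_r2d2_config(cfg)
# Length of each stored sequence: burn-in steps plus the training unroll.
sequence_len = cfg['burnin_step'] + cfg['learn_unroll_len']
print(cfg['nstep'], cfg['collect'], sequence_len)
```

Keeping the defaults untouched and returning a merged copy means several experiments can derive their configs from the same defaults without interfering with each other.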

default_model()

Overview

Return the default neural network model setting of this algorithm for demonstration. The __init__ method automatically calls this method to get the default model setting and create the model.

Returns:

- model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.

Note: Users can define and use a customized network model, but it must follow the same interface definition as indicated by the import_names path. For R2D2, the registered name is drqn and the import_names is ding.model.template.q_learning.
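The `(name, import_names)` pair returned by `default_model` names a registered model together with the modules that must be imported so that the registration actually happens. The sketch below imitates that pattern with a plain dictionary; `TOY_MODEL_REGISTRY`, `register_toy_model`, `build_model` and `ToyDRQN` are hypothetical stand-ins, not DI-engine's API.

```python
import importlib
from typing import Any, Callable, Dict, List, Type

# Hypothetical registry mapping a registered name to a model class.
TOY_MODEL_REGISTRY: Dict[str, Type] = {}


def register_toy_model(name: str) -> Callable[[Type], Type]:
    """Class decorator that records ``cls`` under ``name`` (hypothetical helper)."""
    def decorator(cls: Type) -> Type:
        TOY_MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def build_model(name: str, import_names: List[str], **kwargs: Any) -> Any:
    """Import the listed modules so their decorators run, then build the registered class."""
    for module_name in import_names:
        importlib.import_module(module_name)
    if name not in TOY_MODEL_REGISTRY:
        raise KeyError(f"model '{name}' is not registered; check import_names")
    return TOY_MODEL_REGISTRY[name](**kwargs)


@register_toy_model('drqn')
class ToyDRQN:
    """Placeholder model; a real DRQN would be a recurrent Q-network."""

    def __init__(self, obs_shape: int, action_shape: int) -> None:
        self.obs_shape = obs_shape
        self.action_shape = action_shape


# The toy class is defined in this script, so import_names can stay empty here.
model = build_model('drqn', [], obs_shape=4, action_shape=2)
print(type(model).__name__, model.action_shape)
```

In this sketch, importing first and looking up second is what lets a customized model be swapped in: registering a class in a user module and listing that module in `import_names` is enough for the lookup to find it.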

Full Source Code

../ding/policy/r2d2.py

import copy
from collections import namedtuple
from typing import List, Dict, Any, Tuple, Union, Optional

import torch

from ding.model import model_wrap
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
    get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('r2d2')
class R2D2Policy(Policy):
    """
    Overview:
        Policy class of R2D2, from the paper `Recurrent Experience Replay in Distributed Reinforcement Learning`.
        R2D2 improves upon DRQN with several techniques, namely recurrent experience replay tricks and a \
            burn-in mechanism for off-policy training.
    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      r2d2           | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be
                                                                                                 | different from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     True           | Whether to use priority (PER)          | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     True           | Whether to use Importance Sampling
           | ``_weight``                                | Weight to correct biased update. If
                                                        | True, priority must be True.
        6  | ``discount_``      float    0.997,         | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      5,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  ``burnin_step``      int      20             | Number of burn-in timesteps, designed
                                                        | to mitigate the RNN hidden-state
                                                        | mismatch caused by off-policy replay
        9  | ``learn.update``   int      1              | How many updates (iterations) to train | This arg can vary
           | ``per_collect``                            | after each collection. Only valid in   | across envs. A bigger
                                                        | serial training                        | value means more
                                                                                                 | off-policy
        10 | ``learn.batch_``   int      64             | The number of samples in an iteration
           | ``size``
        11 | ``learn.learning`` float    0.0001         | Gradient step length of an iteration
           | ``_rate``
        12 | ``learn.value_``   bool     True           | Whether to use the value_rescale
           | ``rescale``                                | function for the predicted value
        13 | ``learn.target_``  float    0.001          | Momentum coefficient of the soft       | Soft (momentum) update
           | ``update_theta``                           | target network update
        14 | ``learn.ignore_``  bool     False          | Whether to ignore done for target      | Enable it for some
           | ``done``                                   | value calculation                      | fake termination env
        15 ``collect.n_sample`` int      32             | The number of training samples of a    | It varies across
                                                        | call of collector                      | different envs
        16 ``learn_unroll_len`` int      80             | RNN unroll length used for training,   | Sequence length is
                                                        | excluding the burn-in timesteps        | burnin_step + this
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='r2d2',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        priority=True,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.997,
        # (int) N-step reward for target q_value estimation.
        nstep=5,
        # (int) The number of burn-in timesteps, designed to mitigate the RNN hidden-state mismatch
        # caused by off-policy training.
        burnin_step=20,
        # (int) The trajectory length used to unroll the RNN network for training,
        # excluding the burn-in timesteps.
        learn_unroll_len=80,
        # learn_mode config
        learn=dict(
            # (int) The number of training updates (iterations) to perform after each data collection by the collector.
            # A larger "update_per_collect" value implies a more off-policy approach.
            # The whole pipeline process follows this cycle: collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            # (int) The number of samples in a training batch.
            batch_size=64,
            # (float) The step size of gradient descent, determining the rate of learning.
            learning_rate=0.0001,
            # (int) Frequency of the hard target network update (unused; replaced by target_update_theta).
            # target_update_freq=100,
            # (float) Momentum coefficient of the soft target network update.
            target_update_theta=0.001,
            # (bool) Whether to use the value_rescale function for the predicted value.
            value_rescale=True,
            # (bool) Whether to ignore done (usually for envs that terminate at a max step).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, HalfCheetah itself never terminates, so when the episode step exceeds the max episode step,
            # we replace done==True with done==False in place to keep the TD-error computation
            # (``gamma * (1 - done) * next_v + reward``) accurate.
            ignore_done=False,
        ),
        # collect_mode config
        collect=dict(
            # (int) How many training samples are collected in one collection procedure.
            # In each collect phase, we collect a total of <n_sample> sequence samples.
            n_sample=32,
            # (bool) It is important to set traj_len_inf=True here,
            # to make sure self._traj_len=INF in serial_sample_collector.py.
            # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
            # unless the episode enters the 'done' state.
            traj_len_inf=True,
            # (int) `env_num` is used in hidden state and should equal the one in env config (e.g. collector_env_num).
            # User should specify this value in user config. `None` is a placeholder.
            env_num=None,
        ),
        # eval_mode config
        eval=dict(
            # (int) `env_num` is used in hidden state and should equal the one in env config (e.g. evaluator_env_num).
            # User should specify this value in user config.
            env_num=None,
        ),
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # (str) Type of decay. Supports either 'exp' (exponential) or 'linear'.
                type='exp',
                # (float) Initial value of epsilon at the start.
                start=0.95,
                # (float) Final value of epsilon after decay.
                end=0.05,
                # (int) The number of environment steps over which epsilon should decay.
                decay=10000,
            ),
            replay_buffer=dict(
                # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
                replay_buffer_size=10000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return the default neural network model setting of this algorithm for demonstration. The ``__init__`` \
                method automatically calls this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.

        .. note::
            Users can define and use a customized network model, but it must follow the same interface definition \
                indicated by the import_names path. For R2D2, the registered name is ``drqn`` and the import_names \
                is ``ding.model.template.q_learning``.
        """
        return 'drqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including some attributes and modules. For R2D2, it mainly contains \
                optimizer, algorithm-specific arguments such as burnin_step, value_rescale and gamma, main and target \
                model. Because of the use of RNN, all the models should be wrapped with ``hidden_state``, which needs \
                to be initialized with proper size.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
                and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
                with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._value_rescale = self._cfg.learn.value_rescale

        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )

        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Preprocess the data to fit the required data format for learning.
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): The data collected from the collect function.
        Returns:
            - data (:obj:`Dict[str, torch.Tensor]`): The processed data, including at least \
                ['main_obs', 'target_obs', 'burnin_nstep_obs', 'action', 'reward', 'done', 'weight']
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)

        if self._priority_IS_weight:
            assert self._priority, "Use IS Weight correction, but Priority is not used."
        if self._priority and self._priority_IS_weight:
            data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)

        burnin_step = self._burnin_step

        # data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
        # the q_nstep_td_error, so each should have length [self._sequence_len - self._burnin_step]
        ignore_done = self._cfg.learn.ignore_done
        if ignore_done:
            data['done'] = [None for _ in range(self._sequence_len - burnin_step)]
        else:
            data['done'] = data['done'][burnin_step:].float()  # for computation of online model self._learn_model
            # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample,
            # data['done'][t] is already the n-step done

        # if the data doesn't include 'weight' or 'value_gamma', fill in a list of None
        # with length [self._sequence_len - self._burnin_step];
        # the two branches below show two slightly different ways to do this
        if 'value_gamma' not in data:
            data['value_gamma'] = [None for _ in range(self._sequence_len - burnin_step)]
        else:
            data['value_gamma'] = data['value_gamma'][burnin_step:]

        if 'weight' not in data or data['weight'] is None:
            data['weight'] = [None for _ in range(self._sequence_len - burnin_step)]
        else:
            data['weight'] = data['weight'] * torch.ones_like(data['done'])
            # every timestep in the sequence has the same weight, which is the _priority_IS_weight in PER

        # cut the seq_len from burn_in step to (seq_len - nstep) step
        data['action'] = data['action'][burnin_step:-self._nstep]
        # cut the seq_len from burn_in step to (seq_len - nstep) step
        data['reward'] = data['reward'][burnin_step:-self._nstep]

        # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
        # target_q_value, and target_q_action

        # these slices are all taken along the outermost dimension, which is the seq_len dim
        data['burnin_nstep_obs'] = data['obs'][:burnin_step + self._nstep]
        # the main_obs is used to calculate the q_value, the [bs:-self._nstep] means using the data from
        # [bs] timestep to [self._sequence_len-self._nstep] timestep
        data['main_obs'] = data['obs'][burnin_step:-self._nstep]
        # the target_obs is used to calculate the target_q_value
        data['target_obs'] = data['obs'][burnin_step + self._nstep:]

        return data

    def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
                that the policy inputs some training batch data (trajectory for R2D2) from the replay buffer and then \
                returns the output result, including various training information such as loss, q value, priority.
        Arguments:
            - data (:obj:`List[List[Dict[str, Any]]]`): The input data used for policy forward, including a batch of \
                training samples. For each dict element, the key of the dict is the name of data items and the \
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often need to first be stacked in the time and \
                batch dimension by the utility function ``self._data_preprocess_learn``. \
                For R2D2, each element in list is a trajectory with the length of ``unroll_len``, and the element in \
                trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \
                ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
                and ``value_gamma``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
                recorded in text log and tensorboard; values must be python scalars or lists of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
                For data types that are not supported, the main reason is that the corresponding model does not \
                support them. You can implement your own model rather than use the default model. For more \
                information, please raise an issue in GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
        """
        # forward
        data = self._data_preprocess_learn(data)  # output datatype: Dict
        self._learn_model.train()
        self._target_model.train()
        # use the hidden state in timestep=0
        # note the reset method is performed at the hidden state wrapper, to reset self._state.
        self._learn_model.reset(data_id=None, state=data['prev_state'][0])
        self._target_model.reset(data_id=None, state=data['prev_state'][0])

        if len(data['burnin_nstep_obs']) != 0:
            with torch.no_grad():
                inputs = {'obs': data['burnin_nstep_obs']}
                burnin_output = self._learn_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )  # keys include 'logit', 'hidden_state', 'saved_state', \
                # 'action', for their specific dim, please refer to DRQN model
                burnin_output_target = self._target_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )

        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
        inputs = {'obs': data['main_obs']}
        q_value = self._learn_model.forward(inputs)['logit']
        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
        self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])

        next_inputs = {'obs': data['target_obs']}
        with torch.no_grad():
            target_q_value = self._target_model.forward(next_inputs)['logit']
            # argmax action of the online model (double DQN)
            target_q_action = self._learn_model.forward(next_inputs)['action']

        action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
        value_gamma = data['value_gamma']
        # T, B, nstep -> T, nstep, B
        reward = reward.permute(0, 2, 1).contiguous()
        loss = []
        td_error = []
        for t in range(self._sequence_len - self._burnin_step - self._nstep):
            # here t=0 means timestep <self._burnin_step> in the original sample sequence; we subtract self._nstep
            # because for the last <self._nstep> timesteps in the sequence, we don't have their target obs
            td_data = q_nstep_td_data(
                q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
            )
            if self._value_rescale:
                l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
                loss.append(l)
                td_error.append(e.abs())
            else:
                l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
                loss.append(l)
                # td_error will be a list of length
                # <self._sequence_len - self._burnin_step - self._nstep>
                # and each value is a tensor of size batch_size
                td_error.append(e.abs())
        loss = sum(loss) / (len(loss) + 1e-8)

        # use the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
        td_error_per_sample = 0.9 * torch.max(
            torch.stack(td_error), dim=0
        )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
        # torch.max(torch.stack(td_error), dim=0) returns a (values, indices) namedtuple, please refer to torch.max
        # td_error shape list(<self._sequence_len-self._burnin_step-self._nstep>, B),
        # for example, (75,64)
        # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)

        # update
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # after update
        self._target_model.update(self._learn_model.state_dict())

        # the information for debug
        batch_range = torch.arange(action[0].shape[0])
        q_s_a_t0 = q_value[0][batch_range, action[0]]
        target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]

        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.tolist(),  # note abs operation has been performed above
            # the first timestep in the sequence, which may not be the start of an episode
            'q_s_taken-a_t0': q_s_a_t0.mean().item(),
            'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
            'q_s_a-mean_t0': q_value[0].mean().item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \
                memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
                variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
                different trajectories in ``data_id`` will have different hidden state in RNN.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
                (i.e. RNN hidden_state in R2D2) specified by ``data_id``.
        """

        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model, target_model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
                load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
                complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        """
        Overview:
            Initialize the collect mode of policy, including related attributes and modules. For R2D2, it contains the \
                collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \
                maintain the hidden state of rnn. Besides, there are some initialization operations about other \
                algorithm-specific arguments such as burnin_step, unroll_len and nstep.
            This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.

        .. note::
            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
                with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.

        .. tip::
            Some variables need to be initialized independently in different modes, such as gamma and nstep in R2D2. \
                This design is for the convenience of parallel execution of different policy modes.
456 """ 457 self._nstep = self._cfg.nstep 458 self._burnin_step = self._cfg.burnin_step 459 self._gamma = self._cfg.discount_factor 460 self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step 461 self._unroll_len = self._sequence_len 462 463 # for r2d2, this hidden_state wrapper is to add the 'prev hidden state' for each transition. 464 # Note that collect env forms a batch and the key is added for the batch simultaneously. 465 self._collect_model = model_wrap( 466 self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True 467 ) 468 self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample') 469 self._collect_model.reset() 470 471 def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: 472 """ 473 Overview: 474 Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ 475 that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ 476 data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ 477 exploration, i.e., classic epsilon-greedy exploration strategy. 478 Arguments: 479 - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ 480 key of the dict is environment id and the value is the corresponding data of the env. 481 - eps (:obj:`float`): The epsilon value for exploration. 482 Returns: 483 - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ 484 other necessary data (prev_state) for learn mode defined in ``self._process_transition`` method. The \ 485 key of the dict is the same as the input data, i.e. environment id. 486 487 .. 
note:: 488 RNN's hidden states are maintained in the policy, so we don't need pass them into data but to reset the \ 489 hidden states with ``_reset_collect`` method when episode ends. Besides, the previous hidden states are \ 490 necessary for training, so we need to return them in ``_process_transition`` method. 491 .. note:: 492 The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ 493 For the data type that not supported, the main reason is that the corresponding model does not support it. \ 494 You can implement you own model rather than use the default model. For more information, please raise an \ 495 issue in GitHub repo and we will continue to follow up. 496 497 .. note:: 498 For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``. 499 """ 500 data_id = list(data.keys()) 501 data = default_collate(list(data.values())) 502 if self._cuda: 503 data = to_device(data, self._device) 504 data = {'obs': data} 505 self._collect_model.eval() 506 with torch.no_grad(): 507 # in collect phase, inference=True means that each time we only pass one timestep data, 508 # so the we can get the hidden state of rnn: <prev_state> at each timestep. 509 output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True) 510 if self._cuda: 511 output = to_device(output, 'cpu') 512 output = default_decollate(output) 513 return {i: d for i, d in zip(data_id, output)} 514 515 def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: 516 """ 517 Overview: 518 Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ 519 memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ 520 varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. 
For example, \ 521 different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. 522 Arguments: 523 - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ 524 (i.e., RNN hidden_state in R2D2) specified by ``data_id``. 525 """ 526 self._collect_model.reset(data_id=data_id) 527 528 def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], 529 timestep: namedtuple) -> Dict[str, torch.Tensor]: 530 """ 531 Overview: 532 Process and pack one timestep transition data into a dict, which can be directly used for training and \ 533 saved in replay buffer. For R2D2, it contains obs, action, prev_state, reward, and done. 534 Arguments: 535 - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. 536 - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network given the observation \ 537 as input. For R2D2, it contains the action and the prev_state of RNN. 538 - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ 539 except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ 540 reward, done, info, etc. 541 Returns: 542 - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. 543 """ 544 transition = { 545 'obs': obs, 546 'action': policy_output['action'], 547 'prev_state': policy_output['prev_state'], 548 'reward': timestep.reward, 549 'done': timestep.done, 550 } 551 return transition 552 553 def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 554 """ 555 Overview: 556 For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ 557 can be used for training directly. In R2D2, a train sample is processed transitions with unroll_len \ 558 length. 
This method is usually used in collectors to execute necessary \ 559 RL data preprocessing before training, which can help learner amortize revelant time consumption. \ 560 In addition, you can also implement this method as an identity function and do the data processing \ 561 in ``self._forward_learn`` method. 562 Arguments: 563 - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ 564 the same format as the return value of ``self._process_transition`` method. 565 Returns: 566 - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each sample is a fixed-length \ 567 trajectory, and each element in a sample is the similar format as input transitions, but may contain \ 568 more data for training, such as nstep reward and value_gamma factor. 569 """ 570 transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma) 571 return get_train_sample(transitions, self._unroll_len) 572 573 def _init_eval(self) -> None: 574 """ 575 Overview: 576 Initialize the eval mode of policy, including related attributes and modules. For R2D2, it contains the \ 577 eval model to greedily select action with argmax q_value mechanism and main the hidden state. 578 This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. 579 580 .. note:: 581 If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ 582 with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. 583 """ 584 self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num) 585 self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample') 586 self._eval_model.reset() 587 588 def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: 589 """ 590 Overview: 591 Policy forward function of eval mode (evaluation policy performance by interacting with envs). 
            Forward means that the policy gets some necessary data (mainly observations) from the envs and then \
            returns the actions to interact with the envs. ``_forward_eval`` usually uses the argmax sample method \
            to select the action with the highest q_value.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
                key of the dict is the environment id and the value is the corresponding data of that env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                keys of the dict are the same as those of the input data, i.e. environment ids.

        .. note::
            RNN hidden states are maintained inside the policy, so you don't need to pass them in ``data``, but \
            you do need to reset them with the ``_reset_eval`` method when an episode ends.

        .. note::
            The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. If a data type is not supported, the main reason is that the corresponding model does not support \
            it. You can implement your own model rather than using the default model. For more information, please \
            raise an issue in the GitHub repo and we will continue to follow up.

        .. note::
            For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset some stateful variables for eval mode when necessary, such as the hidden state of the RNN or the \
            memory bank of some special algorithms. If ``data_id`` is None, all the stateful variables are reset. \
            Otherwise, only the stateful variables specified by ``data_id`` are reset. For example, different \
            environments/episodes in ``data_id`` will have different RNN hidden states during evaluation.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
                (i.e., RNN hidden_state in R2D2) specified by ``data_id``.
        """
        self._eval_model.reset(data_id=data_id)

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. Logger modules, such \
            as the text logger and the tensorboard logger, use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return super()._monitor_vars_learn() + [
            'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
        ]
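In ``_get_train_sample`` the trajectory first goes through ``get_nstep_return_data`` and is then cut into fixed-length samples of ``unroll_len`` transitions by ``get_train_sample``. The plain-Python sketch below illustrates that two-stage data path under stated assumptions. It is not DI-engine's implementation: ``add_nstep_return`` and ``split_fixed_length`` are hypothetical names, the transitions are assumed to come from a single trajectory, rewards past its end are padded with zero, and a leftover tail shorter than ``unroll_len`` is covered by reusing the last ``unroll_len`` steps, which is only one possible choice. The sketch also stores a summed n-step return for readability, whereas the library helper may keep the per-step rewards and leave the discounted sum to the loss.

```python
from typing import Any, Dict, List


def add_nstep_return(transitions: List[Dict[str, Any]], nstep: int, gamma: float) -> List[Dict[str, Any]]:
    """Hypothetical helper: attach an n-step discounted return and a bootstrap discount to each transition."""
    T = len(transitions)
    processed = []
    for t in range(T):
        # The next ``nstep`` rewards, padded with 0.0 past the end of the trajectory.
        rewards = [transitions[t + k]['reward'] if t + k < T else 0.0 for k in range(nstep)]
        n_real = min(nstep, T - t)  # number of rewards that actually exist
        item = dict(transitions[t])  # shallow copy, the input list is left untouched
        item['nstep_return'] = sum((gamma ** k) * r for k, r in enumerate(rewards))
        # Discount for the bootstrapped target value; a loss (not shown) would also
        # drop the bootstrap term when the episode has terminated.
        item['value_gamma'] = gamma ** n_real
        processed.append(item)
    return processed


def split_fixed_length(transitions: List[Any], unroll_len: int) -> List[List[Any]]:
    """Hypothetical helper: cut a trajectory into samples of exactly ``unroll_len`` steps."""
    if len(transitions) < unroll_len:
        return []  # assumption: trajectories shorter than one sample are dropped
    samples = [
        transitions[i:i + unroll_len]
        for i in range(0, len(transitions) - unroll_len + 1, unroll_len)
    ]
    if len(transitions) % unroll_len != 0:
        # Cover the leftover tail by reusing the last ``unroll_len`` steps (overlaps the previous sample).
        samples.append(transitions[-unroll_len:])
    return samples


# Usage: 5 steps with reward 1.0, nstep=3, gamma=0.5, then split a 10-step trajectory.
trajectory = [{'obs': t, 'reward': 1.0, 'done': t == 4} for t in range(5)]
data = add_nstep_return(trajectory, nstep=3, gamma=0.5)
print(data[0]['nstep_return'], data[0]['value_gamma'])  # 1.75 0.125
print(split_fixed_length(list(range(10)), unroll_len=4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [6, 7, 8, 9]]
```

In R2D2 each fixed-length sample also has to cover the ``burnin_step`` prefix used to warm up the RNN hidden state, so in practice ``unroll_len`` is chosen with the burn-in length in mind.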
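``_init_eval`` wraps the model with a ``hidden_state`` wrapper (one state per eval env, ``state_num=self._cfg.eval.env_num``) and an ``argmax_sample`` wrapper, and ``_forward_eval`` / ``_reset_eval`` address those states by ``data_id``. The sketch below imitates that bookkeeping in plain Python so the per-env state flow is visible. It is an illustration under assumptions, not DI-engine's wrappers: ``PerEnvHiddenState`` and ``toy_step`` are hypothetical names, the "RNN" is a scalar toy recurrence, and each env is stepped individually rather than in one collated batch.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple

# A step function maps (observation, hidden_state) -> (q_values, new_hidden_state).
StepFn = Callable[[float, float], Tuple[List[float], float]]


class PerEnvHiddenState:
    """Hypothetical sketch: one recurrent hidden state per env id, plus greedy (argmax) actions."""

    def __init__(self, step_fn: StepFn, init_state: float, env_num: int) -> None:
        self._step_fn = step_fn
        self._init_state = init_state
        self._states: Dict[int, float] = {i: init_state for i in range(env_num)}

    @property
    def states(self) -> Dict[int, float]:
        return dict(self._states)  # copy, so callers cannot mutate internal state

    def reset(self, data_id: Optional[Iterable[int]] = None) -> None:
        # None resets every env; otherwise only the listed env ids (e.g. envs whose episode just ended).
        ids = list(self._states) if data_id is None else list(data_id)
        for env_id in ids:
            self._states[env_id] = self._init_state

    def forward(self, data: Dict[int, float]) -> Dict[int, dict]:
        output = {}
        for env_id, obs in data.items():
            prev_state = self._states[env_id]
            q_values, new_state = self._step_fn(obs, prev_state)
            self._states[env_id] = new_state  # carry the hidden state to this env's next step
            # Greedy action: index of the largest q_value (ties go to the lowest index).
            action = max(range(len(q_values)), key=q_values.__getitem__)
            output[env_id] = {'action': action, 'prev_state': prev_state}
        return output


def toy_step(obs: float, h: float) -> Tuple[List[float], float]:
    """Toy 'RNN': leaky accumulation of observations; two actions with q = [h, -h]."""
    h = 0.5 * h + obs
    return [h, -h], h


# Usage: two eval envs; env 0 finishes its episode and only its state is reset.
model = PerEnvHiddenState(toy_step, init_state=0.0, env_num=2)
print(model.forward({0: 1.0, 1: -1.0}))
model.forward({0: 1.0})
model.reset([0])
print(model.states)  # {0: 0.0, 1: -1.0}
```

Looping over envs keeps the example short; the policy code above instead collates all observations into one batch with ``default_collate``, runs a single forward pass, and splits the result back per env id with ``default_decollate``.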