
ding.policy.r2d2_collect_traj


R2D2CollectTrajPolicy

Bases: Policy

Overview

Policy class of R2D2 for collecting expert traj for R2D3.

Config

== ==================== ======== ============== ======================================== =======================
ID Symbol               Type     Default Value  Description                              Other(Shape)
== ==================== ======== ============== ======================================== =======================
1  type                 str      dqn            | RL policy register name, refer to      | This arg is optional,
                                                | registry POLICY_REGISTRY               | a placeholder
2  cuda                 bool     False          | Whether to use cuda for network        | This arg can differ
                                                                                        | between modes
3  on_policy            bool     False          | Whether the RL algorithm is on-policy
                                                | or off-policy
4  priority             bool     False          | Whether to use priority (PER)          | Priority sample,
                                                                                        | update priority
5  priority_IS_weight   bool     False          | Whether to use the Importance Sampling
                                                | Weight to correct biased update. If
                                                | True, priority must be True.
6  discount_factor      float    0.997,         | Reward's future discount factor, aka.  | May be 1 in sparse
                                 [0.95, 0.999]  | gamma                                  | reward envs
7  nstep                int      3, [3, 5]      | N-step reward discount sum for target
                                                | q_value estimation
8  burnin_step          int      2              | The timestep of the burnin operation,
                                                | which is designed to mitigate the RNN
                                                | hidden state difference caused by
                                                | off-policy training
9  learn.update_        int      1              | How many updates (iterations) to train | This arg can vary from
   per_collect                                  | after one collection by the collector. | env to env. A bigger
                                                | Only valid in serial training          | value means more
                                                                                        | off-policy
10 learn.batch_size     int      64             | The number of samples of an iteration
11 learn.learning_rate  float    0.001          | Gradient step length of an iteration
12 learn.value_rescale  bool     True           | Whether to use the value_rescale
                                                | function for the predicted value
13 learn.target_        int      100            | Frequency of target network update     | Hard (assign) update
   update_freq
14 learn.ignore_done    bool     False          | Whether to ignore done for target      | Enable it for some
                                                | value calculation                      | fake-termination envs
15 collect.n_sample     int      [8, 128]       | The number of training samples of a    | It varies between
                                                | call of the collector                  | different envs
16 collect.unroll_len   int      1              | Unroll length of an iteration          | In RNN, unroll_len>1
== ==================== ======== ============== ======================================== =======================
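The defaults above are overridden from a user config. The sketch below shows the override shape implied by the table and the `config` dict in the source; the concrete `env_num` values are illustrative only (they must match your env config, since the policy keeps one RNN hidden state per environment):

```python
# Illustrative user-side config for R2D2CollectTrajPolicy.
# Keys mirror the table above; env_num values are placeholders.
r2d2_collect_traj_config = dict(
    cuda=False,
    priority=True,
    priority_IS_weight=True,  # requires priority=True
    discount_factor=0.997,
    nstep=5,
    burnin_step=2,
    unroll_len=80,
    learn=dict(
        update_per_collect=1,
        batch_size=64,
        learning_rate=1e-4,
        value_rescale=True,
        ignore_done=False,
    ),
    collect=dict(
        # Must equal the number of collector envs in the env config.
        env_num=8,
    ),
    eval=dict(env_num=4),
)
```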

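When `learn.value_rescale` is True, the policy computes the TD error through the invertible value transform from the R2D2 paper, h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x. A standalone sketch of that transform and its closed-form inverse (the eps = 1e-2 value is the commonly used constant, not taken from this file):

```python
import math

EPS = 1e-2  # commonly used epsilon for the R2D2 value transform (assumption)

def value_transform(x: float) -> float:
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return math.copysign(math.sqrt(abs(x) + 1) - 1, x) + EPS * x

def value_inv_transform(x: float) -> float:
    # Closed-form inverse of h, applied to the target before the n-step backup.
    sign = 1.0 if x >= 0 else -1.0
    return sign * (((math.sqrt(1 + 4 * EPS * (abs(x) + 1 + EPS)) - 1) / (2 * EPS)) ** 2 - 1)

# Round-trip check: h^{-1}(h(x)) == x
y = value_transform(10.0)
assert abs(value_inv_transform(y) - 10.0) < 1e-6
```

The transform compresses large returns so a single network can handle rewards of very different scales without clipping.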
Full Source Code

../ding/policy/r2d2_collect_traj.py

import copy
from collections import namedtuple
from typing import List, Dict, Any, Tuple, Union, Optional

import torch

from ding.model import model_wrap
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
    get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('r2d2_collect_traj')
class R2D2CollectTrajPolicy(Policy):
    r"""
    Overview:
        Policy class of R2D2 for collecting expert traj for R2D3.

    Config:
        == ==================== ======== ============== ======================================== =======================
        ID Symbol               Type     Default Value  Description                              Other(Shape)
        == ==================== ======== ============== ======================================== =======================
        1  ``type``             str      dqn            | RL policy register name, refer to      | This arg is optional,
                                                        | registry ``POLICY_REGISTRY``           | a placeholder
        2  ``cuda``             bool     False          | Whether to use cuda for network        | This arg can be diff-
                                                                                                 | erent from modes
        3  ``on_policy``        bool     False          | Whether the RL algorithm is on-policy
                                                        | or off-policy
        4  ``priority``         bool     False          | Whether use priority(PER)              | Priority sample,
                                                                                                 | update priority
        5  | ``priority_IS``    bool     False          | Whether use Importance Sampling Weight
           | ``_weight``                                | to correct biased update. If True,
                                                        | priority must be True.
        6  | ``discount_``      float    0.997,         | Reward's future discount factor, aka.  | May be 1 when sparse
           | ``factor``                  [0.95, 0.999]  | gamma                                  | reward env
        7  ``nstep``            int      3,             | N-step reward discount sum for target
                                         [3, 5]         | q_value estimation
        8  ``burnin_step``      int      2              | The timestep of burnin operation,
                                                        | which is designed to mitigate RNN
                                                        | hidden state difference caused by
                                                        | off-policy
        9  | ``learn.update``   int      1              | How many updates(iterations) to train  | This arg can vary
           | ``per_collect``                            | after collector's one collection. Only | from envs. Bigger val
                                                        | valid in serial training               | means more off-policy
        10 | ``learn.batch_``   int      64             | The number of samples of an iteration
           | ``size``
        11 | ``learn.learning`` float    0.001          | Gradient step length of an iteration.
           | ``_rate``
        12 | ``learn.value_``   bool     True           | Whether use value_rescale function for
           | ``rescale``                                | predicted value
        13 | ``learn.target_``  int      100            | Frequency of target network update.    | Hard(assign) update
           | ``update_freq``
        14 | ``learn.ignore_``  bool     False          | Whether ignore done for target value   | Enable it for some
           | ``done``                                   | calculation.                           | fake termination env
        15 ``collect.n_sample`` int      [8, 128]       | The number of training samples of a    | It varies from
                                                        | call of collector.                     | different envs
        16 | ``collect.unroll`` int      1              | unroll length of an iteration          | In RNN, unroll_len>1
           | ``_len``
        == ==================== ======== ============== ======================================== =======================
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='r2d2',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=True,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=True,
        # ==============================================================
        # The following configs are algorithm-specific
        # ==============================================================
        # (float) Reward's future discount factor, aka. gamma.
        discount_factor=0.997,
        # (int) N-step reward for target q_value estimation
        nstep=5,
        # (int) the timestep of burnin operation, which is designed to mitigate RNN hidden state difference
        # caused by off-policy
        burnin_step=2,
        # (int) the trajectory length to unroll the RNN network minus
        # the timestep of burnin operation
        unroll_len=80,
        learn=dict(
            update_per_collect=1,
            batch_size=64,
            learning_rate=0.0001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (int) Frequency of target network update.
            # target_update_freq=100,
            target_update_theta=0.001,
            # (bool) whether use value_rescale function for predicted value
            value_rescale=True,
            ignore_done=False,
        ),
        collect=dict(
            # NOTE it is important not to include key n_sample here, to make sure self._traj_len=INF
            # each_iter_n_sample=32,
            # `env_num` is used in hidden state, should equal to that one in env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        eval=dict(
            # `env_num` is used in hidden state, should equal to that one in env config.
            # User should specify this value in user config.
            env_num=None,
        ),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.05,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=10000, ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'drqn', ['ding.model.template.q_learning']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Init the learner model of R2D2Policy

        Arguments:
            .. note::

                The _init_learn method takes the argument from the self._cfg.learn in the config file

            - learning_rate (:obj:`float`): The learning rate of the optimizer
            - gamma (:obj:`float`): The discount factor
            - nstep (:obj:`int`): The num of n step return
            - value_rescale (:obj:`bool`): Whether to use value rescaled loss in algorithm
            - burnin_step (:obj:`int`): The num of step of burnin
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._gamma = self._cfg.discount_factor
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._value_rescale = self._cfg.learn.value_rescale

        self._target_model = copy.deepcopy(self._model)
        # self._target_model = model_wrap(  # TODO(pu)
        #     self._target_model,
        #     wrapper_name='target',
        #     update_type='assign',
        #     update_kwargs={'freq': self._cfg.learn.target_update_freq}
        # )
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )

        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
        )
        self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning

        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function

        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
            - data_info (:obj:`dict`): the data info, such as replay_buffer_idx, replay_unique_id
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)

        if self._priority_IS_weight:
            assert self._priority, "Use IS Weight correction, but Priority is not used."
        if self._priority and self._priority_IS_weight:
            data['weight'] = data['IS']
        else:
            data['weight'] = data.get('weight', None)

        bs = self._burnin_step

        # data['done'], data['weight'], data['value_gamma'] is used in def _forward_learn() to calculate
        # the q_nstep_td_error, should be length of [self._unroll_len_add_burnin_step-self._burnin_step]
        ignore_done = self._cfg.learn.ignore_done
        if ignore_done:
            data['done'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
        else:
            data['done'] = data['done'][bs:].float()  # for computation of online model self._learn_model
            # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample
            # the data['done'][t] is already the n-step done

        # if the data don't include 'weight' or 'value_gamma' then fill in None in a list
        # with length of [self._unroll_len_add_burnin_step-self._burnin_step],
        # below are two different implementation ways
        if 'value_gamma' not in data:
            data['value_gamma'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
        else:
            data['value_gamma'] = data['value_gamma'][bs:]

        if 'weight' not in data:
            data['weight'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
        else:
            data['weight'] = data['weight'] * torch.ones_like(data['done'])
            # every timestep in sequence has same weight, which is the _priority_IS_weight in PER

        data['action'] = data['action'][bs:-self._nstep]
        data['reward'] = data['reward'][bs:-self._nstep]

        # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
        # target_q_value, and target_q_action
        data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
        # the main_obs is used to calculate the q_value, the [bs:-self._nstep] means using the data from
        # [bs] timestep to [self._unroll_len_add_burnin_step-self._nstep] timestep
        data['main_obs'] = data['obs'][bs:-self._nstep]
        # the target_obs is used to calculate the target_q_value
        data['target_obs'] = data['obs'][bs + self._nstep:]

        return data

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
            Acquire the data, calculate the loss and optimize learner model.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least \
                ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']

        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
        """
        # forward
        data = self._data_preprocess_learn(data)
        self._learn_model.train()
        self._target_model.train()
        # take out timestep=0
        self._learn_model.reset(data_id=None, state=data['prev_state'][0])
        self._target_model.reset(data_id=None, state=data['prev_state'][0])

        if len(data['burnin_nstep_obs']) != 0:
            with torch.no_grad():
                inputs = {'obs': data['burnin_nstep_obs']}
                burnin_output = self._learn_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )
                burnin_output_target = self._target_model.forward(
                    inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
                )

        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
        inputs = {'obs': data['main_obs']}
        q_value = self._learn_model.forward(inputs)['logit']
        self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
        self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])

        next_inputs = {'obs': data['target_obs']}
        with torch.no_grad():
            target_q_value = self._target_model.forward(next_inputs)['logit']
            # argmax_action double_dqn
            target_q_action = self._learn_model.forward(next_inputs)['action']

        action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
        value_gamma = data['value_gamma']
        # T, B, nstep -> T, nstep, B
        reward = reward.permute(0, 2, 1).contiguous()
        loss = []
        td_error = []
        for t in range(self._unroll_len_add_burnin_step - self._burnin_step - self._nstep):
            # here t=0 means timestep <self._burnin_step> in the original sample sequence; we subtract self._nstep
            # because for the last <self._nstep> timesteps in the sequence, we don't have their target obs
            td_data = q_nstep_td_data(
                q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
            )
            if self._value_rescale:
                l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
                loss.append(l)
                td_error.append(e.abs())
            else:
                l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
                loss.append(l)
                td_error.append(e.abs())
        loss = sum(loss) / (len(loss) + 1e-8)

        # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
        td_error_per_sample = 0.9 * torch.max(
            torch.stack(td_error), dim=0
        )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
        # td_error shape list(<self._unroll_len_add_burnin_step-self._burnin_step-self._nstep>, B), for example, (75,64)
        # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)

        # update
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # after update
        self._target_model.update(self._learn_model.state_dict())

        # the information for debug
        batch_range = torch.arange(action[0].shape[0])
        q_s_a_t0 = q_value[0][batch_range, action[0]]
        target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]

        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
            # the first timestep in the sequence, may not be the start of episode  TODO(pu)
            'q_s_taken-a_t0': q_s_a_t0.mean().item(),
            'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
            'q_s_a-mean_t0': q_value[0].mean().item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {
            'model': self._learn_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init traj and unroll length, collect model.
        """
        # assert 'unroll_len' not in self._cfg.collect, "r2d2 use default unroll_len"
        self._nstep = self._cfg.nstep
        self._burnin_step = self._cfg.burnin_step
        self._gamma = self._cfg.discount_factor
        self._unroll_len_add_burnin_step = self._cfg.unroll_len + self._cfg.burnin_step
        self._unroll_len = self._unroll_len_add_burnin_step  # for compatibility
        # self._unroll_len = self._cfg.collect.unroll_len

        self._collect_model = model_wrap(
            self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
        )
        # self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
        self._collect_model = model_wrap(self._collect_model, wrapper_name='argmax_sample')

        self._collect_model.reset()

    # def _forward_collect(self, data: dict, eps: float) -> dict:
    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Collect output according to the argmax_sample plugin (eps_greedy is disabled for expert traj collection)

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - data (:obj:`dict`): The collected data
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._collect_model.eval()
        with torch.no_grad():
            # in collect phase, inference=True means that each time we only pass one timestep of data,
            # so we can get the hidden state of rnn: <prev_state> at each timestep.
            # output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True)
            output = self._collect_model.forward(data, data_id=data_id, inference=True)
            # output = self._collect_model.forward(data, inference=True)

        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        self._collect_model.reset(data_id=data_id)

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            # 'prev_state': model_output['prev_state'],
            'prev_state': None,
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the trajectory and the n step return data, then sample from the n_step return data

        Arguments:
            - data (:obj:`list`): The trajectory's cache

        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        from copy import deepcopy
        data_one_step = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
        # data_one_step = deepcopy(data)
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        for i in range(len(data)):
            # here we record the one-step done; we don't need to record the one-step reward,
            # because the n-step reward in data already includes the one-step reward
            data[i]['done_one_step'] = data_one_step[i]['done']
        return get_train_sample(data, self._unroll_len)  # self._unroll_len_add_burnin_step

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with argmax strategy.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
        self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``.

        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].

        Returns:
            - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, data_id=data_id, inference=True)
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        self._eval_model.reset(data_id=data_id)

    def _monitor_vars_learn(self) -> List[str]:
        return super()._monitor_vars_learn() + [
            'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
        ]