ding.policy.dt

DTPolicy
Bases: Policy
Overview
Policy class of Decision Transformer algorithm in discrete environments. Paper link: https://arxiv.org/abs/2106.01345.
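Decision Transformer conditions action prediction on a fixed-length context (``context_len``, the "K" of the paper) of recent return-to-go, state, and action tokens. The per-environment history slicing used at evaluation time can be sketched in plain Python (an illustrative helper, not part of ding; it mirrors the index arithmetic in ``_forward_eval`` below under the assumption of a list-like history):

```python
def context_window(history, t, context_len):
    # Before the window fills up, take the first context_len entries
    # (the unfilled tail acts as padding); afterwards, slide the window
    # so that it ends at the current step t.
    if t <= context_len:
        return history[:context_len]
    return history[t - context_len + 1:t + 1]
```

For example, at step 50 with ``context_len=20`` the window covers steps 31 through 50 inclusive.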
default_model()
Overview
Return this algorithm's default neural network model setting for demonstration. The __init__ method automatically calls this method to get the default model setting and create the model.
Returns:
- model_info (Tuple[str, List[str]]): The registered model name and the model's import_names.
.. note::
The user can define and use a customized network model, but it must obey the same interface definition indicated by the import_names path. For example, for DQN the registered name is dqn and the import_names is ding.model.template.q_learning.
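The returned pair is a registered name plus a list of module paths whose import makes the model class visible to the registry. A minimal sketch of how such a pair could be resolved (``resolve_default_model`` is a hypothetical helper, not part of ding; a stdlib module stands in for the real import path so the sketch runs without DI-engine installed):

```python
import importlib


def resolve_default_model(model_info):
    # model_info is the (registered_name, import_names) pair returned by
    # default_model(); importing each listed path loads the modules that
    # register the model class.
    name, import_names = model_info
    modules = [importlib.import_module(path) for path in import_names]
    return name, modules


# Demonstrated with a stdlib module so this runs anywhere; for DTPolicy the
# actual pair would be ('dt', ['ding.model.template.dt']).
name, mods = resolve_default_model(('json_demo', ['json']))
```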
Full Source Code
../ding/policy/dt.py
from typing import List, Dict, Any, Tuple, Optional
from collections import namedtuple
import torch.nn.functional as F
import torch
import numpy as np
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('dt')
class DTPolicy(Policy):
    """
    Overview:
        Policy class of Decision Transformer algorithm in discrete environments.
        Paper link: https://arxiv.org/abs/2106.01345.
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dt',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority (priority sample, IS weight, update priority).
        priority=False,
        # (int) The shape of observation space.
        obs_shape=4,
        # (int) The shape of action space.
        action_shape=2,
        rtg_scale=1000,  # normalize returns to go
        max_eval_ep_len=1000,  # max len of one episode
        batch_size=64,  # training batch size
        wt_decay=1e-4,  # decay weight in optimizer
        warmup_steps=10000,  # steps for learning rate warmup
        context_len=20,  # length of transformer input
        learning_rate=1e-4,
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create the model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and the model's import_names.

        .. note::
            The user can define and use a customized network model, but it must obey the same interface definition \
            indicated by the import_names path. For example, for DQN the registered name is ``dqn`` and the \
            import_names is ``ding.model.template.q_learning``.
        """
        return 'dt', ['ding.model.template.dt']

    def _init_learn(self) -> None:
        """
        Overview:
            Initialize the learn mode of policy, including related attributes and modules. For Decision Transformer, \
            it mainly contains the optimizer, algorithm-specific arguments such as rtg_scale, and the lr scheduler.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.

        .. note::
            For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
            and ``_load_state_dict_learn`` methods.

        .. note::
            For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.

        .. note::
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
            with prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
        """
        # rtg_scale: scale of `return to go`
        # rtg_target: max target of `return to go`
        # Our goal is to normalize `return to go` to (0, 1), which favours convergence.
        # As a result, we usually set rtg_scale == rtg_target.
        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
        self.max_eval_ep_len = self._cfg.max_eval_ep_len  # max len of one episode

        lr = self._cfg.learning_rate  # learning rate
        wt_decay = self._cfg.wt_decay  # weight decay
        warmup_steps = self._cfg.warmup_steps  # warmup steps for lr scheduler

        self.clip_grad_norm_p = self._cfg.clip_grad_norm_p
        self.context_len = self._cfg.model.context_len  # K in decision transformer

        self.state_dim = self._cfg.model.state_dim
        self.act_dim = self._cfg.model.act_dim

        self._learn_model = self._model
        self._atari_env = 'state_mean' not in self._cfg
        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg

        if self._atari_env:
            self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
        else:
            self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)

        self._scheduler = torch.optim.lr_scheduler.LambdaLR(
            self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
        )

        self.max_env_score = -1.0

    def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Policy forward function of learn mode (training policy and updating parameters). Forward means \
            that the policy inputs some training batch data from the offline dataset and then returns the output \
            result, including various training information such as loss and current learning rate.
        Arguments:
            - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \
                processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will \
                be recorded in text log and tensorboard; values must be python scalars or lists of scalars. For the \
                detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.

        .. note::
            The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
            For a data type that is not supported, the main reason is that the corresponding model does not support \
            it. You can implement your own model rather than use the default model. For more information, please \
            raise an issue in GitHub repo and we will continue to follow up.
        """
        self._learn_model.train()

        timesteps, states, actions, returns_to_go, traj_mask = data

        # The shape of `returns_to_go` may differ between datasets (B x T or B x T x 1),
        # and we need a 3-dim tensor.
        if len(returns_to_go.shape) == 2:
            returns_to_go = returns_to_go.unsqueeze(-1)

        if self._basic_discrete_env:
            actions = actions.to(torch.long)
            actions = actions.squeeze(-1)
        action_target = torch.clone(actions).detach().to(self._device)

        if self._atari_env:
            state_preds, action_preds, return_preds = self._learn_model.forward(
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
            )
        else:
            state_preds, action_preds, return_preds = self._learn_model.forward(
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
            )

        if self._atari_env:
            action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
        else:
            traj_mask = traj_mask.view(-1, )

            # only consider non-padded elements
            action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0]

            if self._cfg.model.continuous:
                action_target = action_target.view(-1, self.act_dim)[traj_mask > 0]
                action_loss = F.mse_loss(action_preds, action_target)
            else:
                action_target = action_target.view(-1)[traj_mask > 0]
                action_loss = F.cross_entropy(action_preds, action_target)

        self._optimizer.zero_grad()
        action_loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p)
        self._optimizer.step()
        self._scheduler.step()

        return {
            'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'],
            'action_loss': action_loss.detach().cpu().item(),
            'total_loss': action_loss.detach().cpu().item(),
        }

    def _init_eval(self) -> None:
        """
        Overview:
            Initialize the eval mode of policy, including related attributes and modules. For Decision Transformer, \
            it contains the eval model and some algorithm-specific parameters such as context_len, max_eval_ep_len, \
            etc. This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.

        .. tip::
            For the evaluation of complete episodes, we need to maintain some historical information for transformer \
            inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \
            necessary.

        .. note::
            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
            with prefix ``_eval_`` to avoid conflicts with other modes, such as ``self._eval_attr1``.
        """
        self._eval_model = self._model
        # init data
        self._device = torch.device(self._device)
        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
        self.state_dim = self._cfg.model.state_dim
        self.act_dim = self._cfg.model.act_dim
        self.eval_batch_size = self._cfg.evaluator_env_num
        self.max_eval_ep_len = self._cfg.max_eval_ep_len
        self.context_len = self._cfg.model.context_len  # K in decision transformer

        self.t = [0 for _ in range(self.eval_batch_size)]
        if self._cfg.model.continuous:
            self.actions = torch.zeros(
                (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
            )
        else:
            self.actions = torch.zeros(
                (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
            )
        self._atari_env = 'state_mean' not in self._cfg
        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
        if self._atari_env:
            self.states = torch.zeros(
                (
                    self.eval_batch_size,
                    self.max_eval_ep_len,
                ) + tuple(self.state_dim),
                dtype=torch.float32,
                device=self._device
            )
            self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)]
        else:
            self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)]
            self.states = torch.zeros(
                (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
            )
            self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device)
            self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device)
        self.timesteps = torch.arange(
            start=0, end=self.max_eval_ep_len, step=1
        ).repeat(self.eval_batch_size, 1).to(self._device)
        self.rewards_to_go = torch.zeros(
            (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
        )

    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """
        Overview:
            Policy forward function of eval mode (evaluating policy performance by interacting with envs). \
            Forward means that the policy gets some input data (current obs/return-to-go and historical information) \
            from the envs and then returns the output data, such as the action to interact with the envs.
        Arguments:
            - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \
                reward to calculate running return-to-go. The key of the dict is environment id and the value is the \
                corresponding data of the env.
        Returns:
            - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
                key of the dict is the same as the input data, i.e. environment id.

        .. note::
            Decision Transformer will do different operations for different types of envs in evaluation.
        """
        # save and forward
        data_id = list(data.keys())

        self._eval_model.eval()
        with torch.no_grad():
            if self._atari_env:
                states = torch.zeros(
                    (
                        self.eval_batch_size,
                        self.context_len,
                    ) + tuple(self.state_dim),
                    dtype=torch.float32,
                    device=self._device
                )
                timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device)
            else:
                states = torch.zeros(
                    (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device
                )
                timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device)
            if not self._cfg.model.continuous:
                actions = torch.zeros(
                    (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device
                )
            else:
                actions = torch.zeros(
                    (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device
                )
            rewards_to_go = torch.zeros(
                (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
            )
            for i in data_id:
                if self._atari_env:
                    self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
                else:
                    self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
                self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
                self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

                if self.t[i] <= self.context_len:
                    if self._atari_env:
                        timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                            (1, 1), dtype=torch.int64
                        ).to(self._device)
                    else:
                        timesteps[i] = self.timesteps[i, :self.context_len]
                    states[i] = self.states[i, :self.context_len]
                    actions[i] = self.actions[i, :self.context_len]
                    rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
                else:
                    if self._atari_env:
                        timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                            (1, 1), dtype=torch.int64
                        ).to(self._device)
                    else:
                        timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                    states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                    actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                    rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
            if self._basic_discrete_env:
                actions = actions.squeeze(-1)
            _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
            del timesteps, states, actions, rewards_to_go

            logits = act_preds[:, -1, :]
            if not self._cfg.model.continuous:
                if self._atari_env:
                    probs = F.softmax(logits, dim=-1)
                    act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
                    for i in data_id:
                        act[i] = torch.multinomial(probs[i], num_samples=1)
                else:
                    act = torch.argmax(logits, axis=1).unsqueeze(1)
            else:
                act = logits
            for i in data_id:
                self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
                self.t[i] += 1

        if self._cuda:
            act = to_device(act, 'cpu')
        output = {'action': act}
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        """
        Overview:
            Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \
            for decision transformer. If ``data_id`` is None, it means to reset all the stateful \
            variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
            different environments/episodes in evaluation in ``data_id`` will have different histories.
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful \
                variables specified by ``data_id``.
        """
        # clean data
        if data_id is None:
            self.t = [0 for _ in range(self.eval_batch_size)]
            self.timesteps = torch.arange(
                start=0, end=self.max_eval_ep_len, step=1
            ).repeat(self.eval_batch_size, 1).to(self._device)
            if not self._cfg.model.continuous:
                self.actions = torch.zeros(
                    (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
                )
            else:
                self.actions = torch.zeros(
                    (self.eval_batch_size, self.max_eval_ep_len, self.act_dim),
                    dtype=torch.float32,
                    device=self._device
                )
            if self._atari_env:
                self.states = torch.zeros(
                    (
                        self.eval_batch_size,
                        self.max_eval_ep_len,
                    ) + tuple(self.state_dim),
                    dtype=torch.float32,
                    device=self._device
                )
                self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)]
            else:
                self.states = torch.zeros(
                    (self.eval_batch_size, self.max_eval_ep_len, self.state_dim),
                    dtype=torch.float32,
                    device=self._device
                )
                self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)]

            self.rewards_to_go = torch.zeros(
                (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
            )
        else:
            for i in data_id:
                self.t[i] = 0
                if not self._cfg.model.continuous:
                    self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device)
                else:
                    self.actions[i] = torch.zeros(
                        (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                    )
                if self._atari_env:
                    self.states[i] = torch.zeros(
                        (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
                    )
                    self.running_rtg[i] = self.rtg_target
                else:
                    self.states[i] = torch.zeros(
                        (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
                    )
                    self.running_rtg[i] = self.rtg_target / self.rtg_scale
                self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device)
                self.rewards_to_go[i] = torch.zeros(
                    (self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
                )

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, \
            such as text logger, tensorboard logger, will use these keys to save the corresponding data.
        Returns:
            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
        """
        return ['cur_lr', 'action_loss']

    def _init_collect(self) -> None:
        pass

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        pass

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        pass

    def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
        pass
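The conditioning signal throughout this policy is the return-to-go: during training it is the suffix sum of rewards (scaled by ``rtg_scale``), and during evaluation ``_forward_eval`` keeps a running value that starts at ``rtg_target / rtg_scale`` and decreases as rewards arrive. A plain-Python sketch of both computations (illustrative helpers, not part of ding):

```python
def returns_to_go(rewards):
    # Suffix sums: rtg[t] = r[t] + r[t+1] + ... + r[T-1].
    rtg, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg


def running_rtg_trace(rtg_target, rtg_scale, rewards):
    # Mimic the non-Atari branch of _forward_eval: start from the scaled
    # target and subtract each scaled reward as it is observed.
    running = rtg_target / rtg_scale
    trace = [running]
    for r in rewards:
        running -= r / rtg_scale
        trace.append(running)
    return trace
```

Setting ``rtg_scale == rtg_target`` (as the ``_init_learn`` comment suggests) makes the initial conditioning value exactly 1.0, so the signal stays roughly within (0, 1) over an episode.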