
ding.model.template.decision_transformer


This extremely minimal Decision Transformer model is based on the following causal transformer (GPT) implementation:

Misha Laskin's tweet: https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA

and its corresponding notebook: https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing

Note: the above Colab notebook has a bug in its application of masked_fill, which is fixed in the code below.

MaskedCausalAttention

Bases: Module

Overview

The implementation of masked causal attention in the decision transformer. The input of this module is a sequence of tokens. The hidden embedding computed for the i-th token depends only on the tokens at positions 0 through i, which is enforced by applying a causal mask to the attention map. Hence, this module is called masked causal attention.

Interfaces: __init__, forward
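The masking mechanism described above can be sketched in a few lines. Here `max_T`, `T`, and `scores` are illustrative stand-ins (not module attributes); the mask construction and `masked_fill` pattern follow the module's source:

```python
import torch
import torch.nn.functional as F

# Lower-triangular causal mask, as built in MaskedCausalAttention.__init__.
max_T = 5
mask = torch.tril(torch.ones(max_T, max_T)).view(1, 1, max_T, max_T)

# Toy attention scores: batch 1, 1 head, sequence length 4.
T = 4
scores = torch.randn(1, 1, T, T)

# Positions where the mask is 0 (future tokens) are set to -inf, so after
# softmax the i-th token attends only to tokens 0..i.
scores = scores.masked_fill(mask[..., :T, :T] == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Each row sums to 1, and every entry above the diagonal is exactly 0.
assert torch.allclose(weights.sum(dim=-1), torch.ones(1, 1, T))
assert torch.all(weights[0, 0].triu(1) == 0)
```

Because the `-inf` entries become exact zeros after softmax, each token's output is a convex combination of the value vectors of itself and its predecessors only.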

__init__(h_dim, max_T, n_heads, drop_p)

Overview

Initialize the MaskedCausalAttention Model according to input arguments.

Arguments:
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.

forward(x)

Overview

MaskedCausalAttention forward computation graph: takes a sequence tensor as input and returns a tensor of the same shape.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor.

Returns:
- out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.

Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = MaskedCausalAttention(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> assert outputs.shape == torch.Size([2, 4, 64])

Block

Bases: Module

Overview

The implementation of a transformer block in decision transformer.

Interfaces: __init__, forward

__init__(h_dim, max_T, n_heads, drop_p)

Overview

Initialize the Block Model according to input arguments.

Arguments:
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- max_T (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.

forward(x)

Overview

Forward computation graph of the decision transformer block: takes a sequence tensor as input and returns a tensor of the same shape.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor.

Returns:
- output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.

Examples:
>>> inputs = torch.randn(2, 4, 64)
>>> model = Block(64, 5, 4, 0.1)
>>> outputs = model(inputs)
>>> assert outputs.shape == torch.Size([2, 4, 64])
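The block applies attention and an MLP, each followed by a residual add and then LayerNorm (a post-LN ordering), as in the source listing at the end of this page. A minimal sketch of that ordering, with an identity placeholder standing in for the attention sub-layer (`MiniBlock` is illustrative, not part of the library):

```python
import torch
import torch.nn as nn


class MiniBlock(nn.Module):
    """Post-LN transformer block sketch: residual add, then LayerNorm, after
    each sub-layer. The attention module is replaced by nn.Identity() here
    purely to keep the example short."""

    def __init__(self, h_dim: int) -> None:
        super().__init__()
        self.attention = nn.Identity()  # placeholder for MaskedCausalAttention
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),  # same 4x expansion as the real Block
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln1(x + self.attention(x))  # residual, then norm
        x = self.ln2(x + self.mlp(x))  # residual, then norm
        return x


x = torch.randn(2, 4, 8)
out = MiniBlock(8)(x)
assert out.shape == x.shape
```

Note the ordering: the residual connection wraps the sub-layer and the LayerNorm is applied afterwards, unlike the pre-LN variant (`x + attention(ln1(x))`) left commented out in the original source.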

DecisionTransformer

Bases: Module

Overview

The implementation of decision transformer.

Interfaces: __init__, forward, configure_optimizers

__init__(state_dim, act_dim, n_blocks, h_dim, context_len, n_heads, drop_p, max_timestep=4096, state_encoder=None, continuous=False)

Overview

Initialize the DecisionTransformer Model according to input arguments.

Arguments:
- state_dim (:obj:`Union[int, SequenceType]`): The dimension of states, such as 128 or (4, 84, 84).
- act_dim (:obj:`int`): The dimension of actions, such as 6.
- n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
- h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
- context_len (:obj:`int`): The max context length of the attention, such as 6.
- n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
- drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
- max_timestep (:obj:`int`): The max length of the total sequence, defaults to 4096.
- state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If set to None, the raw state will be pushed into the transformer.
- continuous (:obj:`bool`): Whether the action space is continuous, defaults to ``False``.

forward(timesteps, states, actions, returns_to_go, tar=None)

Overview

Forward computation graph of the decision transformer: takes sequences of timesteps, states, actions, and returns-to-go as input, and returns the predicted states, actions, and returns-to-go.

Arguments:
- timesteps (:obj:`torch.Tensor`): The timesteps of the input sequence.
- states (:obj:`torch.Tensor`): The sequence of states.
- actions (:obj:`torch.Tensor`): The sequence of actions.
- returns_to_go (:obj:`torch.Tensor`): The sequence of returns-to-go.
- tar (:obj:`Optional[int]`): Whether to predict action, regardless of index.

Returns:
- output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Three tensors: the predicted states, the predicted actions, and the predicted returns-to-go.

Examples:
>>> B, T = 4, 6
>>> state_dim = 3
>>> act_dim = 2
>>> DT_model = DecisionTransformer(state_dim=state_dim, act_dim=act_dim, n_blocks=3, h_dim=8, context_len=T, n_heads=2, drop_p=0.1)
>>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long)  # B x T
>>> states = torch.randn([B, T, state_dim])  # B x T x state_dim
>>> actions = torch.randint(0, act_dim, [B, T, 1])
>>> action_target = torch.randint(0, act_dim, [B, T, 1])
>>> returns_to_go = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
>>> traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
>>> actions = actions.squeeze(-1)
>>> state_preds, action_preds, return_preds = DT_model.forward(timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go)
>>> assert state_preds.shape == torch.Size([B, T, state_dim])
>>> assert return_preds.shape == torch.Size([B, T, 1])
>>> assert action_preds.shape == torch.Size([B, T, act_dim])

configure_optimizers(weight_decay, learning_rate, betas=(0.9, 0.95))

Overview

This function returns an optimizer given the input arguments. We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights).

Arguments:
- weight_decay (:obj:`float`): The weight decay of the optimizer.
- learning_rate (:obj:`float`): The learning rate of the optimizer.
- betas (:obj:`Tuple[float, float]`): The betas for the Adam optimizer.

Outputs:
- optimizer (:obj:`torch.optim.Optimizer`): The desired optimizer.
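The decay/no-decay bucketing rule described above can be sketched independently of the model: Linear/Conv2d weights get weight decay, while biases and LayerNorm/Embedding weights do not, and each bucket becomes its own AdamW parameter group. `split_decay_params` is an illustrative helper, not part of DI-engine:

```python
import torch
import torch.nn as nn


def split_decay_params(model: nn.Module):
    """Split parameter names into a decayed and a non-decayed bucket,
    following the rule used by configure_optimizers."""
    decay, no_decay = set(), set()
    whitelist = (nn.Linear, nn.Conv2d)  # weights here are decayed
    blacklist = (nn.LayerNorm, nn.Embedding)  # weights here are not
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full parameter name
            if pn.endswith('bias'):
                no_decay.add(fpn)  # biases are never decayed
            elif pn.endswith('weight') and isinstance(m, whitelist):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist):
                no_decay.add(fpn)
    return decay, no_decay


model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8))
decay, no_decay = split_decay_params(model)
assert decay == {'0.weight'}
assert no_decay == {'0.bias', '1.weight', '1.bias'}

# Each bucket becomes one AdamW parameter group with its own weight_decay.
params = dict(model.named_parameters())
optim_groups = [
    {"params": [params[n] for n in sorted(decay)], "weight_decay": 0.1},
    {"params": [params[n] for n in sorted(no_decay)], "weight_decay": 0.0},
]
opt = torch.optim.AdamW(optim_groups, lr=1e-3, betas=(0.9, 0.95))
```

Decaying LayerNorm gains or embedding tables would fight their role as per-feature scales and lookup vectors, which is why they are excluded; this mirrors the common GPT training recipe the method is based on.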

Full Source Code

../ding/model/template/decision_transformer.py

"""
this extremely minimal Decision Transformer model is based on
the following causal transformer (GPT) implementation:

Misha Laskin's tweet:
https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA

and its corresponding notebook:
https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing

** the above colab notebook has a bug while applying masked_fill
which is fixed in the following code
"""

import math
from typing import Union, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType


class MaskedCausalAttention(nn.Module):
    """
    Overview:
        The implementation of masked causal attention in decision transformer. The input of this module is a sequence \
        of several tokens. For the calculated hidden embedding for the i-th token, it is only related the 0 to i-1 \
        input tokens by applying a mask to the attention map. Thus, this module is called masked-causal attention.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the MaskedCausalAttention Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()

        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)

        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)

        # register buffer makes sure mask does not get updated
        # during backpropagation
        self.register_buffer('mask', mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            MaskedCausalAttention forward computation graph, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = MaskedCausalAttention(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([2, 4, 64])
        """
        B, T, C = x.shape  # batch size, seq length, h_dim * n_heads

        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = attention dim

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # weights (B, N, T, T)
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # causal mask applied to weights
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        # normalize weights, all -inf -> 0 after softmax
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D)
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))
        return out


class Block(nn.Module):
    """
    Overview:
        The implementation of a transformer block in decision transformer.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the Block Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Forward computation graph of the decision transformer block, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = Block(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> outputs.shape == torch.Size([2, 4, 64])
        """
        # Attention -> LayerNorm -> MLP -> LayerNorm
        x = x + self.attention(x)  # residual
        x = self.ln1(x)
        x = x + self.mlp(x)  # residual
        x = self.ln2(x)
        # x = x + self.attention(self.ln1(x))
        # x = x + self.mlp(self.ln2(x))
        return x


class DecisionTransformer(nn.Module):
    """
    Overview:
        The implementation of decision transformer.
    Interfaces:
        ``__init__``, ``forward``, ``configure_optimizers``
    """

    def __init__(
        self,
        state_dim: Union[int, SequenceType],
        act_dim: int,
        n_blocks: int,
        h_dim: int,
        context_len: int,
        n_heads: int,
        drop_p: float,
        max_timestep: int = 4096,
        state_encoder: Optional[nn.Module] = None,
        continuous: bool = False
    ):
        """
        Overview:
            Initialize the DecisionTransformer Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128 or (4, 84, 84).
            - act_dim (:obj:`int`): The dimension of actions, such as 6.
            - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - context_len (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
            - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096.
            - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \
                None, the raw state will be pushed into the transformer.
            - continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``.
        """
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        # transformer blocks
        input_seq_len = 3 * context_len

        # projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.drop = nn.Dropout(drop_p)

        self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))

        if state_encoder is None:
            self.state_encoder = None
            blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
            self.embed_rtg = torch.nn.Linear(1, h_dim)
            self.embed_state = torch.nn.Linear(state_dim, h_dim)
            self.predict_rtg = torch.nn.Linear(h_dim, 1)
            self.predict_state = torch.nn.Linear(h_dim, state_dim)
            if continuous:
                # continuous actions
                self.embed_action = torch.nn.Linear(act_dim, h_dim)
                use_action_tanh = True  # True for continuous actions
            else:
                # discrete actions
                self.embed_action = torch.nn.Embedding(act_dim, h_dim)
                use_action_tanh = False  # False for discrete actions
            self.predict_action = nn.Sequential(
                *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
            )
        else:
            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
            self.state_encoder = state_encoder
            self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
            self.head = nn.Linear(h_dim, act_dim, bias=False)
            self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())
        self.transformer = nn.Sequential(*blocks)

    def forward(
        self,
        timesteps: torch.Tensor,
        states: torch.Tensor,
        actions: torch.Tensor,
        returns_to_go: torch.Tensor,
        tar: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation graph of the decision transformer, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - timesteps (:obj:`torch.Tensor`): The timestep for input sequence.
            - states (:obj:`torch.Tensor`): The sequence of states.
            - actions (:obj:`torch.Tensor`): The sequence of actions.
            - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go.
            - tar (:obj:`Optional[int]`): Whether to predict action, regardless of index.
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \
                they are correspondingly the predicted states, predicted actions and predicted return-to-go.
        Examples:
            >>> B, T = 4, 6
            >>> state_dim = 3
            >>> act_dim = 2
            >>> DT_model = DecisionTransformer(\
                state_dim=state_dim,\
                act_dim=act_dim,\
                n_blocks=3,\
                h_dim=8,\
                context_len=T,\
                n_heads=2,\
                drop_p=0.1,\
                )
            >>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long)  # B x T
            >>> states = torch.randn([B, T, state_dim])  # B x T x state_dim
            >>> actions = torch.randint(0, act_dim, [B, T, 1])
            >>> action_target = torch.randint(0, act_dim, [B, T, 1])
            >>> returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
            >>> traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
            >>> actions = actions.squeeze(-1)
            >>> state_preds, action_preds, return_preds = DT_model.forward(\
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
                )
            >>> assert state_preds.shape == torch.Size([B, T, state_dim])
            >>> assert return_preds.shape == torch.Size([B, T, 1])
            >>> assert action_preds.shape == torch.Size([B, T, act_dim])
        """
        B, T = states.shape[0], states.shape[1]
        if self.state_encoder is None:
            time_embeddings = self.embed_timestep(timesteps)

            # time embeddings are treated similar to positional embeddings
            state_embeddings = self.embed_state(states) + time_embeddings
            action_embeddings = self.embed_action(actions) + time_embeddings
            returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

            # stack rtg, states and actions and reshape sequence as
            # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
            t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
                              dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
            h = self.embed_ln(t_p)
            # transformer and prediction
            h = self.transformer(h)
            # get h reshaped such that its size = (B x 3 x T x h_dim) and
            # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
            # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
            # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
            # that is, for each timestep (t) we have 3 output embeddings from the transformer,
            # each conditioned on all previous timesteps plus
            # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
            h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

            return_preds = self.predict_rtg(h[:, 2])  # predict next rtg given r, s, a
            state_preds = self.predict_state(h[:, 2])  # predict next state given r, s, a
            action_preds = self.predict_action(h[:, 1])  # predict action given r, s
        else:
            state_embeddings = self.state_encoder(
                states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
            )  # (batch * block_size, h_dim)
            state_embeddings = state_embeddings.reshape(B, T, self.h_dim)  # (batch, block_size, h_dim)
            returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
            action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1))  # (batch, block_size, h_dim)

            token_embeddings = torch.zeros(
                (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
            )
            token_embeddings[:, ::3, :] = returns_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]

            all_global_pos_emb = torch.repeat_interleave(
                self.global_pos_emb, B, dim=0
            )  # batch_size, traj_length, h_dim

            position_embeddings = torch.gather(
                all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
            ) + self.pos_emb[:, :token_embeddings.shape[1], :]

            t_p = token_embeddings + position_embeddings

            h = self.drop(t_p)
            h = self.transformer(h)
            h = self.embed_ln(h)
            logits = self.head(h)

            return_preds = None
            state_preds = None
            action_preds = logits[:, 1::3, :]  # only keep predictions from state_embeddings

        return state_preds, action_preds, return_preds

    def configure_optimizers(
        self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
    ) -> torch.optim.Optimizer:
        """
        Overview:
            This function returns an optimizer given the input arguments. \
            We are separating out all parameters of the model into two buckets: those that will experience \
            weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        Arguments:
            - weight_decay (:obj:`float`): The weigh decay of the optimizer.
            - learning_rate (:obj:`float`): The learning rate of the optimizer.
            - betas (:obj:`Tuple[float, float]`): The betas for Adam optimizer.
        Outputs:
            - optimizer (:obj:`torch.optim.Optimizer`): The desired optimizer.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        # whitelist_weight_modules = (torch.nn.Linear, )
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')
        no_decay.add('global_pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0,\
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0
            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer