ding.torch_utils.network.gtrxl

Overview

This file implements the core modules of GTrXL Transformer as described in "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).

PositionalEmbedding

Bases: Module

Overview

The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model.

Interfaces: __init__, forward

.. note:: This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py

__init__(embedding_dim)

Overview

Initialize the PositionalEmbedding module.

Arguments:
- embedding_dim (:obj:int): The dimensionality of the embeddings.

forward(pos_seq)

Overview

Compute positional embedding given a sequence of positions.

Arguments:
- pos_seq (:obj:torch.Tensor): The positional sequence, typically a 1D tensor of integers in the form [seq_len-1, seq_len-2, ..., 1, 0].
Returns:
- pos_embedding (:obj:torch.Tensor): The computed positional embeddings, with shape (seq_len, 1, embedding_dim).
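In plain Python, the computation can be sketched as follows (a torch-free illustration; the actual module vectorizes this with torch.outer and registers inv_freq as a buffer):

```python
import math

def positional_embedding(pos_seq, embedding_dim):
    # Torch-free sketch of PositionalEmbedding.forward: each position p maps to
    # [sin(p*f_0), ..., sin(p*f_{d/2-1}), cos(p*f_0), ..., cos(p*f_{d/2-1})]
    # with frequencies f_i = 1 / 10000**(2i / embedding_dim).
    half = embedding_dim // 2
    inv_freq = [1.0 / (10000 ** (2.0 * i / embedding_dim)) for i in range(half)]
    out = []
    for p in pos_seq:
        phases = [p * f for f in inv_freq]
        out.append([math.sin(t) for t in phases] + [math.cos(t) for t in phases])
    return out

# Positions are given in reverse order, as in the docstring: [seq_len-1, ..., 1, 0].
emb = positional_embedding([3, 2, 1, 0], embedding_dim=8)
```

Note that position 0 maps to [0, 0, ..., 1, 1, ...] (all sines zero, all cosines one), and the sin/cos ordering is immaterial since downstream matrix multiplications are permutation-invariant over the feature dimension.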

GRUGatingUnit

Bases: Module

Overview

The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model.

Interfaces: __init__, forward

__init__(input_dim, bg=2.0)

Overview

Initialize the GRUGatingUnit module.

Arguments:
- input_dim (:obj:int): The dimensionality of the input.
- bg (:obj:float): The gate bias. Setting bg > 0 explicitly initializes the gating mechanism close to the identity map, which can greatly improve learning speed and stability, since it initializes the agent close to a Markovian policy (attention is ignored at the beginning of training).

forward(x, y)

Overview

Compute the output value using the GRU gating mechanism.

Arguments:
- x (:obj:torch.Tensor): The first input tensor.
- y (:obj:torch.Tensor): The second input tensor. x and y must have the same shape, and their last dimension must match input_dim.
Returns:
- g (:obj:torch.Tensor): The output of the GRU gating mechanism, with the same shape as x and y.
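The gating equations can be sketched on scalars (a simplified illustration: w and u here are scalar stand-ins for the learned Linear maps Wr/Ur, Wz/Uz, Wg/Ug, which in the real unit are separate per gate):

```python
import math

def gru_gate(x, y, w, u, bg):
    # Scalar sketch of GRUGatingUnit.forward.
    sig = lambda t: 1.0 / (1.0 + math.exp(-t))
    r = sig(w * y + u * x)                # reset gate
    z = sig(w * y + u * x - bg)           # update gate, pushed toward 0 by bg
    h = math.tanh(w * y + u * (r * x))    # candidate activation
    return (1.0 - z) * x + z * h          # g = (1 - z) * x + z * h

# With bg = 2 the update gate z starts near 0, so the output stays close to x
# (the identity map the paper recommends); with bg = 0 the candidate h mixes in
# much more strongly.
near_identity = gru_gate(1.0, 0.5, 0.1, 0.1, 2.0)
no_bias = gru_gate(1.0, 0.5, 0.1, 0.1, 0.0)
```

This is why a positive gate bias makes the untrained layer behave almost like a pass-through of its first input.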

Memory

Overview

A class that stores the context used to add memory to Transformer.

Interfaces: __init__, init, update, get, to

.. note:: For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860

__init__(memory_len=20, batch_size=64, embedding_dim=256, layer_num=3, memory=None)

Overview

Initialize the Memory module.

Arguments:
- memory_len (:obj:int): The length of the memory, i.e., how many past observations to use as memory.
- batch_size (:obj:int): The batch size.
- embedding_dim (:obj:int): The dimension of the embedding, i.e., the dimension of a single observation after embedding.
- layer_num (:obj:int): The number of transformer layers.
- memory (:obj:Optional[torch.Tensor]): The initial memory. Default is None.

init(memory=None)

Overview

Initialize memory with an input list of tensors or create it automatically given its dimensions.

Arguments:
- memory (:obj:Optional[torch.Tensor]): The input memory tensor, with shape (layer_num, memory_len, bs, embedding_dim), where memory_len is the length of the memory, bs is the batch size, and embedding_dim is the dimension of the embedding.

update(hidden_state)

Overview

Update the memory given a sequence of hidden states. Example for a single layer (memory_len=3, hidden_size_len=2, bs=3):

    m00 m01 m02      h00 h01 h02              m20 m21 m22
m = m10 m11 m12  h = h10 h11 h12  => new_m =  h00 h01 h02
    m20 m21 m22                               h10 h11 h12

Arguments:
- hidden_state (:obj:List[torch.Tensor]): The hidden states used to update the memory. Each tensor in the list has shape (cur_seq, bs, embedding_dim), where cur_seq is the length of the current sequence.
Returns:
- memory (:obj:Optional[torch.Tensor]): The updated memory, with shape (layer_num, memory_len, bs, embedding_dim).
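The sliding-window update above can be sketched with plain lists (each string stands in for one (bs, embedding_dim) slice of a single layer's memory):

```python
# Sliding-window memory update for one layer (memory_len=3, cur_seq=2),
# mirroring the index arithmetic in Memory.update.
memory_len = 3
m = ['m0', 'm1', 'm2']        # old memory, oldest entry first
h = ['h0', 'h1']              # hidden states of the new segment
cat = m + h                   # concatenate along the sequence dimension
end = memory_len + len(h)
beg = max(0, end - memory_len)
new_m = cat[beg:end]          # keep only the most recent memory_len entries
```

The result matches the matrix example in the docstring: the oldest rows are dropped and the new hidden states are appended at the end.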

get()

Overview

Get the current memory.

Returns:
- memory (:obj:Optional[torch.Tensor]): The current memory, with shape (layer_num, memory_len, bs, embedding_dim).

to(device='cpu')

Overview

Move the current memory to the specified device.

Arguments:
- device (:obj:str): The device to move the memory to. Default is 'cpu'.

AttentionXL

Bases: Module

Overview

An implementation of the Attention mechanism used in the TransformerXL model.

Interfaces: __init__, forward

__init__(input_dim, head_dim, head_num, dropout)

Overview

Initialize the AttentionXL module.

Arguments:
- input_dim (:obj:int): The dimensionality of the input features.
- head_dim (:obj:int): The dimensionality of each attention head.
- head_num (:obj:int): The number of attention heads.
- dropout (:obj:nn.Module): The dropout layer to use.

forward(inputs, pos_embedding, full_input, u, v, mask=None)

Overview

Compute the forward pass for the AttentionXL module.

Arguments:
- inputs (:obj:torch.Tensor): The attention input, with shape (cur_seq, bs, input_dim).
- pos_embedding (:obj:torch.Tensor): The positional embedding, with shape (full_seq, 1, input_dim).
- full_input (:obj:torch.Tensor): The concatenated memory and input tensor, with shape (full_seq, bs, input_dim).
- u (:obj:torch.nn.Parameter): The content parameter, with shape (head_num, head_dim).
- v (:obj:torch.nn.Parameter): The position parameter, with shape (head_num, head_dim).
- mask (:obj:Optional[torch.Tensor]): The attention mask, with shape (cur_seq, full_seq, 1). If None, no masking is applied.
Returns:
- output (:obj:torch.Tensor): The output of the attention mechanism, with shape (cur_seq, bs, input_dim).
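The mask passed here is the causal mask that GTrXL's forward builds with torch.triu(..., diagonal=1 + prev_seq). A plain-Python sketch of its contents (True marks a masked position):

```python
# Causal attention mask over (cur_seq x full_seq); True marks a masked position.
cur_seq, prev_seq = 3, 2
full_seq = cur_seq + prev_seq
# Step i of the current segment may attend to every memory slot and to
# current-segment steps up to itself, i.e. absolute positions j <= i + prev_seq.
mask = [[j > i + prev_seq for j in range(full_seq)] for i in range(cur_seq)]
```

So the first current-segment position sees only the prev_seq memory slots plus itself, while the last position sees the entire full_seq window.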

GatedTransformerXLLayer

Bases: Module

Overview

This class implements the attention layer of GTrXL (Gated Transformer-XL).

Interfaces: __init__, forward

__init__(input_dim, head_dim, hidden_dim, head_num, mlp_num, dropout, activation, gru_gating=True, gru_bias=2.0)

Overview

Initialize GatedTransformerXLLayer.

Arguments:
- input_dim (:obj:int): The dimension of the input tensor.
- head_dim (:obj:int): The dimension of each head in the multi-head attention.
- hidden_dim (:obj:int): The dimension of the hidden layer in the MLP.
- head_num (:obj:int): The number of heads for the multi-head attention.
- mlp_num (:obj:int): The number of MLP layers in the attention layer.
- dropout (:obj:nn.Module): The dropout module used in the MLP and attention layers.
- activation (:obj:nn.Module): The activation function used in the MLP layers.
- gru_gating (:obj:bool, optional): Whether to use GRU gates. If False, GRU gates are replaced with residual connections. Default is True.
- gru_bias (:obj:float, optional): The bias of the GRU gate. Default is 2.0.

forward(inputs, pos_embedding, u, v, memory, mask=None)

Overview

Compute forward pass of GTrXL layer.

Arguments:
- inputs (:obj:torch.Tensor): The attention input tensor, with shape (cur_seq, bs, input_dim).
- pos_embedding (:obj:torch.Tensor): The positional embedding tensor, with shape (full_seq, 1, input_dim).
- u (:obj:torch.nn.Parameter): The content parameter tensor, with shape (head_num, head_dim).
- v (:obj:torch.nn.Parameter): The position parameter tensor, with shape (head_num, head_dim).
- memory (:obj:torch.Tensor): The memory tensor, with shape (prev_seq, bs, input_dim).
- mask (:obj:Optional[torch.Tensor]): The attention mask tensor, with shape (cur_seq, full_seq, 1). Default is None.
Returns:
- output (:obj:torch.Tensor): The layer output, with shape (cur_seq, bs, input_dim).

GTrXL

Bases: Module

Overview

GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).

Interfaces: __init__, forward, reset_memory, get_memory

__init__(input_dim, head_dim=128, embedding_dim=256, head_num=2, mlp_num=2, layer_num=3, memory_len=64, dropout_ratio=0.0, activation=nn.ReLU(), gru_gating=True, gru_bias=2.0, use_embedding_layer=True)

Overview

Init GTrXL Model.

Arguments:
- input_dim (:obj:int): The dimension of the input observation.
- head_dim (:obj:int, optional): The dimension of each head. Default is 128.
- embedding_dim (:obj:int, optional): The dimension of the embedding. Default is 256.
- head_num (:obj:int, optional): The number of heads for multi-head attention. Default is 2.
- mlp_num (:obj:int, optional): The number of MLP layers in the attention layer. Default is 2.
- layer_num (:obj:int, optional): The number of transformer layers. Default is 3.
- memory_len (:obj:int, optional): The length of memory. Default is 64.
- dropout_ratio (:obj:float, optional): The dropout ratio. Default is 0.
- activation (:obj:nn.Module, optional): The activation function. Default is nn.ReLU().
- gru_gating (:obj:bool, optional): If False, replace GRU gates with residual connections. Default is True.
- gru_bias (:obj:float, optional): The GRU gate bias. Default is 2.0.
- use_embedding_layer (:obj:bool, optional): If False, don't use the input embedding layer. Default is True.
Raises:
- AssertionError: If embedding_dim is not an even number.

reset_memory(batch_size=None, state=None)

Overview

Clear or set the memory of GTrXL.

Arguments:
- batch_size (:obj:Optional[int]): The batch size. Default is None.
- state (:obj:Optional[torch.Tensor]): The input memory, with shape (layer_num, memory_len, bs, embedding_dim). Default is None.

get_memory()

Overview

Returns the memory of GTrXL.

Returns:
- memory (:obj:Optional[torch.Tensor]): The output memory, or None if the memory has not been initialized. The shape is (layer_num, memory_len, bs, embedding_dim).

forward(x, batch_first=False, return_mem=True)

Overview

Performs a forward pass on the GTrXL.

Arguments:
- x (:obj:torch.Tensor): The input tensor, with shape (seq_len, bs, input_size).
- batch_first (:obj:bool, optional): If the input data has shape (bs, seq_len, input_size), set this parameter to True to transpose the first and second dimensions and obtain shape (seq_len, bs, input_size). This does not affect the output memory. Default is False.
- return_mem (:obj:bool, optional): If False, return only the output tensor without the dict. Default is True.
Returns:
- x (:obj:Dict[str, torch.Tensor]): A dictionary containing the transformer output of shape (seq_len, bs, embedding_size) and the memory of shape (layer_num, memory_len, bs, embedding_size).

Full Source Code

../ding/torch_utils/network/gtrxl.py

1""" 2Overview: 3 This file implements the core modules of GTrXL Transformer as described in 4 "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764). 5""" 6from typing import Optional, Dict, List 7import warnings 8import numpy as np 9import torch 10import torch.nn as nn 11from ding.torch_utils.network.nn_module import fc_block, build_normalization, F 12 13 14class PositionalEmbedding(nn.Module): 15 """ 16 Overview: 17 The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model. 18 Interfaces: 19 ``__init__``, ``forward`` 20 21 .. note:: 22 This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/ \ 23 master/pytorch/mem_transformer.py 24 """ 25 26 def __init__(self, embedding_dim: int): 27 """ 28 Overview: 29 Initialize the PositionalEmbedding module. 30 Arguments: 31 - embedding_dim: (:obj:`int`): The dimensionality of the embeddings. 32 """ 33 34 super(PositionalEmbedding, self).__init__() 35 self.embedding_dim = embedding_dim 36 inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim)) # (embedding_dim / 2) 37 self.register_buffer('inv_freq', inv_freq) 38 39 def forward(self, pos_seq: torch.Tensor) -> torch.Tensor: 40 """ 41 Overview: 42 Compute positional embedding given a sequence of positions. 43 Arguments: 44 - pos_seq (:obj:`torch.Tensor`): The positional sequence, \ 45 typically a 1D tensor of integers in the form of [seq_len-1, seq_len-2, ..., 1, 0], 46 Returns: 47 - pos_embedding (:obj:`torch.Tensor`): The computed positional embeddings. \ 48 The shape of the tensor is (seq_len, 1, embedding_dim). 49 """ 50 51 sinusoid_inp = torch.outer(pos_seq, self.inv_freq) 52 # For position embedding, the order of sin/cos is negligible. 53 # This is because tokens are consumed by the matrix multiplication which is permutation-invariant. 
54 pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 55 return pos_embedding.unsqueeze(1) 56 57 58class GRUGatingUnit(torch.nn.Module): 59 """ 60 Overview: 61 The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model. 62 Interfaces: 63 ``__init__``, ``forward`` 64 """ 65 66 def __init__(self, input_dim: int, bg: float = 2.): 67 """ 68 Overview: 69 Initialize the GRUGatingUnit module. 70 Arguments: 71 - input_dim (:obj:`int`): The dimensionality of the input. 72 - bg (:obj:`bg`): The gate bias. By setting bg > 0 we can explicitly initialize the gating mechanism to \ 73 be close to the identity map. This can greatly improve the learning speed and stability since it \ 74 initializes the agent close to a Markovian policy (ignore attention at the beginning). 75 """ 76 77 super(GRUGatingUnit, self).__init__() 78 self.Wr = torch.nn.Linear(input_dim, input_dim, bias=False) 79 self.Ur = torch.nn.Linear(input_dim, input_dim, bias=False) 80 self.Wz = torch.nn.Linear(input_dim, input_dim, bias=False) 81 self.Uz = torch.nn.Linear(input_dim, input_dim, bias=False) 82 self.Wg = torch.nn.Linear(input_dim, input_dim, bias=False) 83 self.Ug = torch.nn.Linear(input_dim, input_dim, bias=False) 84 self.bg = nn.Parameter(torch.full([input_dim], bg)) # bias 85 self.sigmoid = torch.nn.Sigmoid() 86 self.tanh = torch.nn.Tanh() 87 88 def forward(self, x: torch.Tensor, y: torch.Tensor): 89 """ 90 Overview: 91 Compute the output value using the GRU gating mechanism. 92 Arguments: 93 - x: (:obj:`torch.Tensor`): The first input tensor. 94 - y: (:obj:`torch.Tensor`): The second input tensor. \ 95 x and y should have the same shape and their last dimension should match the input_dim. 96 Returns: 97 - g: (:obj:`torch.Tensor`): The output of the GRU gating mechanism. \ 98 The shape of g matches the shapes of x and y. 
99 """ 100 101 r = self.sigmoid(self.Wr(y) + self.Ur(x)) 102 z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg) 103 h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x))) # element wise multiplication 104 g = torch.mul(1 - z, x) + torch.mul(z, h) 105 return g # x.shape == y.shape == g.shape 106 107 108class Memory: 109 """ 110 Overview: 111 A class that stores the context used to add memory to Transformer. 112 Interfaces: 113 ``__init__``, ``init``, ``update``, ``get``, ``to`` 114 115 .. note:: 116 For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860 117 """ 118 119 def __init__( 120 self, 121 memory_len: int = 20, 122 batch_size: int = 64, 123 embedding_dim: int = 256, 124 layer_num: int = 3, 125 memory: Optional[torch.Tensor] = None 126 ) -> None: 127 """ 128 Overview: 129 Initialize the Memory module. 130 Arguments: 131 - memory_len (:obj:`int`): The dimension of memory, i.e., how many past observations to use as memory. 132 - batch_size (:obj:`int`): The dimension of each batch. 133 - embedding_dim (:obj:`int`): The dimension of embedding, which is the dimension of a single observation \ 134 after embedding. 135 - layer_num (:obj:`int`): The number of transformer layers. 136 - memory (:obj:`Optional[torch.Tensor]`): The initial memory. Default is None. 137 """ 138 139 super(Memory, self).__init__() 140 self.embedding_dim = embedding_dim 141 self.bs = batch_size 142 self.layer_num = layer_num 143 self.memory_len = memory_len 144 self.memory = None 145 self.init(memory) 146 147 def init(self, memory: Optional[torch.Tensor] = None): 148 """ 149 Overview: 150 Initialize memory with an input list of tensors or create it automatically given its dimensions. 151 Arguments: 152 - memory (:obj:`Optional[torch.Tensor]`): Input memory tensor with shape \ 153 (layer_num, memory_len, bs, embedding_dim). 
Its shape is (layer_num, memory_len, bs, embedding_dim), \ 154 where memory_len is length of memory, bs is batch size and embedding_dim is the dimension of embedding. 155 """ 156 157 if memory is not None: 158 self.memory = memory 159 layer_num_plus1, self.memory_len, self.bs, self.embedding_dim = memory.shape 160 self.layer_num = layer_num_plus1 - 1 161 else: 162 self.memory = torch.zeros( 163 self.layer_num + 1, self.memory_len, self.bs, self.embedding_dim, dtype=torch.float 164 ) 165 166 def update(self, hidden_state: List[torch.Tensor]): 167 """ 168 Overview: 169 Update the memory given a sequence of hidden states. 170 Example for single layer: (memory_len=3, hidden_size_len=2, bs=3) 171 172 m00 m01 m02 h00 h01 h02 m20 m21 m22 173 m = m10 m11 m12 h = h10 h11 h12 => new_m = h00 h01 h02 174 m20 m21 m22 h10 h11 h12 175 Arguments: 176 - hidden_state: (:obj:`List[torch.Tensor]`): The hidden states to update the memory. \ 177 Each tensor in the list has shape (cur_seq, bs, embedding_dim), where cur_seq \ 178 is the length of the sequence. 179 Returns: 180 - memory: (:obj:`Optional[torch.Tensor]`): The updated memory, with shape \ 181 (layer_num, memory_len, bs, embedding_dim). 182 """ 183 184 if self.memory is None or hidden_state is None: 185 raise ValueError('Failed to update memory! Memory would be None') # TODO add support of no memory 186 sequence_len = hidden_state[0].shape[0] 187 with torch.no_grad(): 188 new_memory = [] 189 end = self.memory_len + sequence_len 190 beg = max(0, end - self.memory_len) 191 for i in range(self.layer_num + 1): 192 m = self.memory[i] 193 h = hidden_state[i] 194 cat = torch.cat([m, h], dim=0) 195 new_memory.append(cat[beg:end].detach()) 196 new_memory = torch.stack(new_memory, dim=0) 197 self.memory = new_memory 198 return new_memory 199 200 def get(self): 201 """ 202 Overview: 203 Get the current memory. 
204 Returns: 205 - memory: (:obj:`Optional[torch.Tensor]`): The current memory, \ 206 with shape (layer_num, memory_len, bs, embedding_dim). 207 """ 208 209 return self.memory 210 211 def to(self, device: str = 'cpu'): 212 """ 213 Overview: 214 Move the current memory to the specified device. 215 Arguments: 216 device (:obj:`str`): The device to move the memory to. Default is 'cpu'. 217 """ 218 219 self.memory = self.memory.to(device) 220 221 222class AttentionXL(torch.nn.Module): 223 """ 224 Overview: 225 An implementation of the Attention mechanism used in the TransformerXL model. 226 Interfaces: 227 ``__init__``, ``forward`` 228 """ 229 230 def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Module) -> None: 231 """ 232 Overview: 233 Initialize the AttentionXL module. 234 Arguments: 235 - input_dim (:obj:`int`): The dimensionality of the input features. 236 - head_dim (:obj:`int`): The dimensionality of each attention head. 237 - head_num (:obj:`int`): The number of attention heads. 238 - dropout (:obj:`nn.Module`): The dropout layer to use 239 """ 240 241 super(AttentionXL, self).__init__() 242 self.head_num = head_num 243 self.head_dim = head_dim 244 self.dropout = dropout 245 self.attention_kv = fc_block(input_dim, head_dim * head_num * 2) # key, value 246 self.attention_q = fc_block(input_dim, head_dim * head_num) # query (not computed with past hidden states) 247 self.project = fc_block(head_dim * head_num, input_dim) # project attention output back to input_dim 248 self.project_pos = fc_block(input_dim, head_dim * head_num) # project the positional embedding 249 self.scale = 1 / (head_dim ** 0.5) # for scaled dot product attention 250 251 def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor: 252 """ 253 Overview: 254 Perform a relative shift operation on the attention score matrix. 
255 Example: 256 a00 a01 a02 0 a00 a01 a02 0 a00 a01 a02 0 a10 a02 0 0 257 a10 a11 a12 => 0 a10 a11 a12 => a02 0 a10 => a11 a12 0 => a11 a12 0 258 a20 a21 a22 0 a20 a21 a22 a11 a12 0 a20 a21 a22 a20 a21 a22 259 a20 a21 a22 260 1) Append one "column" of zeros to the left 261 2) Reshape the matrix from [3 x 4] into [4 x 3] 262 3) Remove the first "row" 263 4) Mask out the upper triangle (optional) 264 265 .. note:: 266 See the following material for better understanding: https://github.com/kimiyoung/transformer-xl/issues/8 \ 267 https://arxiv.org/pdf/1901.02860.pdf (Appendix B) 268 Arguments: 269 - x (:obj:`torch.Tensor`): The input tensor with shape (cur_seq, full_seq, bs, head_num). 270 - zero_upper (:obj:`bool`): If True, the upper-right triangle of the matrix is set to zero. 271 Returns: 272 - x (:obj:`torch.Tensor`): The input tensor after the relative shift operation, \ 273 with shape (cur_seq, full_seq, bs, head_num). 274 """ 275 276 x_padded = F.pad(x, [1, 0]) # step 1 277 x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) # step 2 278 x = x_padded[:, :, 1:].view_as(x) # step 3 279 if zero_upper: 280 ones = torch.ones((x.size(2), x.size(3))).unsqueeze(0).unsqueeze(0) 281 x = x * torch.tril(ones.to(x.device), x.size(3) - x.size(2)) # step 4 282 return x 283 284 def forward( 285 self, 286 inputs: torch.Tensor, 287 pos_embedding: torch.Tensor, 288 full_input: torch.Tensor, 289 u: torch.nn.Parameter, 290 v: torch.nn.Parameter, 291 mask: Optional[torch.Tensor] = None, 292 ) -> torch.Tensor: 293 """ 294 Overview: 295 Compute the forward pass for the AttentionXL module. 296 Arguments: 297 - inputs (:obj:`torch.Tensor`): The attention input with shape (cur_seq, bs, input_dim). 298 - pos_embedding (:obj:`torch.Tensor`): The positional embedding with shape (full_seq, 1, full_seq). 299 - full_input (:obj:`torch.Tensor`): The concatenated memory and input tensor with shape \ 300 (full_seq, bs, input_dim). 
301 - u (:obj:`torch.nn.Parameter`): The content parameter with shape (head_num, head_dim). 302 - v (:obj:`torch.nn.Parameter`): The position parameter with shape (head_num, head_dim). 303 - mask (:obj:`Optional[torch.Tensor]`): The attention mask with shape (cur_seq, full_seq, 1). \ 304 If None, no masking is applied. 305 Returns: 306 - output (:obj:`torch.Tensor`): The output of the attention mechanism with shape (cur_seq, bs, input_dim). 307 """ 308 309 bs, cur_seq, full_seq = inputs.shape[1], inputs.shape[0], full_input.shape[0] 310 prev_seq = full_seq - cur_seq 311 312 kv = self.attention_kv(full_input) 313 key, value = torch.chunk(kv, 2, dim=-1) # full_seq x bs x num_head*dim_head 314 query = self.attention_q(inputs) # cur_seq x bs x num_head*dim_head 315 r = self.project_pos(pos_embedding) # full_seq x 1 x num_head*dim_head 316 317 key = key.view(full_seq, bs, self.head_num, self.head_dim) 318 query = query.view(cur_seq, bs, self.head_num, self.head_dim) 319 value = value.view(cur_seq + prev_seq, bs, self.head_num, self.head_dim) 320 r = r.view(full_seq, self.head_num, self.head_dim) 321 322 # (query + u) * key^T 323 q_u = query + u 324 content_attn = q_u.permute(1, 2, 0, 3) @ key.permute(1, 2, 3, 0) # bs x head_num x cur_seq x full_seq 325 326 # (query + v) * R^T 327 q_v = query + v 328 position_attn = q_v.permute(1, 2, 0, 3) @ r.permute(1, 2, 0) # bs x head_num x cur_seq x full_seq 329 position_attn = self._rel_shift(position_attn) 330 331 attn = content_attn + position_attn # bs x head_num x cur_seq x full_seq 332 attn.mul_(self.scale) 333 334 # fills float('-inf') where mask is True to let softmax ignore those positions. 
335 if mask is not None and mask.any().item(): 336 mask = mask.permute(2, 0, 1).unsqueeze(1) # 1 x 1 x cur_seq x full_seq 337 assert mask.shape[2:] == attn.shape[2:] # check shape of mask 338 attn = attn.masked_fill(mask, -float("inf")).type_as(attn) 339 340 attn = F.softmax(attn, dim=-1) 341 attn = self.dropout(attn) 342 343 # multiply softmax output by value 344 attn_vec = attn @ value.permute(1, 2, 0, 3) 345 attn_vec = attn_vec.permute(2, 0, 1, 3) 346 347 attn_vec = attn_vec.contiguous().view(cur_seq, bs, self.head_num * self.head_dim) 348 # cur_seq x bs x head_num * head_dim 349 output = self.dropout(self.project(attn_vec)) # cur_seq x bs x input_dim 350 return output 351 352 353class GatedTransformerXLLayer(torch.nn.Module): 354 """ 355 Overview: 356 This class implements the attention layer of GTrXL (Gated Transformer-XL). 357 Interfaces: 358 ``__init__``, ``forward`` 359 """ 360 361 def __init__( 362 self, 363 input_dim: int, 364 head_dim: int, 365 hidden_dim: int, 366 head_num: int, 367 mlp_num: int, 368 dropout: nn.Module, 369 activation: nn.Module, 370 gru_gating: bool = True, 371 gru_bias: float = 2. 372 ) -> None: 373 """ 374 Overview: 375 Initialize GatedTransformerXLLayer. 376 Arguments: 377 - input_dim (:obj:`int`): The dimension of the input tensor. 378 - head_dim (:obj:`int`): The dimension of each head in the multi-head attention. 379 - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP. 380 - head_num (:obj:`int`): The number of heads for the multi-head attention. 381 - mlp_num (:obj:`int`): The number of MLP layers in the attention layer. 382 - dropout (:obj:`nn.Module`): The dropout module used in the MLP and attention layers. 383 - activation (:obj:`nn.Module`): The activation function to be used in the MLP layers. 384 - gru_gating (:obj:`bool`, optional): Whether to use GRU gates. If False, replace GRU gates with \ 385 residual connections. Default is True. 386 - gru_bias (:obj:`float`, optional): The bias of the GRU gate. 
Default is 2. 387 """ 388 389 super(GatedTransformerXLLayer, self).__init__() 390 self.dropout = dropout 391 self.gating = gru_gating 392 if self.gating is True: 393 self.gate1 = GRUGatingUnit(input_dim, gru_bias) 394 self.gate2 = GRUGatingUnit(input_dim, gru_bias) 395 self.attention = AttentionXL( 396 input_dim, 397 head_dim, 398 head_num, 399 dropout, 400 ) 401 layers = [] 402 dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim] 403 for i in range(mlp_num): 404 layers.append(fc_block(dims[i], dims[i + 1], activation=activation)) 405 if i != mlp_num - 1: 406 layers.append(self.dropout) 407 layers.append(self.dropout) 408 self.mlp = nn.Sequential(*layers) 409 self.layernorm1 = build_normalization('LN')(input_dim) 410 self.layernorm2 = build_normalization('LN')(input_dim) 411 self.activation = activation 412 413 def forward( 414 self, 415 inputs: torch.Tensor, 416 pos_embedding: torch.Tensor, 417 u: torch.nn.Parameter, 418 v: torch.nn.Parameter, 419 memory: torch.Tensor, 420 mask: Optional[torch.Tensor] = None, 421 ) -> torch.Tensor: 422 """ 423 Overview: 424 Compute forward pass of GTrXL layer. 425 Arguments: 426 - inputs (:obj:`torch.Tensor`): The attention input tensor of shape (cur_seq, bs, input_dim). 427 - pos_embedding (:obj:`torch.Tensor`): The positional embedding tensor of shape (full_seq, 1, full_seq). 428 - u (:obj:`torch.nn.Parameter`): The content parameter tensor of shape (head_num, head_dim). 429 - v (:obj:`torch.nn.Parameter`): The position parameter tensor of shape (head_num, head_dim). 430 - memory (:obj:`torch.Tensor`): The memory tensor of shape (prev_seq, bs, input_dim). 431 - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor of shape (cur_seq, full_seq, 1). 432 Default is None. 
433 Returns: 434 - output (:obj:`torch.Tensor`): layer output of shape (cur_seq, bs, input_dim) 435 """ 436 437 # concat memory with input across sequence dimension 438 full_input = torch.cat([memory, inputs], dim=0) # full_seq x bs x input_dim 439 x1 = self.layernorm1(full_input) 440 a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask)) 441 a1 = self.activation(a1) # RELU after attention 442 o1 = self.gate1(inputs, a1) if self.gating else inputs + a1 443 x2 = self.layernorm2(o1) 444 m2 = self.dropout(self.mlp(x2)) 445 o2 = self.gate2(o1, m2) if self.gating else o1 + m2 446 return o2 447 448 449class GTrXL(nn.Module): 450 """ 451 Overview: 452 GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning" 453 (https://arxiv.org/abs/1910.06764). 454 Interfaces: 455 ``__init__``, ``forward``, ``reset_memory``, ``get_memory`` 456 """ 457 458 def __init__( 459 self, 460 input_dim: int, 461 head_dim: int = 128, 462 embedding_dim: int = 256, 463 head_num: int = 2, 464 mlp_num: int = 2, 465 layer_num: int = 3, 466 memory_len: int = 64, 467 dropout_ratio: float = 0., 468 activation: nn.Module = nn.ReLU(), 469 gru_gating: bool = True, 470 gru_bias: float = 2., 471 use_embedding_layer: bool = True, 472 ) -> None: 473 """Overview: 474 Init GTrXL Model. 475 Arguments: 476 - input_dim (:obj:`int`): The dimension of the input observation. 477 - head_dim (:obj:`int`, optional): The dimension of each head. Default is 128. 478 - embedding_dim (:obj:`int`, optional): The dimension of the embedding. Default is 256. 479 - head_num (:obj:`int`, optional): The number of heads for multi-head attention. Default is 2. 480 - mlp_num (:obj:`int`, optional): The number of MLP layers in the attention layer. Default is 2. 481 - layer_num (:obj:`int`, optional): The number of transformer layers. Default is 3. 482 - memory_len (:obj:`int`, optional): The length of memory. Default is 64. 
483 - dropout_ratio (:obj:`float`, optional): The dropout ratio. Default is 0. 484 - activation (:obj:`nn.Module`, optional): The activation function. Default is nn.ReLU(). 485 - gru_gating (:obj:`bool`, optional): If False, replace GRU gates with residual connections. \ 486 Default is True. 487 - gru_bias (:obj:`float`, optional): The GRU gate bias. Default is 2.0. 488 - use_embedding_layer (:obj:`bool`, optional): If False, don't use input embedding layer. Default is True. 489 Raises: 490 - AssertionError: If `embedding_dim` is not an even number. 491 """ 492 493 super(GTrXL, self).__init__() 494 assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(input_dim) 495 self.head_num = head_num 496 self.head_dim = head_dim 497 self.layer_num = layer_num 498 if isinstance(input_dim, list): 499 input_dim = np.prod(input_dim) 500 self.use_embedding_layer = use_embedding_layer 501 if use_embedding_layer: 502 self.embedding = fc_block(input_dim, embedding_dim, activation=activation) 503 self.activation = activation 504 self.pos_embedding = PositionalEmbedding(embedding_dim) 505 # memory to save hidden states of past segments 506 # it will be initialized in the forward method to get its size dynamically 507 self.memory = None 508 self.memory_len = memory_len 509 layers = [] 510 dims = [embedding_dim] + [embedding_dim] * layer_num 511 self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity() 512 for i in range(layer_num): 513 layers.append( 514 GatedTransformerXLLayer( 515 dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating, 516 gru_bias 517 ) 518 ) 519 self.layers = nn.Sequential(*layers) 520 self.embedding_dim = embedding_dim 521 # u and v are the parameters to compute global content bias and global positional bias 522 self.u, self.v = ( 523 torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)), 524 torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)), 525 ) 526 
self.att_mask = {} # create an attention mask for each different seq_len, in this way we don't need to create a 527 # new one each time we call the forward method 528 self.pos_embedding_dict = {} # create a pos embedding for each different seq_len 529 530 def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None): 531 """ 532 Overview: 533 Clear or set the memory of GTrXL. 534 Arguments: 535 - batch_size (:obj:`Optional[int]`): The batch size. Default is None. 536 - state (:obj:`Optional[torch.Tensor]`): The input memory with shape \ 537 (layer_num, memory_len, bs, embedding_dim). Default is None. 538 """ 539 540 self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim) 541 if batch_size is not None: 542 self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num) 543 elif state is not None: 544 self.memory.init(state) 545 546 def get_memory(self): 547 """ 548 Overview: 549 Returns the memory of GTrXL. 550 Returns: 551 - memory (:obj:`Optional[torch.Tensor]`): The output memory or None if memory has not been initialized. \ 552 The shape is (layer_num, memory_len, bs, embedding_dim). 553 """ 554 555 if self.memory is None: 556 return None 557 else: 558 return self.memory.get() 559 560 def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]: 561 """ 562 Overview: 563 Performs a forward pass on the GTrXL. 564 Arguments: 565 - x (:obj:`torch.Tensor`): The input tensor with shape (seq_len, bs, input_size). 566 - batch_first (:obj:`bool`, optional): If the input data has shape (bs, seq_len, input_size), \ 567 set this parameter to True to transpose along the first and second dimension and obtain shape \ 568 (seq_len, bs, input_size). This does not affect the output memory. Default is False. \ 569 - return_mem (:obj:`bool`, optional): If False, return only the output tensor without dict. Default is True. 
570 Returns: 571 - x (:obj:`Dict[str, torch.Tensor]`): A dictionary containing the transformer output of shape \ 572 (seq_len, bs, embedding_size) and memory of shape (layer_num, seq_len, bs, embedding_size). 573 """ 574 575 if batch_first: 576 x = torch.transpose(x, 1, 0) # bs x cur_seq x input_dim -> cur_seq x bs x input_dim 577 cur_seq, bs = x.shape[:2] 578 memory = None if self.memory is None else self.memory.get() 579 if memory is None: 580 self.reset_memory(bs) # (layer_num+1) x memory_len x batch_size x embedding_dim 581 elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim: 582 warnings.warn( 583 "Memory {} and Input {} dimensions don't match," 584 " this will cause the memory to be initialized to fit your input!".format( 585 list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim] 586 ) 587 ) 588 self.reset_memory(bs) 589 self.memory.to(x.device) 590 memory = self.memory.get() 591 592 if self.use_embedding_layer: 593 x = self.dropout(self.embedding(x)) 594 prev_seq = self.memory_len 595 full_seq = cur_seq + prev_seq 596 597 if cur_seq in self.att_mask.keys(): 598 attn_mask = self.att_mask[cur_seq] 599 else: 600 attn_mask = ( 601 torch.triu( 602 torch.ones((cur_seq, full_seq)), 603 diagonal=1 + prev_seq, # fixed in train, eval, collect 604 ).bool().unsqueeze(-1).to(x.device) 605 ) # cur_seq x full_seq x 1 606 self.att_mask[cur_seq] = attn_mask 607 608 if cur_seq in self.pos_embedding_dict.keys(): 609 pos_embedding = self.pos_embedding_dict[cur_seq] 610 else: 611 pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float) # full_seq 612 pos_embedding = self.pos_embedding(pos_ips.to(x.device)) 613 self.pos_embedding_dict[cur_seq] = pos_embedding 614 pos_embedding = self.dropout(pos_embedding) # full_seq x 1 x embedding_dim 615 616 hidden_state = [x] 617 out = x 618 for i in range(self.layer_num): 619 layer = self.layers[i] 620 out = layer( 621 out, 622 pos_embedding, 623 self.u, 624 self.v, 625 mask=attn_mask, 626 memory=memory[i], 
# (layer_num+1) x memory_len x batch_size x embedding_dim 627 ) # cur_seq x bs x embedding_dim 628 hidden_state.append(out.clone()) 629 630 out = self.dropout(out) 631 self.memory.update(hidden_state) # (layer_num+1) x memory_len x batch_size x embedding_dim 632 633 if batch_first: 634 out = torch.transpose(out, 1, 0) # cur_seq x bs x embedding_dim -> bs x cur_seq x embedding_dim 635 if return_mem: 636 output = {"logit": out, "memory": memory} # return the content of the memory before the last update 637 else: 638 output = {"logit": out} 639 return output