ding.torch_utils.network.gtrxl
Overview
This file implements the core modules of the GTrXL Transformer as described in "Stabilizing Transformers for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).
PositionalEmbedding
Bases: Module
Overview
The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model.
Interfaces:
__init__, forward
.. note:: This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
__init__(embedding_dim)
Overview
Initialize the PositionalEmbedding module.
Arguments:
- embedding_dim (:obj:int): The dimensionality of the embeddings.
forward(pos_seq)
Overview
Compute positional embedding given a sequence of positions.
Arguments:
- pos_seq (:obj:torch.Tensor): The positional sequence, typically a 1D tensor of integers of the form [seq_len-1, seq_len-2, ..., 1, 0].
Returns:
- pos_embedding (:obj:torch.Tensor): The computed positional embeddings. The shape of the tensor is (seq_len, 1, embedding_dim).
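As a rough illustration of what forward computes, here is a pure-Python sketch of the sinusoidal scheme (the actual module works on torch tensors and returns shape (seq_len, 1, embedding_dim); the function below is a hypothetical stand-in that drops the middle singleton dimension):

```python
import math

def positional_embedding(pos_seq, embedding_dim):
    """Sketch of sinusoidal positional embeddings, shape (seq_len, embedding_dim).

    Inverse frequencies 1 / 10000^(2i/d) are applied to each position, and the
    sin and cos halves are concatenated along the feature dimension.
    """
    half = embedding_dim // 2
    inv_freq = [1.0 / (10000 ** (2 * i / embedding_dim)) for i in range(half)]
    rows = []
    for pos in pos_seq:
        sinusoid = [pos * f for f in inv_freq]
        rows.append([math.sin(s) for s in sinusoid] + [math.cos(s) for s in sinusoid])
    return rows

emb = positional_embedding([2, 1, 0], 8)
# position 0 (last row) gives sin = 0.0 and cos = 1.0 in every frequency slot
```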
GRUGatingUnit
Bases: Module
Overview
The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model.
Interfaces:
__init__, forward
__init__(input_dim, bg=2.0)
Overview
Initialize the GRUGatingUnit module.
Arguments:
- input_dim (:obj:int): The dimensionality of the input.
- bg (:obj:float): The gate bias. By setting bg > 0 we can explicitly initialize the gating mechanism to be close to the identity map, which can greatly improve learning speed and stability, since it initializes the agent close to a Markovian policy (ignoring attention at the beginning).
forward(x, y)
Overview
Compute the output value using the GRU gating mechanism.
Arguments:
- x (:obj:torch.Tensor): The first input tensor.
- y (:obj:torch.Tensor): The second input tensor. x and y should have the same shape, and their last dimension should match input_dim.
Returns:
- g (:obj:torch.Tensor): The output of the GRU gating mechanism. The shape of g matches the shapes of x and y.
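The gate follows the GRU equations from the GTrXL paper. Below is a purely illustrative scalar sketch, where every weight matrix is collapsed to a single hypothetical scalar w (the real module applies linear layers to vectors):

```python
import math

def gru_gate(x, y, bg=2.0, w=1.0):
    """Scalar sketch of the GTrXL GRU gate; `w` is a hypothetical scalar
    standing in for every weight matrix, purely for illustration."""
    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
    r = sigmoid(w * y + w * x)            # reset gate
    z = sigmoid(w * y + w * x - bg)       # update gate, shifted down by the bias bg
    h = math.tanh(w * y + w * (r * x))    # candidate activation
    return (1.0 - z) * x + z * h          # blend of the skip input x and candidate h

# With a large bias the update gate z stays near 0, so the output stays near x:
out = gru_gate(0.5, 0.3, bg=50.0)
```

This makes concrete the note about bg above: a positive bias pushes z toward 0 at initialization, so the gate output is close to x, i.e. close to the identity map over the skip connection.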
Memory
Overview
A class that stores the context used to add memory to the Transformer.
Interfaces:
__init__, init, update, get, to
.. note:: For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860
__init__(memory_len=20, batch_size=64, embedding_dim=256, layer_num=3, memory=None)
Overview
Initialize the Memory module.
Arguments:
- memory_len (:obj:int): The length of the memory, i.e., how many past observations to use as memory.
- batch_size (:obj:int): The batch size.
- embedding_dim (:obj:int): The dimension of embedding, which is the dimension of a single observation after embedding.
- layer_num (:obj:int): The number of transformer layers.
- memory (:obj:Optional[torch.Tensor]): The initial memory. Default is None.
init(memory=None)
Overview
Initialize memory with an input list of tensors or create it automatically given its dimensions.
Arguments:
- memory (:obj:Optional[torch.Tensor]): The input memory tensor with shape (layer_num, memory_len, bs, embedding_dim), where memory_len is the length of the memory, bs is the batch size, and embedding_dim is the dimension of the embedding.
update(hidden_state)
Overview
Update the memory given a sequence of hidden states. Example for a single layer (memory_len=3, hidden_size_len=2, bs=3):

    m = | m00 m01 m02 |    h = | h00 h01 h02 |    =>    new_m = | m20 m21 m22 |
        | m10 m11 m12 |        | h10 h11 h12 |                  | h00 h01 h02 |
        | m20 m21 m22 |                                         | h10 h11 h12 |
Arguments:
- hidden_state (:obj:List[torch.Tensor]): The hidden states used to update the memory. Each tensor in the list has shape (cur_seq, bs, embedding_dim), where cur_seq is the length of the sequence.
Returns:
- memory (:obj:Optional[torch.Tensor]): The updated memory, with shape (layer_num, memory_len, bs, embedding_dim).
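The update amounts to concatenating old memory and new hidden states along the time axis, then keeping only the newest memory_len steps. A minimal single-layer sketch with nested lists standing in for tensors (outer index = time step):

```python
def update_memory(memory, hidden_state, memory_len=3):
    """Sketch of the Transformer-XL style memory update for one layer:
    append the new hidden states after the old memory in time order,
    then keep only the last `memory_len` time steps."""
    combined = memory + hidden_state
    return combined[-memory_len:]

m = [["m00", "m01", "m02"], ["m10", "m11", "m12"], ["m20", "m21", "m22"]]
h = [["h00", "h01", "h02"], ["h10", "h11", "h12"]]
new_m = update_memory(m, h, memory_len=3)
# new_m rows are m2*, h0*, h1* -- matching the diagram above
```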
get()
Overview
Get the current memory.
Returns:
- memory (:obj:Optional[torch.Tensor]): The current memory, with shape (layer_num, memory_len, bs, embedding_dim).
to(device='cpu')
Overview
Move the current memory to the specified device.
Arguments:
- device (:obj:str): The device to move the memory to. Default is 'cpu'.
AttentionXL
Bases: Module
Overview
An implementation of the attention mechanism used in the Transformer-XL model.
Interfaces:
__init__, forward
__init__(input_dim, head_dim, head_num, dropout)
Overview
Initialize the AttentionXL module.
Arguments:
- input_dim (:obj:int): The dimensionality of the input features.
- head_dim (:obj:int): The dimensionality of each attention head.
- head_num (:obj:int): The number of attention heads.
- dropout (:obj:nn.Module): The dropout layer to use.
forward(inputs, pos_embedding, full_input, u, v, mask=None)
Overview
Compute the forward pass for the AttentionXL module.
Arguments:
- inputs (:obj:torch.Tensor): The attention input with shape (cur_seq, bs, input_dim).
- pos_embedding (:obj:torch.Tensor): The positional embedding with shape (full_seq, 1, input_dim).
- full_input (:obj:torch.Tensor): The concatenated memory and input tensor with shape (full_seq, bs, input_dim).
- u (:obj:torch.nn.Parameter): The content parameter with shape (head_num, head_dim).
- v (:obj:torch.nn.Parameter): The position parameter with shape (head_num, head_dim).
- mask (:obj:Optional[torch.Tensor]): The attention mask with shape (cur_seq, full_seq, 1). If None, no masking is applied.
Returns:
- output (:obj:torch.Tensor): The output of the attention mechanism with shape (cur_seq, bs, input_dim).
GatedTransformerXLLayer
Bases: Module
Overview
This class implements the attention layer of GTrXL (Gated Transformer-XL).
Interfaces:
__init__, forward
__init__(input_dim, head_dim, hidden_dim, head_num, mlp_num, dropout, activation, gru_gating=True, gru_bias=2.0)
Overview
Initialize GatedTransformerXLLayer.
Arguments:
- input_dim (:obj:int): The dimension of the input tensor.
- head_dim (:obj:int): The dimension of each head in the multi-head attention.
- hidden_dim (:obj:int): The dimension of the hidden layer in the MLP.
- head_num (:obj:int): The number of heads for the multi-head attention.
- mlp_num (:obj:int): The number of MLP layers in the attention layer.
- dropout (:obj:nn.Module): The dropout module used in the MLP and attention layers.
- activation (:obj:nn.Module): The activation function to be used in the MLP layers.
- gru_gating (:obj:bool, optional): Whether to use GRU gates. If False, replace GRU gates with residual connections. Default is True.
- gru_bias (:obj:float, optional): The bias of the GRU gate. Default is 2.0.
forward(inputs, pos_embedding, u, v, memory, mask=None)
Overview
Compute forward pass of GTrXL layer.
Arguments:
- inputs (:obj:torch.Tensor): The attention input tensor of shape (cur_seq, bs, input_dim).
- pos_embedding (:obj:torch.Tensor): The positional embedding tensor of shape (full_seq, 1, input_dim).
- u (:obj:torch.nn.Parameter): The content parameter tensor of shape (head_num, head_dim).
- v (:obj:torch.nn.Parameter): The position parameter tensor of shape (head_num, head_dim).
- memory (:obj:torch.Tensor): The memory tensor of shape (prev_seq, bs, input_dim).
- mask (:obj:Optional[torch.Tensor]): The attention mask tensor of shape (cur_seq, full_seq, 1). Default is None.
Returns:
- output (:obj:torch.Tensor): The layer output of shape (cur_seq, bs, input_dim).
GTrXL
Bases: Module
Overview
GTrXL Transformer implementation as described in "Stabilizing Transformers for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).
Interfaces:
__init__, forward, reset_memory, get_memory
__init__(input_dim, head_dim=128, embedding_dim=256, head_num=2, mlp_num=2, layer_num=3, memory_len=64, dropout_ratio=0.0, activation=nn.ReLU(), gru_gating=True, gru_bias=2.0, use_embedding_layer=True)
Overview
Initialize the GTrXL model.
Arguments:
- input_dim (:obj:int): The dimension of the input observation.
- head_dim (:obj:int, optional): The dimension of each head. Default is 128.
- embedding_dim (:obj:int, optional): The dimension of the embedding. Default is 256.
- head_num (:obj:int, optional): The number of heads for multi-head attention. Default is 2.
- mlp_num (:obj:int, optional): The number of MLP layers in the attention layer. Default is 2.
- layer_num (:obj:int, optional): The number of transformer layers. Default is 3.
- memory_len (:obj:int, optional): The length of memory. Default is 64.
- dropout_ratio (:obj:float, optional): The dropout ratio. Default is 0.
- activation (:obj:nn.Module, optional): The activation function. Default is nn.ReLU().
- gru_gating (:obj:bool, optional): If False, replace GRU gates with residual connections. Default is True.
- gru_bias (:obj:float, optional): The GRU gate bias. Default is 2.0.
- use_embedding_layer (:obj:bool, optional): If False, don't use input embedding layer. Default is True.
Raises:
- AssertionError: If embedding_dim is not an even number.
reset_memory(batch_size=None, state=None)
Overview
Clear or set the memory of GTrXL.
Arguments:
- batch_size (:obj:Optional[int]): The batch size. Default is None.
- state (:obj:Optional[torch.Tensor]): The input memory with shape (layer_num, memory_len, bs, embedding_dim). Default is None.
get_memory()
Overview
Return the memory of GTrXL.
Returns:
- memory (:obj:Optional[torch.Tensor]): The output memory or None if memory has not been initialized. The shape is (layer_num, memory_len, bs, embedding_dim).
forward(x, batch_first=False, return_mem=True)
Overview
Perform a forward pass of the GTrXL model.
Arguments:
- x (:obj:torch.Tensor): The input tensor with shape (seq_len, bs, input_size).
- batch_first (:obj:bool, optional): If the input data has shape (bs, seq_len, input_size), set this parameter to True to transpose the first and second dimensions and obtain shape (seq_len, bs, input_size). This does not affect the output memory. Default is False.
- return_mem (:obj:bool, optional): If False, return only the output tensor without the dict. Default is True.
Returns:
- x (:obj:Dict[str, torch.Tensor]): A dictionary containing the transformer output of shape (seq_len, bs, embedding_dim) and the memory of shape (layer_num, memory_len, bs, embedding_dim).
Full Source Code
../ding/torch_utils/network/gtrxl.py