ding.torch_utils.network.rnn

LSTMForwardWrapper

Bases: object

Overview

Class providing methods to use before and after the LSTM forward method. Wraps the LSTM forward method.

Interfaces: _before_forward, _after_forward

LSTM

Bases: Module, LSTMForwardWrapper

Overview

Implementation of an LSTM cell with Layer Normalization (LN).

Interfaces: __init__, forward

.. note::

For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.

__init__(input_size, hidden_size, num_layers, norm_type=None, dropout=0.0)

Overview

Initialize LSTM cell parameters.

Arguments:

- input_size (:obj:`int`): Size of the input vector.
- hidden_size (:obj:`int`): Size of the hidden state vector.
- num_layers (:obj:`int`): Number of LSTM layers.
- norm_type (:obj:`Optional[str]`): Normalization type, default is None.
- dropout (:obj:`float`): Dropout rate, default is 0.

forward(inputs, prev_state, list_next_state=True)

Overview

Compute output and next state given previous state and input.

Arguments:

- inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
- prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, hidden_size].
- list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.

Returns:

- x (:obj:`torch.Tensor`): Output from LSTM.
- next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.

PytorchLSTM

Bases: LSTM, LSTMForwardWrapper

Overview

Wrapper class for PyTorch's nn.LSTM, formats the input and output. For more details on nn.LSTM, refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

Interfaces: forward

forward(inputs, prev_state, list_next_state=True)

Overview

Executes nn.LSTM.forward with preprocessed input.

Arguments:

- inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
- prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, hidden_size].
- list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.

Returns:

- output (:obj:`torch.Tensor`): Output from LSTM.
- next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.

GRU

Bases: GRUCell, LSTMForwardWrapper

Overview

This class extends the torch.nn.GRUCell and LSTMForwardWrapper classes, and formats inputs and outputs accordingly.

Interfaces: __init__, forward

Properties: hidden_size, num_layers

.. note:: For further details, refer to the official PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU

__init__(input_size, hidden_size, num_layers)

Overview

Initialize the GRU class with input size, hidden size, and number of layers.

Arguments:

- input_size (:obj:`int`): The size of the input vector.
- hidden_size (:obj:`int`): The size of the hidden state vector.
- num_layers (:obj:`int`): The number of GRU layers.

forward(inputs, prev_state=None, list_next_state=True)

Overview

Wraps the nn.GRUCell.forward method, performing one GRU step over the (length-1) input sequence.

Arguments:

- inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
- prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size].
- list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).

Returns:

- output (:obj:`torch.Tensor`): Output from GRU.
- next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU.

is_sequence(data)

Overview

Determines if the input data is of type list or tuple.

Arguments:

- data: The input data to be checked.

Returns:

- boolean: True if the input is a list or a tuple, False otherwise.

sequence_mask(lengths, max_len=None)

Overview

Generates a boolean mask for a batch of sequences with differing lengths.

Arguments:

- lengths (:obj:`torch.Tensor`): A tensor with the lengths of each sequence. Shape could be (n, 1) or (n).
- max_len (:obj:`int`, optional): The padding size. If max_len is None, the padding size is the max length of sequences.

Returns:

- masks (:obj:`torch.BoolTensor`): A boolean mask tensor. The mask has the same device as lengths.

get_lstm(lstm_type, input_size, hidden_size, num_layers=1, norm_type='LN', dropout=0.0, seq_len=None, batch_size=None)

Overview

Build and return the corresponding LSTM cell based on the provided parameters.

Arguments:

- lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
- input_size (:obj:`int`): Size of the input vector.
- hidden_size (:obj:`int`): Size of the hidden state vector.
- num_layers (:obj:`int`): Number of LSTM layers (default is 1).
- norm_type (:obj:`str`): Type of normalization (default is 'LN').
- dropout (:obj:`float`): Dropout rate (default is 0.0).
- seq_len (:obj:`Optional[int]`): Sequence length (default is None). Only used by the 'hpc' variant.
- batch_size (:obj:`Optional[int]`): Batch size (default is None). Only used by the 'hpc' variant.

Returns:

- lstm (:obj:`Union[LSTM, PytorchLSTM]`): The corresponding LSTM cell.

Full Source Code

../ding/torch_utils/network/rnn.py

```python
from typing import Optional, Union, List, Tuple, Dict
import math
import torch
import torch.nn as nn
import treetensor.torch as ttorch

import ding
from ding.torch_utils.network.normalization import build_normalization
if ding.enable_hpc_rl:
    from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
else:
    HPCLSTM = None


def is_sequence(data):
    """
    Overview:
        Determines if the input data is of type list or tuple.
    Arguments:
        - data: The input data to be checked.
    Returns:
        - boolean: True if the input is a list or a tuple, False otherwise.
    """
    return isinstance(data, list) or isinstance(data, tuple)


def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
    """
    Overview:
        Generates a boolean mask for a batch of sequences with differing lengths.
    Arguments:
        - lengths (:obj:`torch.Tensor`): A tensor with the lengths of each sequence. Shape could be (n, 1) or (n).
        - max_len (:obj:`int`, optional): The padding size. If max_len is None, the padding size is the max length of \
            sequences.
    Returns:
        - masks (:obj:`torch.BoolTensor`): A boolean mask tensor. The mask has the same device as lengths.
    """
    if len(lengths.shape) == 1:
        lengths = lengths.unsqueeze(dim=1)
    bz = lengths.numel()
    if max_len is None:
        max_len = lengths.max()
    else:
        max_len = min(max_len, lengths.max())
    return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device)


class LSTMForwardWrapper(object):
    """
    Overview:
        Class providing methods to use before and after the LSTM `forward` method.
        Wraps the LSTM `forward` method.
    Interfaces:
        ``_before_forward``, ``_after_forward``
    """

    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
        """
        Overview:
            Preprocesses the inputs and previous states before the LSTM `forward` method.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of the LSTM cell. Shape: [seq_len, batch_size, input_size]
            - prev_state (:obj:`Union[None, List[Dict]]`): Previous state tensor. Shape: [num_directions*num_layers, \
                batch_size, hidden_size]. If None, prev_state will be initialized to all zeros.
        Returns:
            - prev_state (:obj:`torch.Tensor`): Preprocessed previous state for the LSTM batch.
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        seq_len, batch_size = inputs.shape[:2]
        if prev_state is None:
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=inputs.dtype,
                device=inputs.device
            )
            prev_state = (zeros, zeros)
        elif is_sequence(prev_state):
            if len(prev_state) != batch_size:
                raise RuntimeError(
                    "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                )
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
            )
            state = []
            for prev in prev_state:
                if prev is None:
                    state.append([zeros, zeros])
                else:
                    if isinstance(prev, (Dict, ttorch.Tensor)):
                        state.append([v for v in prev.values()])
                    else:
                        state.append(prev)
            state = list(zip(*state))
            prev_state = [torch.cat(t, dim=1) for t in state]
        elif isinstance(prev_state, dict):
            prev_state = list(prev_state.values())
        else:
            raise TypeError("not support prev_state type: {}".format(type(prev_state)))
        return prev_state

    def _after_forward(self,
                       next_state: Tuple[torch.Tensor],
                       list_next_state: bool = False) -> Union[List[Dict], Dict[str, torch.Tensor]]:
        """
        Overview:
            Post-processes the next_state after the LSTM `forward` method.
        Arguments:
            - next_state (:obj:`Tuple[torch.Tensor]`): Tuple containing the next state (h, c).
            - list_next_state (:obj:`bool`, optional): Determines the format of the returned next_state. \
                If True, returns next_state in list format. Default is False.
        Returns:
            - next_state(:obj:`Union[List[Dict], Dict[str, torch.Tensor]]`): The post-processed next_state.
        """
        if list_next_state:
            h, c = next_state
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            next_state = list(zip(*next_state))
            next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
        else:
            next_state = {k: v for k, v in zip(['h', 'c'], next_state)}
        return next_state


class LSTM(nn.Module, LSTMForwardWrapper):
    """
    Overview:
        Implementation of an LSTM cell with Layer Normalization (LN).
    Interfaces:
        ``__init__``, ``forward``

    .. note::

        For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = None,
            dropout: float = 0.
    ) -> None:
        """
        Overview:
            Initialize LSTM cell parameters.
        Arguments:
            - input_size (:obj:`int`): Size of the input vector.
            - hidden_size (:obj:`int`): Size of the hidden state vector.
            - num_layers (:obj:`int`): Number of LSTM layers.
            - norm_type (:obj:`Optional[str]`): Normalization type, default is None.
            - dropout (:obj:`float`): Dropout rate, default is 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        norm_func = build_normalization(norm_type)
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        dims = [input_size] + [hidden_size] * num_layers
        for l in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _init(self):
        """
        Overview:
            Initialize the parameters of the LSTM cell.
        """

        gain = math.sqrt(1. / self.hidden_size)
        for l in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[l], -gain, gain)
            torch.nn.init.uniform_(self.wh[l], -gain, gain)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[l], -gain, gain)

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Compute output and next state given previous state and input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - x (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
        """
        seq_len, batch_size = inputs.shape[:2]
        prev_state = self._before_forward(inputs, prev_state)

        H, C = prev_state
        x = inputs
        next_state = []
        for l in range(self.num_layers):
            h, c = H[l], C[l]
            new_x = []
            for s in range(seq_len):
                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
                                        ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
                if self.bias is not None:
                    gate += self.bias[l]
                gate = list(torch.chunk(gate, 4, dim=1))
                i, f, o, u = gate
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                u = torch.tanh(u)
                c = f * c + i * u
                h = o * torch.tanh(c)
                new_x.append(h)
            next_state.append((h, c))
            x = torch.stack(new_x, dim=0)
            if self.use_dropout and l != self.num_layers - 1:
                x = self.dropout(x)
        next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]

        next_state = self._after_forward(next_state, list_next_state)
        return x, next_state


class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
    """
    Overview:
        Wrapper class for PyTorch's nn.LSTM, formats the input and output. For more details on nn.LSTM,
        refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
    Interfaces:
        ``forward``
    """

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Executes nn.LSTM.forward with preprocessed input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, \
                hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - output (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
        """
        prev_state = self._before_forward(inputs, prev_state)
        output, next_state = nn.LSTM.forward(self, inputs, prev_state)
        next_state = self._after_forward(next_state, list_next_state)
        return output, next_state


class GRU(nn.GRUCell, LSTMForwardWrapper):
    """
    Overview:
        This class extends the `torch.nn.GRUCell` and `LSTMForwardWrapper` classes, and formats inputs and outputs
        accordingly.
    Interfaces:
        ``__init__``, ``forward``
    Properties:
        hidden_size, num_layers

    .. note::
        For further details, refer to the official PyTorch documentation:
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU>
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        """
        Overview:
            Initialize the GRU class with input size, hidden size, and number of layers.
        Arguments:
            - input_size (:obj:`int`): The size of the input vector.
            - hidden_size (:obj:`int`): The size of the hidden state vector.
            - num_layers (:obj:`int`): The number of GRU layers.
        """
        super(GRU, self).__init__(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self,
                inputs: torch.Tensor,
                prev_state: Optional[torch.Tensor] = None,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, List]]:
        """
        Overview:
            Wraps the `nn.GRUCell.forward` method.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
            - prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).
        Returns:
            - output (:obj:`torch.Tensor`): Output from GRU.
            - next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU.
        """
        # for compatibility
        prev_state, _ = self._before_forward(inputs, prev_state)
        inputs, prev_state = inputs.squeeze(0), prev_state.squeeze(0)
        next_state = nn.GRUCell.forward(self, inputs, prev_state)
        next_state = next_state.unsqueeze(0)
        x = next_state
        # for compatibility
        next_state = self._after_forward([next_state, next_state.clone()], list_next_state)
        return x, next_state


def get_lstm(
        lstm_type: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        norm_type: str = 'LN',
        dropout: float = 0.,
        seq_len: Optional[int] = None,
        batch_size: Optional[int] = None
) -> Union[LSTM, PytorchLSTM]:
    """
    Overview:
        Build and return the corresponding LSTM cell based on the provided parameters.
    Arguments:
        - lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
        - input_size (:obj:`int`): Size of the input vector.
        - hidden_size (:obj:`int`): Size of the hidden state vector.
        - num_layers (:obj:`int`): Number of LSTM layers (default is 1).
        - norm_type (:obj:`str`): Type of normalization (default is 'LN').
        - dropout (:obj:`float`): Dropout rate (default is 0.0).
        - seq_len (:obj:`Optional[int]`): Sequence length (default is None).
        - batch_size (:obj:`Optional[int]`): Batch size (default is None).
    Returns:
        - lstm (:obj:`Union[LSTM, PytorchLSTM]`): The corresponding LSTM cell.
    """
    assert lstm_type in ['normal', 'pytorch', 'hpc', 'gru']
    if lstm_type == 'normal':
        return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
    elif lstm_type == 'pytorch':
        return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
    elif lstm_type == 'hpc':
        return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
    elif lstm_type == 'gru':
        assert num_layers == 1
        return GRU(input_size, hidden_size, num_layers)
```