
ding.torch_utils.network.merge


This file provides implementations of several neural network modules used for merging and transforming input data in various ways. These components are useful when dealing with data from multiple modalities, or when multiple intermediate embedded representations need to be merged in the forward pass of a model.

The main classes defined in this code are:

- BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to
    incoming data, as described in "Multiplicative Interactions and Where to Find Them" (ICLR 2020,
    https://openreview.net/forum?id=rylnK6VtDH). The transformation involves two input features and an output
    feature, and also includes an optional bias term.

- TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch
    (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the
    BilinearGeneral class.

- TorchBilinear: This class is a simple alias of PyTorch's built-in nn.Bilinear module. It provides the
    same functionality as PyTorch's nn.Bilinear but within the structure of the current module.

- FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine
    transformation to the input data, conditioned on some additional context information.

- GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in
    the modules.

- SumMerge: This class provides a simple summing mechanism to merge input streams.

- VectorMerge: This class implements a more complex merging mechanism for vector streams.
    The streams are first transformed using layer normalization, a ReLU activation, and a linear layer.
    Then they are merged either by simple summing or by using a gating mechanism.

The implementation of these classes relies on the PyTorch and NumPy libraries, and the classes use PyTorch's nn.Module as the base class, making them compatible with PyTorch's neural network modules and functionalities. These modules can serve as useful building blocks in more complex deep learning architectures.

BilinearGeneral

Bases: Module

Overview

Bilinear implementation as in: Multiplicative Interactions and Where to Find Them, ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH.

Interfaces: __init__, forward

__init__(in1_features, in2_features, out_features)

Overview

Initialize the Bilinear layer.

Arguments:

- in1_features (int): The size of each first input sample.
- in2_features (int): The size of each second input sample.
- out_features (int): The size of each output sample.

reset_parameters()

Overview

Initialize the parameters of the Bilinear layer.

forward(x, z)

Overview

Compute the bilinear function.

Arguments:

- x (torch.Tensor): The first input tensor.
- z (torch.Tensor): The second input tensor.
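The layer combines a bilinear term with two linear terms and a bias. A minimal self-contained sketch of the forward computation in plain torch (tensor names mirror the class's parameters; no ding import is assumed):

```python
import torch

torch.manual_seed(0)
B, in1, in2, out = 4, 3, 5, 2

# Parameters mirroring BilinearGeneral: W (out, in1, in2), U (out, in2), V (out, in1), b (out,)
W = torch.randn(out, in1, in2)
U = torch.randn(out, in2)
V = torch.randn(out, in1)
b = torch.randn(out)

x = torch.randn(B, in1)  # first input
z = torch.randn(B, in2)  # second input

# Bilinear term x^T W z, batched over B and output channel k
out_W = torch.einsum('bi,kij,bj->bk', x, W, z)
# Linear terms U z and V x, plus the bias
y = out_W + z @ U.t() + x @ V.t() + b
print(y.shape)  # torch.Size([4, 2])
```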

TorchBilinearCustomized

Bases: Module

Overview

Customized Torch Bilinear implementation.

Interfaces: __init__, forward

__init__(in1_features, in2_features, out_features)

Overview

Initialize the Bilinear layer.

Arguments:

- in1_features (int): The size of each first input sample.
- in2_features (int): The size of each second input sample.
- out_features (int): The size of each output sample.

reset_parameters()

Overview

Initialize the parameters of the Bilinear layer.

forward(x, z)

Overview

Compute the bilinear function.

Arguments:

- x (torch.Tensor): The first input tensor.
- z (torch.Tensor): The second input tensor.
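The forward pass of this layer is a single einsum contraction. As a sanity check, that contraction is numerically equivalent to PyTorch's built-in torch.nn.functional.bilinear; a sketch with random tensors standing in for the layer's learned parameters:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)
z = torch.randn(4, 5)
weight = torch.randn(2, 3, 5)  # (out_features, in1_features, in2_features)
bias = torch.randn(2)

# The einsum used in TorchBilinearCustomized.forward ...
out_einsum = torch.einsum('bi,oij,bj->bo', x, weight, z) + bias
# ... matches PyTorch's built-in bilinear function
out_builtin = F.bilinear(x, z, weight, bias)
assert torch.allclose(out_einsum, out_builtin, atol=1e-5)
```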

FiLM

Bases: Module

Overview

Feature-wise Linear Modulation (FiLM) Layer. This layer applies feature-wise affine transformation based on context.

Interfaces: __init__, forward

__init__(feature_dim, context_dim)

Overview

Initialize the FiLM layer.

Arguments:

- feature_dim (int): The dimension of the input feature vector.
- context_dim (int): The dimension of the input context vector.

forward(feature, context)

Overview

Forward propagation.

Arguments:

- feature (torch.Tensor): The input feature, shape (batch_size, feature_dim).
- context (torch.Tensor): The input context, shape (batch_size, context_dim).

Returns:

- conditioned_feature (torch.Tensor): The output feature after FiLM, shape (batch_size, feature_dim).
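The FiLM computation itself is small: a linear layer maps the context to gamma and beta, which scale and shift the feature. A self-contained sketch in plain torch, mirroring the class's context layer and split:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feature_dim, context_dim, batch = 8, 4, 2

# The context layer outputs 2 * feature_dim values: gamma and beta
context_layer = nn.Linear(context_dim, 2 * feature_dim)

feature = torch.randn(batch, feature_dim)
context = torch.randn(batch, context_dim)

# Split the projected context into gamma and beta along the feature dimension
gamma, beta = torch.split(context_layer(context), feature_dim, dim=1)
# Feature-wise affine transformation conditioned on the context
conditioned = gamma * feature + beta
print(conditioned.shape)  # torch.Size([2, 8])
```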

GatingType

Bases: Enum

Overview

Enum class defining different types of tensor gating and aggregation in modules.

SumMerge

Bases: Module

Overview

A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the same size. This module can work with any type of tensor (vector, units or visual).

Interfaces: __init__, forward

forward(tensors)

Overview

Forward pass of the SumMerge module, which sums the input tensors.

Arguments:

- tensors (List[Tensor]): List of input tensors to be summed. All tensors must have the same size.

Returns:

- summed (Tensor): The tensor resulting from the sum of all input tensors.
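A quick sketch of what this forward pass does: stack the tensors along a new leading dimension, then sum along it.

```python
import torch

tensors = [torch.ones(2, 3), 2 * torch.ones(2, 3)]

# Stack along a new first dimension, then reduce over it
summed = torch.sum(torch.stack(tensors, dim=0), dim=0)
assert torch.equal(summed, 3 * torch.ones(2, 3))
```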

VectorMerge

Bases: Module

Overview

Merges multiple vector streams. Streams are first transformed through layer normalization, a ReLU activation, and linear layers, then summed. They do not need to have the same size. Gating can also be applied before the sum.

Interfaces: __init__, encode, _compute_gate, forward

.. note:: For more details about the gating types, please refer to the GatingType enum class.

__init__(input_sizes, output_size, gating_type=GatingType.NONE, use_layer_norm=True)

Overview

Initialize the VectorMerge module.

Arguments:

- input_sizes (Dict[str, int]): A dictionary mapping input names to their sizes. The size is a single integer for 1D inputs, or None for 0D inputs. If an input size is None, we assume it's ().
- output_size (int): The size of the output vector.
- gating_type (GatingType): The type of gating mechanism to use. Default is GatingType.NONE.
- use_layer_norm (bool): Whether to use layer normalization. Default is True.

encode(inputs)

Overview

Encode the input tensors using layer normalization, ReLU, and linear transformations.

Arguments:

- inputs (Dict[str, Tensor]): The input tensors.

Returns:

- gates (List[Tensor]): The gate tensors after transformations.
- outputs (List[Tensor]): The output tensors after transformations.

forward(inputs)

Overview

Forward pass through the VectorMerge module.

Arguments:

- inputs (Dict[str, Tensor]): The input tensors.

Returns:

- output (Tensor): The output tensor after passing through the module.
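To illustrate the gated merge, here is a sketch of the two-stream shortcut used by _compute_gate with pointwise gating: a single sigmoid yields complementary gates, so the per-feature weights of the two streams sum to 1. Random tensors stand in for the encoded streams and the outputs of the gating linear layers.

```python
import torch

torch.manual_seed(0)
# Two already-encoded streams of equal output size (as produced by `encode`)
outputs = [torch.randn(4, 6), torch.randn(4, 6)]
# Stand-ins for the per-stream gating linear outputs (pointwise: one logit per feature)
gate_logits = [torch.randn(4, 6), torch.randn(4, 6)]

# Two-stream shortcut: sum the logits, take a sigmoid, and use the
# complement for the second stream, so the gates sum to 1 per feature.
sigmoid = torch.sigmoid(sum(gate_logits))
gate = [sigmoid, 1.0 - sigmoid]

# Gated sum of the streams
output = sum(g * d for g, d in zip(gate, outputs))
assert output.shape == (4, 6)
```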

Full Source Code

../ding/torch_utils/network/merge.py

1""" 2This file provides an implementation of several different neural network modules that are used for merging and 3transforming input data in various ways. The following components can be used when we are dealing with 4data from multiple modes, or when we need to merge multiple intermediate embedded representations in 5the forward process of a model. 6 7The main classes defined in this code are: 8 9 - BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to 10 incoming data, as described in the "Multiplicative Interactions and Where to Find Them", published at ICLR 2020, 11 https://openreview.net/forum?id=rylnK6VtDH. The transformation involves two input features and an output 12 feature, and also includes an optional bias term. 13 14 - TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch 15 (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the 16 BilinearGeneral class. 17 18 - TorchBilinear: This class is a simple wrapper around the PyTorch's built-in nn.Bilinear module. It provides the 19 same functionality as PyTorch's nn.Bilinear but within the structure of the current module. 20 21 - FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine 22 transformation to the input data, conditioned on some additional context information. 23 24 - GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in 25 the modules. 26 27 - SumMerge: This class provides a simple summing mechanism to merge input streams. 28 29 - VectorMerge: This class implements a more complex merging mechanism for vector streams. 30 The streams are first transformed using layer normalization, a ReLU activation, and a linear layer. 31 Then they are merged either by simple summing or by using a gating mechanism. 
32 33The implementation of these classes involves PyTorch and Numpy libraries, and the classes use PyTorch's nn.Module as 34the base class, making them compatible with PyTorch's neural network modules and functionalities. 35These modules can be useful building blocks in more complex deep learning architectures. 36""" 37 38import enum 39import math 40from collections import OrderedDict 41from typing import List, Dict, Tuple 42 43import numpy as np 44import torch 45import torch.nn as nn 46import torch.nn.functional as F 47from torch import Tensor 48 49 50class BilinearGeneral(nn.Module): 51 """ 52 Overview: 53 Bilinear implementation as in: Multiplicative Interactions and Where to Find Them, 54 ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH. 55 Interfaces: 56 ``__init__``, ``forward`` 57 """ 58 59 def __init__(self, in1_features: int, in2_features: int, out_features: int): 60 """ 61 Overview: 62 Initialize the Bilinear layer. 63 Arguments: 64 - in1_features (:obj:`int`): The size of each first input sample. 65 - in2_features (:obj:`int`): The size of each second input sample. 66 - out_features (:obj:`int`): The size of each output sample. 67 """ 68 69 super(BilinearGeneral, self).__init__() 70 # Initialize the weight matrices W and U, and the bias vectors V and b 71 self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features)) 72 self.U = nn.Parameter(torch.Tensor(out_features, in2_features)) 73 self.V = nn.Parameter(torch.Tensor(out_features, in1_features)) 74 self.b = nn.Parameter(torch.Tensor(out_features)) 75 self.in1_features = in1_features 76 self.in2_features = in2_features 77 self.out_features = out_features 78 self.reset_parameters() 79 80 def reset_parameters(self): 81 """ 82 Overview: 83 Initialize the parameters of the Bilinear layer. 84 """ 85 86 stdv = 1. 
/ np.sqrt(self.in1_features) 87 self.W.data.uniform_(-stdv, stdv) 88 self.U.data.uniform_(-stdv, stdv) 89 self.V.data.uniform_(-stdv, stdv) 90 self.b.data.uniform_(-stdv, stdv) 91 92 def forward(self, x: torch.Tensor, z: torch.Tensor): 93 """ 94 Overview: 95 compute the bilinear function. 96 Arguments: 97 - x (:obj:`torch.Tensor`): The first input tensor. 98 - z (:obj:`torch.Tensor`): The second input tensor. 99 """ 100 101 # Compute the bilinear function 102 # x^TWz 103 out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z) 104 # x^TU 105 out_U = z.matmul(self.U.t()) 106 # Vz 107 out_V = x.matmul(self.V.t()) 108 # x^TWz + x^TU + Vz + b 109 out = out_W + out_U + out_V + self.b 110 return out 111 112 113class TorchBilinearCustomized(nn.Module): 114 """ 115 Overview: 116 Customized Torch Bilinear implementation. 117 Interfaces: 118 ``__init__``, ``forward`` 119 """ 120 121 def __init__(self, in1_features: int, in2_features: int, out_features: int): 122 """ 123 Overview: 124 Initialize the Bilinear layer. 125 Arguments: 126 - in1_features (:obj:`int`): The size of each first input sample. 127 - in2_features (:obj:`int`): The size of each second input sample. 128 - out_features (:obj:`int`): The size of each output sample. 129 """ 130 131 super(TorchBilinearCustomized, self).__init__() 132 self.in1_features = in1_features 133 self.in2_features = in2_features 134 self.out_features = out_features 135 self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features)) 136 self.bias = nn.Parameter(torch.Tensor(out_features)) 137 self.reset_parameters() 138 139 def reset_parameters(self): 140 """ 141 Overview: 142 Initialize the parameters of the Bilinear layer. 143 """ 144 145 bound = 1 / math.sqrt(self.in1_features) 146 nn.init.uniform_(self.weight, -bound, bound) 147 nn.init.uniform_(self.bias, -bound, bound) 148 149 def forward(self, x, z): 150 """ 151 Overview: 152 Compute the bilinear function. 
153 Arguments: 154 - x (:obj:`torch.Tensor`): The first input tensor. 155 - z (:obj:`torch.Tensor`): The second input tensor. 156 """ 157 158 # Using torch.einsum for the bilinear operation 159 out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias 160 return out.squeeze(-1) 161 162 163""" 164Overview: 165 Implementation of the Bilinear layer as in PyTorch: 166 https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear 167Arguments: 168 - in1_features (:obj:`int`): The size of each first input sample. 169 - in2_features (:obj:`int`): The size of each second input sample. 170 - out_features (:obj:`int`): The size of each output sample. 171 - bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``. 172""" 173TorchBilinear = nn.Bilinear 174 175 176class FiLM(nn.Module): 177 """ 178 Overview: 179 Feature-wise Linear Modulation (FiLM) Layer. 180 This layer applies feature-wise affine transformation based on context. 181 Interfaces: 182 ``__init__``, ``forward`` 183 """ 184 185 def __init__(self, feature_dim: int, context_dim: int): 186 """ 187 Overview: 188 Initialize the FiLM layer. 189 Arguments: 190 - feature_dim (:obj:`int`). The dimension of the input feature vector. 191 - context_dim (:obj:`int`). The dimension of the input context vector. 192 """ 193 194 super(FiLM, self).__init__() 195 # Define the fully connected layer for context 196 # The output dimension is twice the feature dimension for gamma and beta 197 self.context_layer = nn.Linear(context_dim, 2 * feature_dim) 198 199 def forward(self, feature: torch.Tensor, context: torch.Tensor): 200 """ 201 Overview: 202 Forward propagation. 203 Arguments: 204 - feature (:obj:`torch.Tensor`). The input feature, shape (batch_size, feature_dim). 205 - context (:obj:`torch.Tensor`). The input context, shape (batch_size, context_dim). 206 Returns: 207 - conditioned_feature : torch.Tensor. 
The output feature after FiLM, shape (batch_size, feature_dim). 208 """ 209 210 # Pass context through the fully connected layer 211 out = self.context_layer(context) 212 # Split the output into two parts: gamma and beta 213 # The dimension for splitting is 1 (feature dimension) 214 gamma, beta = torch.split(out, out.shape[1] // 2, dim=1) 215 # Apply feature-wise affine transformation 216 conditioned_feature = gamma * feature + beta 217 return conditioned_feature 218 219 220class GatingType(enum.Enum): 221 """ 222 Overview: 223 Enum class defining different types of tensor gating and aggregation in modules. 224 """ 225 NONE = 'none' 226 GLOBAL = 'global' 227 POINTWISE = 'pointwise' 228 229 230class SumMerge(nn.Module): 231 """ 232 Overview: 233 A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the same 234 size. This module can work with any type of tensor (vector, units or visual). 235 Interfaces: 236 ``__init__``, ``forward`` 237 """ 238 239 def forward(self, tensors: List[Tensor]) -> Tensor: 240 """ 241 Overview: 242 Forward pass of the SumMerge module, which sums the input tensors. 243 Arguments: 244 - tensors (:obj:`List[Tensor]`): List of input tensors to be summed. All tensors must have the same size. 245 Returns: 246 - summed (:obj:`Tensor`): Tensor resulting from the sum of all input tensors. 247 """ 248 # stack the tensors along the first dimension 249 stacked = torch.stack(tensors, dim=0) 250 251 # compute the sum along the first dimension 252 summed = torch.sum(stacked, dim=0) 253 # summed = sum(tensors) 254 return summed 255 256 257class VectorMerge(nn.Module): 258 """ 259 Overview: 260 Merges multiple vector streams. Streams are first transformed through layer normalization, relu, and linear 261 layers, then summed. They don't need to have the same size. Gating can also be used before the sum. 262 Interfaces: 263 ``__init__``, ``encode``, ``_compute_gate``, ``forward`` 264 265 .. 
note:: 266 For more details about the gating types, please refer to the GatingType enum class. 267 """ 268 269 def __init__( 270 self, 271 input_sizes: Dict[str, int], 272 output_size: int, 273 gating_type: GatingType = GatingType.NONE, 274 use_layer_norm: bool = True, 275 ): 276 """ 277 Overview: 278 Initialize the `VectorMerge` module. 279 Arguments: 280 - input_sizes (:obj:`Dict[str, int]`): A dictionary mapping input names to their sizes. \ 281 The size is a single integer for 1D inputs, or `None` for 0D inputs. \ 282 If an input size is `None`, we assume it's `()`. 283 - output_size (:obj:`int`): The size of the output vector. 284 - gating_type (:obj:`GatingType`): The type of gating mechanism to use. Default is `GatingType.NONE`. 285 - use_layer_norm (:obj:`bool`): Whether to use layer normalization. Default is `True`. 286 """ 287 super().__init__() 288 self._input_sizes = OrderedDict(input_sizes) 289 self._output_size = output_size 290 self._gating_type = gating_type 291 self._use_layer_norm = use_layer_norm 292 293 if self._use_layer_norm: 294 self._layer_norms = nn.ModuleDict() 295 else: 296 self._layer_norms = None 297 298 self._linears = nn.ModuleDict() 299 for name, size in self._input_sizes.items(): 300 linear_input_size = size if size > 0 else 1 301 if self._use_layer_norm: 302 self._layer_norms[name] = nn.LayerNorm(linear_input_size) 303 self._linears[name] = nn.Linear(linear_input_size, self._output_size) 304 305 self._gating_linears = nn.ModuleDict() 306 if self._gating_type is GatingType.GLOBAL: 307 self.gate_size = 1 308 elif self._gating_type is GatingType.POINTWISE: 309 self.gate_size = self._output_size 310 elif self._gating_type is GatingType.NONE: 311 self._gating_linears = None 312 else: 313 raise ValueError(f'Gating type {self._gating_type} is not supported') 314 315 if self._gating_linears is not None: 316 if len(self._input_sizes) == 2: 317 # more efficient than the general version below 318 for name, size in self._input_sizes.items(): 
319 gate_input_size = size if size > 0 else 1 320 gating_layer = nn.Linear(gate_input_size, self.gate_size) 321 torch.nn.init.normal_(gating_layer.weight, std=0.005) 322 torch.nn.init.constant_(gating_layer.bias, 0.0) 323 self._gating_linears[name] = gating_layer 324 else: 325 for name, size in self._input_sizes.items(): 326 gate_input_size = size if size > 0 else 1 327 gating_layer = nn.Linear(gate_input_size, len(self._input_sizes) * self.gate_size) 328 torch.nn.init.normal_(gating_layer.weight, std=0.005) 329 torch.nn.init.constant_(gating_layer.bias, 0.0) 330 self._gating_linears[name] = gating_layer 331 332 def encode(self, inputs: Dict[str, Tensor]) -> Tuple[List[Tensor], List[Tensor]]: 333 """ 334 Overview: 335 Encode the input tensors using layer normalization, relu, and linear transformations. 336 Arguments: 337 - inputs (:obj:`Dict[str, Tensor]`): The input tensors. 338 Returns: 339 - gates (:obj:`List[Tensor]`): The gate tensors after transformations. 340 - outputs (:obj:`List[Tensor]`): The output tensors after transformations. 341 """ 342 gates, outputs = [], [] 343 for name, size in self._input_sizes.items(): 344 feature = inputs[name] 345 if size <= 0 and feature.dim() == 1: 346 feature = feature.unsqueeze(-1) 347 feature = feature.to(torch.float32) 348 if self._use_layer_norm and name in self._layer_norms: 349 feature = self._layer_norms[name](feature) 350 feature = F.relu(feature) 351 gates.append(feature) 352 outputs.append(self._linears[name](feature)) 353 return gates, outputs 354 355 def _compute_gate( 356 self, 357 init_gate: List[Tensor], 358 ) -> List[Tensor]: 359 """ 360 Overview: 361 Compute the gate values based on the initial gate values. 362 Arguments: 363 - init_gate (:obj:`List[Tensor]`): The initial gate values. 364 Returns: 365 - gate (:obj:`List[Tensor]`): The computed gate values. 
366 """ 367 if len(self._input_sizes) == 2: 368 gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)] 369 gate = sum(gate) 370 sigmoid = torch.sigmoid(gate) 371 gate = [sigmoid, 1.0 - sigmoid] 372 else: 373 gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)] 374 gate = sum(gate) 375 gate = gate.reshape([-1, len(self._input_sizes), self.gate_size]) 376 gate = F.softmax(gate, dim=1) 377 assert gate.shape[1] == len(self._input_sizes) 378 gate = [gate[:, i] for i in range(len(self._input_sizes))] 379 return gate 380 381 def forward(self, inputs: Dict[str, Tensor]) -> Tensor: 382 """ 383 Overview: 384 Forward pass through the VectorMerge module. 385 Arguments: 386 - inputs (:obj:`Dict[str, Tensor]`): The input tensors. 387 Returns: 388 - output (:obj:`Tensor`): The output tensor after passing through the module. 389 """ 390 gates, outputs = self.encode(inputs) 391 if len(outputs) == 1: 392 # Special case of 1-D inputs that do not need any gating. 393 output = outputs[0] 394 elif self._gating_type is GatingType.NONE: 395 output = sum(outputs) 396 else: 397 gate = self._compute_gate(gates) 398 data = [g * d for g, d in zip(gate, outputs)] 399 output = sum(data) 400 return output