
ding.torch_utils.network.activation


Lambda

Bases: Module

Overview

A custom lambda module for constructing custom layers.

Interfaces: `__init__`, `forward`.

`__init__(f)`

Overview

Initialize the lambda module with a given function.

Arguments:

- f (:obj:`Callable`): A Python function.

`forward(x)`

Overview

Compute the function of the input tensor.

Arguments:

- x (:obj:`torch.Tensor`): The input tensor.
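As a quick illustration, `Lambda` wraps any callable into a module. The sketch below re-defines a minimal `Lambda` inline (mirroring the class above) so it stays self-contained; the `square` example is illustrative, not from the library:

```python
import torch
import torch.nn as nn


class Lambda(nn.Module):
    """Wrap an arbitrary callable as an nn.Module (mirrors the class above)."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)


# Build a custom element-wise layer from a plain lambda.
square = Lambda(lambda x: x ** 2)
x = torch.tensor([1.0, -2.0, 3.0])
y = square(x)  # element-wise square
```

Because `Lambda` is a regular `nn.Module`, it composes with `nn.Sequential` like any other layer.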

GLU

Bases: Module

Overview

Gating Linear Unit (GLU), a specific type of activation function, first proposed in [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).

Interfaces: `__init__`, `forward`.

`__init__(input_dim, output_dim, context_dim, input_type='fc')`

Overview

Initialize the GLU module.

Arguments:

- input_dim (:obj:`int`): The dimension of the input tensor.
- output_dim (:obj:`int`): The dimension of the output tensor.
- context_dim (:obj:`int`): The dimension of the context tensor.
- input_type (:obj:`str`): The type of input, now supports ['fc', 'conv2d'].

`forward(x, context)`

Overview

Compute the GLU transformation of the input tensor.

Arguments:

- x (:obj:`torch.Tensor`): The input tensor.
- context (:obj:`torch.Tensor`): The context tensor.

Returns:

- x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
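The `'fc'` branch of this transformation can be sketched directly with plain `torch` ops: a gate in (0, 1) is computed from the context, applied element-wise to the input, and projected to the output dimension. The layer names mirror the module's internals; the dimensions are illustrative:

```python
import torch
import torch.nn as nn

# Illustrative dimensions (not from the library).
input_dim, output_dim, context_dim = 8, 4, 6

layer1 = nn.Linear(context_dim, input_dim)  # context -> gate logits
layer2 = nn.Linear(input_dim, output_dim)   # gated input -> output

x = torch.randn(2, input_dim)          # batch of inputs
context = torch.randn(2, context_dim)  # batch of context vectors

gate = torch.sigmoid(layer1(context))  # gate values in (0, 1)
out = layer2(gate * x)                 # element-wise gating, then projection
```

Note that the gate is derived from `context`, not from `x`, which is what lets a separate signal modulate the input.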

Swish

Bases: Module

Overview

Swish activation function, which is a smooth, non-monotonic activation function. For more details, please refer to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).

Interfaces: `__init__`, `forward`.

`__init__()`

Overview

Initialize the Swish module.

`forward(x)`

Overview

Compute the Swish transformation of the input tensor.

Arguments:

- x (:obj:`torch.Tensor`): The input tensor.

Returns:

- x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
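The formula Swish(x) = x * sigmoid(x) can be checked numerically with a pure-Python sketch (no torch required). The test points are illustrative; they show the two properties named above: smoothness toward the identity for large x, and non-monotonicity on the negative axis.

```python
import math


def swish(x: float) -> float:
    # Swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + math.exp(-x))


# Non-monotonic: swish dips below zero for negative inputs,
# so swish(-5) is *larger* than swish(-1) even though -5 < -1.
# For large positive x, swish(x) approaches x (sigmoid -> 1).
```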

GELU

Bases: Module

Overview

Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models such as GPT and BERT. For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf.

Interfaces: `__init__`, `forward`.

`__init__()`

Overview

Initialize the GELU module.

`forward(x)`

Overview

Compute the GELU transformation of the input tensor.

Arguments:

- x (:obj:`torch.Tensor`): The input tensor.

Returns:

- x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
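This module implements the tanh approximation of GELU rather than the exact erf-based definition GELU(x) = x * Phi(x). The pure-Python sketch below (helper names are illustrative) compares the two and shows they agree closely:

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation used in the module above.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


# The two definitions track each other to well under 1e-2 over typical inputs.
max_err = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in [-3.0, -1.0, 0.0, 0.5, 2.0])
```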

`build_activation(activation, inplace=None)`

Overview

Build and return the activation module according to the given type.

Arguments:

- activation (:obj:`str`): The type of activation module, now supports ['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'silu', 'square', 'identity'].
- inplace (Optional[:obj:`bool`]): Whether to execute the operation in-place; defaults to None (only supported for 'relu').

Returns:

- act_func (:obj:`nn.Module`): The corresponding activation module.
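The dispatch logic is a case-insensitive lookup into a dict of prebuilt modules, raising KeyError for unknown names. The sketch below illustrates just that logic, with plain callables standing in for `nn.Module` instances (the entries are illustrative, not the full table):

```python
def build_activation(activation: str):
    # Plain callables stand in for nn.Module instances in this sketch.
    act_func = {
        'relu': lambda x: max(x, 0.0),
        'square': lambda x: x ** 2,
        'identity': lambda x: x,
    }
    key = activation.lower()  # lookup is case-insensitive
    if key in act_func:
        return act_func[key]
    raise KeyError("invalid key for activation: {}".format(activation))


relu = build_activation('ReLU')  # mixed case resolves via .lower()
```

Note that in the real function the lowered key must also be used for the final lookup, otherwise a mixed-case name passes the membership check but fails retrieval.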

Full Source Code

../ding/torch_utils/network/activation.py

```python
import math
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn as nn


class Lambda(nn.Module):
    """
    Overview:
        A custom lambda module for constructing custom layers.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, f: Callable):
        """
        Overview:
            Initialize the lambda module with a given function.
        Arguments:
            - f (:obj:`Callable`): A Python function.
        """
        super(Lambda, self).__init__()
        self.f = f

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the function of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        """
        return self.f(x)


class GLU(nn.Module):
    """
    Overview:
        Gating Linear Unit (GLU), a specific type of activation function, which is first proposed in
        [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
        """
        Overview:
            Initialize the GLU module.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input tensor.
            - output_dim (:obj:`int`): The dimension of the output tensor.
            - context_dim (:obj:`int`): The dimension of the context tensor.
            - input_type (:obj:`str`): The type of input, now supports ['fc', 'conv2d'].
        """
        super(GLU, self).__init__()
        assert input_type in ['fc', 'conv2d']
        if input_type == 'fc':
            self.layer1 = nn.Linear(context_dim, input_dim)
            self.layer2 = nn.Linear(input_dim, output_dim)
        elif input_type == 'conv2d':
            self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0)
            self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the GLU transformation of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
            - context (:obj:`torch.Tensor`): The context tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
        """
        gate = self.layer1(context)
        gate = torch.sigmoid(gate)
        x = gate * x
        x = self.layer2(x)
        return x


class Swish(nn.Module):
    """
    Overview:
        Swish activation function, which is a smooth, non-monotonic activation function. For more details, please refer
        to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the Swish module.
        """
        super(Swish, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the Swish transformation of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
        """
        return x * torch.sigmoid(x)


class GELU(nn.Module):
    """
    Overview:
        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models like GPT, BERT.
        For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the GELU module.
        """
        super(GELU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the GELU transformation of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
        """
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def build_activation(activation: str, inplace: Optional[bool] = None) -> nn.Module:
    """
    Overview:
        Build and return the activation module according to the given type.
    Arguments:
        - activation (:obj:`str`): The type of activation module, now supports \
            ['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'silu', 'square', \
            'identity'].
        - inplace (Optional[:obj:`bool`]): Whether to execute the operation in-place, defaults to None.
    Returns:
        - act_func (:obj:`nn.Module`): The corresponding activation module.
    """
    if inplace is not None:
        assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
    else:
        inplace = False
    act_func = {
        'relu': nn.ReLU(inplace=inplace),
        'glu': GLU,  # returned as a class, since GLU requires constructor arguments
        'prelu': nn.PReLU(),
        'swish': Swish(),
        'gelu': GELU(),
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "softplus": nn.Softplus(),
        "elu": nn.ELU(),
        "silu": torch.nn.SiLU(inplace=inplace),
        "square": Lambda(lambda x: x ** 2),
        "identity": Lambda(lambda x: x),
    }
    if activation.lower() in act_func.keys():
        # Use the lowered key for retrieval as well, so mixed-case names resolve.
        return act_func[activation.lower()]
    else:
        raise KeyError("invalid key for activation: {}".format(activation))
```