
ding.torch_utils.network.res_block


ResBlock

Bases: Module

Overview

Residual Block with 2D convolution layers, including 3 types:

    basic block:
        input channel: C
        x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
        \__________________________________________/+

    bottleneck block:
        x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out
        \_____________________________________________________________________________/+

    downsample block: used in EfficientZero
        input channel: C
        x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
        \__________________ 3*3*C ____________________/+

.. note:: You can refer to `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_ for more details.

Interfaces

__init__, forward

__init__(in_channels, activation=nn.ReLU(), norm_type='BN', res_type='basic', bias=True, out_channels=None)

Overview

Init the 2D convolution residual block.

Arguments:
- in_channels (:obj:`int`): Number of channels in the input tensor.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): Type of the normalization, default set to 'BN' (Batch Normalization), supports ['BN', 'LN', 'IN', 'GN', 'SyncBN', None].
- res_type (:obj:`str`): Type of residual block, supports ['basic', 'bottleneck', 'downsample'].
- bias (:obj:`bool`): Whether to add a learnable bias to the conv2d_block, default set to True.
- out_channels (:obj:`int`): Number of channels in the output tensor, default set to None, which means out_channels = in_channels.

forward(x)

Overview

Return the residual block output.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor.

Returns:
- x (:obj:`torch.Tensor`): The resblock output tensor.
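To make the data flow concrete, here is a minimal standalone sketch of the 'basic' residual path described above, using plain `nn.Conv2d` + `nn.BatchNorm2d` in place of DI-engine's `conv2d_block` helper (the class name `BasicResSketch` is illustrative, not part of the library):

```python
import torch
import torch.nn as nn


class BasicResSketch(nn.Module):
    """Illustrative stand-in for ResBlock(res_type='basic') with norm_type='BN'."""

    def __init__(self, c: int):
        super().__init__()
        self.act = nn.ReLU()
        # 3x3 conv, stride 1, padding 1 keeps the spatial size unchanged
        self.conv1 = nn.Sequential(nn.Conv2d(c, c, 3, 1, 1), nn.BatchNorm2d(c))
        self.conv2 = nn.Sequential(nn.Conv2d(c, c, 3, 1, 1), nn.BatchNorm2d(c))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = self.act(self.conv1(x))
        x = self.conv2(x)              # no activation before the skip-add
        return self.act(x + identity)  # residual add, then activation


block = BasicResSketch(16)
out = block(torch.randn(2, 16, 8, 8))
print(out.shape)  # torch.Size([2, 16, 8, 8]) -- shape is preserved
```

Because out_channels defaults to in_channels and the 3x3 convolutions use stride 1 / padding 1, the basic block preserves the tensor shape, which is what makes the plain `x + identity` skip-add valid.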

ResFCBlock

Bases: Module

Overview

Residual Block with 2 fully connected layers.

    x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out
    \_____________________________________/+

Interfaces

__init__, forward

__init__(in_channels, activation=nn.ReLU(), norm_type='BN', dropout=None)

Overview

Init the fully connected layer residual block.

Arguments:
- in_channels (:obj:`int`): The number of channels in the input tensor.
- activation (:obj:`nn.Module`): The optional activation function.
- norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
- dropout (:obj:`float`): The dropout rate, default set to None.

forward(x)

Overview

Return the output of the residual block.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor.

Returns:
- x (:obj:`torch.Tensor`): The resblock output tensor.
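The fully connected variant follows the same pattern; a minimal standalone sketch of ResFCBlock's wiring (fc -> norm -> act, twice, with the skip-add before the final activation and dropout applied last), using plain `nn.Linear` + `nn.BatchNorm1d` instead of DI-engine's `fc_block` helper (the class name `ResFCSketch` is illustrative):

```python
import torch
import torch.nn as nn


class ResFCSketch(nn.Module):
    """Illustrative stand-in for ResFCBlock with norm_type='BN'."""

    def __init__(self, c: int, dropout: float = None):
        super().__init__()
        self.act = nn.ReLU()
        self.fc1 = nn.Sequential(nn.Linear(c, c), nn.BatchNorm1d(c))
        self.fc2 = nn.Sequential(nn.Linear(c, c), nn.BatchNorm1d(c))
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        x = self.act(x + identity)    # residual add, then activation
        if self.dropout is not None:  # dropout (if any) comes last
            x = self.dropout(x)
        return x


block = ResFCSketch(32, dropout=0.1)
out = block(torch.randn(4, 32))
print(out.shape)  # torch.Size([4, 32])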

TemporalSpatialResBlock

Bases: Module

Overview

Residual Block using MLP layers for both temporal and spatial input.

    t → time_mlp → h1 → dense2 → h2 → out
                ↗+            ↗+
    x → dense1 → ↗
      ↘                     ↗
        → modify_x → → → →

__init__(input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU())

Overview

Init the temporal spatial residual block.

Arguments:
- input_dim (:obj:`int`): The number of channels in the input tensor.
- output_dim (:obj:`int`): The number of channels in the output tensor.
- t_dim (:obj:`int`): The dimension of the temporal input.
- activation (:obj:`nn.Module`): The optional activation function.

forward(x, t)

Overview

Return the residual block output.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor.
- t (:obj:`torch.Tensor`): The temporal input tensor.
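This block uses only plain `nn.Linear`, `nn.Sequential`, and `nn.Identity`, so it runs standalone; reproducing it from the full source gives a quick shape check of how the temporal and spatial paths combine:

```python
import torch
import torch.nn as nn


class TemporalSpatialResBlock(nn.Module):

    def __init__(self, input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU()):
        super().__init__()
        # temporal input is a time embedding (e.g. a Gaussian Fourier Feature tensor)
        self.time_mlp = nn.Sequential(activation, nn.Linear(t_dim, output_dim))
        self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), activation)
        self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), activation)
        # project the skip connection only when the dims differ
        self.modify_x = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()

    def forward(self, x, t):
        h1 = self.dense1(x) + self.time_mlp(t)  # spatial path + temporal path
        h2 = self.dense2(h1)
        return h2 + self.modify_x(x)            # residual skip over the whole block


block = TemporalSpatialResBlock(input_dim=10, output_dim=64, t_dim=128)
x = torch.randn(4, 10)   # spatial input
t = torch.randn(4, 128)  # time embedding
out = block(x, t)
print(out.shape)  # torch.Size([4, 64])
```

Note that `modify_x` makes the skip connection dimension-safe: when `input_dim != output_dim` (as here, 10 vs. 64), the input is projected with a linear layer before the residual add.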

Full Source Code

../ding/torch_utils/network/res_block.py

```python
from typing import Union

import torch
import torch.nn as nn

from .nn_module import conv2d_block, fc_block


class ResBlock(nn.Module):
    """
    Overview:
        Residual Block with 2D convolution layers, including 3 types:
        basic block:
            input channel: C
            x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
            \__________________________________________/+
        bottleneck block:
            x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out
            \_____________________________________________________________________________/+
        downsample block: used in EfficientZero
            input channel: C
            x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
            \__________________ 3*3*C ____________________/+

    .. note::
        You can refer to `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_ for more \
        details.

    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            activation: nn.Module = nn.ReLU(),
            norm_type: str = 'BN',
            res_type: str = 'basic',
            bias: bool = True,
            out_channels: Union[int, None] = None,
    ) -> None:
        """
        Overview:
            Init the 2D convolution residual block.
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor.
            - activation (:obj:`nn.Module`): The optional activation function.
            - norm_type (:obj:`str`): Type of the normalization, default set to 'BN' (Batch Normalization), \
                supports ['BN', 'LN', 'IN', 'GN', 'SyncBN', None].
            - res_type (:obj:`str`): Type of residual block, supports ['basic', 'bottleneck', 'downsample'].
            - bias (:obj:`bool`): Whether to add a learnable bias to the conv2d_block, default set to True.
            - out_channels (:obj:`int`): Number of channels in the output tensor, default set to None, \
                which means out_channels = in_channels.
        """
        super(ResBlock, self).__init__()
        self.act = activation
        assert res_type in ['basic', 'bottleneck', 'downsample'], \
            'residual type only supports basic, bottleneck and downsample, not: {}'.format(res_type)
        self.res_type = res_type
        if out_channels is None:
            out_channels = in_channels
        if self.res_type == 'basic':
            self.conv1 = conv2d_block(
                in_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
            )
        elif self.res_type == 'bottleneck':
            self.conv1 = conv2d_block(
                in_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv3 = conv2d_block(
                out_channels, out_channels, 1, 1, 0, activation=None, norm_type=norm_type, bias=bias
            )
        elif self.res_type == 'downsample':
            self.conv1 = conv2d_block(
                in_channels, out_channels, 3, 2, 1, activation=self.act, norm_type=norm_type, bias=bias
            )
            self.conv2 = conv2d_block(
                out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
            )
            self.conv3 = conv2d_block(in_channels, out_channels, 3, 2, 1, activation=None, norm_type=None, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the residual block output.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The resblock output tensor.
        """
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.res_type == 'bottleneck':
            x = self.conv3(x)
        elif self.res_type == 'downsample':
            identity = self.conv3(identity)
        x = self.act(x + identity)
        return x


class ResFCBlock(nn.Module):
    """
    Overview:
        Residual Block with 2 fully connected layers.
        x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out
        \_____________________________________/+

    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', dropout: float = None
    ):
        """
        Overview:
            Init the fully connected layer residual block.
        Arguments:
            - in_channels (:obj:`int`): The number of channels in the input tensor.
            - activation (:obj:`nn.Module`): The optional activation function.
            - norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
            - dropout (:obj:`float`): The dropout rate, default set to None.
        """
        super(ResFCBlock, self).__init__()
        self.act = activation
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
        self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the output of the residual block.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The resblock output tensor.
        """
        identity = x
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.act(x + identity)
        if self.dropout is not None:
            x = self.dropout(x)
        return x


class TemporalSpatialResBlock(nn.Module):
    """
    Overview:
        Residual Block using MLP layers for both temporal and spatial input.
        t → time_mlp → h1 → dense2 → h2 → out
                    ↗+            ↗+
        x → dense1 → ↗
          ↘                     ↗
            → modify_x → → → →
    """

    def __init__(self, input_dim, output_dim, t_dim=128, activation=torch.nn.SiLU()):
        """
        Overview:
            Init the temporal spatial residual block.
        Arguments:
            - input_dim (:obj:`int`): The number of channels in the input tensor.
            - output_dim (:obj:`int`): The number of channels in the output tensor.
            - t_dim (:obj:`int`): The dimension of the temporal input.
            - activation (:obj:`nn.Module`): The optional activation function.
        """
        super().__init__()
        # temporal input is the embedding of time, which is a Gaussian Fourier Feature tensor
        self.time_mlp = nn.Sequential(
            activation,
            nn.Linear(t_dim, output_dim),
        )
        self.dense1 = nn.Sequential(nn.Linear(input_dim, output_dim), activation)
        self.dense2 = nn.Sequential(nn.Linear(output_dim, output_dim), activation)
        self.modify_x = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()

    def forward(self, x, t) -> torch.Tensor:
        """
        Overview:
            Return the residual block output.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
            - t (:obj:`torch.Tensor`): The temporal input tensor.
        """
        h1 = self.dense1(x) + self.time_mlp(t)
        h2 = self.dense2(h1)
        return h2 + self.modify_x(x)
```