
ding.model.common.encoder


ConvEncoder

Bases: Module

Overview

The Convolution Encoder is used to encode 2-dim image observations.

Interfaces: __init__, forward.

__init__(obs_shape, hidden_size_list=[32, 64, 64, 128], activation=nn.ReLU(), kernel_size=[8, 4, 3], stride=[4, 2, 1], padding=None, layer_norm=False, norm_type=None)

Overview

Initialize the Convolution Encoder according to the provided arguments.

Arguments:
- obs_shape (:obj:SequenceType): Sequence of in_channel, plus one or more input sizes.
- hidden_size_list (:obj:SequenceType): Sequence of hidden_size of subsequent conv layers and the final dense layer.
- activation (:obj:nn.Module): Type of activation to use in the conv layers and ResBlock. Default is nn.ReLU().
- kernel_size (:obj:SequenceType): Sequence of kernel_size of subsequent conv layers.
- stride (:obj:SequenceType): Sequence of stride of subsequent conv layers.
- padding (:obj:SequenceType): Padding added to all four sides of the input for each conv layer. See nn.Conv2d for more details. Default is None.
- layer_norm (:obj:bool): Whether to use DreamerLayerNorm, a special normalization trick proposed in DreamerV3.
- norm_type (:obj:str): Type of normalization to use. See ding.torch_utils.network.ResBlock for more details. Default is None.
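As a sanity check on the conv geometry (not part of the library; standard nn.Conv2d arithmetic), the default kernel_size=[8, 4, 3], stride=[4, 2, 1], padding=None settings shrink an 84x84 Atari-style frame as follows:

```python
def conv2d_out(size, kernel, stride, padding=0):
    # Standard Conv2d output-size formula: floor((size - kernel + 2*padding) / stride) + 1.
    return (size - kernel + 2 * padding) // stride + 1

# Default ConvEncoder settings (padding=None means 0 for every layer).
size = 84
for k, s in zip([8, 4, 3], [4, 2, 1]):
    size = conv2d_out(size, k, s)

# 84 -> 20 -> 9 -> 7, so with the last conv channel being 64, the
# flattened feature fed to the final nn.Linear has 64 * 7 * 7 = 3136 elements.
print(size)              # 7
print(64 * size * size)  # 3136
```

This is exactly the number that `_get_flatten_size` probes at runtime with a dummy forward pass.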

forward(x)

Overview

Return output 1D embedding tensor of the env's 2D image observation.

Arguments:
- x (:obj:torch.Tensor): Raw 2D observation of the environment.

Returns:
- outputs (:obj:torch.Tensor): Output embedding tensor.

Shapes:
- x: :math:(B, C, H, W), where B is batch size, C is channel, H is height, W is width.
- outputs: :math:(B, N), where N = hidden_size_list[-1].

Examples:
>>> conv = ConvEncoder(
>>>     obs_shape=(4, 84, 84),
>>>     hidden_size_list=[32, 64, 64, 128],
>>>     activation=nn.ReLU(),
>>>     kernel_size=[8, 4, 3],
>>>     stride=[4, 2, 1],
>>>     padding=None,
>>>     layer_norm=False,
>>>     norm_type=None
>>> )
>>> x = torch.randn(1, 4, 84, 84)
>>> output = conv(x)

FCEncoder

Bases: Module

Overview

The fully connected encoder is used to encode 1-dim input variables.

Interfaces: __init__, forward.

__init__(obs_shape, hidden_size_list, res_block=False, activation=nn.ReLU(), norm_type=None, dropout=None)

Overview

Initialize the FC Encoder according to arguments.

Arguments:
- obs_shape (:obj:int): Observation shape.
- hidden_size_list (:obj:SequenceType): Sequence of hidden_size of subsequent FC layers.
- res_block (:obj:bool): Whether to use res_block. Default is False.
- activation (:obj:nn.Module): Type of activation to use in ResFCBlock. Default is nn.ReLU().
- norm_type (:obj:str): Type of normalization to use. See ding.torch_utils.network.ResFCBlock for more details. Default is None.
- dropout (:obj:float): Dropout rate of the dropout layer. If None, no dropout layer is used.

forward(x)

Overview

Return output embedding tensor of the env observation.

Arguments:
- x (:obj:torch.Tensor): Env raw observation.

Returns:
- outputs (:obj:torch.Tensor): Output embedding tensor.

Shapes:
- x: :math:(B, M), where M = obs_shape.
- outputs: :math:(B, N), where N = hidden_size_list[-1].

Examples:
>>> fc = FCEncoder(
>>>     obs_shape=4,
>>>     hidden_size_list=[32, 64, 64, 128],
>>>     activation=nn.ReLU(),
>>>     norm_type=None,
>>>     dropout=None
>>> )
>>> x = torch.randn(1, 4)
>>> output = fc(x)

IMPALACnnResidualBlock

Bases: Module

Overview

This CNN encoder residual block is the basic residual block used in the IMPALA algorithm; it preserves the channel number and spatial shape. IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, https://arxiv.org/pdf/1802.01561.pdf

Interfaces: __init__, forward.

__init__(in_channnel, scale=1, batch_norm=False)

Overview

Initialize the IMPALA CNN residual block according to arguments.

Arguments:
- in_channnel (:obj:int): Channel number of input features.
- scale (:obj:float): Scale of the module. Default is 1.
- batch_norm (:obj:bool): Whether to use batch normalization. Default is False.

residual(x)

Overview

Return the output tensor of the residual block, keeping the shape and channel number unchanged. The first ReLU must use inplace=False so that it does not modify the original input tensor of the residual block.

Arguments:
- x (:obj:torch.Tensor): Input tensor.

Returns:
- output (:obj:torch.Tensor): Output tensor.

forward(x)

Overview

Return the output tensor of the residual block, keeping the shape and channel number unchanged.

Arguments:
- x (:obj:torch.Tensor): Input tensor.

Returns:
- output (:obj:torch.Tensor): Output tensor.

Examples:
>>> block = IMPALACnnResidualBlock(16)
>>> x = torch.randn(1, 16, 84, 84)
>>> output = block(x)

IMPALACnnDownStack

Bases: Module

Overview

Downsampling stack of the CNN encoder used in the IMPALA algorithm. Every IMPALACnnDownStack consists of n IMPALACnnResidualBlock modules and reduces the spatial size by 2 with max pooling. IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, https://arxiv.org/pdf/1802.01561.pdf

Interfaces: __init__, forward.

__init__(in_channnel, nblock, out_channel, scale=1, pool=True, **kwargs)

Overview

Initialize each IMPALA CNN block of the IMPALA CNN encoder.

Arguments:
- in_channnel (:obj:int): Channel number of input features.
- nblock (:obj:int): Number of residual blocks in this stack.
- out_channel (:obj:int): Channel number of output features.
- scale (:obj:float): Scale of the module.
- pool (:obj:bool): Whether to use max pooling after the first conv layer.

forward(x)

Overview

Return the output tensor of the downsampling stack. The output shape differs from the input shape; refer to the output_shape method to compute it.

Arguments:
- x (:obj:torch.Tensor): Input tensor.

Returns:
- output (:obj:torch.Tensor): Output tensor.

Examples:
>>> stack = IMPALACnnDownStack(16, 2, 32)
>>> x = torch.randn(1, 16, 84, 84)
>>> output = stack(x)

output_shape(inshape)

Overview

Calculate the output shape of the downsampling stack according to input shape and related arguments.

Arguments:
- inshape (:obj:tuple): Input shape.

Returns:
- output_shape (:obj:tuple): Output shape.

Shapes:
- inshape (:obj:tuple): :math:(C, H, W), where C is channel number, H is height and W is width.
- output_shape (:obj:tuple): :math:(C, H, W), where C is channel number, H is height and W is width.

Examples:
>>> stack = IMPALACnnDownStack(16, 2, 32)
>>> inshape = (16, 84, 84)
>>> output_shape = stack.output_shape(inshape)
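The shape rule can be reproduced by hand. With pool=True, the max pooling uses kernel_size=3, stride=2, padding=1, which maps a spatial size H to (H + 1) // 2. A dependency-free sketch mirroring that formula (an illustration, not the library code itself):

```python
def down_stack_output_shape(inshape, out_channel, pool=True):
    # Mirrors IMPALACnnDownStack.output_shape: max pooling with
    # kernel_size=3, stride=2, padding=1 halves H and W, rounding up.
    c, h, w = inshape
    if pool:
        return (out_channel, (h + 1) // 2, (w + 1) // 2)
    return (out_channel, h, w)

print(down_stack_output_shape((16, 84, 84), 32))  # (32, 42, 42)
```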

IMPALAConvEncoder

Bases: Module

Overview

IMPALA CNN encoder, used in the IMPALA algorithm. IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, https://arxiv.org/pdf/1802.01561.pdf

Interfaces: __init__, forward, output_shape.

__init__(obs_shape, channels=(16, 32, 32), outsize=256, scale_ob=255.0, nblock=2, final_relu=True, **kwargs)

Overview

Initialize the IMPALA CNN encoder according to arguments.

Arguments:
- obs_shape (:obj:SequenceType): 2D image observation shape.
- channels (:obj:SequenceType): The channel numbers of a series of IMPALA CNN blocks. Each element of the sequence is the output channel number of an IMPALA CNN block.
- outsize (:obj:int): The output size of the final linear layer, i.e. the dimension of the 1D embedding vector.
- scale_ob (:obj:float): The scale of the input observation, used to normalize the input, e.g. dividing by 255.0 for raw image observations.
- nblock (:obj:int): The number of residual blocks in each stack.
- final_relu (:obj:bool): Whether to apply a ReLU activation to the final output of the encoder.
- kwargs (:obj:Dict[str, Any]): Other arguments for IMPALACnnDownStack.

forward(x)

Overview

Return the 1D embedding vector of the input 2D observation.

Arguments:
- x (:obj:torch.Tensor): Input 2D observation tensor.

Returns:
- output (:obj:torch.Tensor): Output 1D embedding vector.

Shapes:
- x (:obj:torch.Tensor): :math:(B, C, H, W), where B is batch size, C is channel number, H is height and W is width.
- output (:obj:torch.Tensor): :math:(B, outsize), where B is batch size.

Examples:
>>> encoder = IMPALAConvEncoder(
>>>     obs_shape=(4, 84, 84),
>>>     channels=(16, 32, 32),
>>>     outsize=256,
>>>     scale_ob=255.0,
>>>     nblock=2,
>>>     final_relu=True,
>>> )
>>> x = torch.randn(1, 4, 84, 84)
>>> output = encoder(x)
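Chaining the per-stack halving rule, the default channels=(16, 32, 32) configuration reduces a (4, 84, 84) observation step by step. This pure-Python trace (an illustration, not library code) shows the flattened size that feeds the final linear layer:

```python
# Each down stack with pool=True maps (C, H, W) -> (out_channel, (H+1)//2, (W+1)//2).
shape = (4, 84, 84)
for out_channel in (16, 32, 32):
    c, h, w = shape
    shape = (out_channel, (h + 1) // 2, (w + 1) // 2)

# (4, 84, 84) -> (16, 42, 42) -> (32, 21, 21) -> (32, 11, 11)
flatten = shape[0] * shape[1] * shape[2]
print(shape, flatten)  # (32, 11, 11) 3872
```

The encoder's `__init__` performs exactly this bookkeeping via `output_shape` and passes `prod(curshape)` (here 3872) as the in-features of the dense layer.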

GaussianFourierProjectionTimeEncoder

Bases: Module

Overview

Gaussian random features for encoding time steps. This module is used as the encoder of time in generative models such as diffusion models.

Interfaces: __init__, forward.

__init__(embed_dim, scale=30.0)

Overview

Initialize the Gaussian Fourier Projection Time Encoder according to arguments.

Arguments:
- embed_dim (:obj:int): The dimension of the output embedding vector.
- scale (:obj:float): The scale of the Gaussian random features.

forward(x)

Overview

Return the output embedding vector of the input time step.

Arguments:
- x (:obj:torch.Tensor): Input time step tensor.

Returns:
- output (:obj:torch.Tensor): Output embedding vector.

Shapes:
- x (:obj:torch.Tensor): :math:(B,), where B is batch size.
- output (:obj:torch.Tensor): :math:(B, embed_dim), where B is batch size and embed_dim is the dimension of the output embedding vector.

Examples:
>>> encoder = GaussianFourierProjectionTimeEncoder(128)
>>> x = torch.randn(100)
>>> output = encoder(x)
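The encoding itself is easy to reproduce for a single scalar time step: each fixed random weight w contributes a sin(x*w) and a cos(x*w) feature, concatenated sin-block-first, so the embedding has 2 * (embed_dim // 2) entries. A dependency-free sketch (fixed example weights instead of the Gaussian samples the module actually draws at init):

```python
import math

def fourier_time_features(x, weights):
    # One sin and one cos feature per weight; in the real module the
    # weights are frozen Gaussian samples (requires_grad=False).
    return [math.sin(x * w) for w in weights] + [math.cos(x * w) for w in weights]

feats = fourier_time_features(0.5, [1.0, 2.0, 4.0])
print(len(feats))  # 6 == 2 * len(weights)
```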

prod(iterable)

Overview

Product of all elements. (To be deprecated soon.) This function definition supports Python versions below 3.8; in Python 3.8 and later, 'math.prod()' is recommended.
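A minimal equivalent, matching the reduce-based definition in the source:

```python
from functools import reduce
import operator

def prod(iterable):
    # Multiplicative fold with identity 1, so prod([]) == 1,
    # matching math.prod() in Python 3.8+.
    return reduce(operator.mul, iterable, 1)

print(prod([32, 11, 11]))  # 3872
```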

Full Source Code

../ding/model/common/encoder.py

from typing import Optional, Dict, Union, List
from functools import reduce
import operator
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d
from ding.torch_utils.network.dreamer import Conv2dSame, DreamerLayerNorm
from ding.utils import SequenceType


def prod(iterable):
    """
    Overview:
        Product of all elements. (To be deprecated soon.) This function definition supports Python versions \
        under 3.8. In Python 3.8 and later, 'math.prod()' is recommended.
    """
    return reduce(operator.mul, iterable, 1)


class ConvEncoder(nn.Module):
    """
    Overview:
        The Convolution Encoder is used to encode 2-dim image observations.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [32, 64, 64, 128],
            activation: Optional[nn.Module] = nn.ReLU(),
            kernel_size: SequenceType = [8, 4, 3],
            stride: SequenceType = [4, 2, 1],
            padding: Optional[SequenceType] = None,
            layer_norm: Optional[bool] = False,
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Initialize the ``Convolution Encoder`` according to the provided arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``.
            - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent conv layers \
                and the final dense layer.
            - activation (:obj:`nn.Module`): Type of activation to use in the conv ``layers`` and ``ResBlock``. \
                Default is ``nn.ReLU()``.
            - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
            - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
            - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
                See ``nn.Conv2d`` for more details. Default is ``None``.
            - layer_norm (:obj:`bool`): Whether to use ``DreamerLayerNorm``, a special trick \
                proposed in DreamerV3.
            - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResBlock`` \
                for more details. Default is ``None``.
        """
        super(ConvEncoder, self).__init__()
        self.obs_shape = obs_shape
        self.act = activation
        self.hidden_size_list = hidden_size_list
        if padding is None:
            padding = [0 for _ in range(len(kernel_size))]

        layers = []
        input_size = obs_shape[0]  # in_channel
        for i in range(len(kernel_size)):
            if layer_norm:
                layers.append(
                    Conv2dSame(
                        in_channels=input_size,
                        out_channels=hidden_size_list[i],
                        kernel_size=(kernel_size[i], kernel_size[i]),
                        stride=(2, 2),
                        bias=False,
                    )
                )
                layers.append(DreamerLayerNorm(hidden_size_list[i]))
                layers.append(self.act)
            else:
                layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
                layers.append(self.act)
            input_size = hidden_size_list[i]
        if len(self.hidden_size_list) >= len(kernel_size) + 2:
            assert self.hidden_size_list[len(kernel_size) - 1] == self.hidden_size_list[
                len(kernel_size)], "Please indicate the same hidden size between conv and res block"
            assert len(
                set(hidden_size_list[len(kernel_size):-1])
            ) <= 1, "Please indicate the same hidden size for res block parts"
            for i in range(len(kernel_size), len(self.hidden_size_list) - 1):
                layers.append(ResBlock(self.hidden_size_list[i - 1], activation=self.act, norm_type=norm_type))
        layers.append(Flatten())
        self.main = nn.Sequential(*layers)

        flatten_size = self._get_flatten_size()
        self.output_size = hidden_size_list[-1]  # outside to use
        self.mid = nn.Linear(flatten_size, hidden_size_list[-1])

    def _get_flatten_size(self) -> int:
        """
        Overview:
            Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``.
        Returns:
            - outputs (:obj:`torch.Tensor`): Size ``int`` Tensor representing the number of ``in-features``.
        Shapes:
            - outputs: :math:`(1,)`.
        Examples:
            >>> conv = ConvEncoder(
            >>>    obs_shape=(4, 84, 84),
            >>>    hidden_size_list=[32, 64, 64, 128],
            >>>    activation=nn.ReLU(),
            >>>    kernel_size=[8, 4, 3],
            >>>    stride=[4, 2, 1],
            >>>    padding=None,
            >>>    layer_norm=False,
            >>>    norm_type=None
            >>> )
            >>> flatten_size = conv._get_flatten_size()
        """
        test_data = torch.randn(1, *self.obs_shape)
        with torch.no_grad():
            output = self.main(test_data)
        return output.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output 1D embedding tensor of the env's 2D image observation.
        Arguments:
            - x (:obj:`torch.Tensor`): Raw 2D observation of the environment.
        Returns:
            - outputs (:obj:`torch.Tensor`): Output embedding tensor.
        Shapes:
            - x : :math:`(B, C, H, W)`, where ``B`` is batch size, ``C`` is channel, ``H`` is height, ``W`` is width.
            - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]``.
        Examples:
            >>> conv = ConvEncoder(
            >>>    obs_shape=(4, 84, 84),
            >>>    hidden_size_list=[32, 64, 64, 128],
            >>>    activation=nn.ReLU(),
            >>>    kernel_size=[8, 4, 3],
            >>>    stride=[4, 2, 1],
            >>>    padding=None,
            >>>    layer_norm=False,
            >>>    norm_type=None
            >>> )
            >>> x = torch.randn(1, 4, 84, 84)
            >>> output = conv(x)
        """
        x = self.main(x)
        x = self.mid(x)
        return x


class FCEncoder(nn.Module):
    """
    Overview:
        The fully connected encoder is used to encode 1-dim input variables.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: int,
            hidden_size_list: SequenceType,
            res_block: bool = False,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            dropout: Optional[float] = None
    ) -> None:
        """
        Overview:
            Initialize the FC Encoder according to arguments.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent FC layers.
            - res_block (:obj:`bool`): Whether to use ``res_block``. Default is ``False``.
            - activation (:obj:`nn.Module`): Type of activation to use in ``ResFCBlock``. Default is ``nn.ReLU()``.
            - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResFCBlock`` \
                for more details. Default is ``None``.
            - dropout (:obj:`float`): Dropout rate of the dropout layer. If ``None``, no dropout layer is used.
        """
        super(FCEncoder, self).__init__()
        self.obs_shape = obs_shape
        self.act = activation
        self.init = nn.Linear(obs_shape, hidden_size_list[0])

        if res_block:
            assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
            if len(hidden_size_list) == 1:
                self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
            else:
                layers = []
                for i in range(len(hidden_size_list)):
                    layers.append(
                        ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
                    )
                self.main = nn.Sequential(*layers)
        else:
            layers = []
            for i in range(len(hidden_size_list) - 1):
                layers.append(nn.Linear(hidden_size_list[i], hidden_size_list[i + 1]))
                layers.append(self.act)
                if dropout is not None:
                    layers.append(nn.Dropout(dropout))
            self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output embedding tensor of the env observation.
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation.
        Returns:
            - outputs (:obj:`torch.Tensor`): Output embedding tensor.
        Shapes:
            - x : :math:`(B, M)`, where ``M = obs_shape``.
            - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]``.
        Examples:
            >>> fc = FCEncoder(
            >>>    obs_shape=4,
            >>>    hidden_size_list=[32, 64, 64, 128],
            >>>    activation=nn.ReLU(),
            >>>    norm_type=None,
            >>>    dropout=None
            >>> )
            >>> x = torch.randn(1, 4)
            >>> output = fc(x)
        """
        x = self.act(self.init(x))
        x = self.main(x)
        return x


class StructEncoder(nn.Module):

    def __init__(self, obs_shape: Dict[str, Union[int, List[int]]]) -> None:
        super(StructEncoder, self).__init__()
        # TODO concrete implementation
        raise NotImplementedError


class IMPALACnnResidualBlock(nn.Module):
    """
    Overview:
        This CNN encoder residual block is the basic residual block used in the IMPALA algorithm, \
        which preserves the channel number and shape.
        IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
        https://arxiv.org/pdf/1802.01561.pdf
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False):
        """
        Overview:
            Initialize the IMPALA CNN residual block according to arguments.
        Arguments:
            - in_channnel (:obj:`int`): Channel number of input features.
            - scale (:obj:`float`): Scale of module, defaults to 1.
            - batch_norm (:obj:`bool`): Whether to use batch normalization, defaults to False.
        """
        super().__init__()
        self.in_channnel = in_channnel
        self.batch_norm = batch_norm
        s = math.sqrt(scale)
        self.conv0 = normed_conv2d(self.in_channnel, self.in_channnel, 3, padding=1, scale=s)
        self.conv1 = normed_conv2d(self.in_channnel, self.in_channnel, 3, padding=1, scale=s)
        if self.batch_norm:
            self.bn0 = nn.BatchNorm2d(self.in_channnel)
            self.bn1 = nn.BatchNorm2d(self.in_channnel)

    def residual(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output tensor of the residual block, keeping the shape and channel number unchanged.
            The first relu must use ``inplace=False`` so that it does not change the original input tensor \
            of the residual block.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor.
        """
        if self.batch_norm:
            x = self.bn0(x)
        x = F.relu(x, inplace=False)
        x = self.conv0(x)
        if self.batch_norm:
            x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.conv1(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output tensor of the residual block, keeping the shape and channel number unchanged.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor.
        Examples:
            >>> block = IMPALACnnResidualBlock(16)
            >>> x = torch.randn(1, 16, 84, 84)
            >>> output = block(x)
        """
        return x + self.residual(x)


class IMPALACnnDownStack(nn.Module):
    """
    Overview:
        Downsampling stack of the CNN encoder used in the IMPALA algorithm.
        Every IMPALACnnDownStack consists of n IMPALACnnResidualBlock modules, \
        and reduces the spatial size by 2 with max pooling.
        IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
        https://arxiv.org/pdf/1802.01561.pdf
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, in_channnel, nblock, out_channel, scale=1, pool=True, **kwargs):
        """
        Overview:
            Initialize each IMPALA CNN block of the IMPALA CNN encoder.
        Arguments:
            - in_channnel (:obj:`int`): Channel number of input features.
            - nblock (:obj:`int`): Number of residual blocks in this stack.
            - out_channel (:obj:`int`): Channel number of output features.
            - scale (:obj:`float`): Scale of the module.
            - pool (:obj:`bool`): Whether to use max pooling after the first conv layer.
        """
        super().__init__()
        self.in_channnel = in_channnel
        self.out_channel = out_channel
        self.pool = pool
        self.firstconv = normed_conv2d(in_channnel, out_channel, 3, padding=1)
        s = scale / math.sqrt(nblock)
        self.blocks = nn.ModuleList([IMPALACnnResidualBlock(out_channel, scale=s, **kwargs) for _ in range(nblock)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output tensor of the downsampling stack. The output shape is different from the input shape; \
            refer to the ``output_shape`` method to compute it.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor.
        Examples:
            >>> stack = IMPALACnnDownStack(16, 2, 32)
            >>> x = torch.randn(1, 16, 84, 84)
            >>> output = stack(x)
        """
        x = self.firstconv(x)
        if self.pool:
            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        for block in self.blocks:
            x = block(x)
        return x

    def output_shape(self, inshape: tuple) -> tuple:
        """
        Overview:
            Calculate the output shape of the downsampling stack according to the input shape and related arguments.
        Arguments:
            - inshape (:obj:`tuple`): Input shape.
        Returns:
            - output_shape (:obj:`tuple`): Output shape.
        Shapes:
            - inshape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
            - output_shape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
        Examples:
            >>> stack = IMPALACnnDownStack(16, 2, 32)
            >>> inshape = (16, 84, 84)
            >>> output_shape = stack.output_shape(inshape)
        """
        c, h, w = inshape
        assert c == self.in_channnel
        if self.pool:
            return (self.out_channel, (h + 1) // 2, (w + 1) // 2)
        else:
            return (self.out_channel, h, w)


class IMPALAConvEncoder(nn.Module):
    """
    Overview:
        IMPALA CNN encoder, which is used in the IMPALA algorithm.
        IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, \
        https://arxiv.org/pdf/1802.01561.pdf
    Interfaces:
        ``__init__``, ``forward``, ``output_shape``.
    """
    name = "IMPALAConvEncoder"  # put it here to preserve pickle compat

    def __init__(
            self,
            obs_shape: SequenceType,
            channels: SequenceType = (16, 32, 32),
            outsize: int = 256,
            scale_ob: float = 255.0,
            nblock: int = 2,
            final_relu: bool = True,
            **kwargs
    ) -> None:
        """
        Overview:
            Initialize the IMPALA CNN encoder according to arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): 2D image observation shape.
            - channels (:obj:`SequenceType`): The channel numbers of a series of IMPALA CNN blocks. \
                Each element of the sequence is the output channel number of an IMPALA CNN block.
            - outsize (:obj:`int`): The output size of the final linear layer, which is the dimension of the \
                1D embedding vector.
            - scale_ob (:obj:`float`): The scale of the input observation, which is used to normalize the input \
                observation, such as dividing by 255.0 for raw image observations.
            - nblock (:obj:`int`): The number of residual blocks in each stack.
            - final_relu (:obj:`bool`): Whether to apply a ReLU activation to the final output of the encoder.
            - kwargs (:obj:`Dict[str, Any]`): Other arguments for ``IMPALACnnDownStack``.
        """
        super().__init__()
        self.scale_ob = scale_ob
        c, h, w = obs_shape
        curshape = (c, h, w)
        s = 1 / math.sqrt(len(channels))  # per stack scale
        self.stacks = nn.ModuleList()
        for out_channel in channels:
            stack = IMPALACnnDownStack(curshape[0], nblock=nblock, out_channel=out_channel, scale=s, **kwargs)
            self.stacks.append(stack)
            curshape = stack.output_shape(curshape)
        self.dense = normed_linear(prod(curshape), outsize, scale=1.4)
        self.outsize = outsize
        self.final_relu = final_relu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the 1D embedding vector of the input 2D observation.
        Arguments:
            - x (:obj:`torch.Tensor`): Input 2D observation tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output 1D embedding vector.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size, C is channel number, H is height \
                and W is width.
            - output (:obj:`torch.Tensor`): :math:`(B, outsize)`, where B is batch size.
        Examples:
            >>> encoder = IMPALAConvEncoder(
            >>>    obs_shape=(4, 84, 84),
            >>>    channels=(16, 32, 32),
            >>>    outsize=256,
            >>>    scale_ob=255.0,
            >>>    nblock=2,
            >>>    final_relu=True,
            >>> )
            >>> x = torch.randn(1, 4, 84, 84)
            >>> output = encoder(x)
        """
        x = x / self.scale_ob
        for (i, layer) in enumerate(self.stacks):
            x = layer(x)
        *batch_shape, h, w, c = x.shape
        x = x.reshape((*batch_shape, h * w * c))
        x = F.relu(x)
        x = self.dense(x)
        if self.final_relu:
            x = torch.relu(x)
        return x


class GaussianFourierProjectionTimeEncoder(nn.Module):
    """
    Overview:
        Gaussian random features for encoding time steps.
        This module is used as the encoder of time in generative models such as diffusion models.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, embed_dim, scale=30.):
        """
        Overview:
            Initialize the Gaussian Fourier Projection Time Encoder according to arguments.
        Arguments:
            - embed_dim (:obj:`int`): The dimension of the output embedding vector.
            - scale (:obj:`float`): The scale of the Gaussian random features.
        """
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale * 2 * np.pi, requires_grad=False)

    def forward(self, x):
        """
        Overview:
            Return the output embedding vector of the input time step.
        Arguments:
            - x (:obj:`torch.Tensor`): Input time step tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output embedding vector.
        Shapes:
            - x (:obj:`torch.Tensor`): :math:`(B,)`, where B is batch size.
            - output (:obj:`torch.Tensor`): :math:`(B, embed_dim)`, where B is batch size and embed_dim is the \
                dimension of the output embedding vector.
        Examples:
            >>> encoder = GaussianFourierProjectionTimeEncoder(128)
            >>> x = torch.randn(100)
            >>> output = encoder(x)
        """
        x_proj = x[..., None] * self.W[None, :]
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)