ding.torch_utils.network.normalization
build_normalization(norm_type, dim=None)
Overview
Construct the corresponding normalization module. For beginners, refer to this article (https://zhuanlan.zhihu.com/p/34879333) to learn more about batch normalization.
Arguments:
- norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN'].
- dim (:obj:`Optional[int]`): Dimension of the normalization, required when norm_type is 'BN' or 'IN' and ignored otherwise.
Returns:
- norm_func (:obj:`nn.Module`): The corresponding normalization module class (not an instance), e.g. nn.BatchNorm2d.
Full Source Code
../ding/torch_utils/network/normalization.py
from typing import Optional
import torch.nn as nn


def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module:
    """
    Overview:
        Construct the corresponding normalization module. For beginners,
        refer to [this article](https://zhuanlan.zhihu.com/p/34879333) to learn more about batch normalization.
    Arguments:
        - norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN'].
        - dim (:obj:`Optional[int]`): Dimension of the normalization, applicable when norm_type is in ['BN', 'IN'].
    Returns:
        - norm_func (:obj:`nn.Module`): The corresponding batch normalization function.
    """
    if dim is None:
        key = norm_type
    else:
        if norm_type in ['BN', 'IN']:
            key = norm_type + str(dim)
        elif norm_type in ['LN', 'SyncBN']:
            key = norm_type
        else:
            raise NotImplementedError("not support indicated dim when creates {}".format(norm_type))
    norm_func = {
        'BN1': nn.BatchNorm1d,
        'BN2': nn.BatchNorm2d,
        'LN': nn.LayerNorm,
        'IN1': nn.InstanceNorm1d,
        'IN2': nn.InstanceNorm2d,
        'SyncBN': nn.SyncBatchNorm,
    }
    if key in norm_func.keys():
        return norm_func[key]
    else:
        raise KeyError("invalid norm type: {}".format(key))
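A usage sketch may help: build_normalization returns the normalization class itself, not a constructed layer, so the caller instantiates it with the channel or feature count. The snippet below inlines the function from the source above so it runs standalone (assuming only that torch is installed); the tensor shapes are illustrative, not part of the API.

```python
from typing import Optional

import torch
import torch.nn as nn


def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module:
    # Inlined from the source above for a standalone demo.
    if dim is None:
        key = norm_type
    else:
        if norm_type in ['BN', 'IN']:
            key = norm_type + str(dim)
        elif norm_type in ['LN', 'SyncBN']:
            key = norm_type
        else:
            raise NotImplementedError("not support indicated dim when creates {}".format(norm_type))
    norm_func = {
        'BN1': nn.BatchNorm1d,
        'BN2': nn.BatchNorm2d,
        'LN': nn.LayerNorm,
        'IN1': nn.InstanceNorm1d,
        'IN2': nn.InstanceNorm2d,
        'SyncBN': nn.SyncBatchNorm,
    }
    if key in norm_func:
        return norm_func[key]
    raise KeyError("invalid norm type: {}".format(key))


# 'BN' with dim=2 selects nn.BatchNorm2d; instantiate with the channel count.
norm_cls = build_normalization('BN', dim=2)
norm = norm_cls(16)
x = torch.randn(4, 16, 8, 8)  # (batch, channels, height, width)
assert norm(x).shape == x.shape

# 'LN' and 'SyncBN' ignore dim and map straight to their classes.
assert build_normalization('LN') is nn.LayerNorm

# 'BN'/'IN' without dim yield the lookup key 'BN'/'IN', which is absent
# from the table, so a KeyError is raised rather than a default module.
try:
    build_normalization('BN')
except KeyError:
    pass
```

Note the last case: for 'BN' and 'IN' the dim argument is effectively mandatory, since omitting it produces a key that is not in the lookup table.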