from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import SequenceType


class ContrastiveLoss(nn.Module):
    """
    Overview:
        The class for contrastive learning losses. Only InfoNCE loss is supported currently. \
        Code Reference: https://github.com/rdevon/DIM. Paper Reference: https://arxiv.org/abs/1808.06670.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            x_size: Union[int, SequenceType],
            y_size: Union[int, SequenceType],
            heads: SequenceType = (1, 1),
            encode_shape: int = 64,
            loss_type: str = "infoNCE",  # Only the InfoNCE loss is available now.
            temperature: float = 1.0,
    ) -> None:
        """
        Overview:
            Initialize the ContrastiveLoss object using the given arguments.
        Arguments:
            - x_size (:obj:`Union[int, SequenceType]`): Input shape for x, both the obs shape and the encoding shape \
                are supported.
            - y_size (:obj:`Union[int, SequenceType]`): Input shape for y, both the obs shape and the encoding shape \
                are supported.
            - heads (:obj:`SequenceType`): A sequence of 2 int elems, ``heads[0]`` for x and ``heads[1]`` for y. \
                Used in multi-head, global-local, local-local MI maximization process.
            - encode_shape (:obj:`int`): The dimension of encoder hidden state.
            - loss_type (:obj:`str`): Only the InfoNCE loss is available now.
            - temperature (:obj:`float`): The parameter to adjust the ``log_softmax``.
        """
        super(ContrastiveLoss, self).__init__()
        assert len(heads) == 2, "Expected length of 2, but got: {}".format(len(heads))
        assert loss_type.lower() in ["infonce"]

        self._type = loss_type.lower()
        self._encode_shape = encode_shape
        self._heads = heads
        self._x_encoder = self._create_encoder(x_size, heads[0])
        self._y_encoder = self._create_encoder(y_size, heads[1])
        self._temperature = temperature

    def _create_encoder(self, obs_size: Union[int, SequenceType], heads: int) -> nn.Module:
        """
        Overview:
            Create the encoder for the input obs.
        Arguments:
            - obs_size (:obj:`Union[int, SequenceType]`): Input shape for x, both the obs shape and the encoding \
                shape are supported. If the obs_size is an int, it means the obs is a 1D vector. If the obs_size \
                is a list such as [1, 16, 16], it means the obs is a 3D image with shape [1, 16, 16].
            - heads (:obj:`int`): The number of heads.
        Returns:
            - encoder (:obj:`nn.Module`): The encoder module.
        Examples:
            >>> obs_size = 16
            or
            >>> obs_size = [1, 16, 16]
            >>> heads = 1
            >>> encoder = self._create_encoder(obs_size, heads)
        """
        from ding.model import ConvEncoder, FCEncoder

        # Normalize an int obs size to a 1-elem list so both paths below index uniformly.
        if isinstance(obs_size, int):
            obs_size = [obs_size]
        assert len(obs_size) in [1, 3]

        # The last hidden size carries one encoding of size ``encode_shape`` per head;
        # the caller splits it back into heads with a ``view``.
        if len(obs_size) == 1:
            hidden_size_list = [128, 128, self._encode_shape * heads]
            encoder = FCEncoder(obs_size[0], hidden_size_list)
        else:
            hidden_size_list = [32, 64, 64, self._encode_shape * heads]
            # Small images need smaller kernels/strides to avoid collapsing the feature map.
            if obs_size[-1] >= 36:
                encoder = ConvEncoder(obs_size, hidden_size_list)
            else:
                encoder = ConvEncoder(obs_size, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1])
        return encoder

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Computes the noise contrastive estimation-based loss, a.k.a. infoNCE.
        Arguments:
            - x (:obj:`torch.Tensor`): The input x, both raw obs and encoding are supported.
            - y (:obj:`torch.Tensor`): The input y, both raw obs and encoding are supported.
        Returns:
            - loss (:obj:`torch.Tensor`): The calculated loss value.
        Examples:
            >>> encode_shape = 16
            >>> x = torch.randn(3, 16)
            >>> y = x ** 2 + 0.01 * torch.randn(3, 16)
            >>> estimator = ContrastiveLoss(16, 16, encode_shape=encode_shape)
            >>> loss = estimator.forward(x, y)
        Examples:
            >>> encode_shape = 16
            >>> x = torch.randn(3, 1, 16, 16)
            >>> y = x ** 2 + 0.01 * torch.randn(3, 1, 16, 16)
            >>> estimator = ContrastiveLoss([1, 16, 16], [1, 16, 16], encode_shape=encode_shape)
            >>> loss = estimator.forward(x, y)
        """
        N = x.size(0)
        x_heads, y_heads = self._heads
        # Encode, then split the flat encoder output into per-head embeddings.
        x = self._x_encoder(x).view(N, x_heads, self._encode_shape)
        y = self._y_encoder(y).view(N, y_heads, self._encode_shape)

        x_n = x.view(-1, self._encode_shape)
        y_n = y.view(-1, self._encode_shape)

        # Use inner product to obtain positive samples.
        # [N, x_heads, encode_dim] * [N, encode_dim, y_heads] -> [N, x_heads, y_heads]
        u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
        # Use outer product to obtain all sample permutations.
        # [N * y_heads, encode_dim] X [encode_dim, N * x_heads] -> [N * y_heads, N * x_heads],
        # reshaped and permuted to [N, N, x_heads, y_heads].
        u_all = torch.mm(y_n, x_n.t()).view(N, y_heads, N, x_heads).permute(0, 2, 3, 1)

        # Mask the diagonal part (the paired samples) so negatives exclude positives;
        # masked entries are set to -10 so they contribute ~0 after the softmax.
        mask = torch.eye(N, device=x.device)[:, :, None, None]
        u_neg = ((1 - mask) * u_all) - (10. * mask)
        u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)

        # Concatenate positive and negative samples; the positive logit sits at index 0 of dim=2.
        pred_lgt = torch.cat([u_pos, u_neg], dim=2)
        # NOTE(review): logits are MULTIPLIED by temperature here; the common InfoNCE convention
        # divides by it — identical at the default temperature=1.0, confirm intended scaling.
        pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)

        # The positive score is the first element of the log softmax.
        loss = -pred_log[:, :, 0, :].mean()
        return loss