Skip to content

ding.torch_utils.loss.multi_logits_loss

ding.torch_utils.loss.multi_logits_loss

MultiLogitsLoss

Bases: Module

Overview

Loss for multiple unordered logits-label pairs: builds a cross-entropy cost matrix between predictions and labels, finds the minimum-cost one-to-one assignment with the Kuhn-Munkres (Hungarian) algorithm, and returns the mean matched loss.

Interfaces: __init__, forward.

__init__(criterion=None, smooth_ratio=0.1)

Overview

Initialization method, use cross_entropy as default criterion.

Arguments: - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce']. - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing.

forward(logits, labels)

Overview

Calculate multiple logits loss.

Arguments: - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N). - labels (:obj:`torch.LongTensor`): Ground truth. Returns: - loss (:obj:`torch.Tensor`): Calculated loss.

Full Source Code

../ding/torch_utils/loss/multi_logits_loss.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLogitsLoss(nn.Module):
    """
    Overview:
        Loss for multiple (unordered) logits-label pairs. Builds an (M, M)
        cross-entropy cost matrix between every prediction row and every label,
        solves the minimum-cost one-to-one assignment with the Kuhn-Munkres
        (Hungarian) algorithm, and returns the mean loss over matched pairs.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
        """
        Overview:
            Initialization method, use cross_entropy as default criterion.
        Arguments:
            - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
            - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing.
        """
        super(MultiLogitsLoss, self).__init__()
        if criterion is None:
            criterion = 'cross_entropy'
        assert (criterion in ['cross_entropy', 'label_smooth_ce'])
        self.criterion = criterion
        if self.criterion == 'label_smooth_ce':
            self.ratio = smooth_ratio

    def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Turn integer labels into per-class target distributions according to
            the configured criterion.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, shape (M, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices, shape (M,).
        Returns:
            - ret (:obj:`torch.Tensor`): Target distributions, shape (M, N), same dtype/device as ``logits``.
        """
        N = logits.shape[1]
        if self.criterion == 'cross_entropy':
            # Plain one-hot target, built with the same scatter_ idiom as the
            # smoothing branch below (keeps the module self-contained).
            ret = torch.zeros_like(logits)
            ret.scatter_(1, labels.unsqueeze(1), 1.0)
        else:  # 'label_smooth_ce', guaranteed by the assert in __init__
            # Each off-target entry gets ratio / (N - 1); the target keeps the rest.
            # NOTE(review): the target weight is 1 - ratio/(N-1) (not 1 - ratio),
            # so for N > 2 the row does not sum to 1. Kept as-is for backward
            # compatibility with the original implementation — confirm intent.
            val = float(self.ratio) / (N - 1)
            ret = torch.full_like(logits, val)
            ret.scatter_(1, labels.unsqueeze(1), 1 - val)
        return ret

    def _nll_loss(self, nlls: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Weighted negative log-likelihood: ``sum(-log_prob * target)`` per row.
        Arguments:
            - nlls (:obj:`torch.Tensor`): Log-probabilities, shape (M, N).
            - labels (:obj:`torch.Tensor`): Target distributions, shape (M, N).
        Returns:
            - ret (:obj:`torch.Tensor`): Per-row loss, shape (M,).
        """
        # detach() keeps gradients flowing only through the log-probabilities.
        return (-nlls * labels.detach()).sum(dim=1)

    def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Build the (M, M) cost matrix whose entry (i, j) is the criterion loss
            of prediction row ``i`` measured against label ``j``.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, shape (M, N).
            - labels (:obj:`torch.LongTensor`): Ground-truth class indices, shape (M,).
        Returns:
            - metric (:obj:`torch.Tensor`): Cost matrix, shape (M, M).
        """
        M, N = logits.shape
        labels = self._label_process(logits, labels)
        logits = F.log_softmax(logits, dim=1)
        metric = []
        for i in range(M):
            # Broadcast row i against every label target; expand() is a zero-copy
            # view, unlike the repeat()+reshape() it replaces.
            row = logits[i].unsqueeze(0).expand(M, N)
            metric.append(self._nll_loss(row, labels))
        return torch.stack(metric, dim=0)

    def _match(self, matrix: torch.Tensor) -> np.ndarray:
        """
        Overview:
            Solve the assignment problem on the cost matrix with the Kuhn-Munkres
            (Hungarian) algorithm, run on CPU/numpy.
        Arguments:
            - matrix (:obj:`torch.Tensor`): Cost matrix, shape (M, M).
        Returns:
            - index (:obj:`np.ndarray`): ``index[j]`` is the prediction row matched to label column ``j``.
        """
        mat = matrix.clone().detach().to('cpu').numpy()
        mat = -mat  # KM maximizes total weight; negate to minimize cost
        M = mat.shape[0]
        index = np.full(M, -1, dtype=np.int32)  # -1 means "column not matched yet"
        lx = mat.max(axis=1)  # feasible row (x) labels
        ly = np.zeros(M, dtype=np.float32)  # column (y) labels
        visx = np.zeros(M, dtype=np.bool_)
        visy = np.zeros(M, dtype=np.bool_)

        def has_augmented_path(t, binary_distance_matrix):
            # DFS for an augmenting path in the equality subgraph.
            # Mutates visx, visy and index.
            # NOTE(review): recursion depth can reach M; fine for typical sizes,
            # but very large M could hit the interpreter recursion limit.
            visx[t] = True
            for i in range(M):
                if not visy[i] and binary_distance_matrix[t, i]:
                    visy[i] = True
                    if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
                        index[i] = t
                        return True
            return False

        for i in range(M):
            while True:
                visx.fill(False)
                visy.fill(False)
                distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
                # Edges with (numerically) zero slack form the equality subgraph.
                binary_distance_matrix = np.abs(distance_matrix) < 1e-4
                if has_augmented_path(i, binary_distance_matrix):
                    break
                masked_distance_matrix = distance_matrix[:, ~visy][visx]
                if 0 in masked_distance_matrix.shape:  # empty matrix
                    raise RuntimeError("match error, matrix: {}".format(matrix))
                # Relax labels by the minimum remaining slack so at least one new
                # edge enters the equality subgraph, then retry.
                d = masked_distance_matrix.min()
                lx[visx] -= d
                ly[visy] += d
        return index

    @staticmethod
    def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
        """
        Overview:
            Slack matrix for KM: ``lx[i] + ly[j] - mat[i, j]`` for every (i, j).
        Arguments:
            - lx (:obj:`np.ndarray`): Row labels, shape (M,).
            - ly (:obj:`np.ndarray`): Column labels, shape (M,).
            - mat (:obj:`np.ndarray`): (Negated) cost matrix, shape (M, M).
            - M (:obj:`int`): Matrix size.
        Returns:
            - nret (:obj:`np.ndarray`): Slack matrix, shape (M, M).
        """
        nlx = np.broadcast_to(lx, [M, M]).T
        nly = np.broadcast_to(ly, [M, M])
        return nlx + nly - mat

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate multiple logits loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
            - labels (:obj:`torch.LongTensor`): Ground truth, shape (B,).
        Returns:
            - loss (:obj:`torch.Tensor`): Scalar loss, mean cost over the matched prediction/label pairs.
        """
        assert (len(logits.shape) == 2)
        metric_matrix = self._get_metric_matrix(logits, labels)
        index = self._match(metric_matrix)
        # Average each label column's cost against its matched prediction row.
        loss = [metric_matrix[index[i], i] for i in range(metric_matrix.shape[0])]
        return sum(loss) / len(loss)