ding.torch_utils.loss.cross_entropy_loss
LabelSmoothCELoss

Bases: Module

Overview

Label smooth cross entropy loss.

Interfaces: __init__, forward.

__init__(ratio)

Overview

Initialize the LabelSmoothCELoss object using the given arguments.

Arguments:

- ratio (`float`): The label-smoothing ratio, a value in [0, 1]. A larger ratio means stronger label smoothing.

forward(logits, labels)

Overview

Calculate label smooth cross entropy loss.

Arguments:

- logits (`torch.Tensor`): Predicted logits.
- labels (`torch.LongTensor`): Ground truth.

Returns:

- loss (`torch.Tensor`): Calculated loss.
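The loss builds a smoothed target that places `1 - ratio/(N-1)` on the true class and `ratio/(N-1)` on every other class, then takes the cross entropy against the log-softmax of the logits. A minimal torch-free sketch of the per-sample formula (`label_smooth_ce` is a hypothetical helper name, not part of the library):

```python
import math

def label_smooth_ce(logits, label, ratio):
    # Build the smoothed target: 1 - ratio/(N-1) at the true class,
    # ratio/(N-1) everywhere else (mirrors LabelSmoothCELoss.forward).
    n = len(logits)
    val = ratio / (n - 1)
    target = [val] * n
    target[label] = 1 - val
    # Numerically stable log-softmax.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    # Cross entropy between the smoothed target and the predicted distribution.
    return -sum(t * lp for t, lp in zip(target, log_probs))

loss_hard = label_smooth_ce([2.0, 0.5, -1.0], 0, 0.0)    # ratio=0 reduces to plain CE
loss_smooth = label_smooth_ce([2.0, 0.5, -1.0], 0, 0.1)  # smoothing raises loss on a confident, correct prediction
```

With `ratio=0` the target is a plain one-hot vector, so the result coincides with ordinary cross entropy; as `ratio` grows, probability mass leaks from the true class to the others and the loss penalizes over-confident predictions.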

SoftFocalLoss

Bases: Module

Overview

Soft focal loss.

Interfaces: __init__, forward.

__init__(gamma=2, weight=None, size_average=True, reduce=None)

Overview

Initialize the SoftFocalLoss object using the given arguments.

Arguments:

- gamma (`int`): The extent of focus on hard samples. A smaller `gamma` will lead to more focus on easy samples, while a larger `gamma` will lead to more focus on hard samples.
- weight (`Any`): The weight for the loss of each class.
- size_average (`bool`): By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If `size_average` is set to `False`, the losses are instead summed for each minibatch. Ignored when `reduce` is `False`.
- reduce (`Optional[bool]`): By default, the losses are averaged or summed over observations for each minibatch depending on `size_average`. When `reduce` is `False`, returns a loss for each batch element instead and ignores `size_average`.

forward(inputs, targets)

Overview

Calculate soft focal loss.

Arguments:

- inputs (`torch.Tensor`): Predicted logits.
- targets (`torch.LongTensor`): Ground truth.

Returns:

- loss (`torch.Tensor`): Calculated loss.
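The focal modulation multiplies the log-probability of the true class by `(1 - p)^gamma`, which shrinks the loss contribution of well-classified (easy) samples. A minimal torch-free sketch of the per-sample computation (`soft_focal_loss` is a hypothetical helper, mean reduction over a batch of one):

```python
import math

def soft_focal_loss(logits, label, gamma=2):
    # Softmax probabilities of the logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    p = exps[label] / z
    # Focal modulation: (1 - p)^gamma downweights confident predictions
    # (mirrors SoftFocalLoss.forward for a single sample).
    return -((1 - p) ** gamma) * math.log(p)

easy = soft_focal_loss([4.0, 0.0, 0.0], 0)  # confident and correct: near-zero loss
hard = soft_focal_loss([0.0, 0.0, 0.0], 0)  # uncertain: much larger loss
```

With `gamma=0` this is plain cross entropy; increasing `gamma` progressively suppresses the gradient signal from easy samples so training concentrates on the hard ones.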

build_ce_criterion(cfg)

Overview

Get a cross entropy loss instance according to the given config.

Arguments:

- cfg (`dict`): Config dict. It contains:
    - type (`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', 'soft_focal_loss'].
    - kwargs (`dict`): Arguments for the corresponding loss function.

Returns:

- loss (`nn.Module`): Loss function instance.
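The function is a simple type-keyed dispatch over the config. A torch-free sketch of the same pattern (the real code expects an EasyDict-style `cfg` with attribute access; `SimpleNamespace` and the string return values are stand-ins used here only so the sketch runs without torch):

```python
from types import SimpleNamespace

# Stand-in for the dict-with-attribute-access config used by the library.
cfg = SimpleNamespace(type='label_smooth_ce',
                      kwargs=SimpleNamespace(smooth_ratio=0.1))

def build_ce_criterion_sketch(cfg):
    # Mirrors the dispatch in build_ce_criterion, returning a description
    # string instead of an nn.Module instance.
    if cfg.type == 'cross_entropy':
        return 'nn.CrossEntropyLoss()'
    elif cfg.type == 'label_smooth_ce':
        return 'LabelSmoothCELoss(ratio={})'.format(cfg.kwargs.smooth_ratio)
    elif cfg.type == 'soft_focal_loss':
        return 'SoftFocalLoss()'
    else:
        raise ValueError("invalid criterion type:{}".format(cfg.type))

criterion = build_ce_criterion_sketch(cfg)
```

Note that only `label_smooth_ce` forwards a kwarg (`smooth_ratio`); `soft_focal_loss` is always built with its defaults, and any unrecognized `type` raises `ValueError`.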

Full Source Code

../ding/torch_utils/loss/cross_entropy_loss.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional


class LabelSmoothCELoss(nn.Module):
    """
    Overview:
        Label smooth cross entropy loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, ratio: float) -> None:
        """
        Overview:
            Initialize the LabelSmoothCELoss object using the given arguments.
        Arguments:
            - ratio (:obj:`float`): The ratio of label-smoothing (the value is in 0-1). If the ratio is larger, the \
                extent of label smoothing is larger.
        """
        super().__init__()
        self.ratio = ratio

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate label smooth cross entropy loss.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits.
            - labels (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
        """
        B, N = logits.shape
        val = float(self.ratio) / (N - 1)
        one_hot = torch.full_like(logits, val)
        one_hot.scatter_(1, labels.unsqueeze(1), 1 - val)
        logits = F.log_softmax(logits, dim=1)
        return -torch.sum(logits * (one_hot.detach())) / B


class SoftFocalLoss(nn.Module):
    """
    Overview:
        Soft focal loss.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self, gamma: int = 2, weight: Any = None, size_average: bool = True, reduce: Optional[bool] = None
    ) -> None:
        """
        Overview:
            Initialize the SoftFocalLoss object using the given arguments.
        Arguments:
            - gamma (:obj:`int`): The extent of focus on hard samples. A smaller ``gamma`` will lead to more focus on \
                easy samples, while a larger ``gamma`` will lead to more focus on hard samples.
            - weight (:obj:`Any`): The weight for loss of each class.
            - size_average (:obj:`bool`): By default, the losses are averaged over each loss element in the batch. \
                Note that for some losses, there are multiple elements per sample. If the field ``size_average`` is \
                set to ``False``, the losses are instead summed for each minibatch. Ignored when ``reduce`` is \
                ``False``.
            - reduce (:obj:`Optional[bool]`): By default, the losses are averaged or summed over observations for \
                each minibatch depending on ``size_average``. When ``reduce`` is ``False``, returns a loss for each \
                batch element instead and ignores ``size_average``.
        """
        super().__init__()
        self.gamma = gamma
        self.nll_loss = torch.nn.NLLLoss2d(weight, size_average, reduce=reduce)

    def forward(self, inputs: torch.Tensor, targets: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate soft focal loss.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Predicted logits.
            - targets (:obj:`torch.LongTensor`): Ground truth.
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated loss.
        """
        return self.nll_loss((1 - F.softmax(inputs, 1)) ** self.gamma * F.log_softmax(inputs, 1), targets)


def build_ce_criterion(cfg: dict) -> nn.Module:
    """
    Overview:
        Get a cross entropy loss instance according to the given config.
    Arguments:
        - cfg (:obj:`dict`): Config dict. It contains:
            - type (:obj:`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', \
                'soft_focal_loss'].
            - kwargs (:obj:`dict`): Arguments for the corresponding loss function.
    Returns:
        - loss (:obj:`nn.Module`): Loss function instance.
    """
    if cfg.type == 'cross_entropy':
        return nn.CrossEntropyLoss()
    elif cfg.type == 'label_smooth_ce':
        return LabelSmoothCELoss(cfg.kwargs.smooth_ratio)
    elif cfg.type == 'soft_focal_loss':
        return SoftFocalLoss()
    else:
        raise ValueError("invalid criterion type:{}".format(cfg.type))
```