ding.torch_utils.loss.cross_entropy_loss
LabelSmoothCELoss
Bases: torch.nn.Module
Overview
Label-smoothing cross-entropy loss.
Interfaces:
__init__, forward.
__init__(ratio)
Overview
Initialize the LabelSmoothCELoss object using the given arguments.
Arguments:
- ratio (:obj:float): The label-smoothing ratio, in the range [0, 1]. A larger ratio takes more probability mass away from the ground-truth class, i.e., it applies stronger smoothing.
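To make the role of ratio concrete, here is one common label-smoothing formulation, written with ε = ratio, N ≥ 2 classes, logits z and ground-truth class y. It keeps 1 − ε on the true class, spreads ε uniformly over the other N − 1 classes, and takes the cross entropy against that smoothed target. Whether LabelSmoothCELoss splits ε in exactly this way is an assumption here; with ε = 0 the formula reduces to ordinary cross entropy.

```latex
q_i =
\begin{cases}
  1 - \varepsilon, & i = y, \\
  \dfrac{\varepsilon}{N - 1}, & i \neq y,
\end{cases}
\qquad
\mathcal{L}(z, y) = -\sum_{i=1}^{N} q_i \, \log \operatorname{softmax}(z)_i
```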
forward(logits, labels)
Overview
Calculate the label-smoothing cross-entropy loss.
Arguments:
- logits (:obj:torch.Tensor): Predicted logits.
- labels (:obj:torch.LongTensor): Ground-truth class indices.
Returns:
- loss (:obj:torch.Tensor): Calculated loss.
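The following is a minimal sketch of a label-smoothing cross entropy, not DI-engine's actual source. It assumes logits of shape (B, N) with N ≥ 2, integer labels of shape (B,), and that the true class receives 1 - ratio while the remaining mass is spread evenly over the other classes. The class name LabelSmoothCESketch is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothCESketch(nn.Module):
    """Hypothetical label-smoothing cross entropy (sketch, not DI-engine's code).

    Assumes logits of shape (B, N) with N >= 2 and integer labels of shape (B,).
    The true class gets 1 - ratio; ratio is spread uniformly over the other N - 1 classes.
    """

    def __init__(self, ratio: float) -> None:
        super().__init__()
        assert 0.0 <= ratio <= 1.0, "ratio must lie in [0, 1]"
        self.ratio = ratio

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        _, num_classes = logits.shape
        off_value = self.ratio / (num_classes - 1)
        # Smoothed target distribution; full_like does not track gradients here.
        target = torch.full_like(logits, off_value)
        target.scatter_(1, labels.unsqueeze(1), 1.0 - self.ratio)
        log_probs = F.log_softmax(logits, dim=1)
        # Per-sample cross entropy against the smoothed target, averaged over the batch.
        return -(target * log_probs).sum(dim=1).mean()


logits = torch.randn(4, 5)
labels = torch.tensor([0, 2, 1, 4])
loss = LabelSmoothCESketch(ratio=0.1)(logits, labels)
```

With ratio=0 the sketch matches F.cross_entropy exactly. Note that PyTorch's own nn.CrossEntropyLoss(label_smoothing=...) uses a different split (it mixes the one-hot target with a uniform distribution over all N classes), so its values will not match this sketch for a non-zero ratio.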
SoftFocalLoss
Bases: torch.nn.Module
Overview
Soft focal loss.
Interfaces:
__init__, forward.
__init__(gamma=2, weight=None, size_average=True, reduce=None)
Overview
Initialize the SoftFocalLoss object using the given arguments.
Arguments:
- gamma (:obj:int): The focusing parameter. A larger gamma down-weights well-classified (easy) samples more strongly, so the loss focuses more on hard samples; a smaller gamma weights easy and hard samples more evenly.
- weight (:obj:Any): Manual rescaling weight applied to the loss of each class.
- size_average (:obj:bool): By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False.
- reduce (:obj:Optional[bool]): By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss for each batch element instead and ignores size_average.
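The size_average and reduce rules above are PyTorch's legacy loss flags; current PyTorch losses express the same choice through a single reduction argument ('mean', 'sum' or 'none'). The hypothetical helper legacy_reduction below (not part of DI-engine) spells out the decision rule, treating None as True for both flags, which is how PyTorch interprets the unset defaults.

```python
from typing import Optional


def legacy_reduction(size_average: Optional[bool] = True, reduce: Optional[bool] = None) -> str:
    """Map legacy (size_average, reduce) flags to a modern `reduction` string.

    Hypothetical helper: None counts as True for both flags, and
    reduce=False takes precedence over size_average.
    """
    size_average = True if size_average is None else size_average
    reduce = True if reduce is None else reduce
    if not reduce:
        return "none"  # one loss value per element; size_average is ignored
    return "mean" if size_average else "sum"


print(legacy_reduction(True, None))   # the SoftFocalLoss defaults
print(legacy_reduction(False, None))
print(legacy_reduction(True, False))
```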
forward(inputs, targets)
Overview
Calculate soft focal loss.
Arguments:
- inputs (:obj:torch.Tensor): Predicted logits.
- targets (:obj:torch.LongTensor): Ground-truth class indices.
Returns:
- loss (:obj:torch.Tensor): Calculated loss.
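As an illustration of how gamma shapes the loss, here is a minimal sketch of the standard multi-class focal loss, whose per-sample value is -(1 - p_y)^gamma * log(p_y), with p_y the softmax probability of the target class. It is not DI-engine's source: the class name SoftFocalLossSketch is hypothetical, and it uses the modern reduction argument in place of the legacy size_average/reduce pair.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftFocalLossSketch(nn.Module):
    """Hypothetical multi-class focal loss (sketch, not DI-engine's code).

    Per sample: -(1 - p_y) ** gamma * log p_y, with p_y the softmax
    probability of the target class. gamma = 0 gives plain cross entropy.
    """

    def __init__(self, gamma: float = 2.0, weight=None, reduction: str = "mean") -> None:
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(inputs, dim=1)
        probs = log_probs.exp()
        # Modulating factor (1 - p) ** gamma shrinks the loss of confident predictions.
        focal_log_probs = (1.0 - probs) ** self.gamma * log_probs
        # nll_loss selects -focal_log_probs[i, targets[i]] and applies weight and reduction.
        return F.nll_loss(focal_log_probs, targets, weight=self.weight, reduction=self.reduction)


inputs = torch.randn(6, 3)
targets = torch.tensor([0, 1, 2, 0, 1, 2])
loss = SoftFocalLossSketch(gamma=2.0)(inputs, targets)
```

Because (1 - p_y)^gamma ≤ 1 for gamma ≥ 0, the focal loss never exceeds the plain cross entropy on the same batch; the reduction is largest for samples the model already classifies confidently.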
build_ce_criterion(cfg)
Overview
Build a cross-entropy-style loss instance according to the given config.
Arguments:
- cfg (:obj:dict): Config dict. It contains:
    - type (:obj:str): Type of loss function; currently supports ['cross_entropy', 'label_smooth_ce', 'soft_focal_loss'].
    - kwargs (:obj:dict): Arguments for the corresponding loss function.
Returns:
- loss (:obj:nn.Module): Loss function instance.
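A config-driven builder like this is typically a lookup from the type string to a constructor. The sketch below assumes a plain dict config of the form {'type': ..., 'kwargs': {...}}; the registry LOSS_REGISTRY and the function build_criterion_sketch are hypothetical names, not DI-engine's implementation. Only 'cross_entropy' is registered so that the example stays self-contained; the other supported types would presumably map to LabelSmoothCELoss and SoftFocalLoss.

```python
from typing import Callable, Dict

import torch.nn as nn

# Hypothetical registry mapping the config's "type" string to a loss constructor.
# In DI-engine, 'label_smooth_ce' and 'soft_focal_loss' would presumably map to
# LabelSmoothCELoss and SoftFocalLoss; they are omitted to keep this self-contained.
LOSS_REGISTRY: Dict[str, Callable[..., nn.Module]] = {
    "cross_entropy": nn.CrossEntropyLoss,
}


def build_criterion_sketch(cfg: dict) -> nn.Module:
    """Instantiate a loss from {'type': ..., 'kwargs': {...}} (assumed config layout)."""
    loss_type = cfg["type"]
    if loss_type not in LOSS_REGISTRY:
        raise ValueError(f"unsupported criterion type: {loss_type!r}")
    kwargs = cfg.get("kwargs", {})
    return LOSS_REGISTRY[loss_type](**kwargs)


criterion = build_criterion_sketch({"type": "cross_entropy", "kwargs": {"reduction": "sum"}})
```

Raising on an unknown type, rather than silently falling back to a default loss, surfaces config typos at construction time.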
Full Source Code
../ding/torch_utils/loss/cross_entropy_loss.py