ding.torch_utils.network.soft_argmax

SoftArgmax

Bases: Module

Overview

A neural network module that computes the SoftArgmax operation (essentially a 2-dimensional spatial softmax), which is often used for location regression tasks. It converts a feature map (such as a heatmap) into precise coordinate locations.

Interfaces: __init__, forward

.. note:: For more information on SoftArgmax, you can refer to https://en.wikipedia.org/wiki/Softmax_function and the paper https://arxiv.org/pdf/1504.00702.pdf.
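Concretely, the operation first normalizes the heatmap with a softmax over all spatial positions, then takes the probability-weighted average of the row and column indices (a sketch of the math implemented in `forward`; the symbols follow the shape convention :math:`(B, C, H, W)` used below):

```latex
p_{h,w} = \frac{\exp(x_{h,w})}{\sum_{h'=0}^{H-1}\sum_{w'=0}^{W-1} \exp(x_{h',w'})},
\qquad
\mathrm{location} = \left( \sum_{h,w} h \, p_{h,w}, \;\; \sum_{h,w} w \, p_{h,w} \right)
```

Because the output is an expectation over indices rather than a hard argmax, it is differentiable and can yield sub-pixel coordinates.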

__init__()

Overview

Initialize the SoftArgmax module.

forward(x)

Overview

Perform the forward pass of the SoftArgmax operation.

Arguments:
- x (:obj:`torch.Tensor`): The input tensor, typically a heatmap representing predicted locations.

Returns:
- location (:obj:`torch.Tensor`): The predicted coordinates produced by the SoftArgmax operation.

Shapes:
- x: :math:`(B, C, H, W)`, where B is the batch size, C is the number of channels, and H and W represent height and width respectively.
- location: :math:`(B, 2)`, where B is the batch size and 2 represents the (height, width) coordinates.
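As a quick sanity check, the computation in `forward` can be reproduced with plain tensor ops. The sketch below (self-contained, not importing `ding`) builds a heatmap with a single sharp peak and verifies that the soft-argmax recovers its coordinates:

```python
import torch
import torch.nn.functional as F

# Heatmap of shape (B, C, H, W) with a large logit at (h=3, w=7),
# so the softmax mass concentrates almost entirely at that position.
B, C, H, W = 1, 1, 8, 10
x = torch.zeros(B, C, H, W)
x[0, 0, 3, 7] = 50.0

# Softmax over the flattened spatial dimensions, then reshape back.
prob = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W)

# Index grids for height and width, broadcast against the probability map.
h_idx = torch.arange(H, dtype=prob.dtype).view(1, 1, H, 1)
w_idx = torch.arange(W, dtype=prob.dtype).view(1, 1, 1, W)

# Expected (probability-weighted) row and column indices.
h = (prob * h_idx).sum(dim=[1, 2, 3])
w = (prob * w_idx).sum(dim=[1, 2, 3])
location = torch.stack([h, w], dim=1)
print(location)  # ≈ tensor([[3., 7.]])
```

With a flatter heatmap the result would instead fall between integer indices, which is exactly the sub-pixel behavior that makes the module useful for location regression.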

Full Source Code

../ding/torch_utils/network/soft_argmax.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftArgmax(nn.Module):
    """
    Overview:
        A neural network module that computes the SoftArgmax operation (essentially a 2-dimensional spatial softmax),
        which is often used for location regression tasks. It converts a feature map (such as a heatmap) into precise
        coordinate locations.
    Interfaces:
        ``__init__``, ``forward``

    .. note::
        For more information on SoftArgmax, you can refer to <https://en.wikipedia.org/wiki/Softmax_function>
        and the paper <https://arxiv.org/pdf/1504.00702.pdf>.
    """

    def __init__(self):
        """
        Overview:
            Initialize the SoftArgmax module.
        """
        super(SoftArgmax, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Perform the forward pass of the SoftArgmax operation.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor, typically a heatmap representing predicted locations.
        Returns:
            - location (:obj:`torch.Tensor`): The predicted coordinates as a result of the SoftArgmax operation.
        Shapes:
            - x: :math:`(B, C, H, W)`, where `B` is the batch size, `C` is the number of channels, \
                and `H` and `W` represent height and width respectively.
            - location: :math:`(B, 2)`, where `B` is the batch size and 2 represents the coordinates (height, width).
        """
        # Unpack the dimensions of the input tensor
        B, C, H, W = x.shape
        device, dtype = x.device, x.dtype
        # Ensure the input tensor has a single channel
        assert C == 1, "Input tensor should have only one channel"
        # Create a meshgrid for the height (h_kernel) and width (w_kernel)
        h_kernel = torch.arange(0, H, device=device).to(dtype)
        h_kernel = h_kernel.view(1, 1, H, 1).repeat(1, 1, 1, W)

        w_kernel = torch.arange(0, W, device=device).to(dtype)
        w_kernel = w_kernel.view(1, 1, 1, W).repeat(1, 1, H, 1)

        # Apply the softmax function across the spatial dimensions (height and width)
        x = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W)
        # Compute the expected values for height and width by multiplying the probability map by the meshgrids
        h = (x * h_kernel).sum(dim=[1, 2, 3])  # Sum over the channel, height, and width dimensions
        w = (x * w_kernel).sum(dim=[1, 2, 3])  # Sum over the channel, height, and width dimensions

        # Stack the height and width coordinates along a new dimension to form the final output tensor
        return torch.stack([h, w], dim=1)
```