
ding.torch_utils.network.gumbel_softmax

GumbelSoftmax

Bases: Module

Overview

An nn.Module that computes GumbelSoftmax.

Interfaces: __init__, forward, gumbel_softmax_sample

.. note:: For more information on GumbelSoftmax, refer to the paper Categorical Reparameterization with Gumbel-Softmax.

__init__()

Overview

Initialize the GumbelSoftmax module.

gumbel_softmax_sample(x, temperature, eps=1e-08)

Overview

Draw a sample from the Gumbel-Softmax distribution.

Arguments:

- x (:obj:`torch.Tensor`): Input tensor.
- temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
- eps (:obj:`float`): Small number to prevent division by zero, default is `1e-8`.

Returns:

- output (:obj:`torch.Tensor`): Sample from the Gumbel-Softmax distribution.
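The sampling step can be sketched without any deep-learning framework: draw uniform noise `u`, convert it to Gumbel(0, 1) noise `g = -log(-log(u))`, add it to the logits, and apply a temperature-scaled softmax. A minimal pure-Python sketch (the helper name mirrors the method but is illustrative, not the module's implementation):

```python
import math
import random

def gumbel_softmax_sample(logits, temperature, eps=1e-8):
    # Gumbel(0, 1) noise via the inverse-CDF trick: g = -log(-log(u))
    noisy = [x - math.log(-math.log(random.random() + eps) + eps) for x in logits]
    # temperature-scaled softmax (max-subtracted for numerical stability)
    m = max(v / temperature for v in noisy)
    exps = [math.exp(v / temperature - m) for v in noisy]
    total = sum(exps)
    return [v / total for v in exps]

random.seed(0)
sample = gumbel_softmax_sample([0.1, 0.2, 0.7], temperature=1.0)
# `sample` is a valid probability vector: non-negative entries that sum to 1
```

Because the Gumbel noise is independent per element, repeated calls draw different categorical-like samples from the same logits.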

forward(x, temperature=1.0, hard=False)

Overview

Forward pass for the GumbelSoftmax module.

Arguments:

- x (:obj:`torch.Tensor`): Unnormalized log-probabilities.
- temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
- hard (:obj:`bool`): If `True`, returns one-hot encoded labels. Default is `False`.

Returns:

- output (:obj:`torch.Tensor`): Sample from the Gumbel-Softmax distribution.

Shapes:

- x: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
- y: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
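The `temperature` argument controls how close samples are to one-hot: as it approaches 0 the output concentrates on a single class, while large values flatten it toward uniform. A framework-free illustration (the `sample` helper is hypothetical, mirroring `gumbel_softmax_sample`):

```python
import math
import random

def sample(logits, temperature, eps=1e-8):
    # add Gumbel(0, 1) noise, then apply a temperature-scaled softmax
    noisy = [x - math.log(-math.log(random.random() + eps) + eps) for x in logits]
    m = max(v / temperature for v in noisy)
    exps = [math.exp(v / temperature - m) for v in noisy]
    total = sum(exps)
    return [v / total for v in exps]

random.seed(0)
logits = [1.0, 2.0, 3.0]
cold = sample(logits, temperature=0.01)   # nearly one-hot
hot = sample(logits, temperature=1000.0)  # nearly uniform
```

Annealing the temperature from high to low over training is the usual way to trade gradient quality against sample discreteness.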

Full Source Code

../ding/torch_utils/network/gumbel_softmax.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelSoftmax(nn.Module):
    """
    Overview:
        An `nn.Module` that computes GumbelSoftmax.
    Interfaces:
        ``__init__``, ``forward``, ``gumbel_softmax_sample``

    .. note::
        For more information on GumbelSoftmax, refer to the paper [Categorical Reparameterization \
        with Gumbel-Softmax](https://arxiv.org/abs/1611.01144).
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the `GumbelSoftmax` module.
        """
        super(GumbelSoftmax, self).__init__()

    def gumbel_softmax_sample(self, x: torch.Tensor, temperature: float, eps: float = 1e-8) -> torch.Tensor:
        """
        Overview:
            Draw a sample from the Gumbel-Softmax distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
            - temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
            - eps (:obj:`float`): Small number to prevent division by zero, default is `1e-8`.
        Returns:
            - output (:obj:`torch.Tensor`): Sample from Gumbel-Softmax distribution.
        """
        U = torch.rand(x.shape)
        U = U.to(x.device)
        y = x - torch.log(-torch.log(U + eps) + eps)
        return F.softmax(y / temperature, dim=1)

    def forward(self, x: torch.Tensor, temperature: float = 1.0, hard: bool = False) -> torch.Tensor:
        """
        Overview:
            Forward pass for the `GumbelSoftmax` module.
        Arguments:
            - x (:obj:`torch.Tensor`): Unnormalized log-probabilities.
            - temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
            - hard (:obj:`bool`): If `True`, returns one-hot encoded labels. Default is `False`.
        Returns:
            - output (:obj:`torch.Tensor`): Sample from Gumbel-Softmax distribution.
        Shapes:
            - x: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
            - y: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
        """
        y = self.gumbel_softmax_sample(x, temperature)
        if hard:
            y_hard = torch.zeros_like(x)
            y_hard[torch.arange(0, x.shape[0]), y.max(1)[1]] = 1
            # detach() treats (y_hard - y) as a constant, so the gradient of
            # the hard output equals the gradient of the soft sample y.
            y = (y_hard - y).detach() + y
        return y
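A hedged usage sketch via PyTorch's built-in `torch.nn.functional.gumbel_softmax`, which applies the same reparameterization and straight-through trick as the class above (an assumption: its noise sampling and defaults differ slightly, so outputs are not bit-identical):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# unnormalized log-probabilities, shape (B, N) = (4, 10)
logits = torch.randn(4, 10, requires_grad=True)

# soft sample: each row is a probability distribution over the 10 classes
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)

# hard sample: one-hot rows in the forward pass, but gradients flow
# through the soft sample in the backward pass (straight-through)
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)
y_hard.sum().backward()  # populates logits.grad despite the discrete output
```

The straight-through trick is what makes `hard=True` usable inside a differentiable model: the forward pass sees a discrete one-hot choice while the backward pass sees the smooth softmax.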