ding.torch_utils.network.gumbel_softmax
GumbelSoftmax
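Bases: torch.nn.Module
Overview
An nn.Module that draws samples from the Gumbel-Softmax distribution, a differentiable relaxation of sampling from a categorical distribution.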
Interfaces:
__init__, forward, gumbel_softmax_sample
Note: For more information on Gumbel-Softmax, refer to the paper "Categorical Reparameterization with Gumbel-Softmax".
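For reference, the sampling rule from that paper can be written as below. Here x_i are the logits of the N classes and τ is the temperature. The symbols g_i (Gumbel noise) and u_i (its uniform source) are notation introduced only for this formula.

```latex
g_i = -\log\left(-\log u_i\right), \qquad u_i \sim \mathrm{Uniform}(0, 1)

y_i = \frac{\exp\left((x_i + g_i)/\tau\right)}{\sum_{j=1}^{N} \exp\left((x_j + g_j)/\tau\right)}, \qquad i = 1, \dots, N
```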
__init__()
Overview
Initialize the GumbelSoftmax module.
gumbel_softmax_sample(x, temperature, eps=1e-08)
Overview
Draw a sample from the Gumbel-Softmax distribution.
Arguments:
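- x (torch.Tensor): Input tensor of unnormalized log-probabilities (logits).
- temperature (float): Positive scalar controlling the sharpness of the distribution. Lower values yield samples closer to one-hot; higher values yield smoother samples.
- eps (float): Small constant for numerical stability that keeps the logarithms in the Gumbel-noise computation away from log(0). Default is 1e-8.
Returns:
- output (torch.Tensor): A sample from the Gumbel-Softmax distribution.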
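A minimal sketch of this sampling step in plain PyTorch is shown here. It is written from the Gumbel-Softmax formula rather than copied from DI-engine's source. Details such as exactly where eps is added are therefore assumptions. The standalone function mirrors the method name but is not the library's API.

```python
import torch
import torch.nn.functional as F


def gumbel_softmax_sample(x: torch.Tensor, temperature: float, eps: float = 1e-8) -> torch.Tensor:
    """Sketch: soft Gumbel-Softmax sample for logits x of shape (B, N)."""
    # Uniform noise in [0, 1) with the same shape, dtype and device as x.
    u = torch.rand_like(x)
    # Gumbel(0, 1) noise, g = -log(-log(u)); eps (assumed placement) keeps both logs finite.
    g = -torch.log(-torch.log(u + eps) + eps)
    # Perturb the logits, scale by the temperature, and normalize over the class dimension.
    return F.softmax((x + g) / temperature, dim=-1)


torch.manual_seed(0)
logits = torch.randn(4, 3)
y = gumbel_softmax_sample(logits, temperature=0.5)
print(y.shape)        # torch.Size([4, 3])
print(y.sum(dim=-1))  # each row sums to 1, up to float rounding
```

Because the perturbed logits are divided by temperature, it must be strictly positive. Smaller values push each row toward a one-hot vector.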
forward(x, temperature=1.0, hard=False)
Overview
Forward pass for the GumbelSoftmax module.
Arguments:
- x (torch.Tensor): Unnormalized log-probabilities (logits).
- temperature (float): Positive scalar controlling the sharpness of the distribution. Default is 1.0.
- hard (bool): If True, returns one-hot encoded samples instead of soft probabilities. Default is False.
Returns:
- output (torch.Tensor): A sample from the Gumbel-Softmax distribution.
Shapes:
- x: (B, N), where B is the batch size and N is the number of classes.
- output: (B, N), the same shape as x.
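When hard=True, a common way to return one-hot samples while keeping the computation differentiable is the straight-through estimator, sketched below. Whether GumbelSoftmax uses exactly this trick is an assumption. The helper name straight_through_one_hot is hypothetical.

```python
import torch
import torch.nn.functional as F


def straight_through_one_hot(y_soft: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot forward value, soft-sample gradient.

    y_soft: soft sample of shape (B, N), e.g. a Gumbel-Softmax output.
    """
    # Index of the largest entry in each row, shape (B, 1).
    index = y_soft.argmax(dim=-1, keepdim=True)
    # One-hot tensor with a 1 at that index; it does not require grad.
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Straight-through estimator: the value equals y_hard (up to float rounding),
    # while backward() only sees the "+ y_soft" term.
    return (y_hard - y_soft).detach() + y_soft


torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)
# Soft probabilities; Gumbel noise is omitted here to keep the example short.
y_soft = F.softmax(logits / 0.5, dim=-1)
y = straight_through_one_hot(y_soft)
print(y.argmax(dim=-1).equal(y_soft.argmax(dim=-1)))  # True
y[:, 0].sum().backward()
print(logits.grad is not None)  # True
```

In a module shaped like the one documented here, such a helper would be applied to the soft sample inside forward when hard is True. The output keeps the shape (B, N).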
Full Source Code
../ding/torch_utils/network/gumbel_softmax.py