ding.torch_utils.distribution
Pd
Bases: object
Overview
Abstract class for parameterizable probability distributions and sampling functions.
Interfaces:
neglogp, entropy, noise_mode, mode, sample
.. tip::
In derived classes, `logits` should be stored as an attribute of the class.
neglogp(x)
Overview
Calculate the cross entropy between input x and logits.
Arguments:
- x (:obj:torch.Tensor): the input tensor
Returns:
- cross_entropy (:obj:torch.Tensor): the cross entropy loss
entropy()
Overview
Calculate the softmax entropy of logits.
Arguments:
- reduction (:obj:str): support [None, 'mean'], defaults to 'mean'
Returns:
- entropy (:obj:torch.Tensor): the calculated entropy
noise_mode()
Overview
Add noise to logits. This method is designed for stochastic behavior.
mode()
Overview
Return the argmax of logits. This method is designed for deterministic behavior.
sample()
Overview
Sample from the softmax distribution over logits. This method is designed for multinomial sampling.
CategoricalPd
Bases: Pd
Overview
Categorical probability distribution sampler
Interfaces:
__init__, neglogp, entropy, noise_mode, mode, sample
__init__(logits=None)
Overview
Initialize the Pd with logits.
Arguments:
- logits (:obj:torch.Tensor): logits to sample from
update_logits(logits)
Overview
Update logits.
Arguments:
- logits (:obj:torch.Tensor): the new logits
neglogp(x, reduction='mean')
Overview
Calculate the cross entropy between input x and logits.
Arguments:
- x (:obj:torch.Tensor): the input tensor
- reduction (:obj:str): support [None, 'mean'], defaults to 'mean'
Returns:
- cross_entropy (:obj:torch.Tensor): the cross entropy loss
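The cross-entropy relationship described above can be sketched directly in PyTorch; the tensor shapes here are illustrative assumptions, not part of the documented API:

```python
import torch
import torch.nn.functional as F

# Assumed shapes: logits is (batch, num_classes), x holds integer
# class indices of shape (batch,).
logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
x = torch.tensor([0, 1])

# The negative log-probability of x under softmax(logits) is exactly
# the cross-entropy loss; reduction='mean' averages over the batch.
neglogp = F.cross_entropy(logits, x, reduction='mean')

# Equivalent manual computation: -log_softmax(logits)[x], averaged.
manual = -F.log_softmax(logits, dim=-1).gather(1, x.unsqueeze(1)).squeeze(1).mean()
```

The two computations agree, which is why `neglogp` can be implemented as a plain cross-entropy call.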
entropy(reduction='mean')
Overview
Calculate the softmax entropy of logits.
Arguments:
- reduction (:obj:str): support [None, 'mean'], defaults to 'mean'
Returns:
- entropy (:obj:torch.Tensor): the calculated entropy
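As a hedged sketch, the softmax entropy described here is H = -Σ p_i log p_i over the class dimension; the logits values below are chosen only to illustrate the two extremes:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 0.0, 0.0],    # uniform -> maximum entropy
                       [10.0, 0.0, 0.0]])  # peaked -> near-zero entropy

probs = F.softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim=-1)
entropy = -(probs * log_probs).sum(dim=-1)  # per-sample entropy

# With reduction='mean', the per-sample entropies are averaged;
# with reduction=None, the per-sample tensor is returned as-is.
mean_entropy = entropy.mean()
```

The uniform row attains the maximum entropy log(3) for three classes, while the peaked row is close to zero.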
noise_mode(viz=False)
Overview
Add noise to logits.
Arguments:
- viz (:obj:bool): whether to additionally return the numpy form of logits, noise, and noised logits; short for "visualize" (tensors cannot be visualized directly in TensorBoard or text logs)
Returns:
- result (:obj:torch.Tensor): the noised logits
- viz_feature (:obj:Dict[str, np.ndarray]): ndarray-type data for visualization
mode(viz=False)
Overview
Return the argmax of logits.
Arguments:
- viz (:obj:bool): whether to additionally return the numpy form of logits, noise, and noised logits; short for "visualize" (tensors cannot be visualized directly in TensorBoard or text logs)
Returns:
- result (:obj:torch.Tensor): the argmax result
- viz_feature (:obj:Dict[str, np.ndarray]): ndarray-type data for visualization
sample(viz=False)
Overview
Sample from the softmax distribution over logits.
Arguments:
- viz (:obj:bool): whether to additionally return the numpy form of logits, noise, and noised logits; short for "visualize" (tensors cannot be visualized directly in TensorBoard or text logs)
Returns:
- result (:obj:torch.Tensor): the sampled result
- viz_feature (:obj:Dict[str, np.ndarray]): ndarray-type data for visualization
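The multinomial sampling that sample() is documented to perform can be sketched as follows (shapes are assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fix the RNG so the draw is reproducible

# Assumed shape: logits is (batch, num_classes).
logits = torch.tensor([[5.0, 0.0, 0.0],
                       [0.0, 0.0, 5.0]])

# Convert logits to probabilities via softmax, then draw one class
# index per batch row from the resulting categorical distribution.
probs = F.softmax(logits, dim=-1)
result = torch.multinomial(probs, num_samples=1).squeeze(1)
```

Unlike mode(), which always returns the argmax, repeated calls here can yield different indices, with low-probability classes drawn occasionally.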
CategoricalPdPytorch
Bases: Categorical
Overview
Wrapper around torch.distributions.Categorical
Interfaces:
__init__, update_logits, update_probs, sample, neglogp, mode, entropy
__init__(probs=None)
Overview
Initialize the CategoricalPdPytorch object.
Arguments:
- probs (:obj:torch.Tensor): the tensor of probabilities
update_logits(logits)
Overview
Update logits.
Arguments:
- logits (:obj:torch.Tensor): the new logits
update_probs(probs)
Overview
Update probs.
Arguments:
- probs (:obj:torch.Tensor): the new probabilities
sample()
Overview
Sample from the softmax distribution over logits.
Returns:
- result (:obj:torch.Tensor): the sampled result
neglogp(actions, reduction='mean')
Overview
Calculate the cross entropy between input actions and logits.
Arguments:
- actions (:obj:torch.Tensor): the input action tensor
- reduction (:obj:str): support [None, 'mean'], defaults to 'mean'
Returns:
- cross_entropy (:obj:torch.Tensor): the cross entropy loss
mode()
Overview
Return the argmax of logits.
Returns:
- result (:obj:torch.Tensor): the argmax result
entropy(reduction=None)
Overview
Calculate the softmax entropy of logits.
Arguments:
- reduction (:obj:str): support [None, 'mean'], defaults to None
Returns:
- entropy (:obj:torch.Tensor): the calculated entropy
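A hypothetical minimal re-implementation of this wrapper, built only from the documented interface above; the class name and internals are assumptions for illustration, not the library's actual code:

```python
import torch
from torch.distributions import Categorical

class CategoricalPdPytorchSketch:
    """Hypothetical sketch of a thin wrapper over torch.distributions.Categorical."""

    def __init__(self, probs=None):
        if probs is not None:
            self.update_probs(probs)

    def update_logits(self, logits):
        # Rebuild the underlying distribution from raw logits.
        self.dist = Categorical(logits=logits)

    def update_probs(self, probs):
        # Rebuild the underlying distribution from normalized probabilities.
        self.dist = Categorical(probs=probs)

    def sample(self):
        return self.dist.sample()

    def neglogp(self, actions, reduction='mean'):
        neglogp = -self.dist.log_prob(actions)
        return neglogp.mean() if reduction == 'mean' else neglogp

    def mode(self):
        # Deterministic counterpart of sample(): the most probable class.
        return self.dist.probs.argmax(dim=-1)

    def entropy(self, reduction=None):
        entropy = self.dist.entropy()
        return entropy.mean() if reduction == 'mean' else entropy

pd = CategoricalPdPytorchSketch()
pd.update_logits(torch.tensor([[2.0, 0.1, 0.1]]))
action = pd.mode()  # class 0 carries the largest logit
```

Delegating to torch.distributions.Categorical keeps log_prob, entropy, and sampling numerically stable without reimplementing softmax arithmetic by hand.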
Full Source Code
../ding/torch_utils/distribution.py