ding.torch_utils.distribution

Pd

Bases: object

Overview

Abstract class for parameterizable probability distributions and sampling functions.

Interfaces: neglogp, entropy, noise_mode, mode, sample

.. tip::

In derived classes, `logits` should be an attribute member stored in the class.

neglogp(x)

Overview

Calculate cross_entropy between input x and logits

Arguments:
- x (:obj:`torch.Tensor`): the input tensor

Return:
- cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
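Concretely, this cross entropy is the negative log-probability of the target class under the softmax of the logits. A minimal pure-Python sketch of the single-sample case (illustrative only; the concrete subclass below delegates to `F.cross_entropy`, and the function name here is mine):

```python
import math

def neglogp(logits, target):
    # -log(softmax(logits)[target]), with the usual max-shift
    # for numerical stability
    m = max(logits)
    log_z = math.log(sum(math.exp(l - m) for l in logits))
    return -((logits[target] - m) - log_z)

# Uniform logits over 4 classes: each class has probability 1/4,
# so the loss is log(4).
print(neglogp([0.0, 0.0, 0.0, 0.0], 2))  # ≈ 1.3863 (= log 4)
```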

entropy()

Overview

Calculate the softmax entropy of logits

Arguments:
- reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'

Returns:
- entropy (:obj:`torch.Tensor`): the calculated entropy

noise_mode()

Overview

Add noise to logits. This method is designed for randomness.

mode()

Overview

Return the argmax of logits. This method is designed for deterministic action selection.

sample()

Overview

Sample from the softmax distribution over logits. This method is designed for multinomial sampling.

CategoricalPd

Bases: Pd

Overview

Categorical probability distribution sampler

Interfaces: __init__, neglogp, entropy, noise_mode, mode, sample

__init__(logits=None)

Overview

Initialize the Pd with logits

Arguments:
- logits (:obj:`torch.Tensor`): logits to sample from

update_logits(logits)

Overview

Update logits

Arguments:
- logits (:obj:`torch.Tensor`): logits to update

neglogp(x, reduction='mean')

Overview

Calculate cross_entropy between input x and logits

Arguments:
- x (:obj:`torch.Tensor`): the input tensor
- reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'

Return:
- cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
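`F.cross_entropy` computes one loss per row of logits and, with `reduction='mean'`, averages them over the batch, while `reduction='none'` keeps the per-sample values. A pure-Python sketch of that batched arithmetic (illustrative only; the function name and `'none'` spelling are mine, not the library's):

```python
import math

def neglogp_batch(logits_rows, actions, reduction='mean'):
    # One row of logits and one target action per sample;
    # each per-sample loss is -log(softmax(row)[action]).
    losses = []
    for row, a in zip(logits_rows, actions):
        m = max(row)
        log_z = math.log(sum(math.exp(l - m) for l in row))
        losses.append(-((row[a] - m) - log_z))
    if reduction == 'none':
        return losses
    return sum(losses) / len(losses)  # 'mean'

rows = [[0.0, 0.0], [0.0, 0.0]]
print(neglogp_batch(rows, [0, 1]))  # ≈ 0.6931 (= log 2)
```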

entropy(reduction='mean')

Overview

Calculate the softmax entropy of logits

Arguments:
- reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'

Returns:
- entropy (:obj:`torch.Tensor`): the calculated entropy
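The implementation shifts logits by their maximum before exponentiating, then computes H = Σ p·(log z − a). The same arithmetic in pure Python (a sketch of the formula, not the library code):

```python
import math

def softmax_entropy(logits):
    # a = logits - max(logits); p = exp(a) / z; H = sum(p * (log z - a))
    m = max(logits)
    a = [l - m for l in logits]
    ea = [math.exp(v) for v in a]
    z = sum(ea)
    return sum((e / z) * (math.log(z) - ai) for e, ai in zip(ea, a))

# Uniform logits give the maximum entropy log(n); a sharply peaked
# distribution has entropy near zero.
print(softmax_entropy([1.0, 1.0, 1.0]))  # ≈ 1.0986 (= log 3)
print(softmax_entropy([50.0, 0.0]))      # ≈ 0.0
```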

noise_mode(viz=False)

Overview

Add noise to logits

Arguments:
- viz (:obj:`bool`): whether to also return numpy form of logits, noise and noise_logits; short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)

Returns:
- result (:obj:`torch.Tensor`): the argmax of the noised logits
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization
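The noise used here, `u = -log(-log(uniform))`, is standard Gumbel noise, so `argmax(logits + u)` is an exact sample from the softmax distribution (the Gumbel-max trick). A pure-Python sketch checking this empirically (function name is mine, not the library's):

```python
import math
import random
from collections import Counter

def gumbel_argmax(logits, rng):
    # argmax over logits + Gumbel(0, 1) noise == one softmax sample
    # (rng.random() lies in [0, 1); the exact-zero case is vanishingly rare)
    noised = [l - math.log(-math.log(rng.random())) for l in logits]
    return max(range(len(noised)), key=noised.__getitem__)

rng = random.Random(0)
logits = [1.0, 0.0, -1.0]
counts = Counter(gumbel_argmax(logits, rng) for _ in range(50_000))

# Empirical frequencies approach softmax(logits) ≈ [0.665, 0.245, 0.090]
z = sum(math.exp(l) for l in logits)
for i, l in enumerate(logits):
    print(i, counts[i] / 50_000, math.exp(l) / z)
```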

mode(viz=False)

Overview

Return the argmax of logits

Arguments:
- viz (:obj:`bool`): whether to also return numpy form of logits; short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)

Returns:
- result (:obj:`torch.Tensor`): the argmax of logits
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization

sample(viz=False)

Overview

Sample from the softmax distribution over logits

Arguments:
- viz (:obj:`bool`): whether to also return numpy form of logits; short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)

Returns:
- result (:obj:`torch.Tensor`): the sampled result
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization
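`sample` draws class indices in proportion to `softmax(logits)`, which the implementation does with `torch.multinomial`. A pure-Python equivalent using `random.choices` with unnormalized softmax weights (a sketch, not the library code):

```python
import math
import random
from collections import Counter

def softmax_sample(logits, rng):
    # Unnormalized softmax weights suffice: choices() normalizes them.
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(1)
logits = [2.0, 1.0, 0.0]
counts = Counter(softmax_sample(logits, rng) for _ in range(50_000))

# Frequencies approach softmax(logits) ≈ [0.665, 0.245, 0.090]
z = sum(math.exp(l) for l in logits)
print({i: round(counts[i] / 50_000, 3) for i in range(len(logits))})
```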

CategoricalPdPytorch

Bases: Categorical

Overview

Wrapped torch.distributions.Categorical

Interfaces

__init__, update_logits, update_probs, sample, neglogp, mode, entropy

__init__(probs=None)

Overview

Initialize the CategoricalPdPytorch object.

Arguments:
- probs (:obj:`torch.Tensor`): the tensor of probabilities

update_logits(logits)

Overview

Update logits

Arguments:
- logits (:obj:`torch.Tensor`): logits to update

update_probs(probs)

Overview

Update probs

Arguments:
- probs (:obj:`torch.Tensor`): probs to update

sample()

Overview

Sample from the wrapped categorical distribution

Return:
- result (:obj:`torch.Tensor`): the sampled result

neglogp(actions, reduction='mean')

Overview

Calculate the cross entropy between input actions and logits

Arguments:
- actions (:obj:`torch.Tensor`): the input action tensor
- reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean'

Return:
- cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss

mode()

Overview

Return the argmax of probs

Return:
- result (:obj:`torch.Tensor`): the argmax result

entropy(reduction=None)

Overview

Calculate the softmax entropy of logits

Arguments:
- reduction (:obj:`str`): support [None, 'mean'], default set to None

Returns:
- entropy (:obj:`torch.Tensor`): the calculated entropy

Full Source Code

../ding/torch_utils/distribution.py

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple, Dict

import torch
import numpy as np
import torch.nn.functional as F


class Pd(object):
    """
    Overview:
        Abstract class for parameterizable probability distributions and sampling functions.
    Interfaces:
        ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``

    .. tip::

        In derived classes, ``logits`` should be an attribute member stored in the class.
    """

    def neglogp(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Calculate the cross entropy between input x and logits
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
        """
        raise NotImplementedError

    def entropy(self) -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of logits
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        raise NotImplementedError

    def noise_mode(self):
        """
        Overview:
            Add noise to logits. This method is designed for randomness.
        """
        raise NotImplementedError

    def mode(self):
        """
        Overview:
            Return the argmax of logits. This method is designed for deterministic action selection.
        """
        raise NotImplementedError

    def sample(self):
        """
        Overview:
            Sample from the softmax distribution over logits. This method is designed for multinomial sampling.
        """
        raise NotImplementedError


class CategoricalPd(Pd):
    """
    Overview:
        Categorical probability distribution sampler
    Interfaces:
        ``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
    """

    def __init__(self, logits: torch.Tensor = None) -> None:
        """
        Overview:
            Initialize the Pd with logits
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to sample from
        """
        self.update_logits(logits)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Update logits
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        self.logits = logits

    def neglogp(self, x, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the cross entropy between input x and logits
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
            - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
        """
        return F.cross_entropy(self.logits, x, reduction=reduction)

    def entropy(self, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of logits
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
        ea = torch.exp(a)
        z = ea.sum(dim=-1, keepdim=True)
        p = ea / z
        entropy = (p * (torch.log(z) - a)).sum(dim=-1)
        assert (reduction in [None, 'mean'])
        if reduction is None:
            return entropy
        elif reduction == 'mean':
            return entropy.mean()

    def noise_mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Add noise to logits
        Arguments:
            - viz (:obj:`bool`): whether to also return numpy form of logits, noise and noise_logits; \
                short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)
        Returns:
            - result (:obj:`torch.Tensor`): the argmax of the noised logits
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization
        """
        u = torch.rand_like(self.logits)
        u = -torch.log(-torch.log(u))
        noise_logits = self.logits + u
        result = noise_logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            viz_feature['noise'] = u.data.cpu().numpy()
            viz_feature['noise_logits'] = noise_logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Return the argmax of logits
        Arguments:
            - viz (:obj:`bool`): whether to also return numpy form of logits; \
                short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)
        Returns:
            - result (:obj:`torch.Tensor`): the argmax of logits
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization
        """
        result = self.logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def sample(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Sample from the softmax distribution over logits
        Arguments:
            - viz (:obj:`bool`): whether to also return numpy form of logits; \
                short for ``visualize`` (tensor types cannot be visualized in TensorBoard or text logs)
        Returns:
            - result (:obj:`torch.Tensor`): the sampled result
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization
        """
        p = torch.softmax(self.logits, dim=1)
        result = torch.multinomial(p, 1).squeeze(1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result


class CategoricalPdPytorch(torch.distributions.Categorical):
    """
    Overview:
        Wrapped ``torch.distributions.Categorical``

    Interfaces:
        ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy``
    """

    def __init__(self, probs: torch.Tensor = None) -> None:
        """
        Overview:
            Initialize the CategoricalPdPytorch object.
        Arguments:
            - probs (:obj:`torch.Tensor`): the tensor of probabilities
        """
        if probs is not None:
            self.update_probs(probs)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Update logits
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        super().__init__(logits=logits)

    def update_probs(self, probs: torch.Tensor) -> None:
        """
        Overview:
            Update probs
        Arguments:
            - probs (:obj:`torch.Tensor`): probs to update
        """
        super().__init__(probs=probs)

    def sample(self) -> torch.Tensor:
        """
        Overview:
            Sample from the wrapped categorical distribution
        Return:
            - result (:obj:`torch.Tensor`): the sampled result
        """
        return super().sample()

    def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the cross entropy between input actions and logits
        Arguments:
            - actions (:obj:`torch.Tensor`): the input action tensor
            - reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean'
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
        """
        neglogp = super().log_prob(actions)
        assert (reduction in ['none', 'mean'])
        if reduction == 'none':
            return neglogp
        elif reduction == 'mean':
            return neglogp.mean(dim=0)

    def mode(self) -> torch.Tensor:
        """
        Overview:
            Return the argmax of probs
        Return:
            - result (:obj:`torch.Tensor`): the argmax result
        """
        return self.probs.argmax(dim=-1)

    def entropy(self, reduction: str = None) -> torch.Tensor:
        """
        Overview:
            Calculate the entropy of the distribution
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to None
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        entropy = super().entropy()
        assert (reduction in [None, 'mean'])
        if reduction is None:
            return entropy
        elif reduction == 'mean':
            return entropy.mean()
```