
ding.torch_utils.network.dreamer

Conv2dSame

Bases: Conv2d

Overview

Conv2d with TensorFlow-style "same" padding, used by dreamerv3.

Interfaces: __init__, forward

calc_same_pad(i, k, s, d)

Overview

Calculate the same padding size.

Arguments:

- i (int): Input size.
- k (int): Kernel size.
- s (int): Stride size.
- d (int): Dilation size.
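The padding formula matches TensorFlow's "SAME" convention: the output spatial size becomes ceil(i / s), accounting for the effective (dilated) kernel size. A minimal pure-Python sketch of the same arithmetic (not the module itself):

```python
import math

def calc_same_pad(i, k, s, d):
    # Total padding so that the output size equals ceil(i / s),
    # where the effective (dilated) kernel spans (k - 1) * d + 1 pixels.
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

# 7-wide input, 3-wide kernel, stride 2: pad 2 in total -> output width ceil(7/2) = 4
print(calc_same_pad(7, 3, 2, 1))  # 2
```

forward() then splits this total asymmetrically (pad // 2 on the left/top, the remainder on the right/bottom) before calling F.pad.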

forward(x)

Overview

Compute the forward pass of Conv2dSame.

Arguments:

- x (torch.Tensor): Input tensor.

DreamerLayerNorm

Bases: Module

Overview

LayerNorm over the channel dimension of NCHW feature maps, used by dreamerv3.

Interfaces: __init__, forward

__init__(ch, eps=0.001)

Overview

Init the DreamerLayerNorm class.

Arguments:

- ch (int): Number of input channels.
- eps (float): Epsilon for numerical stability.

forward(x)

Overview

Compute the forward pass of DreamerLayerNorm.

Arguments:

- x (torch.Tensor): Input tensor.
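Because the module permutes NCHW to NHWC before nn.LayerNorm, each spatial position's channel vector is standardized independently. The per-position math reduces to ordinary layer normalization; a hedged pure-Python sketch (gamma = 1, beta = 0, not the module itself):

```python
import math

def layer_norm(xs, eps=1e-3):
    # Standardize one channel vector to (approximately) zero mean, unit variance.
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / math.sqrt(var + eps) for v in xs]

channels = [1.0, 2.0, 3.0, 4.0]
normed = layer_norm(channels)
print(abs(sum(normed)) < 1e-9)  # True: the normalized vector has zero mean
```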

DenseHead

Bases: Module

Overview

DenseHead Network for value head, reward head, and discount head of dreamerv3.

Interfaces: __init__, forward

__init__(inp_dim, shape, layer_num, units, act='SiLU', norm='LN', dist='normal', std=1.0, outscale=1.0, device='cpu')

Overview

Init the DenseHead class.

Arguments:

- inp_dim (int): Input dimension.
- shape (tuple): Output shape.
- layer_num (int): Number of layers.
- units (int): Number of hidden units per layer.
- act (str): Name of the activation function.
- norm (str): Name of the normalization type.
- dist (str): Name of the output distribution.
- std (float): Standard deviation.
- outscale (float): Output scale.
- device (str): Device.

forward(features)

Overview

Compute the forward pass of DenseHead.

Arguments:

- features (torch.Tensor): Input tensor.

ActionHead

Bases: Module

Overview

ActionHead Network for action head of dreamerv3.

Interfaces: __init__, forward

__init__(inp_dim, size, layers, units, act=nn.ELU, norm=nn.LayerNorm, dist='trunc_normal', init_std=0.0, min_std=0.1, max_std=1.0, temp=0.1, outscale=1.0, unimix_ratio=0.01)

Overview

Initialize the ActionHead class.

Arguments:

- inp_dim (int): Input dimension.
- size (int): Output size.
- layers (int): Number of layers.
- units (int): Number of units per layer.
- act (str): Name of the activation function.
- norm (str): Name of the normalization layer.
- dist (str): Name of the output distribution.
- init_std (float): Initial standard deviation.
- min_std (float): Minimum standard deviation.
- max_std (float): Maximum standard deviation.
- temp (float): Temperature.
- outscale (float): Output scale.
- unimix_ratio (float): Ratio of uniform mixing for the onehot distribution.
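For the "normal" distribution branch, the raw std output is squashed into (min_std, max_std) with a shifted sigmoid, so the policy can never collapse to zero variance or explode. A pure-Python sketch of that bounding (inferred from the forward logic, not the module itself):

```python
import math

def bounded_std(raw, min_std=0.1, max_std=1.0):
    # (max_std - min_std) * sigmoid(raw + 2) + min_std keeps std in (min_std, max_std);
    # the +2 shift biases the initial std toward the upper end of the range.
    sigmoid = 1.0 / (1.0 + math.exp(-(raw + 2.0)))
    return (max_std - min_std) * sigmoid + min_std

print(round(bounded_std(0.0), 3))  # 0.893
```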

forward(features)

Overview

Compute the forward pass of ActionHead.

Arguments:

- features (torch.Tensor): Input tensor.

SampleDist

Overview

Sample-based (Monte-Carlo) distribution wrapper for the ActionHead of dreamerv3.

Interfaces: __init__, mean, mode, entropy

__init__(dist, samples=100)

Overview

Initialize the SampleDist class.

Arguments:

- dist (torch.distributions.Distribution): The distribution to wrap.
- samples (int): Number of samples.

mean()

Overview

Calculate the mean of the distribution.

mode()

Overview

Calculate the mode of the distribution.

entropy()

Overview

Calculate the entropy of the distribution.

OneHotDist

Bases: OneHotCategorical

Overview

One-hot categorical distribution with uniform mixing (unimix) for dreamerv3.

Interfaces: __init__, mode, sample

__init__(logits=None, probs=None, unimix_ratio=0.0)

Overview

Initialize the OneHotDist class.

Arguments:

- logits (torch.Tensor): Logits.
- probs (torch.Tensor): Probabilities.
- unimix_ratio (float): Ratio of the uniform distribution mixed into the softmax probabilities.
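With unimix_ratio > 0, the softmax probabilities are blended with a uniform distribution so no class ever has exactly zero probability (which keeps log-probabilities finite). The mixing step alone, as a pure-Python sketch:

```python
def unimix(probs, ratio=0.01):
    # Blend a categorical distribution with the uniform distribution over K classes.
    k = len(probs)
    return [p * (1.0 - ratio) + ratio / k for p in probs]

mixed = unimix([0.9, 0.1, 0.0, 0.0], ratio=0.01)
print(mixed)  # every entry is strictly positive and the total still sums to 1
```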

mode()

Overview

Calculate the mode of the distribution.

sample(sample_shape=(), seed=None)

Overview

Sample from the distribution.

Arguments:

- sample_shape (tuple): Sample shape.
- seed (int): Seed.

TwoHotDistSymlog

Overview

Two-hot distribution over symlog-transformed values for dreamerv3.

Interfaces: __init__, mode, mean, log_prob, log_prob_target

__init__(logits=None, low=-20.0, high=20.0, device='cpu')

Overview

Initialize the TwoHotDistSymlog class.

Arguments:

- logits (torch.Tensor): Logits.
- low (float): Lower bound of the bucket range (in symlog space).
- high (float): Upper bound of the bucket range (in symlog space).
- device (str): Device.

mean()

Overview

Calculate the mean of the distribution.

mode()

Overview

Calculate the mode of the distribution.

log_prob(x)

Overview

Calculate the log probability of the distribution.

Arguments:

- x (torch.Tensor): Input tensor.
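log_prob builds a "two-hot" target: the symlog-transformed scalar is split between its two neighbouring buckets, each weighted by the distance to the opposite bucket, so the weighted bucket values reconstruct the scalar exactly. A pure-Python sketch of that encoding (a hypothetical helper, not from the source; x is assumed to lie inside the bucket range):

```python
def twohot(x, buckets):
    # Indices of the buckets just below and just above x.
    below = max(i for i, b in enumerate(buckets) if b <= x)
    above = min(i for i, b in enumerate(buckets) if b >= x)
    if below == above:  # x sits exactly on a bucket
        return below, above, 1.0, 0.0
    total = (x - buckets[below]) + (buckets[above] - x)
    weight_below = (buckets[above] - x) / total  # closer to below -> larger weight
    weight_above = (x - buckets[below]) / total
    return below, above, weight_below, weight_above

lo, hi, wb, wa = twohot(0.3, [-2.0, -1.0, 0.0, 1.0, 2.0])
print(lo, hi, wb, wa)  # 2 3 0.7 0.3
```

The weighted bucket values recover the input: 0.7 * 0.0 + 0.3 * 1.0 = 0.3.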

log_prob_target(target)

Overview

Calculate the log probability of the target.

Arguments:

- target (torch.Tensor): Target tensor.

SymlogDist

Overview

Symlog distribution (regression in symlog space) for dreamerv3.

Interfaces: __init__, entropy, mode, mean, log_prob

__init__(mode, dist='mse', aggregation='sum', tol=1e-08, dim_to_reduce=[-1, -2, -3])

Overview

Initialize the SymlogDist class.

Arguments:

- mode (torch.Tensor): Prediction in symlog space.
- dist (str): Distance type, 'mse' or 'abs'.
- aggregation (str): Aggregation over the reduced dimensions, 'sum' or 'mean'.
- tol (float): Tolerance below which distances are treated as zero.
- dim_to_reduce (list): Dimensions to reduce.
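Both SymlogDist and TwoHotDistSymlog rely on the symlog transform to compress large magnitudes while staying linear (and invertible) around zero. The actual functions are imported from ding.rl_utils; assuming the standard DreamerV3 definition symlog(x) = sign(x) · ln(1 + |x|), a sketch:

```python
import math

def symlog(x):
    # Compress magnitudes symmetrically: roughly linear near 0, logarithmic far from 0.
    return math.copysign(math.log1p(abs(x)), x)

def inv_symlog(y):
    # Exact inverse: sign(y) * (exp(|y|) - 1).
    return math.copysign(math.expm1(abs(y)), y)

print(round(inv_symlog(symlog(-42.0)), 6))  # -42.0
```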

mode()

Overview

Calculate the mode of the distribution.

mean()

Overview

Calculate the mean of the distribution.

log_prob(value)

Overview

Calculate the log probability of the distribution.

Arguments: - value (:obj:torch.Tensor): Input tensor.

ContDist

Overview

Wrapper around a continuous torch distribution for dreamerv3.

Interfaces: __init__, entropy, mode, sample, log_prob

__init__(dist=None)

Overview

Initialize the ContDist class.

Arguments:

- dist (torch.distributions.Distribution): The distribution to wrap.

__getattr__(name)

Overview

Get attribute.

Arguments:

- name (str): Attribute name.

entropy()

Overview

Calculate the entropy of the distribution.

mode()

Overview

Calculate the mode of the distribution.

sample(sample_shape=())

Overview

Sample from the distribution.

Arguments:

- sample_shape (tuple): Sample shape.

Bernoulli

Overview

Bernoulli distribution wrapper with a straight-through mode for dreamerv3.

Interfaces: __init__, entropy, mode, sample, log_prob

__init__(dist=None)

Overview

Initialize the Bernoulli distribution.

Arguments:

- dist (torch.distributions.Distribution): The distribution to wrap.

__getattr__(name)

Overview

Get attribute.

Arguments:

- name (str): Attribute name.

entropy()

Overview

Calculate the entropy of the distribution.

mode()

Overview

Calculate the mode of the distribution.

sample(sample_shape=())

Overview

Sample from the distribution.

Arguments:

- sample_shape (tuple): Sample shape.

log_prob(x)

Overview

Calculate the log probability of the distribution.

Arguments:

- x (torch.Tensor): Input tensor.
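The implementation computes Bernoulli log-probabilities directly from logits via the softplus identities log sigmoid(l) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l), which avoids taking log of a saturated probability. A pure-Python sketch:

```python
import math

def softplus(v):
    # Numerically stable log(1 + exp(v)).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def bernoulli_log_prob(logit, x):
    # log p(x) = x * log(sigmoid(logit)) + (1 - x) * log(1 - sigmoid(logit))
    return -softplus(logit) * (1 - x) - softplus(-logit) * x
```

For logit = 2.0 and x = 1 this agrees with log(sigmoid(2.0)), roughly -0.127.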

UnnormalizedHuber

Bases: Normal

Overview

Normal distribution with an unnormalized Huber log-probability for dreamerv3.

Interfaces: __init__, mode, log_prob

__init__(loc, scale, threshold=1, **kwargs)

Overview

Initialize the UnnormalizedHuber class.

Arguments:

- loc (torch.Tensor): Location.
- scale (torch.Tensor): Scale.
- threshold (float): Threshold.

log_prob(event)

Overview

Calculate the log probability of the distribution.

Arguments:

- event (torch.Tensor): Event.
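This "log probability" is not a normalized density: it is the negative pseudo-Huber penalty -(sqrt((event - mean)^2 + threshold^2) - threshold), which peaks at 0 at the mean, is quadratic nearby, and behaves like -|error| for large errors. A sketch:

```python
import math

def huber_log_prob(event, mean, threshold=1.0):
    # Negative pseudo-Huber penalty: quadratic near the mean, linear far away.
    return -(math.sqrt((event - mean) ** 2 + threshold ** 2) - threshold)

print(huber_log_prob(0.0, 0.0) == 0.0)  # True: maximal at the mean
```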

mode()

Overview

Calculate the mode of the distribution.

SafeTruncatedNormal

Bases: Normal

Overview

Normal distribution whose samples are clipped to [low, high] with straight-through gradients, for dreamerv3.

Interfaces: __init__, sample

__init__(loc, scale, low, high, clip=1e-06, mult=1)

Overview

Initialize the SafeTruncatedNormal class.

Arguments:

- loc (torch.Tensor): Location.
- scale (torch.Tensor): Scale.
- low (float): Lower clip bound.
- high (float): Upper clip bound.
- clip (float): Margin kept away from the bounds when clipping samples.
- mult (float): Multiplier applied to samples.

sample(sample_shape)

Overview

Sample from the distribution.

Arguments:

- sample_shape (tuple): Sample shape.

TanhBijector

Bases: Transform

Overview

Tanh bijector transform for dreamerv3.

Interfaces: __init__, _forward, _inverse, _forward_log_det_jacobian

__init__(validate_args=False, name='tanh')

Overview

Initialize the TanhBijector class.

Arguments:

- validate_args (bool): Validate arguments.
- name (str): Name.

static_scan(fn, inputs, start)

Overview

Sequentially apply fn along the leading (time) dimension of inputs, threading the carry from step to step and stacking the per-step outputs.

Arguments:

- fn (function): Step function applied to the carry and one timestep of each input.
- inputs (tuple): Input tensors with a leading time dimension.
- start (torch.Tensor): Initial carry state.
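The control flow is a classic scan: thread a carry through fn one timestep at a time and collect every intermediate carry. A minimal list-based analogue (pure Python, not the tensor version in the source, which also stacks dict-valued outputs):

```python
def scan(fn, inputs, start):
    # Apply fn step by step, threading the carry, and collect every carry.
    last, outputs = start, []
    for step in zip(*inputs):
        last = fn(last, *step)
        outputs.append(last)
    return outputs

# Running sum over a sequence, threaded through the carry:
print(scan(lambda carry, a: carry + a, ([1, 2, 3],), 0))  # [1, 3, 6]
```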

weight_init(m)

Overview

Truncated-normal, variance-scaling weight initialization for Linear, Conv2d, ConvTranspose2d, and LayerNorm modules.

Arguments:

- m (nn.Module): Module to initialize.

uniform_weight_init(given_scale)

Overview

Returns a uniform variance-scaling weight initializer for Linear and LayerNorm modules.

Arguments:

- given_scale (float): Scale of the initialization.
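The initializer draws Linear weights from U(-limit, limit) with a variance-scaling bound based on the mean fan: scale = given_scale / ((fan_in + fan_out) / 2) and limit = sqrt(3 * scale), so the weight variance equals scale. The bound alone, as a sketch:

```python
import math

def uniform_limit(fan_in, fan_out, given_scale=1.0):
    # Variance-scaling bound: Var(U(-limit, limit)) = limit^2 / 3 = scale.
    denoms = (fan_in + fan_out) / 2.0
    return math.sqrt(3.0 * given_scale / denoms)

print(round(uniform_limit(512, 255), 4))
```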

Full Source Code

../ding/torch_utils/network/dreamer.py

import math
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd
from ding.torch_utils import MLP
from ding.rl_utils import symlog, inv_symlog


class Conv2dSame(torch.nn.Conv2d):
    """
    Overview:
        Conv2d with TensorFlow-style "same" padding, used by dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    def calc_same_pad(self, i, k, s, d):
        """
        Overview:
            Calculate the "same" padding size.
        Arguments:
            - i (:obj:`int`): Input size.
            - k (:obj:`int`): Kernel size.
            - s (:obj:`int`): Stride size.
            - d (:obj:`int`): Dilation size.
        """
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x):
        """
        Overview:
            Compute the forward pass of Conv2dSame.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        ih, iw = x.size()[-2:]
        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        ret = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return ret


class DreamerLayerNorm(nn.Module):
    """
    Overview:
        LayerNorm over the channel dimension of NCHW feature maps, used by dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, ch, eps=1e-03):
        """
        Overview:
            Init the DreamerLayerNorm class.
        Arguments:
            - ch (:obj:`int`): Number of input channels.
            - eps (:obj:`float`): Epsilon for numerical stability.
        """
        super(DreamerLayerNorm, self).__init__()
        self.norm = torch.nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        """
        Overview:
            Compute the forward pass of DreamerLayerNorm.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC so LayerNorm normalizes channels
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x


class DenseHead(nn.Module):
    """
    Overview:
        DenseHead Network for value head, reward head, and discount head of dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        inp_dim,
        shape,  # (255,)
        layer_num,
        units,  # 512
        act='SiLU',
        norm='LN',
        dist='normal',
        std=1.0,
        outscale=1.0,
        device='cpu',
    ):
        """
        Overview:
            Init the DenseHead class.
        Arguments:
            - inp_dim (:obj:`int`): Input dimension.
            - shape (:obj:`tuple`): Output shape.
            - layer_num (:obj:`int`): Number of layers.
            - units (:obj:`int`): Number of hidden units per layer.
            - act (:obj:`str`): Name of the activation function.
            - norm (:obj:`str`): Name of the normalization type.
            - dist (:obj:`str`): Name of the output distribution.
            - std (:obj:`float`): Standard deviation.
            - outscale (:obj:`float`): Output scale.
            - device (:obj:`str`): Device.
        """
        super(DenseHead, self).__init__()
        self._shape = (shape, ) if isinstance(shape, int) else shape
        if len(self._shape) == 0:
            self._shape = (1, )
        self._layer_num = layer_num
        self._units = units
        self._act = getattr(torch.nn, act)()
        self._norm = norm
        self._dist = dist
        self._std = std
        self._device = device

        self.mlp = MLP(
            inp_dim,
            self._units,
            self._units,
            self._layer_num,
            layer_fn=nn.Linear,
            activation=self._act,
            norm_type=self._norm
        )
        self.mlp.apply(weight_init)

        self.mean_layer = nn.Linear(self._units, np.prod(self._shape))
        self.mean_layer.apply(uniform_weight_init(outscale))

        if self._std == "learned":
            self.std_layer = nn.Linear(self._units, np.prod(self._shape))
            self.std_layer.apply(uniform_weight_init(outscale))

    def forward(self, features):
        """
        Overview:
            Compute the forward pass of DenseHead.
        Arguments:
            - features (:obj:`torch.Tensor`): Input tensor.
        """
        x = features
        out = self.mlp(x)  # (batch, time, _units=512)
        mean = self.mean_layer(out)  # (batch, time, 255)
        if self._std == "learned":
            std = self.std_layer(out)
        else:
            std = self._std
        if self._dist == "normal":
            return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape)))
        elif self._dist == "huber":
            return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape)))
        elif self._dist == "binary":
            return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))
        elif self._dist == "twohot_symlog":
            return TwoHotDistSymlog(logits=mean, low=-1., high=1., device=self._device)
        raise NotImplementedError(self._dist)


class ActionHead(nn.Module):
    """
    Overview:
        ActionHead Network for action head of dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        inp_dim,
        size,
        layers,
        units,
        act=nn.ELU,
        norm=nn.LayerNorm,
        dist="trunc_normal",
        init_std=0.0,
        min_std=0.1,
        max_std=1.0,
        temp=0.1,
        outscale=1.0,
        unimix_ratio=0.01,
    ):
        """
        Overview:
            Initialize the ActionHead class.
        Arguments:
            - inp_dim (:obj:`int`): Input dimension.
            - size (:obj:`int`): Output size.
            - layers (:obj:`int`): Number of layers.
            - units (:obj:`int`): Number of units per layer.
            - act (:obj:`str`): Name of the activation function.
            - norm (:obj:`str`): Name of the normalization layer.
            - dist (:obj:`str`): Name of the output distribution.
            - init_std (:obj:`float`): Initial standard deviation.
            - min_std (:obj:`float`): Minimum standard deviation.
            - max_std (:obj:`float`): Maximum standard deviation.
            - temp (:obj:`float`): Temperature.
            - outscale (:obj:`float`): Output scale.
            - unimix_ratio (:obj:`float`): Unimix ratio.
        """
        super(ActionHead, self).__init__()
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = getattr(torch.nn, act)
        self._norm = getattr(torch.nn, norm)
        self._min_std = min_std
        self._max_std = max_std
        self._init_std = init_std
        self._unimix_ratio = unimix_ratio
        self._temp = temp() if callable(temp) else temp

        pre_layers = []
        for index in range(self._layers):
            pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
            pre_layers.append(self._norm(self._units, eps=1e-03))
            pre_layers.append(self._act())
            if index == 0:
                inp_dim = self._units
        self._pre_layers = nn.Sequential(*pre_layers)
        self._pre_layers.apply(weight_init)

        if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
            self._dist_layer = nn.Linear(self._units, 2 * self._size)
            self._dist_layer.apply(uniform_weight_init(outscale))

        elif self._dist in ["normal_1", "onehot", "onehot_gumble"]:  # spelling matches the dispatch in ``forward``
            self._dist_layer = nn.Linear(self._units, self._size)
            self._dist_layer.apply(uniform_weight_init(outscale))

    def forward(self, features):
        """
        Overview:
            Compute the forward pass of ActionHead.
        Arguments:
            - features (:obj:`torch.Tensor`): Input tensor.
        """
        x = features
        x = self._pre_layers(x)
        if self._dist == "tanh_normal":
            x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)  # split into mean and std halves
            mean = torch.tanh(mean)
            std = F.softplus(std + self._init_std) + self._min_std
            dist = torchd.normal.Normal(mean, std)
            dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
            dist = torchd.independent.Independent(dist, 1)
            dist = SampleDist(dist)
        elif self._dist == "tanh_normal_5":
            x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)
            mean = 5 * torch.tanh(mean / 5)
            std = F.softplus(std + 5) + 5
            dist = torchd.normal.Normal(mean, std)
            dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
            dist = torchd.independent.Independent(dist, 1)
            dist = SampleDist(dist)
        elif self._dist == "normal":
            x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)
            std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
            dist = torchd.normal.Normal(torch.tanh(mean), std)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "normal_1":
            mean = self._dist_layer(x)  # fixed: previously used an undefined ``mean``
            dist = torchd.normal.Normal(mean, 1)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "trunc_normal":
            x = self._dist_layer(x)
            mean, std = torch.split(x, [self._size] * 2, -1)
            mean = torch.tanh(mean)
            std = 2 * torch.sigmoid(std / 2) + self._min_std
            dist = SafeTruncatedNormal(mean, std, -1, 1)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "onehot":
            x = self._dist_layer(x)
            dist = OneHotDist(x, unimix_ratio=self._unimix_ratio)
        elif self._dist == "onehot_gumble":
            x = self._dist_layer(x)
            temp = self._temp
            dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
        else:
            raise NotImplementedError(self._dist)
        return dist


class SampleDist:
    """
    Overview:
        Sample-based (Monte-Carlo) distribution wrapper for the ActionHead of dreamerv3.
    Interfaces:
        ``__init__``, ``mean``, ``mode``, ``entropy``
    """

    def __init__(self, dist, samples=100):
        """
        Overview:
            Initialize the SampleDist class.
        Arguments:
            - dist (:obj:`torch.distributions.Distribution`): The distribution to wrap.
            - samples (:obj:`int`): Number of samples.
        """
        self._dist = dist
        self._samples = samples

    def mean(self):
        """
        Overview:
            Estimate the mean of the distribution from samples.
        """
        samples = self._dist.sample((self._samples, ))  # sample shape must be a tuple
        return torch.mean(samples, 0)

    def mode(self):
        """
        Overview:
            Estimate the mode of the distribution as the highest-probability sample.
        """
        sample = self._dist.sample((self._samples, ))
        logprob = self._dist.log_prob(sample)
        return sample[torch.argmax(logprob)][0]

    def entropy(self):
        """
        Overview:
            Estimate the entropy of the distribution from samples.
        """
        sample = self._dist.sample((self._samples, ))
        logprob = self._dist.log_prob(sample)  # fixed: use the wrapped distribution's log_prob
        return -torch.mean(logprob, 0)


class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """
    Overview:
        One-hot categorical distribution with uniform mixing (unimix) for dreamerv3.
    Interfaces:
        ``__init__``, ``mode``, ``sample``
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        """
        Overview:
            Initialize the OneHotDist class.
        Arguments:
            - logits (:obj:`torch.Tensor`): Logits.
            - probs (:obj:`torch.Tensor`): Probabilities.
            - unimix_ratio (:obj:`float`): Ratio of the uniform distribution mixed into the softmax probabilities.
        """
        if logits is not None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = torch.log(probs)
            super().__init__(logits=logits, probs=None)
        else:
            super().__init__(logits=logits, probs=probs)

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution, with a straight-through gradient.
        """
        _mode = F.one_hot(torch.argmax(super().logits, dim=-1), super().logits.shape[-1])
        return _mode.detach() + super().logits - super().logits.detach()

    def sample(self, sample_shape=(), seed=None):
        """
        Overview:
            Sample from the distribution, with a straight-through gradient.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
            - seed (:obj:`int`): Seed.
        """
        if seed is not None:
            raise ValueError('need to check')
        sample = super().sample(sample_shape)
        probs = super().probs
        while len(probs.shape) < len(sample.shape):
            probs = probs[None]
        sample += probs - probs.detach()  # straight-through gradient through probs
        return sample


class TwoHotDistSymlog:
    """
    Overview:
        Two-hot distribution over symlog-transformed values for dreamerv3.
    Interfaces:
        ``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target``
    """

    def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'):
        """
        Overview:
            Initialize the TwoHotDistSymlog class.
        Arguments:
            - logits (:obj:`torch.Tensor`): Logits.
            - low (:obj:`float`): Lower bound of the bucket range (in symlog space).
            - high (:obj:`float`): Upper bound of the bucket range (in symlog space).
            - device (:obj:`str`): Device.
        """
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255).to(device)
        self.width = (self.buckets[-1] - self.buckets[0]) / 255

    def mean(self):
        """
        Overview:
            Calculate the mean of the distribution.
        """
        _mean = self.probs * self.buckets
        return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True))

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        _mode = self.probs * self.buckets
        return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True))

    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        x = symlog(x)
        # x(time, batch, 1)
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
        below = torch.clip(below, 0, len(self.buckets) - 1)
        above = torch.clip(above, 0, len(self.buckets) - 1)
        equal = (below == above)

        dist_to_below = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, torch.tensor(1).to(x), torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
            F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
        )
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        target = target.squeeze(-2)

        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        """
        Overview:
            Calculate the log probability of the target.
        Arguments:
            - target (:obj:`torch.Tensor`): Target tensor.
        """
        # fixed: this class has no parent holding ``logits``, so use ``self.logits``
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)


class SymlogDist:
    """
    Overview:
        Symlog distribution (regression in symlog space) for dreamerv3.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob``
    """

    def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
        """
        Overview:
            Initialize the SymlogDist class.
        Arguments:
            - mode (:obj:`torch.Tensor`): Prediction in symlog space.
            - dist (:obj:`str`): Distance type, 'mse' or 'abs'.
            - aggregation (:obj:`str`): Aggregation over the reduced dimensions, 'sum' or 'mean'.
            - tol (:obj:`float`): Tolerance below which distances are treated as zero.
            - dim_to_reduce (:obj:`list`): Dimensions to reduce.
        """
        self._mode = mode
        self._dist = dist
        self._aggregation = aggregation
        self._tol = tol
        self._dim_to_reduce = dim_to_reduce

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        return inv_symlog(self._mode)

    def mean(self):
        """
        Overview:
            Calculate the mean of the distribution.
        """
        return inv_symlog(self._mode)

    def log_prob(self, value):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - value (:obj:`torch.Tensor`): Input tensor.
        """
        assert self._mode.shape == value.shape
        if self._dist == 'mse':
            distance = (self._mode - symlog(value)) ** 2.0
            distance = torch.where(distance < self._tol, 0, distance)
        elif self._dist == 'abs':
            distance = torch.abs(self._mode - symlog(value))
            distance = torch.where(distance < self._tol, 0, distance)
        else:
            raise NotImplementedError(self._dist)
        if self._aggregation == 'mean':
            loss = distance.mean(self._dim_to_reduce)
        elif self._aggregation == 'sum':
            loss = distance.sum(self._dim_to_reduce)
        else:
            raise NotImplementedError(self._aggregation)
        return -loss


class ContDist:
    """
    Overview:
        Wrapper around a continuous torch distribution for dreamerv3.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
    """

    def __init__(self, dist=None):
        """
        Overview:
            Initialize the ContDist class.
        Arguments:
            - dist (:obj:`torch.distributions.Distribution`): The distribution to wrap.
        """
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        """
        Overview:
            Get attribute.
        Arguments:
            - name (:obj:`str`): Attribute name.
        """
        return getattr(self._dist, name)

    def entropy(self):
        """
        Overview:
            Calculate the entropy of the distribution.
        """
        return self._dist.entropy()

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        return self._dist.mean

    def sample(self, sample_shape=()):
        """
        Overview:
            Sample from the distribution (reparameterized).
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return self._dist.log_prob(x)


class Bernoulli:
    """
    Overview:
        Bernoulli distribution wrapper with a straight-through mode for dreamerv3.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
    """

    def __init__(self, dist=None):
        """
        Overview:
            Initialize the Bernoulli distribution.
        Arguments:
            - dist (:obj:`torch.distributions.Distribution`): The distribution to wrap.
        """
        super().__init__()
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        """
        Overview:
            Get attribute.
        Arguments:
            - name (:obj:`str`): Attribute name.
        """
        return getattr(self._dist, name)

    def entropy(self):
        """
        Overview:
            Calculate the entropy of the distribution.
        """
        return self._dist.entropy()

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution, with a straight-through gradient.
        """
        _mode = torch.round(self._dist.mean)
        return _mode.detach() + self._dist.mean - self._dist.mean.detach()

    def sample(self, sample_shape=()):
        """
        Overview:
            Sample from the distribution.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        # fixed: Bernoulli is not reparameterizable, so use sample rather than rsample
        return self._dist.sample(sample_shape)

    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        _logits = self._dist.base_dist.logits
        log_probs0 = -F.softplus(_logits)
        log_probs1 = -F.softplus(-_logits)

        return log_probs0 * (1 - x) + log_probs1 * x


class UnnormalizedHuber(torchd.normal.Normal):
    """
    Overview:
        Normal distribution with an unnormalized Huber log-probability for dreamerv3.
    Interfaces:
        ``__init__``, ``mode``, ``log_prob``
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        """
        Overview:
            Initialize the UnnormalizedHuber class.
        Arguments:
            - loc (:obj:`torch.Tensor`): Location.
            - scale (:obj:`torch.Tensor`): Scale.
            - threshold (:obj:`float`): Threshold.
        """
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        """
        Overview:
            Calculate the (unnormalized) log probability of the distribution.
        Arguments:
            - event (:obj:`torch.Tensor`): Event.
        """
        return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        return self.mean


class SafeTruncatedNormal(torchd.normal.Normal):
    """
    Overview:
        Normal distribution whose samples are clipped to [low, high] with straight-through gradients, for dreamerv3.
    Interfaces:
        ``__init__``, ``sample``
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        """
        Overview:
            Initialize the SafeTruncatedNormal class.
        Arguments:
            - loc (:obj:`torch.Tensor`): Location.
            - scale (:obj:`torch.Tensor`): Scale.
            - low (:obj:`float`): Lower clip bound.
            - high (:obj:`float`): Upper clip bound.
            - clip (:obj:`float`): Margin kept away from the bounds when clipping samples.
            - mult (:obj:`float`): Multiplier applied to samples.
        """
        super().__init__(loc, scale)
        self._low = low
        self._high = high
        self._clip = clip
        self._mult = mult

    def sample(self, sample_shape):
        """
        Overview:
            Sample from the distribution, clipping with a straight-through gradient.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        event = super().sample(sample_shape)
        if self._clip:
            clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
            event = event - event.detach() + clipped.detach()
        if self._mult:
            event *= self._mult
        return event


class TanhBijector(torchd.Transform):
    """
    Overview:
        Tanh bijector transform for dreamerv3.
    Interfaces:
        ``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian``
    """

    def __init__(self, validate_args=False, name='tanh'):
        """
        Overview:
            Initialize the TanhBijector class.
        Arguments:
            - validate_args (:obj:`bool`): Validate arguments.
            - name (:obj:`str`): Name.
        """
        super().__init__()

    def _forward(self, x):
        """
        Overview:
            Calculate the forward transform.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return torch.tanh(x)

    def _inverse(self, y):
        """
        Overview:
            Calculate the inverse transform.
        Arguments:
            - y (:obj:`torch.Tensor`): Input tensor.
        """
        y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y)
        y = torch.atanh(y)
        return y

    def _forward_log_det_jacobian(self, x):
        """
        Overview:
            Calculate the forward log-determinant of the Jacobian.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        log2 = math.log(2.0)  # fixed: use the stdlib math module
        return 2.0 * (log2 - x - F.softplus(-2.0 * x))  # fixed: softplus lives in torch.nn.functional


def static_scan(fn, inputs, start):
    """
    Overview:
        Sequentially apply fn along the leading (time) dimension of inputs,
        threading the carry from step to step and stacking the per-step outputs.
    Arguments:
        - fn (:obj:`function`): Step function applied to the carry and one timestep of each input.
        - inputs (:obj:`tuple`): Input tensors with a leading time dimension.
        - start (:obj:`torch.Tensor`): Initial carry state.
    """
    last = start  # {logit, stoch, deter:[batch_size, self._deter]}
    indices = range(inputs[0].shape[0])
    flag = True
    for index in indices:
        inp = lambda x: (_input[x] for _input in inputs)  # inputs:(action:(time, batch, 6), embed:(time, batch, 4096))
        last = fn(last, *inp(index))  # post, prior
        if flag:
            if isinstance(last, dict):
                outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
            else:
                outputs = []
                for _last in last:
                    if isinstance(_last, dict):
                        outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
                    else:
                        outputs.append(_last.clone().unsqueeze(0))
            flag = False
        else:
            if isinstance(last, dict):
                for key in last.keys():
                    outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
            else:
                for j in range(len(outputs)):
                    if isinstance(last[j], dict):
                        for key in last[j].keys():
                            outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0)
                    else:
                        outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
    if isinstance(last, dict):
        outputs = [outputs]
    return outputs


def weight_init(m):
    """
    Overview:
        Truncated-normal, variance-scaling weight initialization for Linear, Conv2d,
        ConvTranspose2d, and LayerNorm modules.
    Arguments:
        - m (:obj:`nn.Module`): Module to initialize.
    """
    if isinstance(m, nn.Linear):
        in_num = m.in_features
        out_num = m.out_features
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = np.sqrt(scale) / 0.87962566103423978
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        space = m.kernel_size[0] * m.kernel_size[1]
        in_num = space * m.in_channels
        out_num = space * m.out_channels
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = np.sqrt(scale) / 0.87962566103423978
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.LayerNorm):
        m.weight.data.fill_(1.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)


def uniform_weight_init(given_scale):
    """
    Overview:
        Returns a uniform variance-scaling weight initializer for Linear and LayerNorm modules.
    Arguments:
        - given_scale (:obj:`float`): Scale of the initialization.
    """

    def f(m):
        if isinstance(m, nn.Linear):
            in_num = m.in_features
            out_num = m.out_features
            denoms = (in_num + out_num) / 2.0
            scale = given_scale / denoms
            limit = np.sqrt(3 * scale)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)
        elif isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1.0)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)

    return f