
ding.torch_utils.network.diffusion


DiffusionConv1d

Bases: Module

Overview

Conv1d with activation and normalization for diffusion models.

Interfaces: __init__, forward

__init__(in_channels, out_channels, kernel_size, padding, activation=None, n_groups=8)

Overview

Create a 1-dimensional convolution layer with activation and normalization. This Conv1d applies GroupNorm, temporarily adding a singleton dimension when computing the norm.

Arguments:
- in_channels (:obj:int): Number of channels in the input tensor
- out_channels (:obj:int): Number of channels in the output tensor
- kernel_size (:obj:int): Size of the convolving kernel
- padding (:obj:int): Zero-padding added to both sides of the input
- activation (:obj:nn.Module): the optional activation function
- n_groups (:obj:int): number of groups for GroupNorm

forward(inputs)

Overview

Compute the 1-D convolution of the inputs.

Arguments:
- inputs (:obj:torch.Tensor): input tensor

Return:
- out (:obj:torch.Tensor): output tensor
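As the source shows, DiffusionConv1d uses padding = kernel_size // 2 so the horizon dimension is preserved. The helper below is a hypothetical illustration of the standard stride-1 Conv1d output-length formula, not part of ding itself:

```python
# Hypothetical helper illustrating the stride-1, dilation-1 Conv1d
# output-length formula; not part of the ding library.
def conv1d_out_len(length: int, kernel_size: int, padding: int) -> int:
    """Output length of a stride-1 Conv1d."""
    return length + 2 * padding - kernel_size + 1

# With kernel_size=5 and padding=kernel_size // 2 = 2 (the choice used by
# the residual temporal blocks below), the horizon is preserved:
print(conv1d_out_len(32, 5, 2))  # -> 32
```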

SinusoidalPosEmb

Bases: Module

Overview

Class for computing sinusoidal position embeddings.

Interfaces: __init__, forward

__init__(dim)

Overview

Initialization of SinusoidalPosEmb class

Arguments:
- dim (:obj:int): dimension of the embedding

forward(x)

Overview

Compute the sinusoidal position embedding.

Arguments:
- x (:obj:torch.Tensor): input tensor

Return:
- emb (:obj:torch.Tensor): output tensor
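The embedding concatenates a sine half and a cosine half over a geometric ladder of frequencies. A NumPy-only sketch of the same computation (the ding version operates on torch tensors):

```python
import math
import numpy as np

# NumPy sketch of the sinusoidal timestep embedding computed by
# SinusoidalPosEmb.forward; same math, torch replaced by NumPy.
def sinusoidal_pos_emb(x: np.ndarray, dim: int) -> np.ndarray:
    half_dim = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000.
    step = math.log(10000) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim) * -step)
    # Outer product: [batch] x [half_dim] -> [batch, half_dim].
    angles = x[:, None] * freqs[None, :]
    # Concatenate sin and cos halves -> [batch, dim].
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

emb = sinusoidal_pos_emb(np.array([0.0, 1.0, 2.0]), 8)
print(emb.shape)  # -> (3, 8)
```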

Residual

Bases: Module

Overview

Basic Residual block

Interfaces: __init__, forward

__init__(fn)

Overview

Initialization of Residual class

Arguments:
- fn (:obj:nn.Module): the wrapped module of the residual block

forward(x, *arg, **kwargs)

Overview

Compute the residual output fn(x, *args, **kwargs) + x.

Arguments:
- x (:obj:torch.Tensor): input tensor
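The residual pattern is small enough to sketch in plain Python; `residual` below is a hypothetical stand-in for the module's forward:

```python
# Minimal sketch of the Residual wrapper's behaviour: the wrapped
# function's output is added back to its input (a hypothetical
# function-level stand-in for the nn.Module above).
def residual(fn, x):
    return fn(x) + x

print(residual(lambda v: 2 * v, 10))  # -> 30
```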

LayerNorm

Bases: Module

Overview

LayerNorm computed over dim = 1, because the temporal input x has shape [batch, dim, horizon].

Interfaces: __init__, forward

__init__(dim, eps=1e-05)

Overview

Initialization of LayerNorm class

Arguments:
- dim (:obj:int): dimension of the input
- eps (:obj:float): epsilon of LayerNorm, for numerical stability

forward(x)

Overview

Compute LayerNorm over dim 1.

Arguments:
- x (:obj:torch.Tensor): input tensor
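The source normalizes with the biased variance over dim = 1 and then applies learnable scale/shift parameters g and b. A NumPy sketch of that computation:

```python
import numpy as np

# NumPy sketch of the channel-wise LayerNorm above: statistics are taken
# over axis 1 because temporal inputs are laid out [batch, dim, horizon].
# g and b stand in for the learnable scale/shift parameters.
def layer_norm_dim1(x, g, b, eps=1e-5):
    var = x.var(axis=1, keepdims=True)    # biased, like torch.var(unbiased=False)
    mean = x.mean(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * g + b

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))
g = np.ones((1, 4, 1))   # scale, initialized to ones as in the module
b = np.zeros((1, 4, 1))  # shift, initialized to zeros
y = layer_norm_dim1(x, g, b)
print(np.allclose(y.mean(axis=1), 0, atol=1e-6))  # -> True
```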

PreNorm

Bases: Module

Overview

PreNorm applies LayerNorm (over dim = 1) before the wrapped function, because the temporal input x has shape [batch, dim, horizon].

Interfaces: __init__, forward

__init__(dim, fn)

Overview

Initialization of PreNorm class

Arguments:
- dim (:obj:int): dimension of the input
- fn (:obj:nn.Module): the wrapped module

forward(x)

Overview

Apply LayerNorm, then the wrapped function.

Arguments:
- x (:obj:torch.Tensor): input tensor

LinearAttention

Bases: Module

Overview

Linear Attention head

Interfaces: __init__, forward

__init__(dim, heads=4, dim_head=32)

Overview

Initialization of LinearAttention class

Arguments:
- dim (:obj:int): dimension of the input
- heads (:obj:int): number of attention heads
- dim_head (:obj:int): dimension of each head

forward(x)

Overview

Compute linear attention.

Arguments:
- x (:obj:torch.Tensor): input tensor
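The key idea in the source is to softmax the keys over the sequence axis and contract k with v first, so the per-head "context" matrix has shape [dim_head, dim_head] and the cost is linear in sequence length. A NumPy sketch of that contraction (the projections `to_qkv`/`to_out` are omitted):

```python
import numpy as np

def softmax(a, axis):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# NumPy sketch of the linear-attention contraction used in
# LinearAttention.forward; q, k, v have shape [batch, heads, dim_head, n].
def linear_attention(q, k, v, scale):
    q = q * scale
    k = softmax(k, axis=-1)                          # softmax over sequence axis
    context = np.einsum('bhdn,bhen->bhde', k, v)     # [b, h, d, e], O(n) in seq length
    return np.einsum('bhde,bhdn->bhen', context, q)  # back to [b, h, e, n]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 4, 8, 16)) for _ in range(3))
out = linear_attention(q, k, v, scale=8 ** -0.5)
print(out.shape)  # -> (1, 4, 8, 16)
```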

ResidualTemporalBlock

Bases: Module

Overview

Temporal residual block.

Interfaces: __init__, forward

__init__(in_channels, out_channels, embed_dim, kernel_size=5, mish=True)

Overview

Initialization of ResidualTemporalBlock class

Arguments:
- in_channels (:obj:int): number of input channels
- out_channels (:obj:int): number of output channels
- embed_dim (:obj:int): dimension of the time-embedding layer
- kernel_size (:obj:int): kernel size of the Conv1d layers
- mish (:obj:bool): whether to use Mish (otherwise SiLU) as the activation function

forward(x, t)

Overview

Compute the residual block.

Arguments:
- x (:obj:torch.Tensor): input tensor
- t (:obj:torch.Tensor): time-embedding tensor

DiffusionUNet1d

Bases: Module

Overview

Diffusion UNet for 1-D vector data.

Interfaces: __init__, forward, get_pred

__init__(transition_dim, dim=32, dim_mults=[1, 2, 4, 8], returns_condition=False, condition_dropout=0.1, calc_energy=False, kernel_size=5, attention=False)

Overview

Initialization of DiffusionUNet1d class

Arguments:
- transition_dim (:obj:int): dimension of a transition, i.e. obs_dim + action_dim
- dim (:obj:int): base dimension of the layers
- dim_mults (:obj:SequenceType): multipliers of dim at each resolution
- returns_condition (:obj:bool): whether to condition on returns
- condition_dropout (:obj:float): dropout probability of the returns condition
- calc_energy (:obj:bool): whether to compute the energy-based gradient instead of the direct output
- kernel_size (:obj:int): kernel size of the Conv1d layers
- attention (:obj:bool): whether to use attention

forward(x, cond, time, returns=None, use_dropout=True, force_dropout=False)

Overview

Compute the diffusion UNet forward pass.

Arguments:
- x (:obj:torch.Tensor): noised trajectory
- cond (:obj:tuple): [ (time, state), ... ]; state is the initial state of the env, with time = 0
- time (:obj:int): timestep of the diffusion step
- returns (:obj:torch.Tensor): condition returns of the trajectory (normalized returns)
- use_dropout (:obj:bool): whether to apply the returns-condition dropout mask
- force_dropout (:obj:bool): whether to drop the returns condition entirely
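In the source, the returns embedding is multiplied by a Bernoulli mask with keep probability 1 - condition_dropout when use_dropout is set, and zeroed entirely when force_dropout is set (a pattern commonly used for classifier-free guidance at sampling time). A NumPy sketch of just that masking logic, with `mask_returns_embed` as a hypothetical helper name:

```python
import numpy as np

# NumPy sketch of the returns-condition dropout in DiffusionUNet1d.forward.
# Each row of the embedding is kept or zeroed as a whole by a Bernoulli mask.
def mask_returns_embed(returns_embed, condition_dropout=0.1,
                       use_dropout=True, force_dropout=False, rng=None):
    if rng is None:
        rng = np.random.default_rng(0)
    if use_dropout:
        keep = rng.binomial(1, 1 - condition_dropout, size=(returns_embed.shape[0], 1))
        returns_embed = keep * returns_embed
    if force_dropout:
        returns_embed = 0 * returns_embed
    return returns_embed

emb = np.ones((4, 32))
print(mask_returns_embed(emb, force_dropout=True).sum())  # -> 0.0
```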

get_pred(x, cond, time, returns=None, use_dropout=True, force_dropout=False)

Overview

Compute the diffusion UNet prediction.

Arguments:
- x (:obj:torch.Tensor): noised trajectory
- cond (:obj:tuple): [ (time, state), ... ]; state is the initial state of the env, with time = 0
- time (:obj:int): timestep of the diffusion step
- returns (:obj:torch.Tensor): condition returns of the trajectory (normalized returns)
- use_dropout (:obj:bool): whether to apply the returns-condition dropout mask
- force_dropout (:obj:bool): whether to drop the returns condition entirely

TemporalValue

Bases: Module

Overview

Temporal network for the value function.

Interfaces: __init__, forward

__init__(horizon, transition_dim, dim=32, time_dim=None, out_dim=1, kernel_size=5, dim_mults=[1, 2, 4, 8])

Overview

Initialization of TemporalValue class

Arguments:
- horizon (:obj:int): horizon of the trajectory
- transition_dim (:obj:int): dimension of a transition, i.e. obs_dim + action_dim
- dim (:obj:int): base dimension of the layers
- time_dim (:obj:int): dimension of the time embedding
- out_dim (:obj:int): dimension of the output
- kernel_size (:obj:int): kernel size of the Conv1d layers
- dim_mults (:obj:SequenceType): multipliers of dim at each resolution

forward(x, cond, time, *args)

Overview

Compute the temporal value forward pass.

Arguments:
- x (:obj:torch.Tensor): noised trajectory
- cond (:obj:tuple): [ (time, state), ... ]; state is the initial state of the env, with time = 0
- time (:obj:int): timestep of the diffusion step

extract(a, t, x_shape)

Overview

Extract values from a at index t and reshape them to broadcast against a tensor of shape x_shape.

Arguments:
- a (:obj:torch.Tensor): input tensor
- t (:obj:torch.Tensor): index tensor
- x_shape (:obj:torch.Tensor): shape of x
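As in the source, the result keeps the batch dimension and pads the rest with singleton dimensions so it broadcasts against x. A NumPy sketch:

```python
import numpy as np

# NumPy sketch of extract: pick one schedule coefficient per batch element
# and reshape so it broadcasts against a tensor of shape x_shape.
def extract_np(a, t, x_shape):
    b = t.shape[0]
    out = a[t]  # gather along the last axis
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

alphas = np.array([0.9, 0.8, 0.7])   # a per-timestep schedule
t = np.array([0, 2])                 # one timestep index per batch element
coef = extract_np(alphas, t, (2, 5, 4))
print(coef.shape)  # -> (2, 1, 1)
```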

cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32)

Overview

Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

Arguments:
- timesteps (:obj:int): number of diffusion timesteps
- s (:obj:float): small offset of the cosine schedule
- dtype (:obj:torch.dtype): dtype of the betas

Return:
Tensor of betas with shape [timesteps, ], computed by the cosine schedule.
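The betas are derived from ratios of the cosine-shaped cumulative-alpha curve and clipped to at most 0.999. A NumPy-only sketch mirroring the source (the ding version returns a torch tensor):

```python
import numpy as np

# NumPy sketch of cosine_beta_schedule: betas come from consecutive
# ratios of the cosine-shaped alphas_cumprod curve, clipped to 0.999.
def cosine_beta_schedule_np(timesteps, s=0.008):
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]  # normalize so it starts at 1
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)

betas = cosine_beta_schedule_np(100)
print(betas.shape)  # -> (100,)
```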

apply_conditioning(x, conditions, action_dim)

Overview

Write the conditioning states into x at the given timesteps.

Arguments:
- x (:obj:torch.Tensor): input tensor
- conditions (:obj:dict): condition dict; key is the timestep, value is the condition
- action_dim (:obj:int): action dimension
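Per the source, at each conditioned timestep the observation part of the trajectory (everything after the first action_dim entries) is overwritten with the given state. A NumPy sketch:

```python
import numpy as np

# NumPy sketch of apply_conditioning: for each (timestep -> state) pair,
# overwrite the observation slice of the trajectory at that timestep.
def apply_conditioning_np(x, conditions, action_dim):
    for t, val in conditions.items():
        x[:, t, action_dim:] = val
    return x

x = np.zeros((1, 4, 5))   # [batch, horizon, action_dim + obs_dim]
cond = {0: np.ones(3)}    # pin the initial observation (obs_dim = 3 here)
x = apply_conditioning_np(x, cond, action_dim=2)
print(x[0, 0])  # -> [0. 0. 1. 1. 1.]
```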

Full Source Code

../ding/torch_utils/network/diffusion.py

1from typing import Union, List, Dict 2from collections import namedtuple 3import numpy as np 4import math 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType 9 10 11def extract(a, t, x_shape): 12 """ 13 Overview: 14 extract output from a through index t. 15 Arguments: 16 - a (:obj:`torch.Tensor`): input tensor 17 - t (:obj:`torch.Tensor`): index tensor 18 - x_shape (:obj:`torch.Tensor`): shape of x 19 """ 20 b, *_ = t.shape 21 out = a.gather(-1, t) 22 return out.reshape(b, *((1, ) * (len(x_shape) - 1))) 23 24 25def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32): 26 """ 27 Overview: 28 cosine schedule 29 as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 30 Arguments: 31 - timesteps (:obj:`int`): timesteps of diffusion step 32 - s (:obj:`float`): s 33 - dtype (:obj:`torch.dtype`): dtype of beta 34 Return: 35 Tensor of beta [timesteps,], computing by cosine. 36 """ 37 steps = timesteps + 1 38 x = np.linspace(0, steps, steps) 39 alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 40 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 41 betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 42 betas_clipped = np.clip(betas, a_min=0, a_max=0.999) 43 return torch.tensor(betas_clipped, dtype=dtype) 44 45 46def apply_conditioning(x, conditions, action_dim): 47 """ 48 Overview: 49 add condition into x 50 Arguments: 51 - x (:obj:`torch.Tensor`): input tensor 52 - conditions (:obj:`dict`): condition dict, key is timestep, value is condition 53 - action_dim (:obj:`int`): action dim 54 """ 55 for t, val in conditions.items(): 56 x[:, t, action_dim:] = val.clone() 57 return x 58 59 60class DiffusionConv1d(nn.Module): 61 """ 62 Overview: 63 Conv1d with activation and normalization for diffusion models. 
64 Interfaces: 65 ``__init__``, ``forward`` 66 """ 67 68 def __init__( 69 self, 70 in_channels: int, 71 out_channels: int, 72 kernel_size: int, 73 padding: int, 74 activation: nn.Module = None, 75 n_groups: int = 8 76 ) -> None: 77 """ 78 Overview: 79 Create a 1-dim convlution layer with activation and normalization. This Conv1d have GropuNorm. 80 And need add 1-dim when compute norm 81 Arguments: 82 - in_channels (:obj:`int`): Number of channels in the input tensor 83 - out_channels (:obj:`int`): Number of channels in the output tensor 84 - kernel_size (:obj:`int`): Size of the convolving kernel 85 - padding (:obj:`int`): Zero-padding added to both sides of the input 86 - activation (:obj:`nn.Module`): the optional activation function 87 """ 88 super().__init__() 89 self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) 90 self.norm = nn.GroupNorm(n_groups, out_channels) 91 self.act = activation 92 93 def forward(self, inputs) -> torch.Tensor: 94 """ 95 Overview: 96 compute conv1d for inputs. 
97 Arguments: 98 - inputs (:obj:`torch.Tensor`): input tensor 99 Return: 100 - out (:obj:`torch.Tensor`): output tensor 101 """ 102 x = self.conv1(inputs) 103 # [batch, channels, horizon] -> [batch, channels, 1, horizon] 104 x = x.unsqueeze(-2) 105 x = self.norm(x) 106 # [batch, channels, 1, horizon] -> [batch, channels, horizon] 107 x = x.squeeze(-2) 108 out = self.act(x) 109 return out 110 111 112class SinusoidalPosEmb(nn.Module): 113 """ 114 Overview: 115 class for computing sin position embeding 116 Interfaces: 117 ``__init__``, ``forward`` 118 """ 119 120 def __init__(self, dim: int) -> None: 121 """ 122 Overview: 123 Initialization of SinusoidalPosEmb class 124 Arguments: 125 - dim (:obj:`int`): dimension of embeding 126 """ 127 128 super().__init__() 129 self.dim = dim 130 131 def forward(self, x) -> torch.Tensor: 132 """ 133 Overview: 134 compute sin position embeding 135 Arguments: 136 - x (:obj:`torch.Tensor`): input tensor 137 Return: 138 - emb (:obj:`torch.Tensor`): output tensor 139 """ 140 141 device = x.device 142 half_dim = self.dim // 2 143 emb = math.log(10000) / (half_dim - 1) 144 emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 145 emb = x[:, None] * emb[None, :] 146 emb = torch.cat((emb.sin(), emb.cos()), dim=1) 147 return emb 148 149 150class Residual(nn.Module): 151 """ 152 Overview: 153 Basic Residual block 154 Interfaces: 155 ``__init__``, ``forward`` 156 """ 157 158 def __init__(self, fn): 159 """ 160 Overview: 161 Initialization of Residual class 162 Arguments: 163 - fn (:obj:`nn.Module`): function of residual block 164 """ 165 166 super().__init__() 167 self.fn = fn 168 169 def forward(self, x, *arg, **kwargs): 170 """ 171 Overview: 172 compute residual block 173 Arguments: 174 - x (:obj:`torch.Tensor`): input tensor 175 """ 176 177 return self.fn(x, *arg, **kwargs) + x 178 179 180class LayerNorm(nn.Module): 181 """ 182 Overview: 183 LayerNorm, compute dim = 1, because Temporal input x [batch, dim, horizon] 184 Interfaces: 
185 ``__init__``, ``forward`` 186 """ 187 188 def __init__(self, dim, eps=1e-5) -> None: 189 """ 190 Overview: 191 Initialization of LayerNorm class 192 Arguments: 193 - dim (:obj:`int`): dimension of input 194 - eps (:obj:`float`): eps of LayerNorm 195 """ 196 197 super().__init__() 198 self.eps = eps 199 self.g = nn.Parameter(torch.ones(1, dim, 1)) 200 self.b = nn.Parameter(torch.zeros(1, dim, 1)) 201 202 def forward(self, x): 203 """ 204 Overview: 205 compute LayerNorm 206 Arguments: 207 - x (:obj:`torch.Tensor`): input tensor 208 """ 209 210 print('x.shape:', x.shape) 211 var = torch.var(x, dim=1, unbiased=False, keepdim=True) 212 mean = torch.mean(x, dim=1, keepdim=True) 213 return (x - mean) / (var + self.eps).sqrt() * self.g + self.b 214 215 216class PreNorm(nn.Module): 217 """ 218 Overview: 219 PreNorm, compute dim = 1, because Temporal input x [batch, dim, horizon] 220 Interfaces: 221 ``__init__``, ``forward`` 222 """ 223 224 def __init__(self, dim, fn) -> None: 225 """ 226 Overview: 227 Initialization of PreNorm class 228 Arguments: 229 - dim (:obj:`int`): dimension of input 230 - fn (:obj:`nn.Module`): function of residual block 231 """ 232 233 super().__init__() 234 self.fn = fn 235 self.norm = LayerNorm(dim) 236 237 def forward(self, x): 238 """ 239 Overview: 240 compute PreNorm 241 Arguments: 242 - x (:obj:`torch.Tensor`): input tensor 243 """ 244 x = self.norm(x) 245 return self.fn(x) 246 247 248class LinearAttention(nn.Module): 249 """ 250 Overview: 251 Linear Attention head 252 Interfaces: 253 ``__init__``, ``forward`` 254 """ 255 256 def __init__(self, dim, heads=4, dim_head=32) -> None: 257 """ 258 Overview: 259 Initialization of LinearAttention class 260 Arguments: 261 - dim (:obj:`int`): dimension of input 262 - heads (:obj:`int`): heads of attention 263 - dim_head (:obj:`int`): dim of head 264 """ 265 super().__init__() 266 self.scale = dim_head ** -0.5 267 self.heads = heads 268 hidden_dim = dim_head * heads 269 self.to_qkv = nn.Conv1d(dim, 
hidden_dim * 3, 1, bias=False) 270 self.to_out = nn.Conv1d(hidden_dim, dim, 1) 271 272 def forward(self, x): 273 """ 274 Overview: 275 compute LinearAttention 276 Arguments: 277 - x (:obj:`torch.Tensor`): input tensor 278 """ 279 qkv = self.to_qkv(x).chunk(3, dim=1) 280 q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv) 281 q = q * self.scale 282 k = k.softmax(dim=-1) 283 context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 284 285 out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 286 out = out.reshape(out.shape[0], -1, out.shape[-1]) 287 return self.to_out(out) 288 289 290class ResidualTemporalBlock(nn.Module): 291 """ 292 Overview: 293 Residual block of temporal 294 Interfaces: 295 ``__init__``, ``forward`` 296 """ 297 298 def __init__( 299 self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True 300 ) -> None: 301 """ 302 Overview: 303 Initialization of ResidualTemporalBlock class 304 Arguments: 305 - in_channels (:obj:'int'): dim of in_channels 306 - out_channels (:obj:'int'): dim of out_channels 307 - embed_dim (:obj:'int'): dim of embeding layer 308 - kernel_size (:obj:'int'): kernel_size of conv1d 309 - mish (:obj:'bool'): whether use mish as a activate function 310 """ 311 super().__init__() 312 if mish: 313 act = nn.Mish() 314 else: 315 act = nn.SiLU() 316 self.blocks = nn.ModuleList( 317 [ 318 DiffusionConv1d(in_channels, out_channels, kernel_size, kernel_size // 2, act), 319 DiffusionConv1d(out_channels, out_channels, kernel_size, kernel_size // 2, act), 320 ] 321 ) 322 self.time_mlp = nn.Sequential( 323 act, 324 nn.Linear(embed_dim, out_channels), 325 ) 326 self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ 327 if in_channels != out_channels else nn.Identity() 328 329 def forward(self, x, t): 330 """ 331 Overview: 332 compute residual block 333 Arguments: 334 - x (:obj:'tensor'): input tensor 335 - t (:obj:'tensor'): time tensor 336 """ 337 out = 
self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1) 338 out = self.blocks[1](out) 339 return out + self.residual_conv(x) 340 341 342class DiffusionUNet1d(nn.Module): 343 """ 344 Overview: 345 Diffusion unet for 1d vector data 346 Interfaces: 347 ``__init__``, ``forward``, ``get_pred`` 348 """ 349 350 def __init__( 351 self, 352 transition_dim: int, 353 dim: int = 32, 354 dim_mults: SequenceType = [1, 2, 4, 8], 355 returns_condition: bool = False, 356 condition_dropout: float = 0.1, 357 calc_energy: bool = False, 358 kernel_size: int = 5, 359 attention: bool = False, 360 ) -> None: 361 """ 362 Overview: 363 Initialization of DiffusionUNet1d class 364 Arguments: 365 - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim 366 - dim (:obj:'int'): dim of layer 367 - dim_mults (:obj:'SequenceType'): mults of dim 368 - returns_condition (:obj:'bool'): whether use return as a condition 369 - condition_dropout (:obj:'float'): dropout of returns condition 370 - calc_energy (:obj:'bool'): whether use calc_energy 371 - kernel_size (:obj:'int'): kernel_size of conv1d 372 - attention (:obj:'bool'): whether use attention 373 """ 374 super().__init__() 375 dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 376 in_out = list(zip(dims[:-1], dims[1:])) 377 378 if calc_energy: 379 mish = False 380 act = nn.SiLU() 381 else: 382 mish = True 383 act = nn.Mish() 384 385 self.time_dim = dim 386 self.returns_dim = dim 387 388 self.time_mlp = nn.Sequential( 389 SinusoidalPosEmb(dim), 390 nn.Linear(dim, dim * 4), 391 act, 392 nn.Linear(dim * 4, dim), 393 ) 394 395 self.returns_condition = returns_condition 396 self.condition_dropout = condition_dropout 397 self.cale_energy = calc_energy 398 399 if self.returns_condition: 400 self.returns_mlp = nn.Sequential( 401 nn.Linear(1, dim), 402 act, 403 nn.Linear(dim, dim * 4), 404 act, 405 nn.Linear(dim * 4, dim), 406 ) 407 self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout) 408 embed_dim = 2 * 
dim 409 else: 410 embed_dim = dim 411 412 self.downs = nn.ModuleList([]) 413 self.ups = nn.ModuleList([]) 414 num_resolution = len(in_out) 415 416 for ind, (dim_in, dim_out) in enumerate(in_out): 417 is_last = ind >= (num_resolution - 1) 418 self.downs.append( 419 nn.ModuleList( 420 [ 421 ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish), 422 ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish), 423 Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(), 424 nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity() 425 ] 426 ) 427 ) 428 429 mid_dim = dims[-1] 430 self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish) 431 self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity() 432 self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish) 433 434 for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 435 is_last = ind >= (num_resolution - 1) 436 self.ups.append( 437 nn.ModuleList( 438 [ 439 ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish), 440 ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish), 441 Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(), 442 nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity() 443 ] 444 ) 445 ) 446 447 self.final_conv = nn.Sequential( 448 DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act), 449 nn.Conv1d(dim, transition_dim, 1), 450 ) 451 452 def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False): 453 """ 454 Overview: 455 compute diffusion unet forward 456 Arguments: 457 - x (:obj:'tensor'): noise trajectory 458 - cond (:obj:'tuple'): [ (time, state), ... 
] state is init state of env, time = 0 459 - time (:obj:'int'): timestep of diffusion step 460 - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return 461 - use_dropout (:obj:'bool'): Whether use returns condition mask 462 - force_dropout (:obj:'bool'): Whether use returns condition 463 """ 464 if self.cale_energy: 465 x_inp = x 466 467 # [batch, horizon, transition ] -> [batch, transition , horizon] 468 x = x.transpose(1, 2) 469 t = self.time_mlp(time) 470 471 if self.returns_condition: 472 assert returns is not None 473 returns_embed = self.returns_mlp(returns) 474 if use_dropout: 475 mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device) 476 returns_embed = mask * returns_embed 477 if force_dropout: 478 returns_embed = 0 * returns_embed 479 t = torch.cat([t, returns_embed], dim=-1) 480 481 h = [] 482 483 for resnet, resnet2, atten, downsample in self.downs: 484 x = resnet(x, t) 485 x = resnet2(x, t) 486 x = atten(x) 487 h.append(x) 488 x = downsample(x) 489 490 x = self.mid_block1(x, t) 491 x = self.mid_atten(x) 492 x = self.mid_block2(x, t) 493 494 for resnet, resnet2, atten, upsample in self.ups: 495 x = torch.cat((x, h.pop()), dim=1) 496 x = resnet(x, t) 497 x = resnet2(x, t) 498 x = atten(x) 499 x = upsample(x) 500 501 x = self.final_conv(x) 502 # [batch, transition , horizon] -> [batch, horizon, transition ] 503 x = x.transpose(1, 2) 504 505 if self.cale_energy: 506 # Energy function 507 energy = ((x - x_inp) ** 2).mean() 508 grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True) 509 return grad[0] 510 else: 511 return x 512 513 def get_pred(self, x, cond, time, returns: bool = None, use_dropout: bool = True, force_dropout: bool = False): 514 """ 515 Overview: 516 compute diffusion unet forward 517 Arguments: 518 - x (:obj:'tensor'): noise trajectory 519 - cond (:obj:'tuple'): [ (time, state), ... 
] state is init state of env, time = 0 520 - time (:obj:'int'): timestep of diffusion step 521 - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return 522 - use_dropout (:obj:'bool'): Whether use returns condition mask 523 - force_dropout (:obj:'bool'): Whether use returns condition 524 """ 525 # [batch, horizon, transition ] -> [batch, transition , horizon] 526 x = x.transpose(1, 2) 527 t = self.time_mlp(time) 528 529 if self.returns_condition: 530 assert returns is not None 531 returns_embed = self.returns_mlp(returns) 532 if use_dropout: 533 mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device) 534 returns_embed = mask * returns_embed 535 if force_dropout: 536 returns_embed = 0 * returns_embed 537 t = torch.cat([t, returns_embed], dim=-1) 538 539 h = [] 540 541 for resnet, resnet2, downsample in self.downs: 542 x = resnet(x, t) 543 x = resnet2(x, t) 544 h.append(x) 545 x = downsample(x) 546 547 x = self.mid_block1(x, t) 548 x = self.mid_block2(x, t) 549 550 for resnet, resnet2, upsample in self.ups: 551 x = torch.cat((x, h.pop()), dim=1) 552 x = resnet(x, t) 553 x = resnet2(x, t) 554 x = upsample(x) 555 556 x = self.final_conv(x) 557 # [batch, transition , horizon] -> [batch, horizon, transition ] 558 x = x.transpose(1, 2) 559 return x 560 561 562class TemporalValue(nn.Module): 563 """ 564 Overview: 565 temporal net for value function 566 Interfaces: 567 ``__init__``, ``forward`` 568 """ 569 570 def __init__( 571 self, 572 horizon: int, 573 transition_dim: int, 574 dim: int = 32, 575 time_dim: int = None, 576 out_dim: int = 1, 577 kernel_size: int = 5, 578 dim_mults: SequenceType = [1, 2, 4, 8], 579 ) -> None: 580 """ 581 Overview: 582 Initialization of TemporalValue class 583 Arguments: 584 - horizon (:obj:'int'): horizon of trajectory 585 - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim 586 - dim (:obj:'int'): dim of layer 587 - time_dim (:obj:'int'): dim of time 
588 - out_dim (:obj:'int'): dim of output 589 - kernel_size (:obj:'int'): kernel_size of conv1d 590 - dim_mults (:obj:'SequenceType'): mults of dim 591 """ 592 super().__init__() 593 dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 594 in_out = list(zip(dims[:-1], dims[1:])) 595 596 time_dim = time_dim or dim 597 self.time_mlp = nn.Sequential( 598 SinusoidalPosEmb(dim), 599 nn.Linear(dim, dim * 4), 600 nn.Mish(), 601 nn.Linear(dim * 4, dim), 602 ) 603 self.blocks = nn.ModuleList([]) 604 605 for ind, (dim_in, dim_out) in enumerate(in_out): 606 self.blocks.append( 607 nn.ModuleList( 608 [ 609 ResidualTemporalBlock(dim_in, dim_out, kernel_size=kernel_size, embed_dim=time_dim), 610 ResidualTemporalBlock(dim_out, dim_out, kernel_size=kernel_size, embed_dim=time_dim), 611 nn.Conv1d(dim_out, dim_out, 3, 2, 1) 612 ] 613 ) 614 ) 615 616 horizon = horizon // 2 617 618 mid_dim = dims[-1] 619 mid_dim_2 = mid_dim // 2 620 mid_dim_3 = mid_dim // 4 621 622 self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim) 623 self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1) 624 625 horizon = horizon // 2 626 self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim) 627 self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1) 628 horizon = horizon // 2 629 630 fc_dim = mid_dim_3 * max(horizon, 1) 631 self.final_block = nn.Sequential( 632 nn.Linear(fc_dim + time_dim, fc_dim // 2), 633 nn.Mish(), 634 nn.Linear(fc_dim // 2, out_dim), 635 ) 636 637 def forward(self, x, cond, time, *args): 638 """ 639 Overview: 640 compute temporal value forward 641 Arguments: 642 - x (:obj:'tensor'): noise trajectory 643 - cond (:obj:'tuple'): [ (time, state), ... 
] state is init state of env, time = 0 644 - time (:obj:'int'): timestep of diffusion step 645 """ 646 # [batch, horizon, transition ] -> [batch, transition , horizon] 647 x = x.transpose(1, 2) 648 t = self.time_mlp(time) 649 for resnet, resnet2, downsample in self.blocks: 650 x = resnet(x, t) 651 x = resnet2(x, t) 652 x = downsample(x) 653 654 x = self.mid_block1(x, t) 655 x = self.mid_down1(x) 656 657 x = self.mid_block2(x, t) 658 x = self.mid_down2(x) 659 x = x.view(len(x), -1) 660 out = self.final_block(torch.cat([x, t], dim=-1)) 661 return out