from typing import Union, List, Dict
from collections import namedtuple
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType
from ding.torch_utils.network.diffusion import extract, cosine_beta_schedule, apply_conditioning, \
    DiffusionUNet1d, TemporalValue

# Result of a sampling run: final trajectories, their values, and (optionally) the whole denoising chain.
Sample = namedtuple('Sample', 'trajectories values chains')


def default_sample_fn(model, x, cond, t):
    """
    Overview:
        Unguided reverse-diffusion step: draw a sample around the mean returned by
        ``model.p_mean_variance`` and return it with an all-zero value tensor.
    Arguments:
        - model: diffusion model exposing ``p_mean_variance(x=..., cond=..., t=...)``
        - x (:obj:`tensor`): current noisy sample; its first dimension is used as the batch size
        - cond: conditioning passed through to ``model.p_mean_variance``
        - t (:obj:`tensor`): diffusion step of every batch element
    Returns:
        - sample (:obj:`tensor`): ``model_mean + mask * exp(0.5 * log_variance) * noise``
        - values (:obj:`tensor`): zeros of shape ``(len(x), )`` on the device of ``x``
    """
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t)
    # The Gaussian noise is scaled by 0.5 (presumably a lower-temperature sampling choice -- TODO confirm).
    noise = 0.5 * torch.randn_like(x)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1, ) * (len(x.shape) - 1)))
    values = torch.zeros(len(x), device=device)
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, values


def get_guide_output(guide, x, cond, t):
    """
    Overview:
        Evaluate ``guide`` on ``x`` and compute the gradient of its summed output w.r.t. ``x``.
    Arguments:
        - guide: callable ``guide(x, cond, t)`` whose output has a trailing dimension that can be squeezed
        - x (:obj:`tensor`): input; ``requires_grad`` is switched on in place
        - cond: conditioning passed through to ``guide``
        - t (:obj:`tensor`): diffusion step passed through to ``guide``
    Returns:
        - y (:obj:`tensor`): guide output with the last dimension squeezed
        - grad (:obj:`tensor`): gradient of ``y.sum()`` with respect to ``x``
    """
    x.requires_grad_()
    y = guide(x, cond, t).squeeze(dim=-1)
    grad = torch.autograd.grad([y.sum()], [x])[0]
    # NOTE: a bare ``x.detach()`` used to follow here; its result was discarded, so it had no effect
    # and has been removed. ``x`` therefore still has ``requires_grad=True`` when this function returns.
    return y, grad


def n_step_guided_p_sample(
        model,
        x,
        cond,
        t,
        guide,
        scale=0.001,
        t_stopgrad=0,
        n_guide_steps=1,
        scale_grad_by_std=True,
):
    """
    Overview:
        Guided reverse-diffusion step: move ``x`` along the gradient of ``guide`` for
        ``n_guide_steps`` iterations, then sample around the model mean evaluated at the moved ``x``.
    Arguments:
        - model: diffusion model exposing ``posterior_log_variance_clipped``, ``action_dim`` \
            and ``p_mean_variance``
        - x (:obj:`tensor`): current noisy sample
        - cond: conditioning, re-applied to ``x`` after every guide step via ``apply_conditioning``
        - t (:obj:`tensor`): diffusion step of every batch element
        - guide: callable ``guide(x, cond, t)`` providing the value to ascend
        - scale (:obj:`float`): step size applied to the guide gradient
        - t_stopgrad (:obj:`int`): batch elements with ``t < t_stopgrad`` get a zero gradient
        - n_guide_steps (:obj:`int`): number of gradient steps; must be at least 1, otherwise the \
            returned guide value is undefined
        - scale_grad_by_std (:obj:`bool`): if True, multiply the gradient by ``exp(log_variance)``
    Returns:
        - sample (:obj:`tensor`): ``model_mean + model_std * noise``
        - y (:obj:`tensor`): guide output of the last guide step
    """
    model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape)
    model_std = torch.exp(0.5 * model_log_variance)
    model_var = torch.exp(model_log_variance)

    for _ in range(n_guide_steps):
        # Gradients are needed here even when the caller runs under ``torch.no_grad()``.
        with torch.enable_grad():
            y, grad = get_guide_output(guide, x, cond, t)

        if scale_grad_by_std:
            grad = model_var * grad

        grad[t < t_stopgrad] = 0

        x = x + scale * grad
        x = apply_conditioning(x, cond, model.action_dim)

    # Only the mean is needed; the noise scale uses ``model_std`` computed before the guide steps.
    model_mean, _, _ = model.p_mean_variance(x=x, cond=cond, t=t)

    # no noise when t == 0
    noise = torch.randn_like(x)
    noise[t == 0] = 0

    return model_mean + model_std * noise, y


class GaussianDiffusion(nn.Module):
    """
    Overview:
        Gaussian diffusion model
    Arguments:
        - model (:obj:`str`): type of model
        - model_cfg (:obj:'dict') config of model
        - horizon (:obj:`int`): horizon of trajectory
        - obs_dim (:obj:`int`): Dim of the observation
        - action_dim (:obj:`int`): Dim of the action
        - n_timesteps (:obj:`int`): Number of timesteps
        - predict_epsilon (:obj:'bool'): Whether predict epsilon
        - loss_discount (:obj:'float'): discount of loss
        - clip_denoised (:obj:'bool'): Whether use clip_denoised
        - action_weight (:obj:'float'): weight of action
        - loss_weights (:obj:'dict'): weight of loss
    """

    def __init__(
            self,
            model: str,
            model_cfg: dict,
            horizon: int,
            obs_dim: Union[int, SequenceType],
            action_dim: Union[int, SequenceType],
            n_timesteps: int = 1000,
            predict_epsilon: bool = True,
            loss_discount: float = 1.0,
            clip_denoised: bool = False,
            action_weight: float = 1.0,
            loss_weights: dict = None,
    ) -> None:
        super().__init__()
        self.horizon = horizon
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.transition_dim = obs_dim + action_dim
        if isinstance(model, str):
            # NOTE(security): ``eval`` resolves the class name against this module's globals
            # (e.g. the imported ``DiffusionUNet1d`` / ``TemporalValue``); only pass trusted configs.
            model = eval(model)
        self.model = model(**model_cfg)
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        betas = cosine_beta_schedule(n_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
        self.n_timesteps = int(n_timesteps)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance
        # is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt (not np.sqrt) keeps the whole computation on torch tensors.
        self.register_buffer(
            'posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        )
        self.register_buffer(
            'posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
        )

        # Kept as a plain attribute (not a buffer); it is moved to the loss device at use time.
        self.loss_weights = self.get_loss_weights(action_weight, loss_discount, loss_weights)

    def get_loss_weights(self, action_weight: float, discount: float, weights_dict: dict):
        """
        Overview:
            sets loss coefficients for trajectory
        Arguments:
            - action_weight (:obj:'float') coefficient on first action loss
            - discount (:obj:'float') multiplies t^th timestep of trajectory loss by discount**t
            - weights_dict (:obj:'dict') { i: c } multiplies dimension i of observation loss by c
        Returns:
            - loss_weights (:obj:'tensor') weights of shape (horizon, transition_dim)
        """
        self.action_weight = action_weight
        dim_weights = torch.ones(self.transition_dim, dtype=torch.float32)

        # set loss coefficients for dimensions of observation
        # (observation dimensions are offset by ``action_dim`` inside a transition)
        for ind, w in (weights_dict or {}).items():
            dim_weights[self.action_dim + ind] *= w

        # decay loss with trajectory timestep: discount**t
        discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
        discounts = discounts / discounts.mean()
        loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)

        # manually set a0 weight
        loss_weights[0, :self.action_dim] = action_weight
        return loss_weights

    def predict_start_from_noise(self, x_t, t, noise):
        """
        if self.predict_epsilon, model output is (scaled) noise;
        otherwise, model predicts x0 directly
        """
        if not self.predict_epsilon:
            return noise
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """
        Overview:
            give noise and step, compute mean, variance.
        Arguments:
            x_start (:obj:'tensor') noise trajectory in timestep 0
            x_t (:obj:'tuple') noise trajectory in timestep t
            t (:obj:'int') timestep of diffusion step
        Returns:
            posterior_mean, posterior_variance, posterior_log_variance_clipped
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t):
        """
        Overview:
            Predict x_0 from the model output and return the posterior mean, variance and
            log-variance of the reverse step.
        Arguments:
            x (:obj:'tensor') noise trajectory in timestep t
            cond: conditioning passed through to the model
            t (:obj:'tensor') timestep of diffusion step
        """
        x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, cond, t))

        # The original ``else: assert RuntimeError()`` never fired (an exception instance is truthy),
        # so leaving the reconstruction unclipped is the effective behaviour and is kept.
        if self.clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample_loop(self, shape, cond, return_chain=False, sample_fn=default_sample_fn, plan_size=1, **sample_kwargs):
        """
        Overview:
            Run the full reverse process from Gaussian noise and return a ``Sample``.
        Arguments:
            shape (:obj:'tuple') shape of the sampled tensor, ``shape[0]`` is the batch size
            cond: conditioning, re-applied via ``apply_conditioning`` after every step
            return_chain (:obj:'bool') whether to also return every intermediate sample
            sample_fn: callable ``(model, x, cond, t, **sample_kwargs) -> (x, values)``
            plan_size (:obj:'int') number of consecutive batch entries grouped as candidate plans
            sample_kwargs: extra keyword arguments forwarded to ``sample_fn``
        Returns:
            Sample(trajectories, values, chains); with ``plan_size > 1`` the plans in each group are
            sorted by descending value.
        """
        device = self.betas.device

        batch_size = shape[0]
        x = torch.randn(shape, device=device)
        x = apply_conditioning(x, cond, self.action_dim)

        chain = [x] if return_chain else None

        for i in reversed(range(0, self.n_timesteps)):
            t = torch.full((batch_size, ), i, device=device, dtype=torch.long)
            x, values = sample_fn(self, x, cond, t, **sample_kwargs)
            x = apply_conditioning(x, cond, self.action_dim)

            if return_chain:
                chain.append(x)

        # Group the flat batch into (num_groups, plan_size, ...).
        values = values.reshape(-1, plan_size, *values.shape[1:])
        x = x.reshape(-1, plan_size, *x.shape[1:])
        if plan_size > 1:
            inds = torch.argsort(values, dim=1, descending=True)
            x = x[torch.arange(x.size(0)).unsqueeze(1), inds]
            values = values[torch.arange(values.size(0)).unsqueeze(1), inds]
        if return_chain:
            chain = torch.stack(chain, dim=1)
        return Sample(x, values, chain)

    @torch.no_grad()
    def conditional_sample(self, cond, horizon=None, **sample_kwargs):
        """
        conditions : [ (time, state), ... ]
        """
        # The batch size is taken from the entry stored under key/index 0 of ``cond``.
        batch_size = len(cond[0])
        horizon = horizon or self.horizon
        shape = (batch_size, horizon, self.transition_dim)

        return self.p_sample_loop(shape, cond, **sample_kwargs)

    def q_sample(self, x_start, t, noise=None):
        """
        Arguments:
            x_start (:obj:'tensor') trajectory in timestep 0
            t (:obj:'int') timestep of diffusion
            noise (:obj:'tensor.float') timestep's noise of diffusion, sampled if None
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        return sample

    def p_losses(self, x_start, cond, t):
        """
        Overview:
            Weighted MSE training loss at diffusion step ``t``.
        Arguments:
            x_start (:obj:'tensor') trajectory in timestep 0
            cond: conditioning applied to both the noised input and the model output
            t (:obj:'tensor') timestep of diffusion
        Returns:
            loss: mean of the element-wise MSE multiplied by ``self.loss_weights``
            a0_loss: mean of the first-step action MSE divided by its loss weight
        """
        noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)

        x_recon = self.model(x_noisy, cond, t)
        x_recon = apply_conditioning(x_recon, cond, self.action_dim)

        assert noise.shape == x_recon.shape

        # The regression target is the noise when predicting epsilon, otherwise the clean trajectory.
        target = noise if self.predict_epsilon else x_start
        loss = F.mse_loss(x_recon, target, reduction='none')
        weights = self.loss_weights.to(loss.device)
        a0_loss = (loss[:, 0, :self.action_dim] / weights[0, :self.action_dim]).mean()
        loss = (loss * weights).mean()
        return loss, a0_loss

    def forward(self, cond, *args, **kwargs):
        return self.conditional_sample(cond, *args, **kwargs)


class ValueDiffusion(GaussianDiffusion):
    """
    Overview:
        Gaussian diffusion model for value function.
    """

    def p_losses(self, x_start, cond, target, t):
        """
        Overview:
            MSE between the model prediction on the noised, conditioned trajectory and ``target``.
        Returns:
            loss: scalar MSE loss
            log (:obj:'dict') mean / max / min of the prediction as python floats
        """
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)

        pred = self.model(x_noisy, cond, t)
        loss = F.mse_loss(pred, target, reduction='none').mean()
        log = {
            'mean_pred': pred.mean().item(),
            'max_pred': pred.max().item(),
            'min_pred': pred.min().item(),
        }

        return loss, log

    def forward(self, x, cond, t):
        return self.model(x, cond, t)


@MODEL_REGISTRY.register('pd')
class PlanDiffuser(nn.Module):
    """
    Overview:
        Diffuser model for plan.
    Arguments:
        - diffuser_model (:obj:`str`): type of plan model
        - diffuser_model_cfg (:obj:'dict') config of diffuser_model
        - value_model (:obj:`str`): type of value model, if haven't use, set it as None
        - value_model_cfg (:obj:`dict`): config of value_model
        - sample_kwargs : config of sample function
    """

    def __init__(
            self, diffuser_model: str, diffuser_model_cfg: dict, value_model: str, value_model_cfg: dict, **sample_kwargs
    ):
        super().__init__()
        # NOTE(security): ``eval`` resolves class names against this module's globals;
        # only pass trusted configs. A class object is accepted as-is.
        if isinstance(diffuser_model, str):
            diffuser_model = eval(diffuser_model)
        self.diffuser = diffuser_model(**diffuser_model_cfg)
        self.value = None
        if value_model:
            if isinstance(value_model, str):
                value_model = eval(value_model)
            self.value = value_model(**value_model_cfg)
        self.sample_kwargs = sample_kwargs

    def diffuser_loss(self, x_start, cond, t):
        return self.diffuser.p_losses(x_start, cond, t)

    def value_loss(self, x_start, cond, target, t):
        return self.value.p_losses(x_start, cond, target, t)

    def get_eval(self, cond, batch_size=1):
        """
        Overview:
            Sample plans for evaluation.
        Arguments:
            cond (:obj:'dict') conditioning; every entry is repeated ``batch_size`` times
            batch_size (:obj:'int') number of candidate plans per condition
        Returns:
            With a value model: the first-step action slice of the first plan in each group
            (plans are sorted by descending value when ``batch_size > 1``).
            Without a value model: the trajectory dimensions after ``action_dim``, with dim 1 squeezed.
        """
        cond = self.repeat_cond(cond, batch_size)
        if self.value is not None:
            samples = self.diffuser(
                cond, sample_fn=n_step_guided_p_sample, plan_size=batch_size, guide=self.value, **self.sample_kwargs
            )
            # extract action [eval_num, batch_size, horizon, transition_dim]
            actions = samples.trajectories[:, :, :, :self.diffuser.action_dim]
            action = actions[:, 0, 0]
            return action
        else:
            samples = self.diffuser(cond, plan_size=batch_size)
            return samples.trajectories[:, :, :, self.diffuser.action_dim:].squeeze(1)

    def repeat_cond(self, cond, batch_size):
        """
        Overview:
            Return a new dict whose tensors are repeated ``batch_size`` times along dim 0.
            The input dict is left untouched, so calling ``get_eval`` twice with the same
            ``cond`` does not compound the repetition.
        """
        return {k: v.repeat_interleave(batch_size, dim=0) for k, v in cond.items()}


@MODEL_REGISTRY.register('dd')
class GaussianInvDynDiffusion(nn.Module):
    """
    Overview:
        Gaussian diffusion model with Invdyn action model.
    Arguments:
        - model (:obj:`str`): type of model
        - model_cfg (:obj:'dict') config of model
        - horizon (:obj:`int`): horizon of trajectory
        - obs_dim (:obj:`int`): Dim of the observation
        - action_dim (:obj:`int`): Dim of the action
        - n_timesteps (:obj:`int`): Number of timesteps
        - hidden_dim (:obj:'int'): hidden dim of inv_model
        - returns_condition (:obj:'bool'): Whether use returns condition
        - ar_inv (:obj:'bool'): Whether use inverse action learning
        - train_only_inv (:obj:'bool'): Whether train inverse action model only
        - predict_epsilon (:obj:'bool'): Whether predict epsilon
        - condition_guidance_w (:obj:'float'): weight of condition guidance
        - loss_discount (:obj:'float'): discount of loss
        - clip_denoised (:obj:'bool'): Whether use clip_denoised
    """

    def __init__(
            self,
            model: str,
            model_cfg: dict,
            horizon: int,
            obs_dim: Union[int, SequenceType],
            action_dim: Union[int, SequenceType],
            n_timesteps: int = 1000,
            hidden_dim: int = 256,
            returns_condition: bool = False,
            ar_inv: bool = False,
            train_only_inv: bool = False,
            predict_epsilon: bool = True,
            condition_guidance_w: float = 0.1,
            loss_discount: float = 1.0,
            clip_denoised: bool = False,
    ) -> None:
        super().__init__()
        self.horizon = horizon
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.transition_dim = obs_dim + action_dim
        if isinstance(model, str):
            # NOTE(security): ``eval`` resolves the class name against this module's globals;
            # only pass trusted configs.
            model = eval(model)
        self.model = model(**model_cfg)
        self.ar_inv = ar_inv
        self.train_only_inv = train_only_inv
        self.predict_epsilon = predict_epsilon
        self.condition_guidance_w = condition_guidance_w

        # MLP mapping a vector of size 2 * obs_dim to a vector of size action_dim.
        self.inv_model = nn.Sequential(
            nn.Linear(2 * self.obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.action_dim),
        )

        self.returns_condition = returns_condition
        self.clip_denoised = clip_denoised

        betas = cosine_beta_schedule(n_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
        self.n_timesteps = int(n_timesteps)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance
        # is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # torch.sqrt (not np.sqrt) keeps the whole computation on torch tensors.
        self.register_buffer(
            'posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        )
        self.register_buffer(
            'posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
        )

        # Kept as a plain attribute (not a buffer); it is moved to the loss device at use time.
        self.loss_weights = self.get_loss_weights(loss_discount)

    def get_loss_weights(self, discount: float):
        """
        Overview:
            Build per-(timestep, observation-dim) loss weights of shape (horizon, obs_dim).
        Arguments:
            - discount (:obj:'float') multiplies t^th timestep of trajectory loss by discount**t
        """
        self.action_weight = 1
        dim_weights = torch.ones(self.obs_dim, dtype=torch.float32)

        # decay loss with trajectory timestep: discount**t
        discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
        discounts = discounts / discounts.mean()
        loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)
        # Cause things are conditioned on t=0
        if self.predict_epsilon:
            loss_weights[0, :] = 0

        return loss_weights

    def predict_start_from_noise(self, x_t, t, noise):
        """
        if self.predict_epsilon, model output is (scaled) noise;
        otherwise, model predicts x0 directly
        """
        if not self.predict_epsilon:
            return noise
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """
        Arguments:
            x_start (:obj:'tensor') noise trajectory in timestep 0
            x_t (:obj:'tuple') noise trajectory in timestep t
            t (:obj:'int') timestep of diffusion step
        Returns:
            posterior_mean, posterior_variance, posterior_log_variance_clipped
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, t, returns=None):
        """
        Arguments:
            x (:obj:'tensor') noise trajectory in timestep t
            cond (:obj:'tuple') [ (time, state), ... ] state is init state of env, time = 0
            t (:obj:'int') timestep of diffusion step
            returns (:obj:'tensor') condition returns of trajectory, returns is normal return
        returns:
            model_mean (:obj:'tensor.float')
            posterior_variance (:obj:'float')
            posterior_log_variance (:obj:'float')
        """
        if self.returns_condition:
            # epsilon could be epsilon or x0 itself
            epsilon_cond = self.model(x, cond, t, returns, use_dropout=False)
            epsilon_uncond = self.model(x, cond, t, returns, force_dropout=True)
            # Blend of the conditional and unconditional predictions, weighted by condition_guidance_w.
            epsilon = epsilon_uncond + self.condition_guidance_w * (epsilon_cond - epsilon_uncond)
        else:
            epsilon = self.model(x, cond, t)

        t = t.detach().to(torch.int64)
        x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)

        # The original ``else: assert RuntimeError()`` never fired (an exception instance is truthy),
        # so leaving the reconstruction unclipped is the effective behaviour and is kept.
        if self.clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, cond, t, returns=None):
        """
        Arguments:
            x (:obj:'tensor') noise trajectory in timestep t
            cond (:obj:'tuple') [ (time, state), ... ] state is init state of env, time = 0
            t (:obj:'int') timestep of diffusion step
            returns (:obj:'tensor') condition returns of trajectory, returns is normal return
        """
        b = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t, returns=returns)
        # The Gaussian noise is scaled by 0.5, matching the 0.5 factor of the initial sample
        # in ``p_sample_loop``.
        noise = 0.5 * torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1, ) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusion=False):
        """
        Arguments:
            shape (:obj:'tuple') (batch_size, horizon, self.obs_dim)
            cond (:obj:'tuple') [ (time, state), ... ] state is init state of env, time = 0
            returns (:obj:'tensor') condition returns of trajectory, returns is normal return
            verbose (:obj:'bool') whether log diffusion progress (currently unused)
            return_diffusion (:obj:'bool') whether to also return every intermediate sample
        """
        device = self.betas.device

        batch_size = shape[0]
        x = 0.5 * torch.randn(shape, device=device)
        # In this model, init state must be given by the env and without noise.
        x = apply_conditioning(x, cond, 0)

        if return_diffusion:
            diffusion = [x]

        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size, ), i, device=device, dtype=torch.long)
            x = self.p_sample(x, cond, timesteps, returns)
            x = apply_conditioning(x, cond, 0)

            if return_diffusion:
                diffusion.append(x)

        if return_diffusion:
            return x, torch.stack(diffusion, dim=1)
        else:
            return x

    @torch.no_grad()
    def conditional_sample(self, cond, returns=None, horizon=None, *args, **kwargs):
        """
        Arguments:
            conditions (:obj:'tuple') [ (time, state), ... ] state is init state of env, time is timestep of trajectory
            returns (:obj:'tensor') condition returns of trajectory, returns is normal return
            horizon (:obj:'int') horizon of trajectory
        returns:
            x (:obj:'tensor') trajectory of env
        """
        # The batch size is taken from the entry stored under key/index 0 of ``cond``.
        batch_size = len(cond[0])
        horizon = horizon or self.horizon
        shape = (batch_size, horizon, self.obs_dim)

        return self.p_sample_loop(shape, cond, returns, *args, **kwargs)

    def q_sample(self, x_start, t, noise=None):
        """
        Arguments:
            x_start (:obj:'tensor') trajectory in timestep 0
            t (:obj:'int') timestep of diffusion
            noise (:obj:'tensor.float') timestep's noise of diffusion, sampled if None
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

        return sample

    def p_losses(self, x_start, cond, t, returns=None):
        """
        Overview:
            Weighted MSE training loss at diffusion step ``t``.
        Arguments:
            x_start (:obj:'tensor') trajectory in timestep 0
            cond: conditioning applied to the noised input (and to the output when predicting x0)
            t (:obj:'tensor') timestep of diffusion
            returns (:obj:'tensor') condition returns passed through to the model
        """
        noise = torch.randn_like(x_start)

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = apply_conditioning(x_noisy, cond, 0)

        x_recon = self.model(x_noisy, cond, t, returns)

        # Conditioning is only re-applied when the model output is a trajectory, not noise.
        if not self.predict_epsilon:
            x_recon = apply_conditioning(x_recon, cond, 0)

        assert noise.shape == x_recon.shape

        # The regression target is the noise when predicting epsilon, otherwise the clean trajectory.
        target = noise if self.predict_epsilon else x_start
        loss = F.mse_loss(x_recon, target, reduction='none')
        loss = (loss * self.loss_weights.to(loss.device)).mean()

        return loss

    def forward(self, cond, *args, **kwargs):
        # ``cond`` is passed positionally: ``cond=cond`` together with ``*args`` would bind the first
        # extra positional argument to ``cond`` as well and raise a TypeError.
        return self.conditional_sample(cond, *args, **kwargs)