
ding.model.template.qgpo


TwinQ

Bases: Module

Overview

Twin Q network for QGPO, which has two Q networks.

Interfaces: __init__, forward, both

__init__(action_dim, state_dim)

Overview

Initialization of Twin Q.

Arguments:
- `action_dim` (`int`): The dimension of action.
- `state_dim` (`int`): The dimension of state.

both(action, condition=None)

Overview

Return the output of two Q networks.

Arguments:
- `action` (`torch.Tensor`): The input action.
- `condition` (`torch.Tensor`): The input condition.

forward(action, condition=None)

Overview

Return the minimum output of two Q networks.

Arguments:
- `action` (`torch.Tensor`): The input action.
- `condition` (`torch.Tensor`): The input condition.

GuidanceQt

Bases: Module

Overview

Energy guidance Qt network for QGPO. In the original paper, the energy guidance is trained with the CEP (contrastive energy prediction) method.

Interfaces: __init__, forward

__init__(action_dim, state_dim, time_embed_dim=32)

Overview

Initialization of Guidance Qt.

Arguments:
- `action_dim` (`int`): The dimension of action.
- `state_dim` (`int`): The dimension of state.
- `time_embed_dim` (`int`): The dimension of the time embedding. The time embedding is a Gaussian Fourier Feature tensor.

forward(action, t, condition=None)

Overview

Return the output of Guidance Qt.

Arguments:
- `action` (`torch.Tensor`): The input action.
- `t` (`torch.Tensor`): The input time.
- `condition` (`torch.Tensor`): The input condition.

QGPOCritic

Bases: Module

Overview

QGPO critic network.

Interfaces: __init__, forward, calculateQ, calculate_guidance

__init__(device, cfg, action_dim, state_dim)

Overview

Initialization of QGPO critic.

Arguments:
- `device` (`torch.device`): The device to use.
- `cfg` (`EasyDict`): The config dict.
- `action_dim` (`int`): The dimension of action.
- `state_dim` (`int`): The dimension of state.

calculate_guidance(a, t, condition=None, guidance_scale=1.0)

Overview

Calculate the guidance for conditional sampling.

Arguments:
- `a` (`torch.Tensor`): The input action.
- `t` (`torch.Tensor`): The input time.
- `condition` (`torch.Tensor`): The input condition.
- `guidance_scale` (`float`): The scale of guidance.
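Conceptually, the guidance is the gradient of the learned energy `qt` with respect to the action, scaled by `guidance_scale`. A minimal pure-Python sketch of that idea, with a hypothetical quadratic energy standing in for the learned network and finite differences standing in for `torch.autograd.grad`:

```python
def energy(a):
    # Hypothetical stand-in for the learned Qt(a, t, condition):
    # a simple concave quadratic, E(a) = -|a|^2.
    return -sum(x * x for x in a)

def calculate_guidance(a, guidance_scale=1.0, h=1e-5):
    # guidance_i = guidance_scale * dE/da_i, estimated with
    # central differences (the real code uses autograd instead).
    grad = []
    for i in range(len(a)):
        up = list(a)
        up[i] += h
        down = list(a)
        down[i] -= h
        grad.append((energy(up) - energy(down)) / (2 * h))
    return [guidance_scale * g for g in grad]
```

For E(a) = -|a|^2 the gradient is -2a, so `calculate_guidance([1.0, -0.5])` is approximately `[-2.0, 1.0]`; the real method returns this gradient for the learned energy so the sampler can be steered toward high-Q actions.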

forward(a, condition=None)

Overview

Return the output of QGPO critic.

Arguments:
- `a` (`torch.Tensor`): The input action.
- `condition` (`torch.Tensor`): The input condition.

calculateQ(a, condition=None)

Overview

Return the output of QGPO critic.

Arguments:
- `a` (`torch.Tensor`): The input action.
- `condition` (`torch.Tensor`): The input condition.

ScoreNet

Bases: Module

Overview

Score-based generative model for QGPO.

Interfaces: __init__, forward

__init__(device, input_dim, output_dim, embed_dim=32)

Overview

Initialization of ScoreNet.

Arguments:
- `device` (`torch.device`): The device to use.
- `input_dim` (`int`): The dimension of input.
- `output_dim` (`int`): The dimension of output.
- `embed_dim` (`int`): The dimension of the time embedding. The time embedding is a Gaussian Fourier Feature tensor.

forward(x, t, condition)

Overview

Return the output of ScoreNet.

Arguments:
- `x` (`torch.Tensor`): The input tensor.
- `t` (`torch.Tensor`): The input time.
- `condition` (`torch.Tensor`): The input condition.

QGPO

Bases: Module

Overview

Model of QGPO algorithm.

Interfaces: __init__, calculateQ, select_actions, sample, score_model_loss_fn, q_loss_fn, qt_loss_fn

__init__(cfg)

Overview

Initialization of QGPO.

Arguments:
- `cfg` (`EasyDict`): The config dict.

calculateQ(s, a)

Overview

Calculate the Q value.

Arguments:
- `s` (`torch.Tensor`): The input state.
- `a` (`torch.Tensor`): The input action.

select_actions(states, diffusion_steps=15, guidance_scale=1.0)

Overview

Select actions for conditional sampling.

Arguments:
- `states` (`list`): The input states.
- `diffusion_steps` (`int`): The diffusion steps.
- `guidance_scale` (`float`): The scale of guidance.

sample(states, sample_per_state=16, diffusion_steps=15, guidance_scale=1.0)

Overview

Sample actions for conditional sampling.

Arguments:
- `states` (`list`): The input states.
- `sample_per_state` (`int`): The number of samples per state.
- `diffusion_steps` (`int`): The diffusion steps.
- `guidance_scale` (`float`): The scale of guidance.

score_model_loss_fn(x, s, eps=0.001)

Overview

The loss function for training score-based generative models.

Arguments:
- `x` (`torch.Tensor`): A mini-batch of training data.
- `s` (`torch.Tensor`): The input state (condition).
- `eps` (`float`): A tolerance value for numerical stability.
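The underlying objective is denoising score matching: perturb `x` at a random time `t` with Gaussian noise `z`, then penalise `(score * std + z)^2`, which is zero exactly when the model predicts `-z / std`. A scalar pure-Python sketch of one loss term (`ideal_score` is a hypothetical stand-in for the trained network; for data concentrated at a single point it is the exact minimiser):

```python
import math

def marginal_prob_std(t, beta_0=0.1, beta_1=20.0):
    # Same closed-form VP-SDE coefficients as in qgpo.py, scalar version.
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    alpha_t = math.exp(log_mean_coeff)
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))
    return alpha_t, std

def dsm_loss(x0, z, t, score_fn):
    # Perturb the clean sample, then penalise (score * std + z)^2.
    alpha_t, std = marginal_prob_std(t)
    x_t = alpha_t * x0 + std * z
    return (score_fn(x_t, t) * std + z) ** 2

def ideal_score(x_t, t, x0=1.5):
    # For a point mass at x0, the true score of p_t is
    # -(x_t - alpha_t * x0) / std^2, i.e. exactly -z / std.
    alpha_t, std = marginal_prob_std(t)
    return -(x_t - alpha_t * x0) / std ** 2
```

With the ideal score, `dsm_loss(1.5, 0.7, 0.5, ideal_score)` vanishes, while a wrong score (e.g. constant zero) leaves a positive residual of `z**2`.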

q_loss_fn(a, s, r, s_, d, fake_a_, discount=0.99)

Overview

The loss function for training Q function.

Arguments:
- `a` (`torch.Tensor`): The input action.
- `s` (`torch.Tensor`): The input state.
- `r` (`torch.Tensor`): The input reward.
- `s_` (`torch.Tensor`): The input next state.
- `d` (`torch.Tensor`): The done flag.
- `fake_a_` (`torch.Tensor`): The support (fake) actions sampled for the next state.
- `discount` (`float`): The discount factor.
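The next-state value in this loss is not a plain max over actions: it is a softmax-weighted aggregation of the target network's energies over the sampled support actions, with `q_alpha` interpolating between the mean (small `q_alpha`) and the max (large `q_alpha`). A pure-Python sketch of the target computation (function names are illustrative, not from the source):

```python
import math

def soft_value(energies, q_alpha):
    # next_v = sum_i softmax(q_alpha * E)_i * E_i over support actions,
    # computed with the usual max-shift for numerical stability.
    m = max(q_alpha * e for e in energies)
    w = [math.exp(q_alpha * e - m) for e in energies]
    total = sum(w)
    return sum((wi / total) * e for wi, e in zip(w, energies))

def q_target(r, done, next_energies, q_alpha, discount=0.99):
    # Bellman target regressed by both twin Q heads via MSE.
    return r + (0.0 if done else discount * soft_value(next_energies, q_alpha))
```

For equal energies the soft value is just that energy; as `q_alpha` grows, `soft_value([0.0, 1.0], q_alpha)` approaches the maximum, 1.0.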

qt_loss_fn(s, fake_a)

Overview

The loss function for training Guidance Qt.

Arguments:
- `s` (`torch.Tensor`): The input state.
- `fake_a` (`torch.Tensor`): The input fake action.

marginal_prob_std(t, device)

Overview

Compute the mean and standard deviation of \(p_{0t}(x(t) | x(0))\).

Arguments:
- `t` (`torch.Tensor`): The input time.
- `device` (`torch.device`): The device to use.
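For reference, the closed-form coefficients for the linear VP-SDE schedule (the source hard-codes `beta_0 = 0.1`, `beta_1 = 20.0`) can be sketched without torch; this is a scalar re-derivation of the same formula, not the library function:

```python
import math

def marginal_prob_std(t, beta_0=0.1, beta_1=20.0):
    # Linear schedule beta(s) = beta_0 + s * (beta_1 - beta_0), so
    # log alpha_t = -1/2 * integral_0^t beta(s) ds.
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    alpha_t = math.exp(log_mean_coeff)                      # mean coefficient of p_0t
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))   # std of p_0t
    return alpha_t, std
```

Note the invariant `alpha_t**2 + std**2 == 1`: at `t -> 0` the sample is essentially clean (`alpha_t -> 1`, `std -> 0`), and at `t = 1` it is almost pure noise.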

Full Source Code

../ding/model/template/qgpo.py

```python
#############################################################
# This QGPO model is a modified implementation of https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################

from easydict import EasyDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from ding.torch_utils import MLP
from ding.torch_utils.diffusion_SDE.dpm_solver_pytorch import DPM_Solver, NoiseScheduleVP
from ding.model.common.encoder import GaussianFourierProjectionTimeEncoder
from ding.torch_utils.network.res_block import TemporalSpatialResBlock


def marginal_prob_std(t, device):
    """
    Overview:
        Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$.
    Arguments:
        - t (:obj:`torch.Tensor`): The input time.
        - device (:obj:`torch.device`): The device to use.
    """

    t = torch.tensor(t, device=device)
    beta_1 = 20.0
    beta_0 = 0.1
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    alpha_t = torch.exp(log_mean_coeff)
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return alpha_t, std


class TwinQ(nn.Module):
    """
    Overview:
        Twin Q network for QGPO, which has two Q networks.
    Interfaces:
        ``__init__``, ``forward``, ``both``
    """

    def __init__(self, action_dim, state_dim):
        """
        Overview:
            Initialization of Twin Q.
        Arguments:
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
        """
        super().__init__()
        self.q1 = MLP(
            in_channels=state_dim + action_dim,
            hidden_channels=256,
            out_channels=1,
            activation=nn.ReLU(),
            layer_num=4,
            output_activation=False
        )
        self.q2 = MLP(
            in_channels=state_dim + action_dim,
            hidden_channels=256,
            out_channels=1,
            activation=nn.ReLU(),
            layer_num=4,
            output_activation=False
        )

    def both(self, action, condition=None):
        """
        Overview:
            Return the output of two Q networks.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        as_ = torch.cat([action, condition], -1) if condition is not None else action
        return self.q1(as_), self.q2(as_)

    def forward(self, action, condition=None):
        """
        Overview:
            Return the minimum output of two Q networks.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        return torch.min(*self.both(action, condition))


class GuidanceQt(nn.Module):
    """
    Overview:
        Energy Guidance Qt network for QGPO. \
        In the original paper, the energy guidance is trained with the CEP method.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, action_dim, state_dim, time_embed_dim=32):
        """
        Overview:
            Initialization of Guidance Qt.
        Arguments:
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
            - time_embed_dim (:obj:`int`): The dimension of time embedding. \
                The time embedding is a Gaussian Fourier Feature tensor.
        """
        super().__init__()
        self.qt = MLP(
            in_channels=action_dim + time_embed_dim + state_dim,
            hidden_channels=256,
            out_channels=1,
            activation=torch.nn.SiLU(),
            layer_num=4,
            output_activation=False
        )
        self.embed = nn.Sequential(
            GaussianFourierProjectionTimeEncoder(embed_dim=time_embed_dim), nn.Linear(time_embed_dim, time_embed_dim)
        )

    def forward(self, action, t, condition=None):
        """
        Overview:
            Return the output of Guidance Qt.
        Arguments:
            - action (:obj:`torch.Tensor`): The input action.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
        """
        embed = self.embed(t)
        ats = torch.cat([action, embed, condition], -1) if condition is not None else torch.cat([action, embed], -1)
        return self.qt(ats)


class QGPOCritic(nn.Module):
    """
    Overview:
        QGPO critic network.
    Interfaces:
        ``__init__``, ``forward``, ``calculateQ``, ``calculate_guidance``
    """

    def __init__(self, device, cfg, action_dim, state_dim) -> None:
        """
        Overview:
            Initialization of QGPO critic.
        Arguments:
            - device (:obj:`torch.device`): The device to use.
            - cfg (:obj:`EasyDict`): The config dict.
            - action_dim (:obj:`int`): The dimension of action.
            - state_dim (:obj:`int`): The dimension of state.
        """

        super().__init__()
        # state_dim == 0 would mean unconditional guidance;
        # only conditional sampling is supported here.
        assert state_dim > 0
        self.device = device
        self.q0 = TwinQ(action_dim, state_dim).to(self.device)
        self.q0_target = copy.deepcopy(self.q0).requires_grad_(False).to(self.device)
        self.qt = GuidanceQt(action_dim, state_dim).to(self.device)

        self.alpha = cfg.alpha
        self.q_alpha = cfg.q_alpha

    def calculate_guidance(self, a, t, condition=None, guidance_scale=1.0):
        """
        Overview:
            Calculate the guidance for conditional sampling.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        with torch.enable_grad():
            a.requires_grad_(True)
            Q_t = self.qt(a, t, condition)
            guidance = guidance_scale * torch.autograd.grad(torch.sum(Q_t), a)[0]
        return guidance.detach()

    def forward(self, a, condition=None):
        """
        Overview:
            Return the output of QGPO critic.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        return self.q0(a, condition)

    def calculateQ(self, a, condition=None):
        """
        Overview:
            Return the output of QGPO critic.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        return self(a, condition)


class ScoreNet(nn.Module):
    """
    Overview:
        Score-based generative model for QGPO.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, device, input_dim, output_dim, embed_dim=32):
        """
        Overview:
            Initialization of ScoreNet.
        Arguments:
            - device (:obj:`torch.device`): The device to use.
            - input_dim (:obj:`int`): The dimension of input.
            - output_dim (:obj:`int`): The dimension of output.
            - embed_dim (:obj:`int`): The dimension of time embedding. \
                The time embedding is a Gaussian Fourier Feature tensor.
        """

        super().__init__()

        # original score base
        self.output_dim = output_dim
        self.embed = nn.Sequential(
            GaussianFourierProjectionTimeEncoder(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim)
        )

        self.device = device
        self.pre_sort_condition = nn.Sequential(nn.Linear(input_dim - output_dim, 32), torch.nn.SiLU())
        self.sort_t = nn.Sequential(
            nn.Linear(64, 128),
            torch.nn.SiLU(),
            nn.Linear(128, 128),
        )
        self.down_block1 = TemporalSpatialResBlock(output_dim, 512)
        self.down_block2 = TemporalSpatialResBlock(512, 256)
        self.down_block3 = TemporalSpatialResBlock(256, 128)
        self.middle1 = TemporalSpatialResBlock(128, 128)
        self.up_block3 = TemporalSpatialResBlock(256, 256)
        self.up_block2 = TemporalSpatialResBlock(512, 512)
        self.last = nn.Linear(1024, output_dim)

    def forward(self, x, t, condition):
        """
        Overview:
            Return the output of ScoreNet.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
            - t (:obj:`torch.Tensor`): The input time.
            - condition (:obj:`torch.Tensor`): The input condition.
        """

        embed = self.embed(t)
        embed = torch.cat([self.pre_sort_condition(condition), embed], dim=-1)
        embed = self.sort_t(embed)
        d1 = self.down_block1(x, embed)
        d2 = self.down_block2(d1, embed)
        d3 = self.down_block3(d2, embed)
        u3 = self.middle1(d3, embed)
        u2 = self.up_block3(torch.cat([d3, u3], dim=-1), embed)
        u1 = self.up_block2(torch.cat([d2, u2], dim=-1), embed)
        u0 = torch.cat([d1, u1], dim=-1)
        h = self.last(u0)
        self.h = h
        # Normalize output
        return h / marginal_prob_std(t, device=self.device)[1][..., None]


class QGPO(nn.Module):
    """
    Overview:
        Model of QGPO algorithm.
    Interfaces:
        ``__init__``, ``calculateQ``, ``select_actions``, ``sample``, ``score_model_loss_fn``, ``q_loss_fn``, \
        ``qt_loss_fn``
    """

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialization of QGPO.
        Arguments:
            - cfg (:obj:`EasyDict`): The config dict.
        """

        super(QGPO, self).__init__()
        self.device = cfg.device
        self.obs_dim = cfg.obs_dim
        self.action_dim = cfg.action_dim

        self.noise_schedule = NoiseScheduleVP(schedule='linear')

        self.score_model = ScoreNet(
            device=self.device,
            input_dim=self.obs_dim + self.action_dim,
            output_dim=self.action_dim,
        )

        self.q = QGPOCritic(self.device, cfg.qgpo_critic, action_dim=self.action_dim, state_dim=self.obs_dim)

    def calculateQ(self, s, a):
        """
        Overview:
            Calculate the Q value.
        Arguments:
            - s (:obj:`torch.Tensor`): The input state.
            - a (:obj:`torch.Tensor`): The input action.
        """

        return self.q(a, s)

    def select_actions(self, states, diffusion_steps=15, guidance_scale=1.0):
        """
        Overview:
            Select actions for conditional sampling.
        Arguments:
            - states (:obj:`list`): The input states.
            - diffusion_steps (:obj:`int`): The diffusion steps.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        def forward_dpm_wrapper_fn(x, t):
            score = self.score_model(x, t, condition=states)
            result = -(score +
                       self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) * marginal_prob_std(
                           t, device=self.device
                       )[1][..., None]
            return result

        self.eval()
        multiple_input = True
        with torch.no_grad():
            states = torch.FloatTensor(states).to(self.device)
            if states.dim() == 1:
                states = states.unsqueeze(0)
                multiple_input = False
            num_states = states.shape[0]

            init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
            results = DPM_Solver(
                forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
            ).sample(
                init_x, steps=diffusion_steps, order=2
            ).cpu().numpy()

            actions = results.reshape(num_states, self.action_dim).copy()  # <bz, A>

        out_actions = [actions[i] for i in range(actions.shape[0])] if multiple_input else actions[0]
        self.train()
        return out_actions

    def sample(self, states, sample_per_state=16, diffusion_steps=15, guidance_scale=1.0):
        """
        Overview:
            Sample actions for conditional sampling.
        Arguments:
            - states (:obj:`list`): The input states.
            - sample_per_state (:obj:`int`): The number of samples per state.
            - diffusion_steps (:obj:`int`): The diffusion steps.
            - guidance_scale (:obj:`float`): The scale of guidance.
        """

        def forward_dpm_wrapper_fn(x, t):
            score = self.score_model(x, t, condition=states)
            result = -(score + self.q.calculate_guidance(x, t, states, guidance_scale=guidance_scale)) \
                * marginal_prob_std(t, device=self.device)[1][..., None]
            return result

        self.eval()
        num_states = states.shape[0]
        with torch.no_grad():
            states = torch.FloatTensor(states).to(self.device)
            states = torch.repeat_interleave(states, sample_per_state, dim=0)

            init_x = torch.randn(states.shape[0], self.action_dim, device=self.device)
            results = DPM_Solver(
                forward_dpm_wrapper_fn, self.noise_schedule, predict_x0=True
            ).sample(
                init_x, steps=diffusion_steps, order=2
            ).cpu().numpy()

            actions = results.reshape(num_states, sample_per_state, self.action_dim).copy()

        self.train()
        return actions

    def score_model_loss_fn(self, x, s, eps=1e-3):
        """
        Overview:
            The loss function for training score-based generative models.
        Arguments:
            - x (:obj:`torch.Tensor`): A mini-batch of training data.
            - s (:obj:`torch.Tensor`): The input state (condition).
            - eps (:obj:`float`): A tolerance value for numerical stability.
        """

        random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
        z = torch.randn_like(x)
        alpha_t, std = marginal_prob_std(random_t, device=x.device)
        perturbed_x = x * alpha_t[:, None] + z * std[:, None]
        score = self.score_model(perturbed_x, random_t, condition=s)
        loss = torch.mean(torch.sum((score * std[:, None] + z) ** 2, dim=(1, )))
        return loss

    def q_loss_fn(self, a, s, r, s_, d, fake_a_, discount=0.99):
        """
        Overview:
            The loss function for training Q function.
        Arguments:
            - a (:obj:`torch.Tensor`): The input action.
            - s (:obj:`torch.Tensor`): The input state.
            - r (:obj:`torch.Tensor`): The input reward.
            - s\_ (:obj:`torch.Tensor`): The input next state.
            - d (:obj:`torch.Tensor`): The done flag.
            - fake_a\_ (:obj:`torch.Tensor`): The support (fake) actions sampled for the next state.
            - discount (:obj:`float`): The discount factor.
        """

        with torch.no_grad():
            softmax = nn.Softmax(dim=1)
            next_energy = self.q.q0_target(fake_a_, torch.stack([s_] * fake_a_.shape[1], axis=1)).detach().squeeze()
            next_v = torch.sum(softmax(self.q.q_alpha * next_energy) * next_energy, dim=-1, keepdim=True)
        # Update Q function
        targets = r + (1. - d.float()) * discount * next_v.detach()
        qs = self.q.q0.both(a, s)
        q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs)

        return q_loss

    def qt_loss_fn(self, s, fake_a):
        """
        Overview:
            The loss function for training Guidance Qt.
        Arguments:
            - s (:obj:`torch.Tensor`): The input state.
            - fake_a (:obj:`torch.Tensor`): The input fake action.
        """

        # input: states <bz, S>, actions <bz, M, A>
        energy = self.q.q0_target(fake_a, torch.stack([s] * fake_a.shape[1], axis=1)).detach().squeeze()

        # CEP guidance method, as proposed in the paper
        logsoftmax = nn.LogSoftmax(dim=1)
        softmax = nn.Softmax(dim=1)

        x0_data_energy = energy * self.q.alpha
        random_t = torch.rand((fake_a.shape[0], ), device=self.device) * (1. - 1e-3) + 1e-3
        random_t = torch.stack([random_t] * fake_a.shape[1], dim=1)
        z = torch.randn_like(fake_a)
        alpha_t, std = marginal_prob_std(random_t, device=self.device)
        perturbed_fake_a = fake_a * alpha_t[..., None] + z * std[..., None]
        xt_model_energy = self.q.qt(perturbed_fake_a, random_t, torch.stack([s] * fake_a.shape[1], axis=1)).squeeze()
        p_label = softmax(x0_data_energy)

        # <bz, M>
        qt_loss = -torch.mean(torch.sum(p_label * logsoftmax(xt_model_energy), axis=-1))
        return qt_loss
```