Skip to content

ding.policy.qgpo

ding.policy.qgpo

QGPOPolicy

Bases: Policy

Overview

Policy class of the QGPO algorithm (https://arxiv.org/abs/2304.12824), from the paper "Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning".

Interfaces: __init__, forward, learn, eval, state_dict, load_state_dict

Full Source Code

../ding/policy/qgpo.py

#############################################################
# This QGPO model is a modification implementation from https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################

from typing import List, Dict, Any
import torch
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import to_device
from .base_policy import Policy


@POLICY_REGISTRY.register('qgpo')
class QGPOPolicy(Policy):
    """
    Overview:
        Policy class of QGPO algorithm (https://arxiv.org/abs/2304.12824).
        Contrastive Energy Prediction for Exact Energy-Guided Diffusion Sampling in Offline Reinforcement Learning
    Interfaces:
        ``__init__``, ``forward``, ``learn``, ``eval``, ``state_dict``, ``load_state_dict``
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='qgpo',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in QGPO.
        on_policy=False,
        multi_agent=False,
        model=dict(
            qgpo_critic=dict(
                # (float) The scale of the energy guidance when training qt.
                # \pi_{behavior}\exp(f(s,a)) \propto \pi_{behavior}\exp(alpha * Q(s,a))
                alpha=3,
                # (float) The scale of the energy guidance when training q0.
                # \mathcal{T}Q(s,a)=r(s,a)+\mathbb{E}_{s'\sim P(s'|s,a),a'\sim\pi_{support}(a'|s')}Q(s',a')
                # \pi_{support} \propto \pi_{behavior}\exp(q_alpha * Q(s,a))
                q_alpha=1,
            ),
            device='cuda',
            # obs_dim
            # action_dim
        ),
        learn=dict(
            # (float) learning rate for behavior model training
            learning_rate=1e-4,
            # (int) batch size during the training of behavior model
            batch_size=4096,
            # (int) batch size during the training of q value
            batch_size_q=256,
            # (int) number of fake action support
            M=16,
            # (int) number of diffusion time steps
            diffusion_steps=15,
            # (int) training iterations when behavior model is fixed
            behavior_policy_stop_training_iter=600000,
            # (int) training iterations when energy-guided policy begin training
            energy_guided_policy_begin_training_iter=600000,
            # (int) training iterations after which the q value stops training.
            # NOTE: must be an int; the learn loop decrements this counter each
            # iteration, so ``None`` is not supported despite older docs.
            q_value_stop_training_iter=1100000,
        ),
        eval=dict(
            # (int) number of diffusion time steps used when sampling actions in evaluation.
            # Added so that ``_init_eval`` (which reads ``cfg.eval.diffusion_steps``) works
            # with the default config; keep consistent with ``learn.diffusion_steps``.
            diffusion_steps=15,
            # (List[float]) energy guidance scale for policy in evaluation
            # \pi_{evaluation} \propto \pi_{behavior}\exp(guidance_scale * alpha * Q(s,a))
            guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0],
        ),
    )

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode initialization method. For QGPO, it mainly contains the optimizer, \
            algorithm-specific arguments such as qt_update_momentum, discount, behavior_policy_stop_training_iter, \
            energy_guided_policy_begin_training_iter and q_value_stop_training_iter, etc.
            This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
        """
        self.cuda = self._cfg.cuda

        # Optimizer for the diffusion-based behavior model (score function).
        self.behavior_model_optimizer = torch.optim.Adam(
            self._model.score_model.parameters(), lr=self._cfg.learn.learning_rate
        )
        # Optimizers for q0 (value) and qt (energy guidance); lr is hard-coded
        # upstream at 3e-4 and intentionally kept as-is for reproducibility.
        self.q_optimizer = torch.optim.Adam(self._model.q.q0.parameters(), lr=3e-4)
        self.qt_optimizer = torch.optim.Adam(self._model.q.qt.parameters(), lr=3e-4)

        # Soft-update momentum for the q0 target network (Polyak averaging).
        self.qt_update_momentum = 0.005
        # Discount factor for the Bellman target.
        self.discount = 0.99

        # Iteration counters controlling the staged training schedule
        # (behavior model first, then q value / energy guidance).
        self.behavior_policy_stop_training_iter = self._cfg.learn.behavior_policy_stop_training_iter
        self.energy_guided_policy_begin_training_iter = self._cfg.learn.energy_guided_policy_begin_training_iter
        self.q_value_stop_training_iter = self._cfg.learn.q_value_stop_training_iter

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward function for learning mode.
            The training of QGPO algorithm is based on contrastive energy prediction, \
            which needs true action and fake action. The true action is sampled from the dataset, and the fake action \
            is sampled from the action support generated by the behavior policy.
            The training process is divided into two stages:
            1. Train the behavior model, which is modeled as a diffusion model by parameterizing the score function.
            2. Train the Q function by fake action support generated by the behavior model.
            3. Train the energy-guided policy by the Q function.
        Arguments:
            - data (:obj:`dict`): Dict type data with keys ``s`` (state), ``a`` (action), ``r`` (reward), \
              ``s_`` (next state), ``d`` (done), ``fake_a`` / ``fake_a_`` (fake action supports).
        Returns:
            - result (:obj:`dict`): Dict type data of algorithm results.
        """

        if self.cuda:
            data = to_device(data, self._device)

        s = data['s']
        a = data['a']
        r = data['r']
        s_ = data['s_']
        d = data['d']
        fake_a = data['fake_a']
        fake_a_ = data['fake_a_']

        # Stage 1: train the behavior model (diffusion score model) until the
        # configured iteration budget is exhausted.
        if self.behavior_policy_stop_training_iter > 0:

            behavior_model_training_loss = self._model.score_model_loss_fn(a, s)

            self.behavior_model_optimizer.zero_grad()
            behavior_model_training_loss.backward()
            self.behavior_model_optimizer.step()

            self.behavior_policy_stop_training_iter -= 1
            behavior_model_training_loss = behavior_model_training_loss.item()
        else:
            behavior_model_training_loss = 0

        # Stage 2/3: train the Q function and the energy guidance once the
        # energy-guided phase has begun.
        self.energy_guided_policy_begin_training_iter -= 1
        self.q_value_stop_training_iter -= 1
        if self.energy_guided_policy_begin_training_iter < 0:
            if self.q_value_stop_training_iter > 0:
                q0_loss = self._model.q_loss_fn(a, s, r, s_, d, fake_a_, discount=self.discount)

                self.q_optimizer.zero_grad()
                q0_loss.backward()
                self.q_optimizer.step()

                # Polyak soft update of the q0 target network.
                for param, target_param in zip(self._model.q.q0.parameters(), self._model.q.q0_target.parameters()):
                    target_param.data.copy_(
                        self.qt_update_momentum * param.data + (1 - self.qt_update_momentum) * target_param.data
                    )

                q0_loss = q0_loss.item()

            else:
                q0_loss = 0
            # Energy guidance (qt) keeps training as long as the energy-guided
            # phase is active, even after q0 training has stopped.
            qt_loss = self._model.qt_loss_fn(s, fake_a)

            self.qt_optimizer.zero_grad()
            qt_loss.backward()
            self.qt_optimizer.step()

            qt_loss = qt_loss.item()

        else:
            q0_loss = 0
            qt_loss = 0

        total_loss = behavior_model_training_loss + q0_loss + qt_loss

        return dict(
            total_loss=total_loss,
            behavior_model_training_loss=behavior_model_training_loss,
            q0_loss=q0_loss,
            qt_loss=qt_loss,
        )

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode initialization method. Not supported for QGPO.
        """
        pass

    def _forward_collect(self) -> None:
        """
        Overview:
            Forward function for collect mode. Not supported for QGPO.
        """
        pass

    def _init_eval(self) -> None:
        """
        Overview:
            Eval mode initialization method. For QGPO, it mainly contains the diffusion_steps used for sampling.
            This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
        """

        # Provided by the default config's ``eval.diffusion_steps``.
        self.diffusion_steps = self._cfg.eval.diffusion_steps

    def _forward_eval(self, data: dict, guidance_scale: float) -> dict:
        """
        Overview:
            Forward function for eval mode. The eval process is based on the energy-guided policy, \
            which is modeled as a diffusion model by parameterizing the score function.
        Arguments:
            - data (:obj:`dict`): Dict type data, mapping env id to observation.
            - guidance_scale (:obj:`float`): The scale of the energy guidance.
        Returns:
            - output (:obj:`dict`): Dict type data of algorithm output, mapping env id to ``{"action": ...}``.
        """

        data_id = list(data.keys())
        states = default_collate(list(data.values()))
        actions = self._model.select_actions(
            states, diffusion_steps=self.diffusion_steps, guidance_scale=guidance_scale
        )
        output = actions

        return {i: {"action": d} for i, d in zip(data_id, output)}

    def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Overview:
            Get the train sample from the replay buffer, currently not supported for QGPO.
        Arguments:
            - transitions (:obj:`List[Dict[str, Any]]`): The data from replay buffer.
        Returns:
            - samples (:obj:`List[Dict[str, Any]]`): The data for training.
        """
        pass

    def _process_transition(self) -> None:
        """
        Overview:
            Process the transition data, currently not supported for QGPO.
        """
        pass

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state dict for saving.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict, containing the model and \
              the behavior model optimizer (q optimizers are not checkpointed here).
        """
        return {
            'model': self._model.state_dict(),
            'behavior_model_optimizer': self.behavior_model_optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state dict.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict type data of state dict.
        """
        self._model.load_state_dict(state_dict['model'])
        self.behavior_model_optimizer.load_state_dict(state_dict['behavior_model_optimizer'])

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Return the variables names to be monitored.
        Returns:
            - vars (:obj:`List[str]`): List of monitored variable names, matching the keys \
              returned by ``_forward_learn``.
        """
        return ['total_loss', 'behavior_model_training_loss', 'q0_loss', 'qt_loss']