ding.policy.vmpo

VMPOPolicy

Bases: PPOPolicy

Overview

On-policy discrete VMPO policy built on PPO's data flow (collector, train-sample processing, and actor-critic model interface). The learn loss follows the VMPO formulation:

- top-k advantage weighting with a temperature dual variable (eta)
- a KL trust-region penalty with an adaptive dual variable (alpha)
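
To make the weighting step concrete, here is a minimal, self-contained sketch (not part of the policy class) of how the top-k fraction of a batch is selected by advantage and converted into normalized exponential weights under a fixed temperature; the helper name `vmpo_weights` and the toy values are illustrative only.

```python
import torch


def vmpo_weights(adv: torch.Tensor, eta: float, topk_fraction: float = 0.5) -> torch.Tensor:
    # Keep the top-k fraction of samples, ranked by advantage (always at least one).
    k = max(1, int(topk_fraction * adv.numel()))
    topk_vals, topk_idx = torch.topk(adv, k)
    # Softmax over A_i / eta restricted to the selected samples; torch.softmax is
    # max-shifted internally, so this is numerically stable.
    w = torch.softmax(topk_vals / eta, dim=0)
    out = torch.zeros_like(adv)
    out[topk_idx] = w
    return out


adv = torch.tensor([0.5, -1.0, 2.0, 0.1])
print(vmpo_weights(adv, eta=1.0))  # non-zero weights only on the two largest advantages
```

In the actual loss, eta itself is learned through the dual objective in `_forward_learn`, so the weights are recomputed with the updated temperature in each mini-batch.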

default_model()

Overview

For this setup, VMPO defaults to a VAC-style actor-critic with a GTrXL core.

Returns:

- model_info (Tuple[str, List[str]]): the registered model name and the import path(s) of the module that defines it.
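
As a rough usage sketch, the returned pair is typically consumed by importing the listed module so that the registered model becomes available to the framework. The snippet below only uses the standard library and mirrors the values documented above; how DI-engine resolves the name against its model registry internally is not shown here.

```python
import importlib

# Values as returned by VMPOPolicy.default_model().
model_name, import_names = 'gtrxl_vac', ['ding.model.template.vac']

# Importing the module runs the registration decorators defined in it,
# after which the framework can look up `model_name`.
for path in import_names:
    importlib.import_module(path)
```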

Full Source Code

../ding/policy/vmpo.py

```python
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F

from ding.rl_utils import gae, gae_data
from ding.torch_utils import to_device, to_dtype
from ding.utils import POLICY_REGISTRY, split_data_generator

from .common_utils import default_preprocess_learn
from .ppo import PPOPolicy


@POLICY_REGISTRY.register('vmpo')
class VMPOPolicy(PPOPolicy):
    """
    Overview:
        On-policy discrete VMPO policy built on PPO's data flow (collector, train-sample processing,
        and actor-critic model interface). The learn loss uses VMPO-style:
        - top-k advantage weighting with temperature dual variable (eta)
        - KL trust-region penalty with adaptive dual variable (alpha)
    """

    config = dict(
        type='vmpo',
        cuda=False,
        on_policy=True,
        priority=False,
        priority_IS_weight=False,
        recompute_adv=True,
        action_space='discrete',
        nstep_return=False,
        multi_agent=False,
        transition_with_policy_data=True,
        learn=dict(
            epoch_per_collect=10,
            batch_size=64,
            learning_rate=3e-4,
            lr_scheduler=None,
            value_weight=0.5,
            entropy_weight=0.001,
            adv_norm=True,
            value_norm=True,
            ppo_param_init=False,
            grad_clip_type='clip_norm',
            grad_clip_value=0.5,
            ignore_done=False,
            topk_fraction=0.5,
            epsilon_eta=0.1,
            epsilon_kl=0.02,
            temperature_init=1.0,
            temperature_lr=1e-4,
            alpha_init=1.0,
            alpha_lr=1e-4,
        ),
        collect=dict(
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            VMPO for this setup defaults to a VAC-style actor-critic with GTrXL core.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and import path.
        """
        return 'gtrxl_vac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        super()._init_learn()
        assert self._action_space == 'discrete', "Current VMPOPolicy implementation supports discrete action only."

        self._topk_fraction = float(self._cfg.learn.topk_fraction)
        if not 0.0 < self._topk_fraction <= 1.0:
            raise ValueError(f"`topk_fraction` must be in (0, 1], got {self._topk_fraction}.")

        self._epsilon_eta = float(self._cfg.learn.epsilon_eta)
        self._epsilon_kl = float(self._cfg.learn.epsilon_kl)

        eta_init = torch.tensor(float(self._cfg.learn.temperature_init), dtype=torch.float32, device=self._device)
        eta_init = torch.clamp(eta_init, min=1e-6)
        self._log_eta = torch.nn.Parameter(torch.log(torch.expm1(eta_init)))
        self._eta_optimizer = torch.optim.Adam([self._log_eta], lr=float(self._cfg.learn.temperature_lr))

        alpha_init = torch.tensor(float(self._cfg.learn.alpha_init), dtype=torch.float32, device=self._device)
        alpha_init = torch.clamp(alpha_init, min=1e-6)
        self._log_alpha = torch.nn.Parameter(torch.log(torch.expm1(alpha_init)))
        self._alpha_optimizer = torch.optim.Adam([self._log_alpha], lr=float(self._cfg.learn.alpha_lr))

    @staticmethod
    def _positive(log_param: torch.Tensor) -> torch.Tensor:
        return F.softplus(log_param) + 1e-8

    def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        data['obs'] = to_dtype(data['obs'], torch.float32)
        if 'next_obs' in data:
            data['next_obs'] = to_dtype(data['next_obs'], torch.float32)

        return_infos: List[Dict[str, Any]] = []
        self._learn_model.train()

        for _ in range(self._cfg.learn.epoch_per_collect):
            if self._recompute_adv:
                with torch.no_grad():
                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                    if self._value_norm:
                        value *= self._running_mean_std.std
                        next_value *= self._running_mean_std.std

                    traj_flag = data.get('traj_flag', None)
                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag)
                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
                    unnormalized_returns = value + data['adv']

                    if self._value_norm:
                        data['value'] = value / self._running_mean_std.std
                        data['return'] = unnormalized_returns / self._running_mean_std.std
                        self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                    else:
                        data['value'] = value
                        data['return'] = unnormalized_returns
            else:
                if self._value_norm:
                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
                    data['return'] = unnormalized_return / self._running_mean_std.std
                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
                else:
                    data['return'] = data['adv'] + data['value']

            for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                logits = output['logit']
                old_logits = batch['logit']

                adv = batch['adv'].reshape(-1)
                if self._adv_norm:
                    adv = (adv - adv.mean()) / (adv.std(unbiased=False) + 1e-8)

                sample_weight = batch.get('weight', None)
                if sample_weight is None:
                    sample_weight = torch.ones_like(adv)
                else:
                    sample_weight = sample_weight.float().reshape(-1)

                # Keep optimization numerically safe when all sample weights are 0.
                valid_mask = sample_weight > 0
                if not bool(valid_mask.any()):
                    valid_mask = torch.ones_like(valid_mask, dtype=torch.bool)
                    sample_weight = torch.ones_like(sample_weight)

                # \(\eta = \mathrm{softplus}(\theta_\eta), \quad \eta > 0\)
                eta = self._positive(self._log_eta)
                # \(\tilde{A}_i = A_i / \eta\)
                scaled_adv = adv / eta

                with torch.no_grad():
                    valid_scaled_adv = scaled_adv.detach()[valid_mask]
                    valid_num = int(valid_scaled_adv.numel())
                    k = max(1, int(self._topk_fraction * valid_num))
                    k = min(k, valid_num)
                    topk_vals, _ = torch.topk(valid_scaled_adv, k)
                    threshold = topk_vals.min()
                    selected_mask = valid_mask & (scaled_adv.detach() >= threshold)
                    if not bool(selected_mask.any()):
                        selected_mask = valid_mask

                selected_weight = sample_weight[selected_mask]

                # E-step dual update for eta.
                # \(\mathcal{L}_\eta = \eta \left( \epsilon_\eta + \log\left(\frac{1}{\sum_i w_i}\sum_i w_i \exp(A_i/\eta)\right)\right)\)
                selected_scaled_det = (adv.detach() / eta)[selected_mask]
                max_scaled = selected_scaled_det.max().detach()
                exp_scaled = torch.exp(selected_scaled_det - max_scaled) * selected_weight
                sum_exp_scaled = exp_scaled.sum() + 1e-8
                sum_selected_weight = selected_weight.sum() + 1e-8
                log_mean_exp = torch.log(sum_exp_scaled / sum_selected_weight) + max_scaled
                eta_loss = eta * (self._epsilon_eta + log_mean_exp)

                self._eta_optimizer.zero_grad()
                eta_loss.backward()
                self._eta_optimizer.step()

                with torch.no_grad():
                    eta_det = self._positive(self._log_eta).detach()

                # Fixed VMPO policy weights for selected samples.
                # \(q_i = \frac{w_i \exp(A_i/\eta)}{\sum_j w_j \exp(A_j/\eta)}\)
                selected_scaled = (adv.detach() / eta_det)[selected_mask]
                max_selected_scaled = selected_scaled.max()
                unnormalized_w = torch.exp(selected_scaled - max_selected_scaled) * selected_weight
                weights = unnormalized_w / (unnormalized_w.sum() + 1e-8)

                action = batch['action'].long().reshape(-1)
                new_log_prob_all = F.log_softmax(logits, dim=-1)
                new_log_prob_action = new_log_prob_all.gather(1, action.unsqueeze(-1)).squeeze(-1)
                # \(\mathcal{L}_{\pi,\text{NLL}} = -\sum_i q_i \log \pi_\theta(a_i \mid s_i)\)
                policy_nll = -(weights * new_log_prob_action[selected_mask]).sum()

                old_log_prob_all = F.log_softmax(old_logits, dim=-1)
                old_prob_all = old_log_prob_all.exp()
                # \(D_{\mathrm{KL}}(\pi_{\text{old}} \| \pi_\theta) = \sum_a \pi_{\text{old}}(a \mid s)\left[\log \pi_{\text{old}}(a \mid s)-\log \pi_\theta(a \mid s)\right]\)
                kl_all = (old_prob_all * (old_log_prob_all - new_log_prob_all)).sum(dim=-1)
                kl_selected = (kl_all[selected_mask] * selected_weight).sum() / sum_selected_weight

                # \(\alpha = \mathrm{softplus}(\theta_\alpha), \quad \alpha > 0\)
                alpha = self._positive(self._log_alpha)
                # \(\mathcal{L}_\alpha = \alpha \left(\epsilon_{\mathrm{KL}} - \bar{D}_{\mathrm{KL}}\right)\)
                alpha_loss = alpha * (self._epsilon_kl - kl_selected.detach())
                self._alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self._alpha_optimizer.step()

                with torch.no_grad():
                    alpha_det = self._positive(self._log_alpha).detach()

                # \(\mathcal{L}_\pi = \mathcal{L}_{\pi,\text{NLL}} + \alpha \bar{D}_{\mathrm{KL}}\)
                policy_loss = policy_nll + alpha_det * kl_selected

                value_pred = output['value'].reshape(-1)
                value_target = batch['return'].reshape(-1).detach()
                # \(\mathcal{L}_V = \frac{1}{2}\frac{\sum_i w_i\left(V_\phi(s_i)-R_i\right)^2}{\sum_i w_i}\)
                value_loss = 0.5 * ((value_pred - value_target).pow(2) * sample_weight).sum() / (
                    sample_weight.sum() + 1e-8
                )

                # \(\mathcal{H}(\pi_\theta(\cdot \mid s)) = -\sum_a \pi_\theta(a \mid s)\log \pi_\theta(a \mid s)\)
                entropy_all = -(new_log_prob_all.exp() * new_log_prob_all).sum(dim=-1)
                # \(\mathcal{L}_H = \frac{\sum_i w_i \mathcal{H}(\pi_\theta(\cdot \mid s_i))}{\sum_i w_i}\)
                entropy_loss = (entropy_all * sample_weight).sum() / (sample_weight.sum() + 1e-8)

                # \(\mathcal{L}_{\text{total}} = \mathcal{L}_\pi + c_v \mathcal{L}_V - c_e \mathcal{L}_H\)
                total_loss = policy_loss + self._value_weight * value_loss - self._entropy_weight * entropy_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

                if self._cfg.learn.lr_scheduler is not None:
                    cur_lr = sum(self._lr_scheduler.get_last_lr()) / len(self._lr_scheduler.get_last_lr())
                else:
                    cur_lr = self._optimizer.defaults['lr']

                return_infos.append(
                    {
                        'cur_lr': float(cur_lr),
                        'total_loss': float(total_loss.item()),
                        'policy_loss': float(policy_loss.item()),
                        'value_loss': float(value_loss.item()),
                        'entropy_loss': float(entropy_loss.item()),
                        'adv_max': float(adv.max().item()),
                        'adv_mean': float(adv.mean().item()),
                        'value_mean': float(value_pred.mean().item()),
                        'value_max': float(value_pred.max().item()),
                        'approx_kl': float(kl_all.mean().item()),
                        'clipfrac': 0.0,
                        'dual_eta': float(eta_det.item()),
                        'dual_alpha': float(alpha_det.item()),
                        'dual_eta_loss': float(eta_loss.item()),
                        'dual_alpha_loss': float(alpha_loss.item()),
                        'selected_frac': float(selected_mask.float().mean().item()),
                        'policy_nll': float(policy_nll.item()),
                        'kl_selected': float(kl_selected.item()),
                    }
                )

        if self._cfg.learn.lr_scheduler is not None:
            self._lr_scheduler.step()

        return return_infos

    def _state_dict_learn(self) -> Dict[str, Any]:
        state = super()._state_dict_learn()
        state.update(
            {
                'log_eta': self._log_eta.detach(),
                'log_alpha': self._log_alpha.detach(),
                'eta_optimizer': self._eta_optimizer.state_dict(),
                'alpha_optimizer': self._alpha_optimizer.state_dict(),
            }
        )
        return state

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        super()._load_state_dict_learn(state_dict)
        if 'log_eta' in state_dict:
            self._log_eta.data.copy_(state_dict['log_eta'].to(self._device))
        if 'log_alpha' in state_dict:
            self._log_alpha.data.copy_(state_dict['log_alpha'].to(self._device))
        if 'eta_optimizer' in state_dict:
            self._eta_optimizer.load_state_dict(state_dict['eta_optimizer'])
        if 'alpha_optimizer' in state_dict:
            self._alpha_optimizer.load_state_dict(state_dict['alpha_optimizer'])

    def _monitor_vars_learn(self) -> List[str]:
        return super()._monitor_vars_learn() + [
            'dual_eta',
            'dual_alpha',
            'dual_eta_loss',
            'dual_alpha_loss',
            'selected_frac',
            'policy_nll',
            'kl_selected',
        ]
```
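
For reference, below is a hypothetical experiment snippet showing only the VMPO-specific knobs, with values mirroring the defaults in `VMPOPolicy.config`; the surrounding experiment config (env settings, exp_name, model sizes) is omitted and would follow the usual DI-engine layout.

```python
from easydict import EasyDict

# Illustrative policy config only; values copy the defaults above.
vmpo_policy_cfg = EasyDict(dict(
    type='vmpo',
    action_space='discrete',   # this implementation asserts a discrete action space
    recompute_adv=True,
    learn=dict(
        epoch_per_collect=10,
        batch_size=64,
        learning_rate=3e-4,
        topk_fraction=0.5,     # fraction of each batch kept by the top-k advantage filter
        epsilon_eta=0.1,       # constraint used in the temperature (eta) dual loss
        epsilon_kl=0.02,       # KL constraint used in the alpha dual loss
        temperature_init=1.0,
        temperature_lr=1e-4,
        alpha_init=1.0,
        alpha_lr=1e-4,
    ),
    collect=dict(discount_factor=0.99, gae_lambda=0.95),
))
```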