ding.policy.ppof

Full Source Code
../ding/policy/ppof.py
from typing import List, Dict, Any, Tuple, Union, Callable, Optional
from collections import namedtuple
from easydict import EasyDict
import copy
import random
import numpy as np
import torch
import treetensor.torch as ttorch
from torch.optim import AdamW

from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
    get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
    HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
from ding.utils import POLICY_REGISTRY, RunningMeanStd


@POLICY_REGISTRY.register('ppof')
class PPOFPolicy:
    """
    PPO (clip variant) policy for the friendly/functional ("ppof") API.

    Supports discrete, continuous and hybrid action spaces, several value
    normalization schemes (popart / value_rescale / symlog / baseline), and
    exposes ``forward`` (learn), ``collect`` and ``eval`` entry points.
    """
    # Default hyper-parameters; merged/overridden by the user config.
    config = dict(
        type='ppo',
        on_policy=True,
        cuda=True,
        action_space='discrete',
        discount_factor=0.99,
        gae_lambda=0.95,
        # learn
        epoch_per_collect=10,
        batch_size=64,
        learning_rate=3e-4,
        # learning rate scheduler, whose format is (epoch_num, min_lr_lambda), e.g. (10000, 0.1)
        lr_scheduler=None,
        weight_decay=0,
        value_weight=0.5,
        entropy_weight=0.01,
        clip_ratio=0.2,
        adv_norm=True,
        value_norm='baseline',
        ppo_param_init=True,
        grad_norm=0.5,
        # collect
        n_sample=128,
        unroll_len=1,
        # eval
        deterministic_eval=True,
        # model
        model=dict(),
    )
    mode = ['learn', 'collect', 'eval']

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a deep copy of the class-level default config as an EasyDict."""
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @classmethod
    def default_model(cls: type) -> Callable:
        """Return the default model *class* (not an instance) for this policy."""
        from .model import PPOFModel
        return PPOFModel

    def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None:
        """
        Build the policy: model placement, parameter init, optimizer/scheduler
        (learn mode), and action samplers (collect/eval modes).

        Arguments:
            - cfg: merged policy config (see ``default_config``).
            - model: the actor-critic network; if None, the default model is used.
            - enable_mode: subset of ['learn', 'collect', 'eval']; defaults to all.
        """
        self._cfg = cfg
        if model is None:
            # NOTE(review): default_model() returns a class, not an instance —
            # subsequent .cuda()/.parameters() calls would fail; confirm callers
            # always pass a constructed model here.
            self._model = self.default_model()
        else:
            self._model = model
        if self._cfg.cuda and torch.cuda.is_available():
            self._device = 'cuda'
            self._model.cuda()
        else:
            self._device = 'cpu'
        assert self._cfg.action_space in ["continuous", "discrete", "hybrid", 'multi_discrete']
        self._action_space = self._cfg.action_space
        if self._cfg.ppo_param_init:
            self._model_param_init()

        if enable_mode is None:
            enable_mode = self.mode
        self.enable_mode = enable_mode
        if 'learn' in enable_mode:
            self._optimizer = AdamW(
                self._model.parameters(),
                lr=self._cfg.learning_rate,
                weight_decay=self._cfg.weight_decay,
            )
            # define linear lr scheduler: decays from 1.0 to min_lr_lambda over epoch_num steps
            if self._cfg.lr_scheduler is not None:
                epoch_num, min_lr_lambda = self._cfg.lr_scheduler

                self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                    self._optimizer,
                    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
                )

            if self._cfg.value_norm:
                self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        if 'collect' in enable_mode:
            # NOTE(review): 'multi_discrete' passes the assert above but gets no
            # collect sampler here — verify whether it is actually supported.
            if self._action_space == 'discrete':
                self._collect_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                self._collect_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                self._collect_sampler = HybridStochasticSampler()
        if 'eval' in enable_mode:
            if self._action_space == 'discrete':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = ArgmaxSampler()
                else:
                    self._eval_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = MuSampler()
                else:
                    self._eval_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = HybridDeterminsticSampler()
                else:
                    self._eval_sampler = HybridStochasticSampler()
        # for compatibility
        self.learn_mode = self
        self.collect_mode = self
        self.eval_mode = self

    def _model_param_init(self):
        """Orthogonal init for Linear layers; small-weight init for the policy head."""
        for n, m in self._model.named_modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
                torch.nn.init.zeros_(m.bias)
        if self._action_space in ['continuous', 'hybrid']:
            for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
                if isinstance(m, torch.nn.Linear):
                    # orthogonal initialization
                    torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                    torch.nn.init.zeros_(m.bias)
            # init log sigma
            if self._action_space == 'continuous':
                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
                for m in self._model.actor_head.mu.modules():
                    if isinstance(m, torch.nn.Linear):
                        torch.nn.init.zeros_(m.bias)
                        # scale down the final policy layer for stable early training
                        m.weight.data.copy_(0.01 * m.weight.data)
            elif self._action_space == 'hybrid':  # actor_head[1]: ReparameterizationHead, for action_args
                if hasattr(self._model.actor_head[1], 'log_sigma_param'):
                    torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
                for m in self._model.actor_head[1].mu.modules():
                    if isinstance(m, torch.nn.Linear):
                        torch.nn.init.zeros_(m.bias)
                        m.weight.data.copy_(0.01 * m.weight.data)

    def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
        """
        Learn-mode entry: run ``epoch_per_collect`` PPO epochs over ``data``,
        recomputing GAE advantages each epoch, and return a list of per-batch
        training statistics dicts.
        """
        return_infos = []
        self._model.train()
        bs = self._cfg.batch_size
        data = data[:self._cfg.n_sample // bs * bs]  # rounding to a multiple of batch_size

        # outer training loop
        for epoch in range(self._cfg.epoch_per_collect):
            # recompute adv with the current value function
            with torch.no_grad():
                # get the value dictionary
                # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
                value = self._model.compute_critic(data.obs)
                next_value = self._model.compute_critic(data.next_obs)
                reward = data.reward

                assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'],\
                    'Not supported value normalization! Value normalization supported: \
                    popart, value rescale, symlog, baseline'

                if self._cfg.value_norm == 'popart':
                    unnormalized_value = value['unnormalized_pred']
                    # BUGFIX: previously read from `value` instead of `next_value`
                    unnormalized_next_value = next_value['unnormalized_pred']

                    mu = self._model.critic_head.popart.mu
                    sigma = self._model.critic_head.popart.sigma
                    reward = (reward - mu) / sigma

                    value = value['pred']
                    next_value = next_value['pred']
                elif self._cfg.value_norm == 'value_rescale':
                    value = value_inv_transform(value['pred'])
                    next_value = value_inv_transform(next_value['pred'])
                elif self._cfg.value_norm == 'symlog':
                    value = inv_symlog(value['pred'])
                    next_value = inv_symlog(next_value['pred'])
                elif self._cfg.value_norm == 'baseline':
                    value = value['pred'] * self._running_mean_std.std
                    next_value = next_value['pred'] * self._running_mean_std.std

                traj_flag = data.get('traj_flag', None)  # traj_flag indicates termination of trajectory
                adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
                data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)

                unnormalized_returns = value + data.adv  # In popart, this return is normalized

                if self._cfg.value_norm == 'popart':
                    self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
                elif self._cfg.value_norm == 'value_rescale':
                    value = value_transform(value)
                    unnormalized_returns = value_transform(unnormalized_returns)
                elif self._cfg.value_norm == 'symlog':
                    value = symlog(value)
                    unnormalized_returns = symlog(unnormalized_returns)
                elif self._cfg.value_norm == 'baseline':
                    value /= self._running_mean_std.std
                    unnormalized_returns /= self._running_mean_std.std
                    self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                data.value = value
                data.return_ = unnormalized_returns

            # inner training loop
            # BUGFIX: the old code did `random.shuffle(list(split_data))`, which
            # shuffles a throwaway copy and leaves the iteration order unchanged.
            split_data = list(ttorch.split(data, self._cfg.batch_size))
            random.shuffle(split_data)
            for batch in split_data:
                output = self._model.compute_actor_critic(batch.obs)
                adv = batch.adv
                if self._cfg.adv_norm:
                    # Normalize advantage in a train_batch
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

                # Calculate ppo error
                if self._action_space == 'continuous':
                    ppo_batch = ppo_data(
                        output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
                    )
                    ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio)
                elif self._action_space == 'discrete':
                    ppo_batch = ppo_data(
                        output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
                    )
                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
                elif self._action_space == 'hybrid':
                    # discrete part (discrete policy loss and entropy loss)
                    ppo_discrete_batch = ppo_policy_data(
                        output.logit.action_type, batch.logit.action_type, batch.action.action_type, adv, None
                    )
                    ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._cfg.clip_ratio)
                    # continuous part (continuous policy loss and entropy loss, value loss)
                    ppo_continuous_batch = ppo_data(
                        output.logit.action_args, batch.logit.action_args, batch.action.action_args, output.value,
                        batch.value, adv, batch.return_, None
                    )
                    ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
                        ppo_continuous_batch, self._cfg.clip_ratio
                    )
                    # sum discrete and continuous loss
                    ppo_loss = type(ppo_continuous_loss)(
                        ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
                        ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
                    )
                    ppo_info = type(ppo_continuous_info)(
                        max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
                        max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
                    )
                wv, we = self._cfg.value_weight, self._cfg.entropy_weight
                total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
                self._optimizer.step()

                return_info = {
                    # BUGFIX: optimizer.defaults['lr'] is the constructor-time value and
                    # never reflects the lr scheduler; param_groups holds the live lr.
                    'cur_lr': self._optimizer.param_groups[0]['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': ppo_loss.policy_loss.item(),
                    'value_loss': ppo_loss.value_loss.item(),
                    'entropy_loss': ppo_loss.entropy_loss.item(),
                    'adv_max': adv.max().item(),
                    'adv_mean': adv.mean().item(),
                    'value_mean': output.value.mean().item(),
                    'value_max': output.value.max().item(),
                    'approx_kl': ppo_info.approx_kl,
                    'clipfrac': ppo_info.clipfrac,
                }
                if self._action_space == 'continuous':
                    return_info.update(
                        {
                            'action': batch.action.float().mean().item(),
                            'mu_mean': output.logit.mu.mean().item(),
                            'sigma_mean': output.logit.sigma.mean().item(),
                        }
                    )
                elif self._action_space == 'hybrid':
                    return_info.update(
                        {
                            'action': batch.action.action_args.float().mean().item(),
                            'mu_mean': output.logit.action_args.mu.mean().item(),
                            'sigma_mean': output.logit.action_args.sigma.mean().item(),
                        }
                    )
                return_infos.append(return_info)

        if self._cfg.lr_scheduler is not None:
            self._lr_scheduler.step()

        return return_infos

    def state_dict(self) -> Dict[str, Any]:
        """Serialize model (and, in learn mode, optimizer) state."""
        state_dict = {
            'model': self._model.state_dict(),
        }
        if 'learn' in self.enable_mode:
            state_dict['optimizer'] = self._optimizer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore state produced by ``state_dict``."""
        self._model.load_state_dict(state_dict['model'])
        if 'learn' in self.enable_mode:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def collect(self, data: ttorch.Tensor) -> ttorch.Tensor:
        """Collect-mode inference: sample an action from the stochastic policy."""
        self._model.eval()
        with torch.no_grad():
            output = self._model.compute_actor_critic(data)
            action = self._collect_sampler(output.logit)
            output.action = action
        return output

    def process_transition(self, obs: ttorch.Tensor, inference_output: dict, timestep: namedtuple) -> ttorch.Tensor:
        """Pack one environment step into the transition tensor stored in the buffer."""
        return ttorch.as_tensor(
            {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': inference_output.action,
                'logit': inference_output.logit,
                'value': inference_output.value,
                'reward': timestep.reward,
                'done': timestep.done,
            }
        )

    def eval(self, data: ttorch.Tensor) -> ttorch.Tensor:
        """Eval-mode inference: deterministic or stochastic per ``deterministic_eval``."""
        self._model.eval()
        with torch.no_grad():
            logit = self._model.compute_actor(data)
            action = self._eval_sampler(logit)
        return ttorch.as_tensor({'logit': logit, 'action': action})

    def monitor_vars(self) -> List[str]:
        """Return the list of scalar variable names to be logged by the monitor."""
        variables = [
            'cur_lr',
            'policy_loss',
            'value_loss',
            'entropy_loss',
            'adv_max',
            'adv_mean',
            'approx_kl',
            'clipfrac',
            'value_max',
            'value_mean',
        ]
        # BUGFIX: the old condition compared the action space against the variable
        # names themselves (['action', 'mu_mean', 'sigma_mean']) and was always false.
        if self._action_space in ['continuous', 'hybrid']:
            variables += ['mu_mean', 'sigma_mean', 'action']
        return variables

    def reset(self, env_id_list: Optional[List[int]] = None) -> None:
        """No per-episode state to reset for PPO (on-policy, stateless inference)."""
        pass