ding.policy.plan_diffuser

PDPolicy

Bases: Policy

Overview

Implicit Plan Diffuser, a planning policy based on trajectory-level diffusion models. Paper: Janner et al., "Planning with Diffusion for Flexible Behavior Synthesis", https://arxiv.org/pdf/2205.09991.pdf
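
PDPolicy trains a denoising diffusion model over whole trajectories. As background for the ``n_timesteps`` and ``predict_epsilon`` fields in the config below, here is a minimal sketch of the standard epsilon-prediction (DDPM) training objective; the denoiser, the beta schedule, and all names are illustrative, not this module's API:

import torch
import torch.nn.functional as F

n_timesteps = 1000
betas = torch.linspace(1e-4, 2e-2, n_timesteps)        # illustrative noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def diffusion_loss(denoiser, x0):
    # x0: clean trajectories, shape (batch, horizon, transition_dim)
    b = x0.shape[0]
    t = torch.randint(0, n_timesteps, (b, ), device=x0.device).long()
    noise = torch.randn_like(x0)
    a = alphas_cumprod.to(x0.device)[t].view(b, 1, 1)
    x_t = a.sqrt() * x0 + (1.0 - a).sqrt() * noise     # forward (noising) process
    return F.mse_loss(denoiser(x_t, t), noise)         # regress the injected noise

With predict_epsilon=True the network regresses the injected noise, as sketched here; with predict_epsilon=False it would regress the clean trajectory x0 directly.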

Full Source Code

../ding/policy/plan_diffuser.py

from typing import List, Dict, Any, Optional, Tuple, Union
from collections import namedtuple, defaultdict
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent

from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
from ding.policy import Policy
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY, DatasetNormalizer
from ding.utils.data import default_collate, default_decollate
from .common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('pd')
class PDPolicy(Policy):
    r"""
    Overview:
        Implicit Plan Diffuser
        https://arxiv.org/pdf/2205.09991.pdf
    """
    config = dict(
        type='pd',
        # (bool) Whether to use cuda for the network.
        cuda=False,
        # (bool) priority: Whether to use priority in buffer sampling. Default False in SAC.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct the biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of (randomly collected) training samples in the replay buffer when training starts.
        # Default 10000 in SAC.
        random_collect_size=10000,
        nstep=1,
        # normalizer type
        normalizer='GaussianNormalizer',
        model=dict(
            diffuser_model='GaussianDiffusion',
            diffuser_model_cfg=dict(
                # the type of model
                model='TemporalUnet',
                # config of the model
                model_cfg=dict(
                    # model dim. In GaussianInvDynDiffusion it is obs_dim; in others it is obs_dim + action_dim.
                    transition_dim=23,
                    dim=32,
                    dim_mults=[1, 2, 4, 8],
                    # whether to use return as a condition
                    returns_condition=False,
                    condition_dropout=0.1,
                    # whether to use the energy-based formulation
                    calc_energy=False,
                    kernel_size=5,
                    # whether to use attention
                    attention=False,
                ),
                # horizon of the trajectory generated by the model
                horizon=80,
                # number of diffusion timesteps
                n_timesteps=1000,
                # whether to predict epsilon (the noise) instead of x0
                predict_epsilon=True,
                # discount of loss
                loss_discount=1.0,
                # whether to clip the denoised output
                clip_denoised=False,
                action_weight=10,
            ),
            value_model='ValueDiffusion',
            value_model_cfg=dict(
                # the type of model
                model='TemporalValue',
                # config of the model
                model_cfg=dict(
                    horizon=4,
                    # model dim. In GaussianInvDynDiffusion it is obs_dim; in others it is obs_dim + action_dim.
                    transition_dim=23,
                    dim=32,
                    dim_mults=[1, 2, 4, 8],
                    kernel_size=5,
                ),
                # horizon of the trajectory generated by the model
                horizon=80,
                # number of diffusion timesteps
                n_timesteps=1000,
                # whether to predict epsilon (the noise) instead of x0
                predict_epsilon=True,
                # discount of loss
                loss_discount=1.0,
                # whether to clip the denoised output
                clip_denoised=False,
                action_weight=1.0,
            ),
            # number of guide steps for p_sample
            n_guide_steps=2,
            # scale of the guidance gradient for p_sample
            scale=0.1,
            # diffusion step t below which the guidance gradient is stopped in p_sample
            t_stopgrad=2,
            # whether to scale the gradient by the posterior std
            scale_grad_by_std=True,
        ),
        learn=dict(
            # How many updates (iterations) to train after one collection by the collector.
            # Bigger "update_per_collect" means more off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=100,

            # (float) learning_rate: Learning rate for the model.
            # Default to 3e-4.
            # Please set to 1e-3 when model.value_network is True.
            learning_rate=3e-4,
            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # However, interaction with HalfCheetah never terminates with done=True,
            # so we replace done=True with done=False when the episode step exceeds the max
            # episode step, to keep the TD-error computation accurate
            # (``gamma * (1 - done) * next_v + reward``).
            ignore_done=False,

            # (float) target_theta: Used for the soft update of the target network,
            # aka. the interpolation factor in Polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) Discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            gradient_accumulate_every=2,
            # effective train_epoch = train_epoch * gradient_accumulate_every
            train_epoch=60000,
            # batch size for each env during evaluation
            plan_batch_size=64,

            # step at which the target model starts EMA updates, and the update frequency
            step_start_update_target=2000,
            update_target_freq=10,
            # EMA weight of the target net
            target_weight=0.995,
            value_step=200e3,

            # whether the dataset includes returns
            include_returns=True,

            # (float) Weight uniform initialization range in the last output layer
            init_w=3e-3,
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        return 'pd', ['ding.model.template.diffusion']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Initialize the diffuser and value optimizers, the algorithm config, and the main and target models.
        """
        # Init
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self.action_dim = self._cfg.model.diffuser_model_cfg.action_dim
        self.obs_dim = self._cfg.model.diffuser_model_cfg.obs_dim
        self.n_timesteps = self._cfg.model.diffuser_model_cfg.n_timesteps
        self.gradient_accumulate_every = self._cfg.learn.gradient_accumulate_every
        self.plan_batch_size = self._cfg.learn.plan_batch_size
        self.gradient_steps = 1
        self.update_target_freq = self._cfg.learn.update_target_freq
        self.step_start_update_target = self._cfg.learn.step_start_update_target
        self.target_weight = self._cfg.learn.target_weight
        self.value_step = self._cfg.learn.value_step
        self.use_target = False
        self.horizon = self._cfg.model.diffuser_model_cfg.horizon
        self.include_returns = self._cfg.learn.include_returns

        # Optimizers
        self._plan_optimizer = Adam(
            self._model.diffuser.model.parameters(),
            lr=self._cfg.learn.learning_rate,
        )
        if self._model.value:
            self._value_optimizer = Adam(
                self._model.value.model.parameters(),
                lr=self._cfg.learn.learning_rate,
            )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        loss_dict = {}

        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )

        # Build the condition dict: maps a trajectory timestep to the state pinned there.
        conds = {}
        vals = data['condition_val']
        ids = data['condition_id']
        for i in range(len(ids)):
            conds[ids[i][0].item()] = vals[i]
        if len(ids) > 1:
            self.use_target = True
        data['conditions'] = conds
        if 'returns' in data.keys():
            data['returns'] = data['returns'].unsqueeze(-1)
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        x = data['trajectories']

        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size, ), device=x.device).long()
        cond = data['conditions']
        if 'returns' in data.keys():
            target = data['returns']
        loss_dict['diffuse_loss'], loss_dict['a0_loss'] = self._model.diffuser_loss(x, cond, t)
        loss_dict['diffuse_loss'] = loss_dict['diffuse_loss'] / self.gradient_accumulate_every
        loss_dict['diffuse_loss'].backward()
        if self._forward_learn_cnt < self.value_step and self._model.value:
            loss_dict['value_loss'], logs = self._model.value_loss(x, cond, target, t)
            loss_dict['value_loss'] = loss_dict['value_loss'] / self.gradient_accumulate_every
            loss_dict['value_loss'].backward()
            loss_dict.update(logs)

        # Gradient accumulation: only step the optimizers every `gradient_accumulate_every` calls.
        if self.gradient_steps >= self.gradient_accumulate_every:
            self._plan_optimizer.step()
            self._plan_optimizer.zero_grad()
            if self._forward_learn_cnt < self.value_step and self._model.value:
                self._value_optimizer.step()
                self._value_optimizer.zero_grad()
            self.gradient_steps = 1
        else:
            self.gradient_steps += 1

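        # Target-network schedule: for the first `step_start_update_target` iterations the
        # target model is a hard copy of the online model; after that it tracks the online
        # model by exponential moving average with Polyak weight `target_weight`.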
        self._forward_learn_cnt += 1
        if self._forward_learn_cnt % self.update_target_freq == 0:
            if self._forward_learn_cnt < self.step_start_update_target:
                self._target_model.load_state_dict(self._model.state_dict())
            else:
                self.update_model_average(self._target_model, self._learn_model)

        if 'returns' in data.keys():
            loss_dict['max_return'] = target.max().item()
            loss_dict['min_return'] = target.min().item()
            loss_dict['mean_return'] = target.mean().item()
        loss_dict['max_traj'] = x.max().item()
        loss_dict['min_traj'] = x.min().item()
        loss_dict['mean_traj'] = x.mean().item()
        return loss_dict

    def update_model_average(self, ma_model, current_model):
        # Polyak/EMA update: theta_ma <- w * theta_ma + (1 - w) * theta_current.
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            if old_weight is None:
                ma_params.data = up_weight
            else:
                ma_params.data = old_weight * self.target_weight + (1 - self.target_weight) * up_weight

    def _monitor_vars_learn(self) -> List[str]:
        return [
            'diffuse_loss',
            'value_loss',
            'max_return',
            'min_return',
            'mean_return',
            'max_traj',
            'min_traj',
            'mean_traj',
            'mean_pred',
            'max_pred',
            'min_pred',
            'a0_loss',
        ]

    def _state_dict_learn(self) -> Dict[str, Any]:
        if self._model.value:
            return {
                'model': self._learn_model.state_dict(),
                'target_model': self._target_model.state_dict(),
                'plan_optimizer': self._plan_optimizer.state_dict(),
                'value_optimizer': self._value_optimizer.state_dict(),
            }
        else:
            return {
                'model': self._learn_model.state_dict(),
                'target_model': self._target_model.state_dict(),
                'plan_optimizer': self._plan_optimizer.state_dict(),
            }

    def _init_eval(self):
        self._eval_model = model_wrap(self._target_model, wrapper_name='base')
        self._eval_model.reset()
        if self.use_target:
            self._plan_seq = []

    def init_data_normalizer(self, normalizer: DatasetNormalizer = None):
        self.normalizer = normalizer

    def _forward_eval(self, data: dict) -> Dict[str, Any]:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))

        self._eval_model.eval()
        if self.use_target:
            cur_obs = self.normalizer.normalize(data[:, :self.obs_dim], 'observations')
            target_obs = self.normalizer.normalize(data[:, self.obs_dim:], 'observations')
        else:
            obs = self.normalizer.normalize(data, 'observations')
        with torch.no_grad():
            if self.use_target:
                cur_obs = torch.tensor(cur_obs)
                target_obs = torch.tensor(target_obs)
                if self._cuda:
                    cur_obs = to_device(cur_obs, self._device)
                    target_obs = to_device(target_obs, self._device)
                # Pin the first state to the current observation and the last state to the goal.
                conditions = {0: cur_obs, self.horizon - 1: target_obs}
            else:
                obs = torch.tensor(obs)
                if self._cuda:
                    obs = to_device(obs, self._device)
                conditions = {0: obs}

            if self.use_target:
                # Replan only when an episode (re)starts; otherwise follow the cached plan.
                if self._plan_seq == [] or 0 in self._eval_t:
                    plan_traj = self._eval_model.get_eval(conditions, self.plan_batch_size)
                    plan_traj = to_device(plan_traj, 'cpu').numpy()
                    if self._plan_seq == []:
                        self._plan_seq = plan_traj
                        self._eval_t = [0] * len(data_id)
                    else:
                        for id in data_id:
                            if self._eval_t[id] == 0:
                                self._plan_seq[id] = plan_traj[id]
                action = []
                for id in data_id:
                    if self._eval_t[id] < len(self._plan_seq[id]) - 1:
                        next_waypoint = self._plan_seq[id][self._eval_t[id] + 1]
                    else:
                        next_waypoint = self._plan_seq[id][-1].copy()
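                        # Past the end of the plan: hold the final waypoint and zero its
                        # velocity dims so the controller below only damps residual motion.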
                        next_waypoint[2:] = 0
                    cur_ob = cur_obs[id]
                    cur_ob = to_device(cur_ob, 'cpu').numpy()
                    # PD-style control toward the next waypoint: position error plus velocity error.
                    act = next_waypoint[:2] - cur_ob[:2] + (next_waypoint[2:] - cur_ob[2:])
                    action.append(act)
                    self._eval_t[id] += 1
            else:
                action = self._eval_model.get_eval(conditions, self.plan_batch_size)
                if self._cuda:
                    action = to_device(action, 'cpu')
                action = self.normalizer.unnormalize(action, 'actions')
        action = torch.tensor(action).to('cpu')
        output = {'action': action}
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        if self.use_target and data_id:
            for id in data_id:
                self._eval_t[id] = 0

    def _init_collect(self) -> None:
        pass

    def _forward_collect(self, data: dict, **kwargs) -> dict:
        pass

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        pass

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        pass
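
Note on the target update: ``update_model_average`` implements Polyak (exponential moving average) smoothing of the target network, theta_target <- w * theta_target + (1 - w) * theta_online, with w = ``target_weight`` = 0.995. A minimal, self-contained sketch of the same rule (the function and variable names here are illustrative, not part of this module):

import copy
import torch.nn as nn

def ema_update(target_net: nn.Module, online_net: nn.Module, w: float = 0.995) -> None:
    # theta_target <- w * theta_target + (1 - w) * theta_online, in place.
    for t_p, o_p in zip(target_net.parameters(), online_net.parameters()):
        t_p.data.mul_(w).add_(o_p.data, alpha=1 - w)

online = nn.Linear(4, 2)
target = copy.deepcopy(online)  # the target starts as an exact copy
# ... after each optimizer step on `online`:
ema_update(target, online, w=0.995)

Updating in place with ``mul_``/``add_`` avoids allocating new parameter tensors on every step, which is why EMA helpers are usually written this way.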