ding.world_model.ddppo
Full Source Code
../ding/world_model/ddppo.py
from functools import partial
from ditk import logging
import itertools
import copy
import numpy as np
import multiprocessing
import torch
import torch.nn as nn

from ding.utils import WORLD_MODEL_REGISTRY
from ding.utils.data import default_collate
from ding.torch_utils import unsqueeze_repeat
from ding.world_model.base_world_model import HybridWorldModel
from ding.world_model.model.ensemble import EnsembleModel, StandardScaler


#======================= Helper functions =======================
# tree_query = lambda datapoint: tree.query(datapoint, k=k+1)[1][1:]
def tree_query(datapoint, tree, k):
    """
    Return the indices of the ``k`` nearest neighbors of ``datapoint`` in ``tree``.
    Queries ``k + 1`` neighbors and drops the first hit, since a point in the tree
    is always its own nearest neighbor.
    """
    return tree.query(datapoint, k=k + 1)[1][1:]


def get_neighbor_index(data, k, serial=False):
    """
    Compute the k-nearest-neighbor indices for every row of ``data`` using a KD-tree.

    data: [B, N]
    k: int
    serial: if True, query neighbors one point at a time in the current process;
        otherwise fan the queries out over a multiprocessing pool.

    ret: [B, k] LongTensor of neighbor indices (excluding each point itself)
    """
    try:
        from scipy.spatial import KDTree
    except ImportError:
        import sys
        logging.warning("Please install scipy first, such as `pip3 install scipy`.")
        sys.exit(1)
    data = data.cpu().numpy()
    tree = KDTree(data)

    if serial:
        nn_index = [torch.from_numpy(np.array(tree_query(d, tree, k))) for d in data]
        nn_index = torch.stack(nn_index).long()
    else:
        # TODO: speed up multiprocessing
        fn = partial(tree_query, tree=tree, k=k)
        # Fix: use the pool as a context manager so workers are terminated and
        # joined even if the query raises. The previous code called close()
        # without join(), which could leak worker processes on exceptions.
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            nn_index = torch.from_numpy(np.array(list(pool.map(fn, data)), dtype=np.int32)).to(torch.long)
    return nn_index


def get_batch_jacobian(net, x, noutputs):  # x: [b, in_dim], noutputs: out dim
    """
    Compute the per-sample Jacobian of ``net`` at every row of ``x`` in one
    backward pass, by replicating each input once per output dimension and
    back-propagating an identity upstream gradient.

    Returns a tensor of shape [b, noutputs, in_dim].
    """
    x = x.unsqueeze(1)  # b, 1, in_dim
    n = x.size()[0]
    x = x.repeat(1, noutputs, 1)  # b, out_dim, in_dim
    x.requires_grad_(True)
    y = net(x)
    upstream_gradient = torch.eye(noutputs).reshape(1, noutputs, noutputs).repeat(n, 1, 1).to(x.device)
    # create_graph=True so the Jacobian itself can be differentiated (needed
    # for the regulation loss in _train_gradient_model).
    re = torch.autograd.grad(y, x, upstream_gradient, create_graph=True)[0]

    return re


class EnsembleGradientModel(EnsembleModel):
    """Ensemble model whose training step adds a weighted gradient-regulation loss."""

    def train(self, loss, loss_reg, reg):
        """
        One optimizer step on ``loss`` plus ``reg * loss_reg``, with the usual
        log-variance bound penalties and (optionally) weight-decay loss.
        """
        self.optimizer.zero_grad()

        loss += 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
        loss += reg * loss_reg
        if self.use_decay:
            loss += self.get_decay_loss()

        loss.backward()

        self.optimizer.step()


# TODO: derive from MBPO instead of implementing from scratch
@WORLD_MODEL_REGISTRY.register('ddppo')
class DDPPOWorldMode(HybridWorldModel, nn.Module):
    """rollout model + gradient model"""
    config = dict(
        model=dict(
            ensemble_size=7,
            elite_size=5,
            state_size=None,  # has to be specified
            action_size=None,  # has to be specified
            reward_size=1,
            hidden_size=200,
            use_decay=False,
            batch_size=256,
            holdout_ratio=0.2,
            max_epochs_since_update=5,
            deterministic_rollout=True,
            # parameters for DDPPO
            gradient_model=True,
            k=3,
            reg=1,
            neighbor_pool_size=10000,
            train_freq_gradient_model=250
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        HybridWorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        cfg = cfg.model
        self.ensemble_size = cfg.ensemble_size
        self.elite_size = cfg.elite_size
        self.state_size = cfg.state_size
        self.action_size = cfg.action_size
        self.reward_size = cfg.reward_size
        self.hidden_size = cfg.hidden_size
        self.use_decay = cfg.use_decay
        self.batch_size = cfg.batch_size
        self.holdout_ratio = cfg.holdout_ratio
        self.max_epochs_since_update = cfg.max_epochs_since_update
        self.deterministic_rollout = cfg.deterministic_rollout
        # parameters for DDPPO
        self.gradient_model = cfg.gradient_model
        self.k = cfg.k
        self.reg = cfg.reg
        self.neighbor_pool_size = cfg.neighbor_pool_size
        self.train_freq_gradient_model = cfg.train_freq_gradient_model

        self.rollout_model = EnsembleModel(
            self.state_size,
            self.action_size,
            self.reward_size,
            self.ensemble_size,
            self.hidden_size,
            use_decay=self.use_decay
        )
        self.scaler = StandardScaler(self.state_size + self.action_size)

        self.ensemble_mse_losses = []
        self.model_variances = []
        self.elite_model_idxes = []

        if self.gradient_model:
            # NOTE: self.gradient_model starts as the config bool and is
            # replaced here by the model instance; later truthiness checks
            # (`if self.gradient_model:`) rely on both being truthy/falsy
            # consistently.
            self.gradient_model = EnsembleGradientModel(
                self.state_size,
                self.action_size,
                self.reward_size,
                self.ensemble_size,
                self.hidden_size,
                use_decay=self.use_decay
            )
            self.elite_model_idxes_gradient_model = []

            self.last_train_step_gradient_model = 0
            self.serial_calc_nn = False

        if self._cuda:
            self.cuda()

    def step(self, obs, act, batch_size=8192):
        """
        Predict one transition: given observations and actions, return
        ``(rewards, next_obs, done)``. Forward uses the rollout model; when
        gradients are enabled and a gradient model exists, the backward pass
        is routed through the gradient model via a custom autograd Function.
        """

        class Predict(torch.autograd.Function):
            # TODO: align rollout_model elites with gradient_model elites
            # use different model for forward and backward
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                mean, var = self.rollout_model(x, ret_log_var=False)
                return torch.cat([mean, var], dim=-1)

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                with torch.enable_grad():
                    x = x.detach()
                    x.requires_grad_(True)
                    mean, var = self.gradient_model(x, ret_log_var=False)
                    y = torch.cat([mean, var], dim=-1)
                    return torch.autograd.grad(y, x, grad_outputs=grad_out, create_graph=True)

        if len(act.shape) == 1:
            act = act.unsqueeze(1)
        if self._cuda:
            obs = obs.cuda()
            act = act.cuda()
        inputs = torch.cat([obs, act], dim=1)
        inputs = self.scaler.transform(inputs)
        # predict in mini-batches to bound memory
        ensemble_mean, ensemble_var = [], []
        for i in range(0, inputs.shape[0], batch_size):
            input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
            if not torch.is_grad_enabled() or not self.gradient_model:
                b_mean, b_var = self.rollout_model(input, ret_log_var=False)
            else:
                # use gradient model to compute gradients during backward pass
                output = Predict.apply(input)
                b_mean, b_var = output.chunk(2, dim=2)
            ensemble_mean.append(b_mean)
            ensemble_var.append(b_var)
        ensemble_mean = torch.cat(ensemble_mean, 1)
        ensemble_var = torch.cat(ensemble_var, 1)
        # the model predicts [reward, next_obs - obs]; add obs back to recover next_obs
        ensemble_mean[:, :, 1:] += obs.unsqueeze(0)
        ensemble_std = ensemble_var.sqrt()
        # sample from the predicted distribution
        if self.deterministic_rollout:
            ensemble_sample = ensemble_mean
        else:
            ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean).to(ensemble_mean) * ensemble_std
        # sample from ensemble: one elite model per transition
        model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
        batch_idxes = torch.arange(len(obs)).to(inputs.device)
        sample = ensemble_sample[model_idxes, batch_idxes]
        rewards, next_obs = sample[:, 0], sample[:, 1:]

        return rewards, next_obs, self.env.termination_fn(next_obs)

    def eval(self, env_buffer, envstep, train_iter):
        """
        Evaluate the rollout model on a fresh sample from ``env_buffer`` and
        log MSE, ensemble-mean MSE and ensemble variance to tensorboard.
        """
        data = env_buffer.sample(self.eval_freq, train_iter)
        data = default_collate(data)
        data['done'] = data['done'].float()
        data['weight'] = data.get('weight', None)
        obs = data['obs']
        action = data['action']
        reward = data['reward']
        next_obs = data['next_obs']
        if len(reward.shape) == 1:
            reward = reward.unsqueeze(1)
        if len(action.shape) == 1:
            action = action.unsqueeze(1)

        # build eval samples
        inputs = torch.cat([obs, action], dim=1)
        labels = torch.cat([reward, next_obs - obs], dim=1)
        if self._cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()

        # normalize
        inputs = self.scaler.transform(inputs)

        # repeat for ensemble
        inputs = unsqueeze_repeat(inputs, self.ensemble_size)
        labels = unsqueeze_repeat(labels, self.ensemble_size)

        # eval
        with torch.no_grad():
            mean, logvar = self.rollout_model(inputs, ret_log_var=True)
            loss, mse_loss = self.rollout_model.loss(mean, logvar, labels)
            ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
            model_variance = mean.var(0)
            self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
            self.tb_logger.add_scalar('env_model_step/eval_ensemble_mse_loss', ensemble_mse_loss.mean().item(), envstep)
            self.tb_logger.add_scalar('env_model_step/eval_model_variances', model_variance.mean().item(), envstep)

        self.last_eval_step = envstep

    def train(self, env_buffer, envstep, train_iter):
        """
        Train the rollout model on the whole buffer; periodically (every
        ``train_freq_gradient_model`` envsteps) refresh the neighbor pool and
        train the gradient model as well. Logs all returned scalars.
        """

        def train_sample(data) -> tuple:
            # Collate raw buffer samples into (inputs, labels) where
            # inputs = [obs, action] and labels = [reward, next_obs - obs].
            data = default_collate(data)
            data['done'] = data['done'].float()
            data['weight'] = data.get('weight', None)
            obs = data['obs']
            action = data['action']
            reward = data['reward']
            next_obs = data['next_obs']
            if len(reward.shape) == 1:
                reward = reward.unsqueeze(1)
            if len(action.shape) == 1:
                action = action.unsqueeze(1)
            # build train samples
            inputs = torch.cat([obs, action], dim=1)
            labels = torch.cat([reward, next_obs - obs], dim=1)
            if self._cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
            return inputs, labels

        logvar = dict()

        data = env_buffer.sample(env_buffer.count(), train_iter)
        inputs, labels = train_sample(data)
        logvar.update(self._train_rollout_model(inputs, labels))

        if self.gradient_model:
            # update neighbor pool
            if (envstep - self.last_train_step_gradient_model) >= self.train_freq_gradient_model:
                n = min(env_buffer.count(), self.neighbor_pool_size)
                self.neighbor_pool = env_buffer.sample(n, train_iter, sample_range=slice(-n, None))
                inputs_reg, labels_reg = train_sample(self.neighbor_pool)
                logvar.update(self._train_gradient_model(inputs, labels, inputs_reg, labels_reg))
                self.last_train_step_gradient_model = envstep

        self.last_train_step = envstep

        # log
        if self.tb_logger is not None:
            for k, v in logvar.items():
                self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)

    def _train_rollout_model(self, inputs, labels):
        """
        Fit the rollout ensemble with early stopping on a holdout split;
        select elite members by holdout MSE. Returns a dict of scalar logs.
        """
        #split
        num_holdout = int(inputs.shape[0] * self.holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        #normalize
        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        #repeat for ensemble
        holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
        holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)

        self._epochs_since_update = 0
        self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
        self._save_states()
        for epoch in itertools.count():

            # an independent shuffle per ensemble member
            train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
                                     for _ in range(self.ensemble_size)]).to(train_inputs.device)
            self.mse_loss = []
            for start_pos in range(0, train_inputs.shape[0], self.batch_size):
                idx = train_idx[:, start_pos:start_pos + self.batch_size]
                train_input = train_inputs[idx]
                train_label = train_labels[idx]
                mean, logvar = self.rollout_model(train_input, ret_log_var=True)
                loss, mse_loss = self.rollout_model.loss(mean, logvar, train_label)
                self.rollout_model.train(loss)
                self.mse_loss.append(mse_loss.mean().item())
            self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
                self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
                break_train = self._save_best(epoch, holdout_mse_loss)
                if break_train:
                    break

        # restore the per-member best snapshots before computing elite stats
        self._load_states()
        with torch.no_grad():
            holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
            _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
            sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
            sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
            sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
            self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
            self.top_holdout_mse_loss = sorted_loss[0]
            self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
            self.bottom_holdout_mse_loss = sorted_loss[-1]
            self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
        return {
            'rollout_model/mse_loss': self.mse_loss,
            'rollout_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
            'rollout_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
            'rollout_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
            'rollout_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
            'rollout_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
        }

    def _get_jacobian(self, model, train_input_reg):
        """
        train_input_reg: [ensemble_size, B, state_size+action_size]

        ret: [ensemble_size, B, state_size+reward_size, state_size+action_size]
        """

        def func(x):
            x = x.view(self.ensemble_size, -1, self.state_size + self.action_size)
            # keep the raw (un-normalized) state to add back to the prediction
            state = x[:, :, :self.state_size]
            x = self.scaler.transform(x)
            y, _ = model(x)
            # y[:, :, self.reward_size:] += state, inplace operation leads to error
            null = torch.zeros_like(y)
            null[:, :, self.reward_size:] += state
            y = y + null

            return y.view(-1, self.state_size + self.reward_size, self.state_size + self.reward_size)

        # reshape input
        train_input_reg = train_input_reg.view(-1, self.state_size + self.action_size)
        jacobian = get_batch_jacobian(func, train_input_reg, self.state_size + self.reward_size)

        # reshape jacobian
        return jacobian.view(
            self.ensemble_size, -1, self.state_size + self.reward_size, self.state_size + self.action_size
        )

    def _train_gradient_model(self, inputs, labels, inputs_reg, labels_reg):
        """
        Fit the gradient ensemble with the same early-stopping scheme as the
        rollout model, plus a directional-derivative regulation loss that
        matches the model's Jacobian against finite differences between each
        regulation sample and its k nearest neighbors.
        """
        #split
        num_holdout = int(inputs.shape[0] * self.holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        #normalize
        # self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        #repeat for ensemble
        holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
        holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)

        #no split and normalization on regulation data
        train_inputs_reg, train_labels_reg = inputs_reg, labels_reg

        neighbor_index = get_neighbor_index(train_inputs_reg, self.k, serial=self.serial_calc_nn)
        neighbor_inputs = train_inputs_reg[neighbor_index]  # [N, k, state_size+action_size]
        neighbor_labels = train_labels_reg[neighbor_index]  # [N, k, state_size+reward_size]
        neighbor_inputs_distance = (neighbor_inputs - train_inputs_reg.unsqueeze(1))  # [N, k, state_size+action_size]
        neighbor_labels_distance = (neighbor_labels - train_labels_reg.unsqueeze(1))  # [N, k, state_size+reward_size]

        self._epochs_since_update = 0
        self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
        self._save_states()
        for epoch in itertools.count():

            train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
                                     for _ in range(self.ensemble_size)]).to(train_inputs.device)

            train_idx_reg = torch.stack([torch.randperm(train_inputs_reg.shape[0])
                                         for _ in range(self.ensemble_size)]).to(train_inputs_reg.device)

            self.mse_loss = []
            self.grad_loss = []
            for start_pos in range(0, train_inputs.shape[0], self.batch_size):
                idx = train_idx[:, start_pos:start_pos + self.batch_size]
                train_input = train_inputs[idx]
                train_label = train_labels[idx]
                mean, logvar = self.gradient_model(train_input, ret_log_var=True)
                loss, mse_loss = self.gradient_model.loss(mean, logvar, train_label)

                # regulation loss: cycle through the (smaller) regulation set
                # with modulo arithmetic so its batches stay aligned with the
                # main training batches
                if start_pos % train_inputs_reg.shape[0] < (start_pos + self.batch_size) % train_inputs_reg.shape[0]:
                    idx_reg = train_idx_reg[:, start_pos % train_inputs_reg.shape[0]:(start_pos + self.batch_size) %
                                            train_inputs_reg.shape[0]]
                else:
                    idx_reg = train_idx_reg[:, 0:(start_pos + self.batch_size) % train_inputs_reg.shape[0]]

                train_input_reg = train_inputs_reg[idx_reg]
                neighbor_input_distance = neighbor_inputs_distance[idx_reg
                                                                   ]  # [ensemble_size, B, k, state_size+action_size]
                neighbor_label_distance = neighbor_labels_distance[idx_reg
                                                                   ]  # [ensemble_size, B, k, state_size+reward_size]

                jacobian = self._get_jacobian(self.gradient_model, train_input_reg).unsqueeze(2).repeat_interleave(
                    self.k, dim=2
                )  # [ensemble_size, B, k(repeat), state_size+reward_size, state_size+action_size]

                directional_derivative = (jacobian @ neighbor_input_distance.unsqueeze(-1)).squeeze(
                    -1
                )  # [ensemble_size, B, k, state_size+reward_size]

                loss_reg = torch.pow((neighbor_label_distance - directional_derivative),
                                     2).sum(0).mean()  # sumed over network

                self.gradient_model.train(loss, loss_reg, self.reg)
                self.mse_loss.append(mse_loss.mean().item())
                self.grad_loss.append(loss_reg.item())

            self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
            self.grad_loss = sum(self.grad_loss) / len(self.grad_loss)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
                self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
                break_train = self._save_best(epoch, holdout_mse_loss)
                if break_train:
                    break

        self._load_states()
        with torch.no_grad():
            holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
            _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
            sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
            sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
            sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
            self.elite_model_idxes_gradient_model = sorted_loss_idx[:self.elite_size]
            self.top_holdout_mse_loss = sorted_loss[0]
            self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
            self.bottom_holdout_mse_loss = sorted_loss[-1]
            self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
        return {
            'gradient_model/mse_loss': self.mse_loss,
            'gradient_model/grad_loss': self.grad_loss,
            'gradient_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
            'gradient_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
            'gradient_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
            'gradient_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
            'gradient_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
        }

    def _save_states(self, ):
        # snapshot the full state dict for later per-member restoration
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        # overwrite only ensemble member `id`'s weights/biases in the snapshot
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """
        Snapshot any ensemble member whose holdout loss improved by more than
        1% over its best; return True when no member has improved for more
        than ``max_epochs_since_update`` consecutive epochs (early stop).
        """
        updated = False
        for i in range(len(holdout_losses)):
            current = holdout_losses[i]
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                # self._save_state(i)
                updated = True
                # improvement = (best - current) / best

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update