ding.world_model.mbpo
Full Source Code
../ding/world_model/mbpo.py
import itertools
import numpy as np
import copy
import torch
from torch import nn

from ding.utils import WORLD_MODEL_REGISTRY
from ding.utils.data import default_collate
from ding.world_model.base_world_model import HybridWorldModel
from ding.world_model.model.ensemble import EnsembleModel, StandardScaler
from ding.torch_utils import fold_batch, unfold_batch, unsqueeze_repeat


@WORLD_MODEL_REGISTRY.register('mbpo')
class MBPOWorldModel(HybridWorldModel, nn.Module):
    """
    MBPO-style ensemble world model.

    Wraps an ``EnsembleModel`` which, given a normalized ``(obs, action)``
    input, predicts ``[reward, next_obs - obs]`` — see how labels are built in
    ``train``/``_train`` (``labels = cat([reward, next_obs - obs])``).  Only an
    elite subset of the ensemble (``elite_size`` of ``ensemble_size`` members,
    ranked by holdout MSE in ``_train``) is sampled during ``step`` rollouts.

    NOTE(review): ``train``, ``eval`` and ``step`` shadow the ``nn.Module``
    methods of the same names; callers get the world-model semantics, not the
    usual train/eval mode toggles.
    """
    config = dict(
        model=dict(
            ensemble_size=7,  # number of ensemble members
            elite_size=5,  # members eligible for rollout sampling in step()
            state_size=None,  # must be provided by the user config
            action_size=None,  # must be provided by the user config
            reward_size=1,
            hidden_size=200,
            use_decay=False,  # forwarded to EnsembleModel
            batch_size=256,  # SGD minibatch size used in _train
            holdout_ratio=0.2,  # fraction of data held out for early stopping
            max_epochs_since_update=5,  # early-stopping patience (epochs)
            deterministic_rollout=True,  # rollout uses predicted mean, no sampling
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        """
        Build the ensemble dynamics model and the input scaler.

        Arguments:
            - cfg: world-model config; ``cfg.model`` carries the fields listed
              in ``MBPOWorldModel.config``.
            - env: environment wrapper providing ``termination_fn`` (used in
              ``step`` to derive ``done`` from predicted next states).
            - tb_logger: tensorboard logger; may be ``None`` (``train`` guards
              against that before logging).
        """
        HybridWorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        cfg = cfg.model
        self.ensemble_size = cfg.ensemble_size
        self.elite_size = cfg.elite_size
        self.state_size = cfg.state_size
        self.action_size = cfg.action_size
        self.reward_size = cfg.reward_size
        self.hidden_size = cfg.hidden_size
        self.use_decay = cfg.use_decay
        self.batch_size = cfg.batch_size
        self.holdout_ratio = cfg.holdout_ratio
        self.max_epochs_since_update = cfg.max_epochs_since_update
        self.deterministic_rollout = cfg.deterministic_rollout

        self.ensemble_model = EnsembleModel(
            self.state_size,
            self.action_size,
            self.reward_size,
            self.ensemble_size,
            self.hidden_size,
            use_decay=self.use_decay
        )
        # Normalizes the concatenated (obs, action) input; fitted in _train on
        # the training split only.
        self.scaler = StandardScaler(self.state_size + self.action_size)

        if self._cuda:
            self.cuda()

        self.ensemble_mse_losses = []
        self.model_variances = []
        # Indices of the lowest-holdout-MSE members; filled by _train and
        # consumed by step().  Empty until the first training call.
        self.elite_model_idxes = []

    def step(self, obs, act, batch_size=8192, keep_ensemble=False):
        """
        Perform one imagined environment step with the learned model.

        Arguments:
            - obs: observation batch; with ``keep_ensemble`` expected shape is
              ``[E, B, O]`` (one slice per ensemble member), otherwise
              ``[B, O]`` — TODO confirm against callers.
            - act: action batch; a 1-D tensor is unsqueezed to ``[B, 1]``.
            - batch_size: forward-pass chunk size (memory bound only; does not
              change results).
            - keep_ensemble: if True, return per-member predictions
              ``[E, B, ...]`` instead of sampling one elite member per sample.

        Returns:
            ``(reward, next_obs, done)``; ``done`` comes from
            ``env.termination_fn`` applied to the predicted next observation.
        """
        if len(act.shape) == 1:
            act = act.unsqueeze(1)
        if self._cuda:
            obs = obs.cuda()
            act = act.cuda()
        inputs = torch.cat([obs, act], dim=-1)
        if keep_ensemble:
            # Scaler works on 2-D input, so temporarily fold the ensemble dim.
            inputs, dim = fold_batch(inputs, 1)
            inputs = self.scaler.transform(inputs)
            inputs = unfold_batch(inputs, dim)
        else:
            inputs = self.scaler.transform(inputs)
        # predict in chunks of `batch_size` along the batch dimension
        ensemble_mean, ensemble_var = [], []
        batch_dim = 0 if len(inputs.shape) == 2 else 1
        for i in range(0, inputs.shape[batch_dim], batch_size):
            if keep_ensemble:
                # inputs: [E, B, D]
                input = inputs[:, i:i + batch_size]
            else:
                # input: [B, D] -> repeated to [E, B, D] so every member predicts
                input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
            b_mean, b_var = self.ensemble_model(input, ret_log_var=False)
            ensemble_mean.append(b_mean)
            ensemble_var.append(b_var)
        ensemble_mean = torch.cat(ensemble_mean, 1)
        ensemble_var = torch.cat(ensemble_var, 1)
        # The model predicts a state *delta* (labels are next_obs - obs in
        # _train), so add the current obs back; channel 0 is the reward.
        if keep_ensemble:
            ensemble_mean[:, :, 1:] += obs
        else:
            ensemble_mean[:, :, 1:] += obs.unsqueeze(0)
        ensemble_std = ensemble_var.sqrt()
        # sample from the predicted distribution (or take the mean)
        if self.deterministic_rollout:
            ensemble_sample = ensemble_mean
        else:
            ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean).to(ensemble_mean) * ensemble_std
        if keep_ensemble:
            # [E, B, D]: return every member's prediction unsampled
            rewards, next_obs = ensemble_sample[:, :, 0], ensemble_sample[:, :, 1:]
            next_obs_flatten, dim = fold_batch(next_obs)
            done = unfold_batch(self.env.termination_fn(next_obs_flatten), dim)
            return rewards, next_obs, done
        # sample from ensemble: pick one random elite member per batch element
        model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
        batch_idxes = torch.arange(len(obs)).to(inputs.device)
        sample = ensemble_sample[model_idxes, batch_idxes]
        rewards, next_obs = sample[:, 0], sample[:, 1:]

        return rewards, next_obs, self.env.termination_fn(next_obs)

    def eval(self, env_buffer, envstep, train_iter):
        """
        Evaluate the current model on a fresh sample from ``env_buffer`` and
        log MSE / ensemble-disagreement metrics to tensorboard.

        NOTE(review): shadows ``nn.Module.eval``; does not toggle eval mode.
        """
        data = env_buffer.sample(self.eval_freq, train_iter)
        data = default_collate(data)
        data['done'] = data['done'].float()
        data['weight'] = data.get('weight', None)
        obs = data['obs']
        action = data['action']
        reward = data['reward']
        next_obs = data['next_obs']
        if len(reward.shape) == 1:
            reward = reward.unsqueeze(1)
        if len(action.shape) == 1:
            action = action.unsqueeze(1)

        # build eval samples: input is (obs, action), target is (reward, delta_obs)
        inputs = torch.cat([obs, action], dim=1)
        labels = torch.cat([reward, next_obs - obs], dim=1)
        if self._cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()

        # normalize
        inputs = self.scaler.transform(inputs)

        # repeat for ensemble
        inputs = unsqueeze_repeat(inputs, self.ensemble_size)
        labels = unsqueeze_repeat(labels, self.ensemble_size)

        # eval
        with torch.no_grad():
            mean, logvar = self.ensemble_model(inputs, ret_log_var=True)
            loss, mse_loss = self.ensemble_model.loss(mean, logvar, labels)
            # MSE of the ensemble-averaged prediction, and the spread
            # (variance) across members as a disagreement measure.
            ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
            model_variance = mean.var(0)
            self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
            self.tb_logger.add_scalar('env_model_step/eval_ensemble_mse_loss', ensemble_mse_loss.mean().item(), envstep)
            self.tb_logger.add_scalar('env_model_step/eval_model_variances', model_variance.mean().item(), envstep)

        self.last_eval_step = envstep

    def train(self, env_buffer, envstep, train_iter):
        """
        Train the ensemble on the entire contents of ``env_buffer`` and log
        the loss statistics returned by ``_train``.

        NOTE(review): shadows ``nn.Module.train``; does not toggle train mode.
        """
        data = env_buffer.sample(env_buffer.count(), train_iter)
        data = default_collate(data)
        data['done'] = data['done'].float()
        data['weight'] = data.get('weight', None)
        obs = data['obs']
        action = data['action']
        reward = data['reward']
        next_obs = data['next_obs']
        if len(reward.shape) == 1:
            reward = reward.unsqueeze(1)
        if len(action.shape) == 1:
            action = action.unsqueeze(1)
        # build train samples: input is (obs, action), target is (reward, delta_obs)
        inputs = torch.cat([obs, action], dim=1)
        labels = torch.cat([reward, next_obs - obs], dim=1)
        if self._cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        # train
        logvar = self._train(inputs, labels)
        self.last_train_step = envstep
        # log
        if self.tb_logger is not None:
            for k, v in logvar.items():
                self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)

    def _train(self, inputs, labels):
        """
        Fit the ensemble with holdout-based early stopping.

        Trains until no member's holdout MSE has improved by >1% for
        ``max_epochs_since_update`` consecutive epochs (see ``_save_best``),
        then restores the best per-member snapshot and ranks members by
        holdout MSE to pick ``elite_model_idxes``.

        Returns a dict of scalar loss statistics for logging.
        """
        # split: first `holdout_ratio` fraction is held out for early stopping
        num_holdout = int(inputs.shape[0] * self.holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        # normalize: scaler is fitted on the training split only
        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        # repeat for ensemble
        holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
        holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)

        self._epochs_since_update = 0
        # per-member (best_epoch, best_holdout_mse) snapshots
        self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
        self._save_states()
        for epoch in itertools.count():

            # independent shuffling per ensemble member (bootstrap-style)
            train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
                                     for _ in range(self.ensemble_size)]).to(train_inputs.device)
            self.mse_loss = []
            for start_pos in range(0, train_inputs.shape[0], self.batch_size):
                idx = train_idx[:, start_pos:start_pos + self.batch_size]
                train_input = train_inputs[idx]
                train_label = train_labels[idx]
                mean, logvar = self.ensemble_model(train_input, ret_log_var=True)
                loss, mse_loss = self.ensemble_model.loss(mean, logvar, train_label)
                # NOTE(review): EnsembleModel.train presumably performs the
                # backward pass and optimizer step — defined in the project.
                self.ensemble_model.train(loss)
                self.mse_loss.append(mse_loss.mean().item())
            self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
                self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
                break_train = self._save_best(epoch, holdout_mse_loss)
                if break_train:
                    break

        # restore the per-member best parameters snapshotted during training
        self._load_states()
        with torch.no_grad():
            holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
            _, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
            sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
            sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
            sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
            self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
            self.top_holdout_mse_loss = sorted_loss[0]
            self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
            self.bottom_holdout_mse_loss = sorted_loss[-1]
            self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
        return {
            'mse_loss': self.mse_loss,
            'curr_holdout_mse_loss': self.curr_holdout_mse_loss,
            'best_holdout_mse_loss': self.best_holdout_mse_loss,
            'top_holdout_mse_loss': self.top_holdout_mse_loss,
            'middle_holdout_mse_loss': self.middle_holdout_mse_loss,
            'bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
        }

    def _save_states(self, ):
        # Snapshot the full parameter set (deep-copied state_dict).
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        # Update the snapshot for a single ensemble member: ensemble
        # weight/bias tensors are indexed by member along dim 0.
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        # Restore the snapshotted (per-member best) parameters.
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """
        Record per-member improvements; return True when training should stop.

        A member counts as improved when its holdout loss dropped by more than
        1% relative to its best so far; improved members get their parameters
        snapshotted.  Stops after ``max_epochs_since_update`` epochs with no
        improvement from any member.
        """
        updated = False
        for i in range(len(holdout_losses)):
            current = holdout_losses[i]
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update