ding.world_model.dreamer
Full Source Code
../ding/world_model/dreamer.py
import numpy as np
import copy
import torch
from torch import nn

from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder, FCEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device, one_hot
from ding.torch_utils.network.dreamer import DenseHead


@WORLD_MODEL_REGISTRY.register('dreamer')
class DREAMERWorldModel(WorldModel, nn.Module):
    config = dict(
        pretrain=100,
        train_freq=2,
        model=dict(
            state_size=None,
            action_size=None,
            model_lr=1e-4,
            reward_size=1,
            hidden_size=200,
            batch_size=256,
            max_epochs_since_update=5,
            dyn_stoch=32,
            dyn_deter=512,
            dyn_hidden=512,
            dyn_input_layers=1,
            dyn_output_layers=1,
            dyn_rec_depth=1,
            dyn_shared=False,
            dyn_discrete=32,
            act='SiLU',
            norm='LayerNorm',
            grad_heads=['image', 'reward', 'discount'],
            units=512,
            image_dec_layers=2,
            reward_layers=2,
            discount_layers=2,
            value_layers=2,
            actor_layers=2,
            cnn_depth=32,
            encoder_kernels=[4, 4, 4, 4],
            decoder_kernels=[4, 4, 4, 4],
            reward_head='twohot_symlog',
            kl_lscale=0.1,
            kl_rscale=0.5,
            kl_free=1.0,
            kl_forward=False,
            pred_discount=True,
            dyn_mean_act='none',
            dyn_std_act='sigmoid2',
            dyn_temp_post=True,
            dyn_min_std=0.1,
            dyn_cell='gru_layer_norm',
            unimix_ratio=0.01,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            obs_type='RGB',
            action_type='continuous',
            encoder_hidden_size_list=[64, 128, 128],
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        WorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        self.pretrain_flag = True
        self._cfg = cfg.model
        # self._cfg.act = getattr(torch.nn, self._cfg.act),
        # self._cfg.norm = getattr(torch.nn, self._cfg.norm),
        self._cfg.act = nn.modules.activation.SiLU  # nn.SiLU
        self._cfg.norm = nn.modules.normalization.LayerNorm  # nn.LayerNorm
        self.state_size = self._cfg.state_size
        self.obs_type = self._cfg.obs_type
        self.action_size = self._cfg.action_size
        self.action_type = self._cfg.action_type
        self.reward_size = self._cfg.reward_size
        self.hidden_size = self._cfg.hidden_size
        self.batch_size = self._cfg.batch_size
        if self.obs_type == 'vector':
            self.encoder = FCEncoder(self.state_size, self._cfg.encoder_hidden_size_list, activation=torch.nn.SiLU())
            self.embed_size = self._cfg.encoder_hidden_size_list[-1]
        elif self.obs_type == 'RGB':
            self.encoder = ConvEncoder(
                self.state_size,
                hidden_size_list=[32, 64, 128, 256, 4096],  # to last layer 128?
                activation=torch.nn.SiLU(),
                kernel_size=self._cfg.encoder_kernels,
                layer_norm=True
            )
            self.embed_size = (
                (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
                2 ** (len(self._cfg.encoder_kernels) - 1)
            )

        self.dynamics = RSSM(
            self._cfg.dyn_stoch,
            self._cfg.dyn_deter,
            self._cfg.dyn_hidden,
            self._cfg.action_type,
            self._cfg.dyn_input_layers,
            self._cfg.dyn_output_layers,
            self._cfg.dyn_rec_depth,
            self._cfg.dyn_shared,
            self._cfg.dyn_discrete,
            self._cfg.act,
            self._cfg.norm,
            self._cfg.dyn_mean_act,
            self._cfg.dyn_std_act,
            self._cfg.dyn_temp_post,
            self._cfg.dyn_min_std,
            self._cfg.dyn_cell,
            self._cfg.unimix_ratio,
            self._cfg.action_size,
            self.embed_size,
            self._cfg.device,
        )
        self.heads = nn.ModuleDict()
        if self._cfg.dyn_discrete:
            feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
        else:
            feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter

        if isinstance(self.state_size, int):
            self.heads['image'] = DenseHead(
                feat_size,
                (self.state_size, ),
                self._cfg.image_dec_layers,
                self._cfg.units,
                'SiLU',  # self._cfg.act
                'LN',  # self._cfg.norm
                dist='binary',
                outscale=0.0,
                device=self._cfg.device,
            )
        elif len(self.state_size) == 3:
            self.heads["image"] = ConvDecoder(
                feat_size,  # pytorch version
                self._cfg.cnn_depth,
                self._cfg.act,
                self._cfg.norm,
                self.state_size,
                self._cfg.decoder_kernels,
            )
        self.heads["reward"] = DenseHead(
            feat_size,  # dyn_stoch * dyn_discrete + dyn_deter
            (255, ),
            self._cfg.reward_layers,
            self._cfg.units,
            'SiLU',  # self._cfg.act
            'LN',  # self._cfg.norm
            dist=self._cfg.reward_head,
            outscale=0.0,
            device=self._cfg.device,
        )
        if self._cfg.pred_discount:
            self.heads["discount"] = DenseHead(
                feat_size,  # pytorch version
                [],
                self._cfg.discount_layers,
                self._cfg.units,
                'SiLU',  # self._cfg.act
                'LN',  # self._cfg.norm
                dist="binary",
                device=self._cfg.device,
            )

        if self._cuda:
            self.cuda()
        # to do
        # grad_clip, weight_decay
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)

    def step(self, obs, act):
        pass

    def eval(self, env_buffer, envstep, train_iter):
        pass

    def should_pretrain(self):
        if self.pretrain_flag:
            self.pretrain_flag = False
            return True
        return False

    def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
        self.last_train_step = envstep
        data = env_buffer.sample(
            batch_size, batch_length, train_iter
        )  # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
        data = default_collate(data)  # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
        data = lists_to_dicts(data, recursive=True)  # -> {some_key: T lists}, each list is [B, some_dim]
        data = {k: torch.stack(data[k], dim=1) for k in data}  # -> {dict_key: Tensor([B, T, any_dims])}

        data['discount'] = data.get('discount', 1.0 - data['done'].float())
        data['weight'] = data.get('weight', None)
        if self.obs_type == 'RGB':
            data['image'] = data['obs'] - 0.5
        else:
            data['image'] = data['obs']
        if self.action_type == 'continuous':
            data['action'] *= (1.0 / torch.clip(torch.abs(data['action']), min=1.0))
        else:
            data['action'] = one_hot(data['action'], self.action_size)
        data = to_device(data, self._cfg.device)
        if len(data['reward'].shape) == 2:
            data['reward'] = data['reward'].unsqueeze(-1)
        if len(data['action'].shape) == 2:
            data['action'] = data['action'].unsqueeze(-1)
        if len(data['discount'].shape) == 2:
            data['discount'] = data['discount'].unsqueeze(-1)

        self.requires_grad_(requires_grad=True)

        image = data['image'].reshape([-1] + list(data['image'].shape[2:]))
        embed = self.encoder(image)
        embed = embed.reshape(list(data['image'].shape[:2]) + [embed.shape[-1]])

        post, prior = self.dynamics.observe(embed, data["action"])
        kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
            post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
        )
        losses = {}
        likes = {}
        for name, head in self.heads.items():
            grad_head = name in self._cfg.grad_heads
            feat = self.dynamics.get_feat(post)
            feat = feat if grad_head else feat.detach()
            pred = head(feat)
            like = pred.log_prob(data[name])
            likes[name] = like
            losses[name] = -torch.mean(like)
        model_loss = sum(losses.values()) + kl_loss

        # ====================
        # world model update
        # ====================
        self.optimizer.zero_grad()
        model_loss.backward()
        self.optimizer.step()

        self.requires_grad_(requires_grad=False)
        # log
        if self.tb_logger is not None:
            for name, loss in losses.items():
                self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
            self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
            self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
            self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep)

            prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy()
            post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy()

            self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep)
            self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep)

        context = dict(
            embed=embed,
            feat=self.dynamics.get_feat(post),
            kl=kl_value,
            postent=self.dynamics.get_dist(post).entropy(),
        )
        post = {k: v.detach() for k, v in post.items()}
        return post, context

    # Note: the snapshot helpers below reference self._snapshots,
    # self._epochs_since_update and self.max_epochs_since_update, which are
    # not initialized anywhere in this class; they assume setup elsewhere.
    def _save_states(self):
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        updated = False
        for i in range(len(holdout_losses)):
            current = holdout_losses[i]
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update
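
The two derived sizes in `__init__` are easiest to verify with concrete numbers. The standalone snippet below reproduces the `embed_size` and `feat_size` arithmetic for the default config and a hypothetical 64x64 RGB observation (`state_size = (3, 64, 64)` is an assumption for illustration, not a value from the source):

# Worked example of the derived sizes, using the default config values and a
# hypothetical 64x64 RGB observation (state_size = (3, 64, 64)).
encoder_kernels = [4, 4, 4, 4]
cnn_depth = 32
dyn_stoch, dyn_discrete, dyn_deter = 32, 32, 512
state_size = (3, 64, 64)

# Each conv layer halves the spatial size, so 64x64 shrinks to 4x4 while the
# channel count grows to cnn_depth * 2 ** (len(encoder_kernels) - 1) = 256.
embed_size = (state_size[1] // 2 ** len(encoder_kernels)) ** 2 \
    * cnn_depth * 2 ** (len(encoder_kernels) - 1)
assert embed_size == 4096  # matches the last entry of hidden_size_list

# With a discrete latent, every head consumes the flattened stochastic state
# concatenated with the deterministic recurrent state.
feat_size = dyn_stoch * dyn_discrete + dyn_deter
assert feat_size == 1536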
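
`train` expects `env_buffer.sample` to return a length-B list of length-T trajectories whose steps are dicts of tensors, and its first four statements flatten that into a single `{key: Tensor(B, T, ...)}` batch. The self-contained sketch below uses plain PyTorch on hypothetical data to show an equivalent transformation; it stands in for the `default_collate` / `lists_to_dicts` / `torch.stack` pipeline rather than calling those DI-engine helpers directly:

import torch

B, T = 4, 8
# Hypothetical buffer output: B trajectories, each a list of T step dicts.
sample = [
    [
        {
            'obs': torch.rand(3, 64, 64),
            'action': torch.rand(6),
            'reward': torch.rand(1),
            'done': torch.zeros(1),
        } for _ in range(T)
    ] for _ in range(B)
]

# Stack over T inside each trajectory, then over B across trajectories,
# giving the {dict_key: Tensor(B, T, any_dims)} layout train() works with.
batch = {
    k: torch.stack([torch.stack([step[k] for step in traj]) for traj in sample])
    for k in sample[0][0]
}
assert batch['obs'].shape == (B, T, 3, 64, 64)
assert batch['action'].shape == (B, T, 6)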