Skip to content

ding.world_model.model.networks

ding.world_model.model.networks

Full Source Code

../ding/world_model/model/networks.py

import math
import numpy as np
from typing import Optional, Dict, Union, List

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd
from ding.utils import SequenceType
from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \
    OneHotDist, ContDist, SymlogDist, DreamerLayerNorm


class RSSM(nn.Module):
    """
    Overview:
        Recurrent State-Space Model (RSSM) used by Dreamer-style world models.
        It maintains a deterministic recurrent state ``deter`` (updated by a GRU cell)
        and a stochastic state ``stoch`` (Gaussian or categorical). ``img_step`` is the
        prior (imagination) transition, ``obs_step`` the posterior transition that also
        conditions on the observation embedding, and ``kl_loss`` the KL-balancing loss
        between the two.
    """

    def __init__(
        self,
        stoch=30,
        deter=200,
        hidden=200,
        action_type=None,
        layers_input=1,
        layers_output=1,
        rec_depth=1,
        shared=False,
        discrete=False,
        act=nn.ELU,
        norm=nn.LayerNorm,
        mean_act="none",
        std_act="softplus",
        temp_post=True,
        min_std=0.1,
        cell="gru",
        unimix_ratio=0.01,
        num_actions=None,
        embed=None,
        device=None,
    ):
        """
        Arguments:
            - stoch (:obj:`int`): number of stochastic state variables.
            - deter (:obj:`int`): size of the deterministic (GRU) state.
            - hidden (:obj:`int`): width of the hidden MLP layers.
            - action_type (:obj:`Optional[str]`): ``'continuous'`` enables action \
                magnitude normalization in ``obs_step``/``img_step``.
            - layers_input (:obj:`int`): depth of the input MLP.
            - layers_output (:obj:`int`): depth of the prior/posterior output MLPs.
            - rec_depth (:obj:`int`): recurrent iterations per step (values > 1 are \
                not correctly implemented, see ``img_step``).
            - shared (:obj:`bool`): if True, the posterior reuses the prior networks \
                with the embedding appended to the input.
            - discrete (:obj:`Union[bool, int]`): False for a Gaussian stochastic \
                state; otherwise the number of categories per stochastic variable.
            - act: activation layer class.
            - norm: normalization layer class.
            - mean_act (:obj:`str`): activation applied to the Gaussian mean \
                (``'none'`` or ``'tanh5'``).
            - std_act (:obj:`str`): activation applied to the Gaussian std \
                (``'softplus'``, ``'abs'``, ``'sigmoid'`` or ``'sigmoid2'``).
            - temp_post (:obj:`bool`): if True, posterior conditions on ``deter`` as \
                well as the embedding.
            - min_std (:obj:`float`): lower bound added to the Gaussian std.
            - cell (:obj:`str`): ``'gru'`` or ``'gru_layer_norm'``.
            - unimix_ratio (:obj:`float`): uniform-mixture ratio of the categorical \
                distribution (DreamerV3 trick).
            - num_actions (:obj:`Optional[int]`): action dimensionality.
            - embed (:obj:`Optional[int]`): observation embedding dimensionality.
            - device: device on which ``initial`` states are created.
        """
        super(RSSM, self).__init__()
        self._stoch = stoch
        self._deter = deter
        self._hidden = hidden
        self._action_type = action_type
        self._min_std = min_std
        self._layers_input = layers_input
        self._layers_output = layers_output
        self._rec_depth = rec_depth
        self._shared = shared
        self._discrete = discrete
        self._act = act
        self._norm = norm
        self._mean_act = mean_act
        self._std_act = std_act
        self._temp_post = temp_post
        self._unimix_ratio = unimix_ratio
        self._embed = embed
        self._device = device

        # Input MLP: (stoch [, embed], action) -> hidden.
        inp_layers = []
        if self._discrete:
            inp_dim = self._stoch * self._discrete + num_actions
        else:
            inp_dim = self._stoch + num_actions
        if self._shared:
            inp_dim += self._embed
        for i in range(self._layers_input):
            inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            inp_layers.append(self._norm(self._hidden, eps=1e-03))
            inp_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._inp_layers = nn.Sequential(*inp_layers)
        self._inp_layers.apply(weight_init)

        if cell == "gru":
            self._cell = GRUCell(self._hidden, self._deter)
            self._cell.apply(weight_init)
        elif cell == "gru_layer_norm":
            self._cell = GRUCell(self._hidden, self._deter, norm=True)
            self._cell.apply(weight_init)
        else:
            raise NotImplementedError(cell)

        # Prior output MLP: deter -> hidden.
        img_out_layers = []
        inp_dim = self._deter
        for i in range(self._layers_output):
            img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            img_out_layers.append(self._norm(self._hidden, eps=1e-03))
            img_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._img_out_layers = nn.Sequential(*img_out_layers)
        self._img_out_layers.apply(weight_init)

        # Posterior output MLP: (deter, embed) or embed -> hidden.
        obs_out_layers = []
        if self._temp_post:
            inp_dim = self._deter + self._embed
        else:
            inp_dim = self._embed
        for i in range(self._layers_output):
            obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
            obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
            obs_out_layers.append(self._act())
            if i == 0:
                inp_dim = self._hidden
        self._obs_out_layers = nn.Sequential(*obs_out_layers)
        self._obs_out_layers.apply(weight_init)

        # Heads producing the sufficient statistics of the stochastic state:
        # logits for the categorical case, (mean, std) for the Gaussian case.
        if self._discrete:
            self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._ims_stat_layer.apply(weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
            self._obs_stat_layer.apply(weight_init)
        else:
            self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._ims_stat_layer.apply(weight_init)
            self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
            self._obs_stat_layer.apply(weight_init)

    def initial(self, batch_size):
        """
        Overview:
            Return an all-zero initial latent state dict for a batch of size
            ``batch_size``: keys ``logit``/``stoch``/``deter`` (discrete) or
            ``mean``/``std``/``stoch``/``deter`` (Gaussian).
        """
        deter = torch.zeros(batch_size, self._deter).to(self._device)
        if self._discrete:
            state = dict(
                logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
                stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
                deter=deter,
            )
        else:
            state = dict(
                mean=torch.zeros([batch_size, self._stoch]).to(self._device),
                std=torch.zeros([batch_size, self._stoch]).to(self._device),
                stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
                deter=deter,
            )
        return state

    def observe(self, embed, action, state=None):
        """
        Overview:
            Roll the posterior over a full (batch, time, ...) sequence of embeddings
            and actions, returning per-step ``post`` and ``prior`` state dicts with
            the (batch, time) layout restored.
        """
        swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))  # swap the first two dims
        if state is None:
            state = self.initial(action.shape[0])  # {logit, stoch, deter}
        # (batch, time, ch) -> (time, batch, ch)
        embed, action = swap(embed), swap(action)
        post, prior = static_scan(
            lambda prev_state, prev_act, embed: self.obs_step(prev_state[0], prev_act, embed),
            (action, embed),
            (state, state),
        )

        # (time, batch, stoch, discrete_num) -> (batch, time, stoch, discrete_num)
        post = {k: swap(v) for k, v in post.items()}
        prior = {k: swap(v) for k, v in prior.items()}
        return post, prior

    def imagine(self, action, state=None):
        """
        Overview:
            Roll the prior (imagination) over a (batch, time, ...) action sequence,
            returning the per-step prior state dict in (batch, time) layout.
        """
        swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
        if state is None:
            state = self.initial(action.shape[0])
        assert isinstance(state, dict), state
        action = swap(action)
        prior = static_scan(self.img_step, [action], state)
        prior = prior[0]
        prior = {k: swap(v) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        """
        Overview:
            Concatenate the (flattened) stochastic and deterministic parts of a state
            into the feature vector consumed by downstream heads.
        """
        stoch = state["stoch"]
        if self._discrete:
            shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete]
            stoch = stoch.reshape(shape)
        return torch.cat([stoch, state["deter"]], -1)

    def get_dist(self, state, dtype=None):
        """
        Overview:
            Build the stochastic-state distribution from sufficient statistics:
            unimix one-hot categorical (discrete) or diagonal Gaussian (continuous).
            ``dtype`` is accepted for interface compatibility and unused.
        """
        if self._discrete:
            logit = state["logit"]
            dist = torchd.independent.Independent(OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1)
        else:
            mean, std = state["mean"], state["std"]
            dist = ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), 1))
        return dist

    def obs_step(self, prev_state, prev_action, embed, sample=True):
        """
        Overview:
            One posterior step: run the prior transition, then combine its ``deter``
            state with the observation embedding to produce the posterior stochastic
            state. Returns ``(post, prior)``.
        """
        # if shared is True, prior and post both use the same networks (_inp_layers,
        # _img_out_layers, _ims_stat_layer);
        # otherwise, post uses a separate network (_obs_out_layers) with prior["deter"]
        # and embed as inputs.
        if self._action_type == 'continuous':
            # Normalize action magnitude to [-1, 1] without mutating the caller's
            # tensor (the original in-place `*=` wrote through to the input).
            prev_action = prev_action * (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prior = self.img_step(prev_state, prev_action, None, sample)
        if self._shared:
            post = self.img_step(prev_state, prev_action, embed, sample)
        else:
            if self._temp_post:
                x = torch.cat([prior["deter"], embed], -1)
            else:
                x = embed
            # (batch_size, prior_deter + embed) -> (batch_size, hidden)
            x = self._obs_out_layers(x)
            # (batch_size, hidden) -> (batch_size, stoch, discrete_num)
            stats = self._suff_stats_layer("obs", x)
            if sample:
                stoch = self.get_dist(stats).sample()
            else:
                stoch = self.get_dist(stats).mode()
            post = {"stoch": stoch, "deter": prior["deter"], **stats}
        return post, prior

    # this is used for making future image
    def img_step(self, prev_state, prev_action, embed=None, sample=True):
        """
        Overview:
            One prior (imagination) step: advance the deterministic state through the
            GRU cell and sample the next stochastic state from the learned prior.
        """
        # (batch, stoch, discrete_num)
        if self._action_type == 'continuous':
            # Out-of-place normalization; idempotent, so the extra application when
            # called from obs_step does not change values.
            prev_action = prev_action * (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
        prev_stoch = prev_state["stoch"]
        if self._discrete:
            shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
            # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
            prev_stoch = prev_stoch.reshape(shape)
        if self._shared:
            if embed is None:
                shape = list(prev_action.shape[:-1]) + [self._embed]
                embed = torch.zeros(shape)
            # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
            x = torch.cat([prev_stoch, prev_action, embed], -1)
        else:
            x = torch.cat([prev_stoch, prev_action], -1)
        # (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
        x = self._inp_layers(x)
        for _ in range(self._rec_depth):  # rec depth is not correctly implemented
            deter = prev_state["deter"]
            # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
            x, deter = self._cell(x, [deter])
            deter = deter[0]  # Keras wraps the state in a list.
        # (batch, deter) -> (batch, hidden)
        x = self._img_out_layers(x)
        # (batch, hidden) -> (batch_size, stoch, discrete_num)
        stats = self._suff_stats_layer("ims", x)
        if sample:
            stoch = self.get_dist(stats).sample()
        else:
            stoch = self.get_dist(stats).mode()
        prior = {"stoch": stoch, "deter": deter, **stats}  # {stoch, deter, logit}
        return prior

    def _suff_stats_layer(self, name, x):
        """
        Overview:
            Map a hidden vector to the sufficient statistics of the stochastic state.
            ``name`` selects the prior (``"ims"``) or posterior (``"obs"``) head.
        """
        if self._discrete:
            if name == "ims":
                x = self._ims_stat_layer(x)
            elif name == "obs":
                x = self._obs_stat_layer(x)
            else:
                raise NotImplementedError
            logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
            return {"logit": logit}
        else:
            if name == "ims":
                x = self._ims_stat_layer(x)
            elif name == "obs":
                x = self._obs_stat_layer(x)
            else:
                raise NotImplementedError
            mean, std = torch.split(x, [self._stoch] * 2, -1)
            mean = {
                "none": lambda: mean,
                "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0),
            }[self._mean_act]()
            std = {
                # Fix: `torch.softplus` does not exist; use the functional API
                # (F is already imported at module level).
                "softplus": lambda: F.softplus(std),
                "abs": lambda: torch.abs(std + 1),
                "sigmoid": lambda: torch.sigmoid(std),
                "sigmoid2": lambda: 2 * torch.sigmoid(std / 2),
            }[self._std_act]()
            std = std + self._min_std
            return {"mean": mean, "std": std}

    def kl_loss(self, post, prior, forward, free, lscale, rscale):
        """
        Overview:
            KL-balancing loss between posterior and prior with free bits.
            With ``forward == False`` the first term is the representation loss
            (posterior towards a stopped-gradient prior) and the second the dynamics
            loss (prior towards a stopped-gradient posterior).
        Returns:
            ``(loss, value, loss_lhs, loss_rhs)`` where ``value`` is the raw
            (unclipped) KL of the first term.
        """
        kld = torchd.kl.kl_divergence
        dist = lambda x: self.get_dist(x)
        sg = lambda x: {k: v.detach() for k, v in x.items()}
        # forward == false -> (post, prior)
        lhs, rhs = (prior, post) if forward else (post, prior)

        # forward == false -> Lrep
        value_lhs = value = kld(
            dist(lhs) if self._discrete else dist(lhs)._dist,
            dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
        )
        # forward == false -> Ldyn
        value_rhs = kld(
            dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
            dist(rhs) if self._discrete else dist(rhs)._dist,
        )
        # free bits: do not penalize KL below the `free` threshold
        loss_lhs = torch.mean(torch.clip(value_lhs, min=free))
        loss_rhs = torch.mean(torch.clip(value_rhs, min=free))
        loss = lscale * loss_lhs + rscale * loss_rhs

        return loss, value, loss_lhs, loss_rhs


class ConvDecoder(nn.Module):
    """
    Overview:
        Transposed-convolution decoder that maps latent features back to an image
        observation, returning a ``SymlogDist`` over the reconstruction.
    """

    def __init__(
        self,
        inp_depth,  # config.dyn_stoch * config.dyn_discrete + config.dyn_deter
        depth=32,
        act=nn.ELU,
        norm=nn.LayerNorm,
        shape=(3, 64, 64),
        kernels=(3, 3, 3, 3),
        outscale=1.0,
    ):
        """
        Arguments:
            - inp_depth (:obj:`int`): size of the input feature vector.
            - depth (:obj:`int`): base channel multiplier.
            - act / norm: activation and normalization layer classes.
            - shape (:obj:`tuple`): output image shape (C, H, W).
            - kernels (:obj:`tuple`): kernel size per deconv layer.
            - outscale (:obj:`float`): init scale of the final layer.
        """
        super(ConvDecoder, self).__init__()
        self._inp_depth = inp_depth
        self._act = act
        self._norm = norm
        self._depth = depth
        self._shape = shape
        self._kernels = kernels
        self._embed_size = ((64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1))

        self._linear_layer = nn.Linear(inp_depth, self._embed_size)
        # Divide by the final 4*4 feature map size to get the channel count.
        inp_dim = self._embed_size // 16

        layers = []
        h, w = 4, 4
        for i, kernel in enumerate(self._kernels):
            depth = self._embed_size // 16 // (2 ** (i + 1))
            act = self._act
            bias = False
            initializer = weight_init
            if i == len(self._kernels) - 1:
                # Last layer: raw image channels, no norm/activation, uniform init.
                depth = self._shape[0]
                act = False
                bias = True
                norm = False
                initializer = uniform_weight_init(outscale)

            if i != 0:
                inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
            pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
            pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
            layers.append(
                nn.ConvTranspose2d(
                    inp_dim,
                    depth,
                    kernel,
                    2,
                    padding=(pad_h, pad_w),
                    output_padding=(outpad_h, outpad_w),
                    bias=bias,
                )
            )
            if norm:
                layers.append(DreamerLayerNorm(depth))
            if act:
                layers.append(act())
            [m.apply(initializer) for m in layers[-3:]]
            h, w = h * 2, w * 2

        self.layers = nn.Sequential(*layers)

    def calc_same_pad(self, k, s, d):
        """
        Overview:
            Compute (padding, output_padding) so that a stride-``s`` transposed conv
            exactly doubles the spatial size ("same"-style padding).
        """
        val = d * (k - 1) - s + 1
        pad = math.ceil(val / 2)
        outpad = pad * 2 - val
        return pad, outpad

    def forward(self, features):
        """
        Overview:
            Decode latent features into an image distribution.
            Defined as ``forward`` (not ``__call__``) so that ``nn.Module`` hook
            dispatch keeps working; callers still invoke ``decoder(features)``.
        """
        x = self._linear_layer(features)  # feature: [batch, time, stoch*discrete + deter]
        x = x.reshape([-1, 4, 4, self._embed_size // 16])
        x = x.permute(0, 3, 1, 2)
        x = self.layers(x)
        mean = x.reshape(list(features.shape[:-1]) + self._shape)
        return SymlogDist(mean)


class GRUCell(nn.Module):
    """
    Overview:
        GRU cell with optional layer normalization on the gate pre-activations and a
        configurable update-gate bias, following the Dreamer implementation.
    """

    def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
        """
        Arguments:
            - inp_size (:obj:`int`): input (hidden) size.
            - size (:obj:`int`): recurrent (deter) state size.
            - norm (:obj:`bool`): apply LayerNorm to the gate pre-activations.
            - act: candidate activation function.
            - update_bias (:obj:`float`): bias added to the update gate before the \
                sigmoid (negative values bias towards keeping the old state).
        """
        super(GRUCell, self).__init__()
        self._inp_size = inp_size  # hidden
        self._size = size  # deter
        self._act = act
        self._norm = norm
        self._update_bias = update_bias
        self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
        if norm:
            # Replaces the bool flag with the actual norm module; the truthiness
            # check in forward() works for both cases.
            self._norm = nn.LayerNorm(3 * size, eps=1e-03)

    @property
    def state_size(self):
        """Recurrent state size."""
        return self._size

    def forward(self, inputs, state):
        """
        Overview:
            One GRU step. ``state`` is a single-element list (Keras-style); returns
            ``(output, [output])``.
        """
        state = state[0]  # Keras wraps the state in a list.
        parts = self._layer(torch.cat([inputs, state], -1))
        if self._norm:
            parts = self._norm(parts)
        reset, cand, update = torch.split(parts, [self._size] * 3, -1)
        reset = torch.sigmoid(reset)
        cand = self._act(reset * cand)
        update = torch.sigmoid(update + self._update_bias)
        output = update * cand + (1 - update) * state
        return output, [output]