ding.world_model.model.ensemble
Full Source Code
../ding/world_model/model/ensemble.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.torch_utils import Swish


class StandardScaler(nn.Module):

    def __init__(self, input_size: int):
        super(StandardScaler, self).__init__()
        self.register_buffer('std', torch.ones(1, input_size))
        self.register_buffer('mu', torch.zeros(1, input_size))

    def fit(self, data: torch.Tensor):
        std, mu = torch.std_mean(data, dim=0, keepdim=True)
        # Guard against zero variance so transform() never divides by zero.
        std[std < 1e-12] = 1
        self.std.data.mul_(0.0).add_(std)
        self.mu.data.mul_(0.0).add_(mu)

    def transform(self, data: torch.Tensor):
        return (data - self.mu) / self.std

    def inverse_transform(self, data: torch.Tensor):
        return self.std * data + self.mu


class EnsembleFC(nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, ensemble_size: int, weight_decay: float = 0.) -> None:
        super(EnsembleFC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight = nn.Parameter(torch.zeros(ensemble_size, in_features, out_features))
        self.weight_decay = weight_decay
        self.bias = nn.Parameter(torch.zeros(ensemble_size, 1, out_features))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Batched linear layer: one weight matrix and bias per ensemble member.
        assert input.shape[0] == self.ensemble_size and len(input.shape) == 3
        return torch.bmm(input, self.weight) + self.bias  # w times x + b

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, ensemble_size={}, weight_decay={}'.format(
            self.in_features, self.out_features, self.ensemble_size, self.weight_decay
        )


class EnsembleModel(nn.Module):

    def __init__(
        self,
        state_size,
        action_size,
        reward_size,
        ensemble_size,
        hidden_size=200,
        learning_rate=1e-3,
        use_decay=False
    ):
        super(EnsembleModel, self).__init__()

        self.use_decay = use_decay
        self.hidden_size = hidden_size
        self.output_dim = state_size + reward_size

        self.nn1 = EnsembleFC(state_size + action_size, hidden_size, ensemble_size, weight_decay=0.000025)
        self.nn2 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.00005)
        self.nn3 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.000075)
        self.nn4 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.000075)
        # The last layer outputs both the mean and the log-variance of each target dim.
        self.nn5 = EnsembleFC(hidden_size, self.output_dim * 2, ensemble_size, weight_decay=0.0001)
        self.max_logvar = nn.Parameter(torch.ones(1, self.output_dim).float() * 0.5, requires_grad=False)
        self.min_logvar = nn.Parameter(torch.ones(1, self.output_dim).float() * -10, requires_grad=False)
        self.swish = Swish()

        def init_weights(m: nn.Module):

            def truncated_normal_init(t, mean: float = 0.0, std: float = 0.01):
                torch.nn.init.normal_(t, mean=mean, std=std)
                while True:
                    cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
                    if not torch.sum(cond):
                        break
                    # Redraw the out-of-range entries through .data so the truncation
                    # actually reaches the parameter (rebinding the local name would not).
                    t.data = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t.data)
                return t

            if isinstance(m, nn.Linear) or isinstance(m, EnsembleFC):
                input_dim = m.in_features
                truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(input_dim)))
                m.bias.data.fill_(0.0)

        self.apply(init_weights)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x: torch.Tensor, ret_log_var: bool = False):
        x = self.swish(self.nn1(x))
        x = self.swish(self.nn2(x))
        x = self.swish(self.nn3(x))
        x = self.swish(self.nn4(x))
        x = self.nn5(x)

        mean, logvar = x.chunk(2, dim=2)
        # Softly clamp the log-variance into [min_logvar, max_logvar].
        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)

        if ret_log_var:
            return mean, logvar
        else:
            return mean, torch.exp(logvar)

    def get_decay_loss(self):
        decay_loss = 0.
        for m in self.modules():
            if isinstance(m, EnsembleFC):
                decay_loss += m.weight_decay * torch.sum(torch.square(m.weight)) / 2.
        return decay_loss

    def loss(self, mean: torch.Tensor, logvar: torch.Tensor, labels: torch.Tensor):
        """
        mean, logvar: Ensemble_size x N x dim
        labels: Ensemble_size x N x dim
        """
        assert len(mean.shape) == len(logvar.shape) == len(labels.shape) == 3
        inv_var = torch.exp(-logvar)
        # Average over batch and dim, sum over ensembles.
        mse_loss_inv = (torch.pow(mean - labels, 2) * inv_var).mean(dim=(1, 2))
        var_loss = logvar.mean(dim=(1, 2))
        with torch.no_grad():
            # Used only for logging.
            mse_loss = torch.pow(mean - labels, 2).mean(dim=(1, 2))
        total_loss = mse_loss_inv.sum() + var_loss.sum()
        return total_loss, mse_loss

    def train(self, loss: torch.Tensor):
        # Note: this intentionally shadows nn.Module.train(); it runs one optimizer step.
        self.optimizer.zero_grad()

        loss += 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
        if self.use_decay:
            loss += self.get_decay_loss()

        loss.backward()

        self.optimizer.step()
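A minimal end-to-end sketch of how these classes fit together. It is not part of the source file above: only the class names and the ding.world_model.model.ensemble import path come from the listing, while the dimensions, batch size, and random data are illustrative assumptions.

import torch

from ding.world_model.model.ensemble import EnsembleModel, StandardScaler

# Illustrative sizes (hypothetical, not prescribed by DI-engine).
state_size, action_size, reward_size = 11, 3, 1
ensemble_size, batch = 7, 256

model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
scaler = StandardScaler(state_size + action_size)

# Dummy transition batch: inputs are concatenated (state, action) vectors.
inputs = torch.randn(batch, state_size + action_size)
scaler.fit(inputs)
inputs = scaler.transform(inputs)

# EnsembleFC expects a 3-D batch of shape (ensemble_size, N, dim). Here every
# member sees the same data; training each member on its own bootstrap sample
# is also common.
inputs = inputs.unsqueeze(0).repeat(ensemble_size, 1, 1)
labels = torch.randn(ensemble_size, batch, state_size + reward_size)

mean, logvar = model(inputs, ret_log_var=True)
total_loss, mse_loss = model.loss(mean, logvar, labels)
model.train(total_loss)  # one Adam step on the Gaussian NLL

print(total_loss.item())  # scalar training objective
print(mse_loss.shape)     # (ensemble_size,): detached per-member MSE

Note that the optimized quantity is the Gaussian negative log-likelihood summed over ensemble members, while the returned mse_loss is computed under torch.no_grad() and is intended only for logging.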