ding.model.template.vae

Credit: the following VAE model is modified from https://github.com/AntixK/PyTorch-VAE

VanillaVAE

Bases: Module

Overview

Implementation of a vanilla variational autoencoder (VAE) for action reconstruction.

Interfaces: __init__, encode, decode, decode_with_obs, reparameterize, forward, loss_function.

encode(input)

Overview

Encodes the input by passing it through the encoder network and returns the latent codes.

Arguments:
- input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) and ``action`` (:obj:`torch.Tensor`), representing the observation and the agent's action respectively.

Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`), representing the latent codes.

Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is observation dim.
- action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is action dim.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is hidden dim.
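A minimal sketch of this encode path with plain PyTorch modules, using hypothetical sizes (batch 4, obs dim 8, action dim 3, hidden dim 256, latent size 6) chosen only for illustration; the actual layer sizes come from the constructor's ``hidden_dims`` and ``latent_size``:

```python
import torch
import torch.nn as nn

B, O, A, H, L = 4, 8, 3, 256, 6  # hypothetical batch, obs, action, hidden, latent sizes

# Mirror of the encoder described above: separate heads for obs and action,
# fused by element-wise product, then split into mu and log_var.
encode_action_head = nn.Sequential(nn.Linear(A, H), nn.ReLU())
encode_obs_head = nn.Sequential(nn.Linear(O, H), nn.ReLU())
encode_common = nn.Sequential(nn.Linear(H, H), nn.ReLU())
encode_mu_head = nn.Linear(H, L)
encode_logvar_head = nn.Linear(H, L)

inputs = {'obs': torch.randn(B, O), 'action': torch.randn(B, A)}
action_encoding = encode_action_head(inputs['action'])
obs_encoding = encode_obs_head(inputs['obs'])
result = encode_common(obs_encoding * action_encoding)  # fuse the two encodings
mu, log_var = encode_mu_head(result), encode_logvar_head(result)
```

The outputs follow the documented shapes: ``mu`` and ``log_var`` are ``(B, L)`` and ``obs_encoding`` is ``(B, H)``.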

decode(z, obs_encoding)

Overview

Maps the given latent action and obs_encoding onto the original action space.

Arguments:
- z (:obj:`torch.Tensor`): the sampled latent action.
- obs_encoding (:obj:`torch.Tensor`): observation encoding.

Returns:
- outputs (:obj:`Dict`): Dict containing the decoded outputs.

ReturnsKeys:
- reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
- predition_residual (:obj:`torch.Tensor`): the predicted observation residual.

Shapes:
- z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent_size.
- obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is hidden dim.

decode_with_obs(z, obs)

Overview

Maps the given latent action and obs onto the original action space. It uses self.encode_obs_head(obs) to compute the obs_encoding.

Arguments:
- z (:obj:`torch.Tensor`): the sampled latent action.
- obs (:obj:`torch.Tensor`): observation.

Returns:
- outputs (:obj:`Dict`): Dict containing the decoded outputs.

ReturnsKeys:
- reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
- predition_residual (:obj:`torch.Tensor`): the observation residual predicted by the VAE.

Shapes:
- z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent_size.
- obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is obs_shape.
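A sketch of this decode path with plain PyTorch modules and hypothetical sizes (the same illustrative dims as above; the real ones come from the constructor). Note that here ``z`` is assumed to be already bounded (e.g. produced by a TD3 policy through tanh), so it is fed to the decoder directly:

```python
import torch
import torch.nn as nn

B, O, A, H, L = 4, 8, 3, 256, 6  # hypothetical sizes

encode_obs_head = nn.Sequential(nn.Linear(O, H), nn.ReLU())
decode_action_head = nn.Sequential(nn.Linear(L, H), nn.ReLU())
decode_common = nn.Sequential(nn.Linear(H, H), nn.ReLU())
decode_reconst_action_head = nn.Sequential(nn.Linear(H, A), nn.Tanh())
decode_prediction_head = nn.Sequential(nn.Linear(H, H), nn.ReLU(), nn.Linear(H, O))

z = torch.tanh(torch.randn(B, L))  # latent action, already bounded to [-1, 1]
obs = torch.randn(B, O)

obs_encoding = encode_obs_head(obs)
h = decode_common(decode_action_head(z) * obs_encoding)
reconstruction_action = decode_reconst_action_head(h)  # tanh keeps it in [-1, 1]
predition_residual = decode_prediction_head(h)
```

The reconstructed action is ``(B, A)`` and bounded by the final tanh; the residual prediction is ``(B, O)``.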

reparameterize(mu, logvar)

Overview

Reparameterization trick: sample from N(mu, var) using samples from N(0, 1).

Arguments:
- mu (:obj:`torch.Tensor`): mean of the latent Gaussian.
- logvar (:obj:`torch.Tensor`): log variance of the latent Gaussian.

Shapes:
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent_size.
- logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent_size.
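The trick computes z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, I), which keeps sampling differentiable with respect to mu and log_var. A standalone version of this computation:

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # sigma = exp(0.5 * log_var); z = mu + sigma * eps, eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(4, 6)
logvar = torch.zeros(4, 6)  # log_var = 0  ->  unit variance
z = reparameterize(mu, logvar)
```

With a very negative log variance the noise term vanishes and z collapses to mu, which is a handy sanity check.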

forward(input, **kwargs)

Overview

Encode the input, reparameterize mu and log_var to sample z, and decode z together with obs_encoding.

Arguments:
- input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) and ``action`` (:obj:`torch.Tensor`), representing the observation and the agent's action respectively.

Returns:
- outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).

Shapes:
- recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is action dim.
- prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is observation dim.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent_size.
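The full forward pass chains the three steps above (encode, reparameterize, decode). A compact end-to-end sketch with plain PyTorch modules and the same hypothetical dims as before; in the model itself z is passed through tanh before decoding because the sampled latent is unbounded:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, O, A, H, L = 4, 8, 3, 256, 6  # hypothetical dims

# Encoder pieces
enc_a = nn.Sequential(nn.Linear(A, H), nn.ReLU())
enc_o = nn.Sequential(nn.Linear(O, H), nn.ReLU())
enc_c = nn.Sequential(nn.Linear(H, H), nn.ReLU())
mu_head, logvar_head = nn.Linear(H, L), nn.Linear(H, L)
# Decoder pieces
dec_a = nn.Sequential(nn.Linear(L, H), nn.ReLU())
dec_c = nn.Sequential(nn.Linear(H, H), nn.ReLU())
dec_recon = nn.Sequential(nn.Linear(H, A), nn.Tanh())
dec_pred = nn.Sequential(nn.Linear(H, H), nn.ReLU(), nn.Linear(H, O))

obs, action = torch.randn(B, O), torch.randn(B, A)
obs_enc = enc_o(obs)
h = enc_c(obs_enc * enc_a(action))                        # encode
mu, log_var = mu_head(h), logvar_head(h)
z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)  # reparameterize
d = dec_c(dec_a(torch.tanh(z)) * obs_enc)                 # bound z, then decode
recons_action, prediction_residual = dec_recon(d), dec_pred(d)
```

Every output shape matches the docstring: ``recons_action`` is ``(B, A)``, ``prediction_residual`` is ``(B, O)``, and ``mu``, ``log_var``, ``z`` are ``(B, L)``.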

loss_function(args, **kwargs)

Overview

Computes the VAE loss function.

Arguments:
- args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual``, ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
- kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``.

Returns:
- outputs (:obj:`Dict[str, Tensor]`): Dict containing the different loss results, including ``loss``, ``reconstruction_loss``, ``kld_loss`` and ``predict_loss``.

Shapes:
- recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is action dim.
- prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is observation dim.
- original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is action dim.
- mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is latent size.
- true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is observation dim.
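The total loss is a weighted sum of three terms: an MSE reconstruction loss on the action, the analytic KL divergence between the approximate posterior N(mu, var) and the standard normal prior, and an MSE loss on the predicted observation residual. A standalone demonstration using synthetic inputs (the 0.01 weights are placeholders, not values from the source); with a perfect reconstruction and a posterior exactly equal to N(0, I), every term is zero:

```python
import torch
import torch.nn.functional as F

B, A, O, L = 4, 3, 8, 6  # hypothetical dims
recons_action = torch.randn(B, A)
original_action = recons_action.clone()              # perfect reconstruction
prediction_residual = torch.zeros(B, O)
true_residual = torch.zeros(B, O)                    # perfect residual prediction
mu, log_var = torch.zeros(B, L), torch.zeros(B, L)   # posterior exactly N(0, I)

kld_weight, predict_weight = 0.01, 0.01              # placeholder weights

recons_loss = F.mse_loss(recons_action, original_action)
# KL( N(mu, var) || N(0, I) ) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), averaged over the batch
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
predict_loss = F.mse_loss(prediction_residual, true_residual)

loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
```

With mu = 0 and log_var = 0, each KL summand is 1 + 0 - 0 - 1 = 0, so the KL term vanishes along with both MSE terms.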

Full Source Code

../ding/model/template/vae.py

```python
"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE"""

import torch
from torch.nn import functional as F
from torch import nn
from abc import abstractmethod
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional
from ding.utils.type_helper import Tensor


class VanillaVAE(nn.Module):
    """
    Overview:
        Implementation of Vanilla variational autoencoder for action reconstruction.
    Interfaces:
        ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
        ``forward``, ``loss_function`` .
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: List = [256, 256],
            **kwargs
    ) -> None:
        super(VanillaVAE, self).__init__()
        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        self.hidden_dims = hidden_dims

        # Build Encoder
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())

        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Build Decoder
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
        # TODO(pu): tanh
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())

        # residual prediction
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \
                `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
        """
        action_encoding = self.encode_action_head(input['action'])
        obs_encoding = self.encode_obs_head(input['obs'])
        # obs_encoding = self.condition_obs(input['obs'])  # TODO(pu): using a different network
        input = obs_encoding * action_encoding  # TODO(pu): what about add, cat?
        result = self.encode_common(input)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.encode_mu_head(result)
        log_var = self.encode_logvar_head(result)

        return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding}

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs_encoding onto the original action space.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action
            - obs_encoding (:obj:`torch.Tensor`): observation encoding
        Returns:
            - outputs (:obj:`Dict`): Dict containing the decoded outputs.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): reconstruction_action.
            - predition_residual (:obj:`torch.Tensor`): predition_residual.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
        """
        action_decoding = self.decode_action_head(torch.tanh(z))  # NOTE: tanh, here z is not bounded
        action_obs_decoding = action_decoding * obs_encoding
        action_obs_decoding_tmp = self.decode_common(action_obs_decoding)

        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
        predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
        predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs onto the original action space.
            Using the method self.encode_obs_head(obs) to get the obs_encoding.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action
            - obs (:obj:`torch.Tensor`): observation
        Returns:
            - outputs (:obj:`Dict`): Dict containing the decoded outputs.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by VAE.
            - predition_residual (:obj:`torch.Tensor`): the observation predicted by VAE.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
        """
        obs_encoding = self.encode_obs_head(obs)
        # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
        action_decoding = self.decode_action_head(z)
        action_obs_decoding = action_decoding * obs_encoding
        action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
        predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
        predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)

        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Overview:
            Reparameterization trick to sample from N(mu, var) using samples from N(0, 1).
        Arguments:
            - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian
            - logvar (:obj:`torch.Tensor`): Log variance of the latent Gaussian
        Shapes:
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
        """
        Overview:
            Encode the input, reparameterize `mu` and `log_var`, decode `obs_encoding`.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \
                and `action` (:obj:`torch.Tensor`), representing the observation \
                and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
                where B is batch size and O is ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """

        encode_output = self.encode(input)
        z = self.reparameterize(encode_output['mu'], encode_output['log_var'])
        decode_output = self.decode(z, encode_output['obs_encoding'])
        return {
            'recons_action': decode_output['reconstruction_action'],
            'prediction_residual': decode_output['predition_residual'],
            'input': input,
            'mu': encode_output['mu'],
            'log_var': encode_output['log_var'],
            'z': z
        }

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss function.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual``, \
                ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``.
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \
                ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
                and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
                and O is ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
        """
        recons_action = args['recons_action']
        prediction_residual = args['prediction_residual']
        original_action = args['original_action']
        mu = args['mu']
        log_var = args['log_var']
        true_residual = args['true_residual']

        kld_weight = kwargs['kld_weight']
        predict_weight = kwargs['predict_weight']

        recons_loss = F.mse_loss(recons_action, original_action)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        predict_loss = F.mse_loss(prediction_residual, true_residual)

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}
```