ding.model.template.vae
Credit: The following VAE model is modified from https://github.com/AntixK/PyTorch-VAE
VanillaVAE
Bases: Module
Overview
Implementation of a vanilla variational autoencoder (VAE) for action reconstruction.
Interfaces:
__init__, encode, decode, decode_with_obs, reparameterize, forward, loss_function.
encode(input)
Overview
Encodes the input by passing through the encoder network and returns the latent codes.
Arguments:
- input (:obj:Dict): Dict containing keywords obs (:obj:torch.Tensor) and action (:obj:torch.Tensor), representing the observation and agent's action respectively.
Returns:
- outputs (:obj:Dict): Dict containing keywords mu (:obj:torch.Tensor), log_var (:obj:torch.Tensor) and obs_encoding (:obj:torch.Tensor) representing latent codes.
Shapes:
- obs (:obj:torch.Tensor): :math:(B, O), where B is batch size and O is observation dim.
- action (:obj:torch.Tensor): :math:(B, A), where B is batch size and A is action dim.
- mu (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- log_var (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- obs_encoding (:obj:torch.Tensor): :math:(B, H), where B is batch size and H is hidden dim.
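The input and output contract above can be illustrated with a toy stand-in encoder. This is a minimal sketch, not the DI-engine implementation: `ToyEncoder`, its layer names, the element-wise fusion of the two encodings, and every dimension are assumptions chosen only to reproduce the documented shapes.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in that mimics the documented encode() contract."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.obs_head = nn.Sequential(nn.Linear(obs_dim, hidden_dim), nn.ReLU())
        self.action_head = nn.Sequential(nn.Linear(action_dim, hidden_dim), nn.ReLU())
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def encode(self, input: dict) -> dict:
        obs_encoding = self.obs_head(input['obs'])           # (B, H)
        action_encoding = self.action_head(input['action'])  # (B, H)
        fused = obs_encoding * action_encoding               # assumed fusion, (B, H)
        return {
            'mu': self.mu_head(fused),           # (B, L)
            'log_var': self.logvar_head(fused),  # (B, L)
            'obs_encoding': obs_encoding,        # (B, H)
        }


# Illustrative sizes: batch, observation, action, hidden, latent.
B, O, A, H, L = 4, 8, 2, 16, 3
encoder = ToyEncoder(O, A, H, L)
out = encoder.encode({'obs': torch.randn(B, O), 'action': torch.randn(B, A)})
shapes = {key: tuple(value.shape) for key, value in out.items()}
```

Here `shapes` maps each returned key to its tensor shape, which makes it easy to check a real model against the Shapes list.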
decode(z, obs_encoding)
Overview
Maps the given latent action and obs_encoding onto the original action space.
Arguments:
- z (:obj:torch.Tensor): the sampled latent action
- obs_encoding (:obj:torch.Tensor): observation encoding
Returns:
- outputs (:obj:Dict): Dict containing the decoder outputs listed under ReturnsKeys.
ReturnsKeys:
- reconstruction_action (:obj:torch.Tensor): the action reconstructed by the VAE.
- predition_residual (:obj:torch.Tensor): the observation residual predicted by the VAE.
Shapes:
- z (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent_size
- obs_encoding (:obj:torch.Tensor): :math:(B, H), where B is batch size and H is hidden dim
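A decoder with this contract might look like the sketch below. It is an illustration under stated assumptions rather than the library code: `ToyDecoder`, the concatenation of `z` with `obs_encoding`, the `Tanh` on the action head (which assumes actions normalized to [-1, 1]), and the output widths A and O are all hypothetical. The key `predition_residual` is spelled as documented above.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in that mimics the documented decode() contract."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.common = nn.Sequential(nn.Linear(latent_dim + hidden_dim, hidden_dim), nn.ReLU())
        # Tanh assumes actions are normalized to [-1, 1].
        self.action_head = nn.Sequential(nn.Linear(hidden_dim, action_dim), nn.Tanh())
        self.residual_head = nn.Linear(hidden_dim, obs_dim)

    def decode(self, z: torch.Tensor, obs_encoding: torch.Tensor) -> dict:
        # Condition the latent action on the observation encoding.
        h = self.common(torch.cat([z, obs_encoding], dim=-1))  # (B, H)
        return {
            'reconstruction_action': self.action_head(h),  # (B, A)
            'predition_residual': self.residual_head(h),   # (B, O)
        }


B, O, A, H, L = 4, 8, 2, 16, 3
decoder = ToyDecoder(O, A, H, L)
decoded = decoder.decode(torch.randn(B, L), torch.randn(B, H))
```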
decode_with_obs(z, obs)
Overview
Maps the given latent action and obs onto the original action space, using self.encode_obs_head(obs) to obtain obs_encoding.
Arguments:
- z (:obj:torch.Tensor): the sampled latent action
- obs (:obj:torch.Tensor): observation
Returns:
- outputs (:obj:Dict): Dict containing the decoder outputs listed under ReturnsKeys.
ReturnsKeys:
- reconstruction_action (:obj:torch.Tensor): the action reconstructed by the VAE.
- predition_residual (:obj:torch.Tensor): the observation residual predicted by the VAE.
Shapes:
- z (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent_size
- obs (:obj:torch.Tensor): :math:(B, O), where B is batch size and O is obs_shape
reparameterize(mu, logvar)
Overview
Reparameterization trick: draws a sample from N(mu, var) by scaling and shifting noise sampled from N(0, 1).
Arguments:
- mu (:obj:torch.Tensor): Mean of the latent Gaussian
- logvar (:obj:torch.Tensor): Log variance of the latent Gaussian
Shapes:
- mu (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent_size
- logvar (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent_size
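The trick can be sketched in a few lines. This follows the standard formulation z = mu + eps * exp(0.5 * logvar) with eps drawn from N(0, 1). Treating `logvar` as the log variance is an assumption based on the parameter name, and the free function below is a standalone illustration, not the class method itself.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, exp(logvar)) so that gradients flow to mu and logvar."""
    std = torch.exp(0.5 * logvar)  # log variance -> standard deviation
    eps = torch.randn_like(std)    # noise from N(0, 1), same shape as std
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()  # the randomness sits in eps, so both inputs receive gradients
```

Because the sample is a deterministic function of `mu`, `logvar` and the external noise `eps`, the encoder can be trained by backpropagating through `z`.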
forward(input, **kwargs)
Overview
Encodes the input, samples the latent action z from mu and log_var via reparameterization, then decodes z together with obs_encoding.
Arguments:
- input (:obj:Dict): Dict containing keywords obs (:obj:torch.Tensor) and action (:obj:torch.Tensor), representing the observation and agent's action respectively.
Returns:
- outputs (:obj:Dict): Dict containing keywords recons_action (:obj:torch.Tensor), prediction_residual (:obj:torch.Tensor), input (:obj:torch.Tensor), mu (:obj:torch.Tensor), log_var (:obj:torch.Tensor) and z (:obj:torch.Tensor).
Shapes:
- recons_action (:obj:torch.Tensor): :math:(B, A), where B is batch size and A is action dim.
- prediction_residual (:obj:torch.Tensor): :math:(B, O), where B is batch size and O is observation dim.
- mu (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- log_var (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- z (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent_size
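The three-step pipeline and the returned dict can be sketched as follows. The single linear layers are hypothetical placeholders for the real sub-networks, and the sizes are illustrative. Only the key names and the encode, reparameterize, decode ordering are taken from the descriptions on this page.

```python
import torch
import torch.nn as nn

B, O, A, H, L = 4, 8, 2, 16, 3  # illustrative sizes

# Hypothetical single-layer stand-ins for the real sub-networks.
obs_head = nn.Linear(O, H)
mu_head = nn.Linear(H + A, L)
logvar_head = nn.Linear(H + A, L)
action_head = nn.Linear(L + H, A)
residual_head = nn.Linear(L + H, O)


def encode(input: dict) -> dict:
    obs_encoding = obs_head(input['obs'])
    h = torch.cat([obs_encoding, input['action']], dim=-1)
    return {'mu': mu_head(h), 'log_var': logvar_head(h), 'obs_encoding': obs_encoding}


def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    return mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)


def decode(z: torch.Tensor, obs_encoding: torch.Tensor) -> dict:
    h = torch.cat([z, obs_encoding], dim=-1)
    return {'reconstruction_action': action_head(h), 'predition_residual': residual_head(h)}


def forward(input: dict) -> dict:
    enc = encode(input)                                  # step 1: encode
    z = reparameterize(enc['mu'], enc['log_var'])        # step 2: sample z
    dec = decode(z, enc['obs_encoding'])                 # step 3: decode
    return {
        'recons_action': dec['reconstruction_action'],
        'prediction_residual': dec['predition_residual'],
        'input': input,
        'mu': enc['mu'],
        'log_var': enc['log_var'],
        'z': z,
    }


data = {'obs': torch.randn(B, O), 'action': torch.randn(B, A)}
result = forward(data)
```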
loss_function(args, **kwargs)
Overview
Computes the VAE loss function.
Arguments:
- args (:obj:Dict[str, Tensor]): Dict containing keywords recons_action, prediction_residual, original_action, mu, log_var and true_residual.
- kwargs (:obj:Dict): Dict containing keywords kld_weight and predict_weight.
Returns:
- outputs (:obj:Dict[str, Tensor]): Dict containing different loss results, including loss, reconstruction_loss, kld_loss, predict_loss.
Shapes:
- recons_action (:obj:torch.Tensor): :math:(B, A), where B is batch size and A is action dim.
- prediction_residual (:obj:torch.Tensor): :math:(B, O), where B is batch size and O is observation dim.
- original_action (:obj:torch.Tensor): :math:(B, A), where B is batch size and A is action dim.
- mu (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- log_var (:obj:torch.Tensor): :math:(B, L), where B is batch size and L is latent size.
- true_residual (:obj:torch.Tensor): :math:(B, O), where B is batch size and O is observation dim.
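One plausible way to combine these terms is sketched below. It assumes mean-squared-error reconstruction and prediction losses, the closed-form KL divergence between N(mu, exp(log_var)) and N(0, I), and a weighted sum using `kld_weight` and `predict_weight`. The page does not confirm these choices, so treat `vae_loss` as a hypothetical illustration of the documented inputs and outputs.

```python
import torch
import torch.nn.functional as F


def vae_loss(args: dict, kld_weight: float, predict_weight: float) -> dict:
    """Assumed form: MSE reconstruction + weighted Gaussian KL + weighted MSE prediction."""
    mu, log_var = args['mu'], args['log_var']
    recons_loss = F.mse_loss(args['recons_action'], args['original_action'])
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch.
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1), dim=0)
    predict_loss = F.mse_loss(args['prediction_residual'], args['true_residual'])
    loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
    return {
        'loss': loss,
        'reconstruction_loss': recons_loss,
        'kld_loss': kld_loss,
        'predict_loss': predict_loss,
    }


B, O, A, L = 4, 8, 2, 3  # illustrative sizes
batch = {
    'recons_action': torch.zeros(B, A),
    'original_action': torch.zeros(B, A),
    'prediction_residual': torch.zeros(B, O),
    'true_residual': torch.zeros(B, O),
    'mu': torch.ones(B, L),        # unit mean shift in each latent dimension
    'log_var': torch.zeros(B, L),  # unit variance
}
losses = vae_loss(batch, kld_weight=0.5, predict_weight=1.0)
```

With perfect reconstruction and prediction, only the KL term contributes in this example, which makes the effect of `kld_weight` easy to inspect in isolation.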
Full Source Code
../ding/model/template/vae.py