Skip to content

ding.world_model.idm

ding.world_model.idm

InverseDynamicsModel

Bases: Module

InverseDynamicsModel: inferring missing action information from a state transition. Input and output: given a pair of observations, return the action (s0, s1 --> a0 if n=2).

__init__(obs_shape, action_shape, encoder_hidden_size_list=[60, 80, 100, 40], action_space='regression', activation=nn.LeakyReLU(), norm_type=None)

Overview

Init the Inverse Dynamics (encoder + head) Model according to input arguments.

Arguments: - obs_shape (:obj:Union[int, SequenceType]): Observation space shape, such as 8 or [4, 84, 84]. - action_shape (:obj:Union[int, SequenceType]): Action space shape, such as 6 or [2, 3, 3]. - encoder_hidden_size_list (:obj:SequenceType): Collection of hidden_size to pass to Encoder, \ the last element must match head_hidden_size. - action_space (:obj:String): Action space, such as 'regression', 'reparameterization', 'discrete'. - activation (:obj:Optional[nn.Module]): The type of activation function in networks \ if None then default set it to nn.LeakyReLU() refer to https://arxiv.org/abs/1805.01954 - norm_type (:obj:Optional[str]): The type of normalization in networks, see \ ding.torch_utils.fc_block for more details.

train(training_set, n_epoch, learning_rate, weight_decay)

Overview

Train the idm model: given a pair of states (s_t, s_t+1), predict the action a_t.

Parameters:

Name Type Description Default
- training_set (

obj:dict): state transitions

required
- n_epoch (

obj:int): number of epochs

required
- learning_rate (

obj:float): learning rate for optimizer

required
- weight_decay (

obj:float): weight decay for optimizer

required

Full Source Code

../ding/world_model/idm.py

import torch
import torch.nn as nn
from typing import Union, Optional, Dict
import numpy as np

from ding.model.common.head import DiscreteHead, RegressionHead, ReparameterizationHead
from ding.utils import SequenceType, squeeze
from ding.model.common.encoder import FCEncoder, ConvEncoder
from torch.distributions import Independent, Normal


class InverseDynamicsModel(nn.Module):
    """
    Overview:
        Inverse dynamics model: inferring missing action information from a state transition.
        Input and output: given a pair of consecutive observations, return the action
        (s0, s1 --> a0 if n=2).
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: Optional[SequenceType] = None,
            action_space: str = "regression",
            activation: Optional[nn.Module] = None,
            norm_type: Optional[str] = None
    ) -> None:
        """
        Overview:
            Init the Inverse Dynamics (encoder + head) Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
            - encoder_hidden_size_list (:obj:`Optional[SequenceType]`): Collection of ``hidden_size`` to pass to \
                ``Encoder``; the last element must match ``head_hidden_size``. Defaults to ``[60, 80, 100, 40]``.
            - action_space (:obj:`str`): Action space, one of 'regression', 'reparameterization', 'discrete'.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks; \
                if ``None`` then default set it to ``nn.LeakyReLU()``, refer to https://arxiv.org/abs/1805.01954.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details.
        """
        super(InverseDynamicsModel, self).__init__()
        # Resolve defaults here instead of in the signature: a mutable list default
        # and a def-time-instantiated nn.Module would be shared across all instances.
        if encoder_hidden_size_list is None:
            encoder_hidden_size_list = [60, 80, 100, 40]
        if activation is None:
            activation = nn.LeakyReLU()
        # For compatibility: 1, (1, ), [4, 32, 32]
        obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
        # FC encoder: input is the concatenation of obs and obs[next], so input shape is obs_shape * 2.
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.encoder = FCEncoder(
                obs_shape * 2, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        elif len(obs_shape) == 3:
            # Conv encoder: obs and obs[next] are stacked along channels, so the first channel is doubled.
            obs_shape = (obs_shape[0] * 2, *obs_shape[1:])
            self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own Model".format(obs_shape)
            )
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization',
                                     'discrete'], "not supported action_space: {}".format(self.action_space)
        if self.action_space == "regression":
            self.header = RegressionHead(
                encoder_hidden_size_list[-1],
                action_shape,
                final_tanh=False,
                activation=activation,
                norm_type=norm_type
            )
        elif self.action_space == "reparameterization":
            self.header = ReparameterizationHead(
                encoder_hidden_size_list[-1],
                action_shape,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        elif self.action_space == "discrete":
            self.header = DiscreteHead(
                encoder_hidden_size_list[-1], action_shape, activation=activation, norm_type=norm_type
            )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Encode the concatenated observation pair and predict the intermediate action.
        Arguments:
            - x (:obj:`torch.Tensor`): Concatenated (obs, next_obs) input, shape matching the encoder.
        Returns:
            - output (:obj:`Dict`): For 'regression', ``{'action': ...}``; for 'reparameterization', \
                ``{'logit': [mu, sigma], 'action': ...}`` (tanh-squashed rsample); for 'discrete', the raw \
                head output containing ``'logit'``.
        """
        if self.action_space == "regression":
            x = self.encoder(x)
            x = self.header(x)
            return {'action': x['pred']}
        elif self.action_space == "reparameterization":
            x = self.encoder(x)
            x = self.header(x)
            mu, sigma = x['mu'], x['sigma']
            # Treat the last dim as event dim so rsample draws one multivariate action.
            dist = Independent(Normal(mu, sigma), 1)
            pred = dist.rsample()
            action = torch.tanh(pred)
            return {'logit': [mu, sigma], 'action': action}
        elif self.action_space == "discrete":
            x = self.encoder(x)
            x = self.header(x)
            return x

    def predict_action(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Predict the concrete action for input ``x``. For the discrete action space, take the argmax
            over logits (softmax is monotone, so applying it before argmax is redundant and is skipped).
        Arguments:
            - x (:obj:`torch.Tensor`): Concatenated (obs, next_obs) input.
        Returns:
            - output (:obj:`Dict`): ``{'action': ...}`` for 'discrete'; otherwise the ``forward`` output.
        """
        if self.action_space == "discrete":
            action = torch.argmax(self.forward(x)['logit'], -1)
            return {'action': action}
        else:
            return self.forward(x)

    def train(self, training_set: dict, n_epoch: int, learning_rate: float, weight_decay: float):
        """
        Overview:
            Train the idm model: given pairs of states, regress the action (s_t, s_t+1 -> a_t).

        .. note::
            This method shadows ``nn.Module.train(mode=True)``; the name is kept for backward
            compatibility with existing callers. Use ``module.train(True)``-style mode switching
            on submodules directly if needed.
        Arguments:
            - training_set (:obj:`dict`): State transitions, with keys ``'obs'`` and ``'action'``.
            - n_epoch (:obj:`int`): Number of epochs.
            - learning_rate (:obj:`float`): Learning rate for the optimizer.
            - weight_decay (:obj:`float`): Weight decay for the optimizer.
        Returns:
            - loss (:obj:`float`): Mean training loss over all epochs.
        """
        if self.action_space == "discrete":
            criterion = nn.CrossEntropyLoss()
        else:
            # L1 loss is used instead of MSE for robustness to outlier transitions.
            criterion = nn.L1Loss()
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
        loss_list = []
        for itr in range(n_epoch):
            data = training_set['obs']
            y = training_set['action']
            if self.action_space == "discrete":
                y_pred = self.forward(data)['logit']
            else:
                y_pred = self.forward(data)['action']
            loss = criterion(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        loss = np.mean(loss_list)
        return loss