ding.model.template.edac

EDAC

Bases: Module

Overview

The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC.

Interfaces: __init__, forward, compute_actor, compute_critic
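compute_critic returns one Q-value per ensemble member, as a q_value tensor of shape (Ensemble_num, B). This model does not reduce the ensemble itself. In the EDAC algorithm, as in SAC-N, the members are combined pessimistically: the soft Bellman target uses the minimum over all members.

The sketch below shows that reduction in the standard library only, with nested lists standing in for tensors. `ensemble_min_target` is a hypothetical helper, not a DI-engine function. The `gamma` and `alpha` values are illustrative assumptions.

```python
from typing import List


def ensemble_min_target(
    q_next: List[List[float]],   # hypothetical next-state Q-values, layout (ensemble_num, B) like q_value
    log_prob_next: List[float],  # log pi(a'|s') for each sample, shape (B,)
    reward: List[float],
    done: List[float],
    gamma: float = 0.99,
    alpha: float = 0.2,
) -> List[float]:
    """Clipped (min-over-ensemble) soft Bellman target, as used by SAC-N style methods such as EDAC."""
    targets = []
    for b in range(len(reward)):
        # Pessimistic estimate: the smallest Q-value predicted by any ensemble member for sample b.
        q_min = min(q_member[b] for q_member in q_next)
        # Soft value: subtract the entropy term alpha * log pi(a'|s').
        soft_value = q_min - alpha * log_prob_next[b]
        # Terminal transitions (done = 1) do not bootstrap.
        targets.append(reward[b] + gamma * (1.0 - done[b]) * soft_value)
    return targets
```

For example, `ensemble_min_target([[1.0, 2.0], [0.5, 3.0]], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], gamma=0.9)` takes the per-sample minima 0.5 and 2.0. The second sample is terminal and contributes only its reward.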

__init__(obs_shape, action_shape, ensemble_num=2, actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.ReLU(), norm_type=None, **kwargs)

Overview

Initialize the EDAC model according to the input arguments.

Arguments:
- obs_shape (Union[int, SequenceType]): Observation's shape, such as 128 or (156, ).
- action_shape (Union[int, SequenceType, EasyDict]): Action's shape, such as 4, (3, ), or EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- ensemble_num (int): Number of Q-networks in the critic ensemble.
- actor_head_hidden_size (Optional[int]): The hidden_size passed to the actor head.
- actor_head_layer_num (int): The number of layers in the actor head, which outputs the parameters (mu, sigma) of the action distribution.
- critic_head_hidden_size (Optional[int]): The hidden_size passed to the critic head.
- critic_head_layer_num (int): The number of layers in the critic head, which computes the Q-value output.
- activation (Optional[nn.Module]): The activation function applied in the MLP after each FC layer; defaults to nn.ReLU().
- norm_type (Optional[str]): The type of normalization applied after each network layer (FC, Conv); see ding.torch_utils.network for more details.
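The constructor squeezes both shapes and then sizes the critic input as `obs_shape + action_shape`. As written, the model therefore expects both shapes to reduce to plain integers.

The sketch below reproduces that size bookkeeping under the assumption of 1-dim observations and actions:
- `squeeze_shape` is a hypothetical, simplified stand-in for `ding.utils.squeeze`.
- `edac_io_sizes` is illustrative only and is not part of DI-engine.

```python
from typing import Dict, Sequence, Union

Shape = Union[int, Sequence[int]]


def squeeze_shape(shape: Shape) -> Shape:
    """Hypothetical stand-in for ding.utils.squeeze: unwrap a length-1 tuple or list."""
    if isinstance(shape, (list, tuple)) and len(shape) == 1:
        return shape[0]
    return shape


def edac_io_sizes(
    obs_shape: Shape,
    action_shape: Shape,
    ensemble_num: int = 2,
    actor_head_hidden_size: int = 64,
) -> Dict[str, int]:
    """Input/output widths implied by EDAC.__init__ and compute_critic (illustrative sketch)."""
    obs_dim = squeeze_shape(obs_shape)
    act_dim = squeeze_shape(action_shape)
    if not isinstance(obs_dim, int) or not isinstance(act_dim, int):
        raise ValueError("this sketch only handles 1-dim observation and action shapes")
    critic_input_size = obs_dim + act_dim  # one member sees the concatenated [obs, action]
    return {
        "actor_linear_in": obs_dim,                  # nn.Linear(obs_shape, actor_head_hidden_size)
        "actor_linear_out": actor_head_hidden_size,
        "critic_input_size": critic_input_size,      # passed to EnsembleHead
        # compute_critic tiles [obs, action] once per member before calling the critic.
        "tiled_critic_width": critic_input_size * ensemble_num,
    }
```

With the arguments of the compute_critic example, `edac_io_sizes((8,), 1)` gives a critic input of 9 per member and a tiled width of 18 for the default two members.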

forward(inputs, mode)

Overview

The unique execution (forward) method of the EDAC model. The mode argument selects which computation graph to run: compute_actor or compute_critic. Any other mode raises an AssertionError.

Mode compute_actor:
Arguments:
- inputs (torch.Tensor): Observation tensor.
Returns:
- output (Dict): Output dict data, whose key-values depend on the action space.
Mode compute_critic:
Arguments:
- inputs (Dict): Input dict data, including the obs and action tensors.
Returns:
- output (Dict): Output dict data, including the q_value tensor.

Note: For specific examples, refer to the API docs of compute_actor and compute_critic.
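forward does not compute anything itself. It checks the mode against the class attribute `mode` and then calls the method of that name.

The toy class below mirrors that getattr-based dispatch. `TwoModeModel` is hypothetical, and its two methods return placeholder values rather than real network outputs.

```python
from typing import Any, Dict


class TwoModeModel:
    """Hypothetical toy model mirroring the dispatch logic of EDAC.forward."""

    mode = ['compute_actor', 'compute_critic']

    def forward(self, inputs: Any, mode: str) -> Dict[str, Any]:
        # Reject unknown modes before looking the method up by name.
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: Any) -> Dict[str, Any]:
        return {'logit': [obs, obs]}  # placeholder for [mu, sigma]

    def compute_critic(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {'q_value': inputs['obs'] + inputs['action']}  # placeholder Q-value
```

Keeping the allowed modes in one list means that adding a mode only requires a new method and a new list entry.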

compute_actor(obs)

Overview

The forward computation graph of compute_actor mode. It uses the observation tensor to produce the actor output: logit, a list [mu, sigma] holding the parameters of the reparameterized Gaussian policy.

Arguments:
- obs (torch.Tensor): Observation tensor; currently supports a batch of 1-dim vector data, i.e. (B, obs_shape).
Returns:
- outputs (Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]): Actor output for the reparameterization action space.
ReturnsKeys:
- logit (List[torch.Tensor]): Reparameterization logit [mu, sigma], as commonly used in SAC.
  - mu (torch.Tensor): Mean of the reparameterized Gaussian distribution.
  - sigma (torch.Tensor): Standard deviation of the reparameterized Gaussian distribution.
Shapes:
- obs (torch.Tensor): (B, N0), where B is batch size and N0 corresponds to obs_shape.
- action (torch.Tensor): (B, N1), where B is batch size and N1 corresponds to action_shape.
- logit.mu (torch.Tensor): (B, N1), where B is batch size and N1 corresponds to action_shape.
- logit.sigma (torch.Tensor): (B, N1), where B is batch size and N1 corresponds to action_shape.
Examples:
>>> model = EDAC(64, 64)
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
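The actor returns the Gaussian parameters [mu, sigma], not an action. This page does not show how the policy turns them into actions. A common choice in SAC-style methods, assumed here, is a tanh-squashed reparameterized sample: a = tanh(mu + sigma * eps), with eps drawn from N(0, 1).

The sketch below works on one sample at a time using the standard library only. `sample_tanh_gaussian` is a hypothetical helper.

```python
import math
import random
from typing import List, Tuple


def sample_tanh_gaussian(mu: List[float], sigma: List[float], rng: random.Random) -> Tuple[List[float], float]:
    """Reparameterized sample a = tanh(mu + sigma * eps) and its log-probability (illustrative sketch)."""
    action, log_prob = [], 0.0
    for m, s in zip(mu, sigma):
        eps = rng.gauss(0.0, 1.0)
        u = m + s * eps  # reparameterization: the noise does not depend on (mu, sigma)
        # Log-density of u under N(m, s^2); note (u - m) / s == eps.
        log_prob += -0.5 * eps ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        a = math.tanh(u)
        # Change-of-variables correction for the tanh squashing (small constant avoids log(0)).
        log_prob -= math.log(1.0 - a ** 2 + 1e-6)
        action.append(a)
    return action, log_prob
```

Because the noise is drawn independently of mu and sigma, gradients can flow through the sample back into the actor when the same arithmetic is done with tensors.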

compute_critic(inputs)

Overview

The forward computation graph of compute_critic mode. It uses the observation and action tensors to produce the critic output, i.e. one q_value per ensemble member.

Arguments:
- inputs (Dict[str, torch.Tensor]): Dict structure of input data, including the obs and action tensors.
Returns:
- outputs (Dict[str, torch.Tensor]): Critic output, such as q_value.
ArgumentsKeys:
- obs (torch.Tensor): Observation tensor; currently supports a batch of 1-dim vector data.
- action (Union[torch.Tensor, Dict]): Continuous action with the same size as action_shape. A 1-dim action of shape (B, ) is unsqueezed to (B, 1).
ReturnKeys:
- q_value (torch.Tensor): Q-value tensor with one row per ensemble member and one column per sample.
Shapes:
- obs (torch.Tensor): (B, N1) or (Ensemble_num, B, N1), where B is batch size and N1 is obs_shape. The 3-dim form feeds a separate observation to each ensemble member.
- action (torch.Tensor): (B, N2) or (Ensemble_num, B, N2), where B is batch size and N2 is action_shape.
- q_value (torch.Tensor): (Ensemble_num, B), where B is batch size.
Examples:
>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = EDAC(obs_shape=(8, ), action_shape=1)
>>> q_value = model(inputs, mode='compute_critic')['q_value']
>>> assert q_value.shape == torch.Size([2, 4])  # (Ensemble_num, B) with the default ensemble_num=2
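The two input forms of compute_critic are reshaped into the same layout before reaching the critic head: a (B, Ensemble_num * D, 1) tensor, where D = N1 + N2. We assume, as the source comments suggest, that the e-th block of D values is member e's [obs, action] input.

The sketch below reproduces that bookkeeping with nested lists in place of tensors. `as_2d`, `tile_2d` and `stack_3d` are hypothetical helpers that mirror the tensor operations named in their comments.

```python
from typing import List

Matrix = List[List[float]]


def as_2d(action: list) -> Matrix:
    """(B,) -> (B, 1), mirroring action.unsqueeze(1) for a 1-dim action batch."""
    if action and not isinstance(action[0], list):
        return [[a] for a in action]
    return action


def tile_2d(obs: Matrix, action: Matrix, ensemble_num: int) -> List[Matrix]:
    """(B, N1) + (B, N2) -> (B, E * D, 1), like cat(..., dim=-1).repeat(1, E).unsqueeze(-1)."""
    out = []
    for o, a in zip(obs, action):
        row = o + a                       # concatenate [obs, action] for one sample
        tiled = row * ensemble_num        # list repetition matches tensor repeat along the last dim
        out.append([[v] for v in tiled])  # trailing unsqueeze(-1)
    return out


def stack_3d(obs: List[Matrix], action: List[Matrix]) -> List[Matrix]:
    """(E, B, N1) + (E, B, N2) -> (B, E * D, 1), like cat, transpose(0, 1), then reshape(B, -1, 1)."""
    ensemble_num, batch_size = len(obs), len(obs[0])
    out = []
    for b in range(batch_size):
        flat = []
        for e in range(ensemble_num):  # member e's [obs, action] block comes e-th
            flat.extend(obs[e][b] + action[e][b])
        out.append([[v] for v in flat])
    return out
```

When every member receives the same (obs, action) pair, both branches produce identical layouts. The 3-dim branch only matters when members need distinct inputs.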

Full Source Code

../ding/model/template/edac.py

from typing import Union, Optional, Dict
from easydict import EasyDict

import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('edac')
class EDAC(nn.Module):
    """
    Overview:
        The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            ensemble_num: int = 2,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Initailize the EDAC Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - ensemble_num (:obj:`int`): Q-net number.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for actor head.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for critic head.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \
                see ``ding.torch_utils.network`` for more details.
        """
        super(EDAC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.ensemble_num = ensemble_num
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size), activation,
            ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        )

        critic_input_size = obs_shape + action_shape
        self.critic = EnsembleHead(
            critic_input_size,
            1,
            critic_head_hidden_size,
            critic_head_layer_num,
            self.ensemble_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of EDAC method, and one can indicate different modes to implement \
            different computation graph, including ``compute_actor`` and ``compute_critic`` in EDAC.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including differnet key-values among distinct action_space.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.

        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            The forward computation graph of compute_actor mode, uses observation tensor to produce actor output,
            such as ``action``, ``logit`` and so on.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
                i.e. ``(B, obs_shape)``.
        Returns:
            - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \
                from action_space: ``reparameterization``.
        ReturnsKeys (either):
            - logit (:obj:`Dict[str, torch.Tensor]`): Reparameterization logit, usually in SAC.
                - mu (:obj:`torch.Tensor`): Mean of parameterization gaussion distribution.
                - sigma (:obj:`torch.Tensor`): Standard variation of parameterization gaussion distribution.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
            - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
                ``action_shape.action_type_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
                ``action_shape.action_args_shape``.
        Examples:
            >>> model = EDAC(64, 64,)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs,'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64]) # mu
            >>> actor_outputs['logit'][1].shape == torch.Size([4, 64]) # sigma
        """
        x = self.actor(obs)
        return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The forward computation graph of compute_critic mode, uses observation and action tensor to produce critic
            output, such as ``q_value``.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict strcture of input data, including ``obs`` and \
                ``action`` tensor
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
        ArgumentsKeys:
            - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
            - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or '(Ensemble_num, B, N1)', where B is batch size and N1 is \
                ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or '(Ensemble_num, B, N2)', where B is batch size and N4 \
                is ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = EDAC(obs_shape=(8, ),action_shape=1)
            >>> model(inputs, mode='compute_critic')['q_value'] # q value
            ... tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
        """

        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if len(obs.shape) < 3:
            # [batch_size,dim] -> [batch_size,Ensemble_num * dim,1]
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
        else:
            # [Ensemble_num,batch_size,dim] -> [batch_size,Ensemble_num,dim] -> [batch_size,Ensemble_num * dim, 1]
            x = x.transpose(0, 1)
            batch_size = obs.shape[1]
            x = x.reshape(batch_size, -1, 1)
        # [Ensemble_num,batch_size,1]
        x = self.critic(x)['pred']
        # [batch_size,1*Ensemble_num] -> [Ensemble_num,batch_size]
        x = x.permute(1, 0)
        return {'q_value': x}
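The 3-dim branch of compute_critic accepts a separate (obs, action) copy for each ensemble member. One plausible use, assumed here because the policy code is not shown on this page, is to differentiate each member's Q-value with respect to its own action copy. The EDAC algorithm uses such action gradients in a diversity penalty that discourages members from having aligned gradients.

The sketch below computes that penalty for a single sample:
- It uses the standard library only.
- The per-member gradients are given as inputs; computing them (e.g. with autograd) is out of scope.
- The gradients are L2-normalized, so each term is a cosine similarity. The normalization is an assumption.
- `ensemble_diversity_penalty` is hypothetical.

```python
import math
from typing import List


def ensemble_diversity_penalty(grads: List[List[float]], eps: float = 1e-10) -> float:
    """Sum of pairwise cosine similarities between per-member action gradients, divided by (N - 1).

    grads[j] is a hypothetical gradient of Q_j with respect to the action, for one sample.
    """
    n = len(grads)
    if n < 2:
        raise ValueError("need at least two ensemble members")
    # L2-normalize each member's gradient (eps guards against a zero gradient).
    unit = []
    for g in grads:
        norm = math.sqrt(sum(v * v for v in g))
        unit.append([v / (norm + eps) for v in g])
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:  # self-similarity on the diagonal is excluded
                total += sum(x * y for x, y in zip(unit[i], unit[j]))
    return total / (n - 1)
```

Identical gradients across N members give a penalty of N, and mutually orthogonal gradients give 0. Minimizing the penalty therefore pushes the members' action gradients apart.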