
ding.model.template.qvac


ContinuousQVAC

Bases: Module

Overview

The neural network and computation graph of Actor-Critic algorithms that have both a Q-value and a V-value critic, such as IQL. This model supports continuous and hybrid action spaces. The ContinuousQVAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders are used to extract features; heads are used to predict the corresponding value or action logit. In a high-dimensional observation space such as a 2D image, a shared encoder is often used for both actor_encoder and critic_encoder. In a low-dimensional observation space such as a 1D vector, separate encoders are often used.

Interfaces: __init__, forward, compute_actor, compute_critic

__init__(obs_shape, action_shape, action_space, twin_critic=False, actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.SiLU(), norm_type=None, encoder_hidden_size_list=None, share_encoder=False)

Overview

Initialize the ContinuousQVAC model according to the input arguments.

Arguments:

- obs_shape (:obj:Union[int, SequenceType]): Observation's shape, such as 128, (156, ).
- action_shape (:obj:Union[int, SequenceType, EasyDict]): Action's shape, such as 4, (3, ), EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- action_space (:obj:str): The type of action space, one of [regression, reparameterization, hybrid]; regression is used for DDPG/TD3, reparameterization for SAC, and hybrid for PADDPG.
- twin_critic (:obj:bool): Whether to use the twin critic, one of the tricks in TD3.
- actor_head_hidden_size (:obj:Optional[int]): The hidden_size to pass to the actor head.
- actor_head_layer_num (:obj:int): The number of layers used in the actor network to compute the action.
- critic_head_hidden_size (:obj:Optional[int]): The hidden_size to pass to the critic head.
- critic_head_layer_num (:obj:int): The number of layers used in the critic network to compute the Q-value.
- activation (:obj:Optional[nn.Module]): The activation function to use in the MLP after each FC layer; if None, defaults to nn.ReLU().
- norm_type (:obj:Optional[str]): The type of normalization to apply after each network layer (FC, Conv); see ding.torch_utils.network for more details.
- encoder_hidden_size_list (:obj:SequenceType): Collection of hidden_size values to pass to the Encoder; the last element must match head_hidden_size. This argument is only used with image observations.
- share_encoder (:obj:Optional[bool]): Whether to share the encoder between actor and critic.
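As a rough sketch of how these arguments interact inside the model, the critic's Q-head input width is the encoder output size plus the action size (for hybrid action spaces, the discrete action-type logit width plus the continuous action-args width). This hypothetical helper (not part of ding; a plain dict stands in for EasyDict) mirrors that sizing logic:

```python
# Illustrative helper (not ding code) computing the critic Q-head input
# width from the constructor arguments described above.
def critic_q_input_size(encoder_output_size, action_shape, action_space):
    if action_space == 'hybrid':
        # hybrid: obs feature + discrete action-type logit + continuous action args
        return (encoder_output_size + action_shape['action_type_shape']
                + action_shape['action_args_shape'])
    # regression / reparameterization: obs feature + continuous action
    return encoder_output_size + action_shape
```

For example, with a 1-dim vector observation of size 8 and a scalar action, the Q-head sees 9 input features; the V-head sees only the 8 observation features.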

forward(inputs, mode)

Overview

QVAC forward computation graph: input an observation tensor to predict the Q-value or action logit. Different modes forward through different network modules to produce different outputs and save computation.

Arguments:

- inputs (:obj:Union[torch.Tensor, Dict[str, torch.Tensor]]): The input data for the forward computation graph; for compute_actor it is the observation tensor, for compute_critic it is a dict containing the obs and action tensors.
- mode (:obj:str): The forward mode; all modes are defined at the beginning of this class.

Returns:

- output (:obj:Dict[str, torch.Tensor]): The output dict of the QVAC forward computation graph, whose key-values vary across forward modes.

Examples (Actor):

>>> # Regression mode
>>> model = ContinuousQVAC(64, 6, 'regression')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 6])
>>> # Reparameterization mode
>>> model = ContinuousQVAC(64, 6, 'reparameterization')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma

Examples (Critic):

>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = ContinuousQVAC(obs_shape=(8, ), action_shape=1, action_space='regression')
>>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
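Under the hood, forward only validates the mode and then delegates to the method of the same name via getattr. The dispatch pattern can be sketched in plain Python (the class and method bodies here are illustrative placeholders, not ding code):

```python
# Minimal sketch of the mode-dispatch pattern behind forward(inputs, mode):
# the class lists its supported modes and delegates to the matching method.
class ModeDispatchModel:
    mode = ['compute_actor', 'compute_critic']

    def forward(self, inputs, mode):
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    # Placeholder computations standing in for the real network heads.
    def compute_actor(self, obs):
        return {'action': obs}

    def compute_critic(self, inputs):
        return {'q_value': inputs['obs']}
```

Calling model(obs, 'compute_actor') therefore runs only the actor encoder and head, skipping the critic modules entirely.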

compute_actor(obs)

Overview

QVAC forward computation graph for actor part, input observation tensor to predict action or action logit.

Arguments:

- obs (:obj:torch.Tensor): The input observation tensor data.

Returns:

- outputs (:obj:Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]): Actor output dict varying with action_space: regression, reparameterization, hybrid.

ReturnsKeys (regression):

- action (:obj:torch.Tensor): Continuous action with the same size as action_shape, usually in DDPG/TD3.

ReturnsKeys (reparameterization):

- logit (:obj:Dict[str, torch.Tensor]): The predicted reparameterization action logit, usually in SAC. It is a list containing two tensors: mu and sigma. The former is the mean of the Gaussian distribution, the latter its standard deviation.

ReturnsKeys (hybrid):

- logit (:obj:torch.Tensor): The predicted discrete action type logit; it has the same dimension as action_type_shape, i.e., all possible discrete action types.
- action_args (:obj:torch.Tensor): Continuous action arguments with the same size as action_args_shape.

Shapes:

- obs (:obj:torch.Tensor): :math:(B, N0), where B is batch size and N0 corresponds to obs_shape.
- action (:obj:torch.Tensor): :math:(B, N1), where B is batch size and N1 corresponds to action_shape.
- logit.mu (:obj:torch.Tensor): :math:(B, N1), where B is batch size and N1 corresponds to action_shape.
- logit.sigma (:obj:torch.Tensor): :math:(B, N1), where B is batch size.
- logit (:obj:torch.Tensor): :math:(B, N2), where B is batch size and N2 corresponds to action_shape.action_type_shape.
- action_args (:obj:torch.Tensor): :math:(B, N3), where B is batch size and N3 corresponds to action_shape.action_args_shape.
Examples:

>>> # Regression mode
>>> model = ContinuousQVAC(64, 6, 'regression')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['action'].shape == torch.Size([4, 6])
>>> # Reparameterization mode
>>> model = ContinuousQVAC(64, 6, 'reparameterization')
>>> obs = torch.randn(4, 64)
>>> actor_outputs = model(obs, 'compute_actor')
>>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
>>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma
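In reparameterization mode the actor does not return an action directly; a consumer such as SAC typically samples one from the returned (mu, sigma) pair with the reparameterization trick, action = mu + sigma * eps with eps drawn from a standard normal, which keeps the sample differentiable with respect to mu and sigma. A scalar sketch (not ding code; the function name is illustrative):

```python
import random

# Hypothetical sketch of the reparameterization trick applied to one
# (mu, sigma) pair from compute_actor's 'logit' output: sample
# eps ~ N(0, 1) and shift/scale it, so gradients flow through mu and sigma.
def reparameterize(mu, sigma, eps=None):
    if eps is None:
        eps = random.gauss(0.0, 1.0)
    return mu + sigma * eps
```

In SAC the result is usually squashed through tanh afterwards to respect action bounds.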

compute_critic(inputs)

Overview

QVAC forward computation graph for critic part, input observation and action tensor to predict Q-value.

Arguments:

- inputs (:obj:Dict[str, torch.Tensor]): The dict of input data, including the obs and action tensors; in the hybrid action_space it also contains the logit and action_args tensors.

ArgumentsKeys:

- obs (:obj:torch.Tensor): Observation tensor data; currently supports a batch of 1-dim vector data.
- action (:obj:Union[torch.Tensor, Dict]): Continuous action with the same size as action_shape.
- logit (:obj:torch.Tensor): Discrete action logit, only in the hybrid action_space.
- action_args (:obj:torch.Tensor): Continuous action arguments, only in the hybrid action_space.

Returns:

- outputs (:obj:Dict[str, torch.Tensor]): The output of QVAC's forward computation graph for the critic, including q_value (and v_value).

ReturnsKeys:

- q_value (:obj:torch.Tensor): Q value tensor with the same size as the batch size.

Shapes:

- obs (:obj:torch.Tensor): :math:(B, N1), where B is batch size and N1 is obs_shape.
- logit (:obj:torch.Tensor): :math:(B, N2), where B is batch size and N2 corresponds to action_shape.action_type_shape.
- action_args (:obj:torch.Tensor): :math:(B, N3), where B is batch size and N3 corresponds to action_shape.action_args_shape.
- action (:obj:torch.Tensor): :math:(B, N4), where B is batch size and N4 is action_shape.
- q_value (:obj:torch.Tensor): :math:(B, ), where B is batch size.

Examples:

>>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
>>> model = ContinuousQVAC(obs_shape=(8, ), action_shape=1, action_space='regression')
>>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
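When twin_critic=True, the q_value entry is a list of two Q estimates, one per head, rather than a single tensor. How a consumer combines them is algorithm-specific; TD3/SAC-style targets typically take the element-wise minimum to curb overestimation. A hypothetical sketch with plain Python lists standing in for batch tensors (not ding code):

```python
# Illustrative twin-critic reduction: given two per-sample Q estimates,
# take the element-wise minimum, the standard TD3/SAC anti-overestimation trick.
def min_twin_q(q_values):
    q1, q2 = q_values
    return [min(a, b) for a, b in zip(q1, q2)]
```

The v_value output is a single head (or shared across twins in this model), so no such reduction applies to it.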

Full Source Code

../ding/model/template/qvac.py

from typing import Union, Dict, Optional
from easydict import EasyDict
import numpy as np
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('continuous_qvac')
class ContinuousQVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and \
        V-value critic, such as IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC \
        is composed of four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. \
        Encoders are used to extract the feature. Heads are used to predict corresponding value or action logit. \
        In high-dimensional observation space like 2D image, we often use a shared encoder for both \
        ``actor_encoder`` and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often \
        use different encoders.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str,
            twin_critic: bool = False,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.SiLU(),
            norm_type: Optional[str] = None,
            encoder_hidden_size_list: Optional[SequenceType] = None,
            share_encoder: Optional[bool] = False,
    ) -> None:
        """
        Overview:
            Initialize the ContinuousQVAC model according to the input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - action_space (:obj:`str`): The type of action space, including [``regression``, \
                ``reparameterization``, ``hybrid``], ``regression`` is used for DDPG/TD3, ``reparameterization`` \
                is used for SAC and ``hybrid`` for PADDPG.
            - twin_critic (:obj:`bool`): Whether to use twin critic, one of tricks in TD3.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to apply after each network layer \
                (FC, Conv), see ``ding.torch_utils.network`` for more details.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element must match ``head_hidden_size``, this argument is only used in image observation.
            - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic.
        """
        super(ContinuousQVAC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization', 'hybrid'], self.action_space

        # encoder
        self.share_encoder = share_encoder
        if np.isscalar(obs_shape) or len(obs_shape) == 1:
            assert not self.share_encoder, "Vector observation doesn't need share encoder."
            assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear"
            # Because there is already a layer nn.Linear in the head, so we use nn.Identity here to keep
            # compatible with the image observation and avoid adding an extra layer nn.Linear.
            self.actor_encoder = nn.Identity()
            self.critic_encoder = nn.Identity()
            encoder_output_size = obs_shape
        elif len(obs_shape) == 3:

            def setup_conv_encoder():
                kernel_size = [3 for _ in range(len(encoder_hidden_size_list))]
                stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)]
                return ConvEncoder(
                    obs_shape,
                    encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type,
                    kernel_size=kernel_size,
                    stride=stride
                )

            if self.share_encoder:
                encoder = setup_conv_encoder()
                self.actor_encoder = self.critic_encoder = encoder
            else:
                self.actor_encoder = setup_conv_encoder()
                self.critic_encoder = setup_conv_encoder()
            encoder_output_size = self.actor_encoder.output_size
        else:
            raise RuntimeError("not support observation shape: {}".format(obs_shape))
        # head
        if self.action_space == 'regression':  # DDPG, TD3
            self.actor_head = nn.Sequential(
                nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'reparameterization':  # SAC
            self.actor_head = nn.Sequential(
                nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'hybrid':  # PADDPG
            # hybrid action space: action_type(discrete) + action_args(continuous),
            # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = nn.Sequential(
                nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape.action_args_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
            actor_action_type = nn.Sequential(
                nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
                DiscreteHead(
                    actor_head_hidden_size,
                    action_shape.action_type_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type,
                )
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        self.twin_critic = twin_critic
        if self.action_space == 'hybrid':
            critic_q_input_size = encoder_output_size + action_shape.action_type_shape + action_shape.action_args_shape
            critic_v_input_size = encoder_output_size
        else:
            critic_q_input_size = encoder_output_size + action_shape
            critic_v_input_size = encoder_output_size
        if self.twin_critic:
            self.critic_q_head = nn.ModuleList()
            self.critic_v_head = nn.ModuleList()
            for _ in range(2):
                self.critic_q_head.append(
                    nn.Sequential(
                        nn.Linear(critic_q_input_size, critic_head_hidden_size), activation,
                        RegressionHead(
                            critic_head_hidden_size,
                            1,
                            critic_head_layer_num,
                            final_tanh=False,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
            self.critic_v_head = nn.Sequential(
                nn.Linear(critic_v_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        else:
            self.critic_q_head = nn.Sequential(
                nn.Linear(critic_q_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )
            self.critic_v_head = nn.Sequential(
                nn.Linear(critic_v_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )

        # Convenient for calling some apis (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`
        self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
        self.critic = nn.ModuleList([self.critic_encoder, self.critic_q_head, self.critic_v_head])

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            QVAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \
            ``mode`` will forward with different network modules to get different outputs and save computation.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The input data for forward computation \
                graph, for ``compute_actor``, it is the observation tensor, for ``compute_critic``, it is the \
                dict data including obs and action tensor.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC forward computation graph, whose \
                key-values vary in different forward modes.
        Examples (Actor):
            >>> # Regression mode
            >>> model = ContinuousQVAC(64, 6, 'regression')
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
            >>> # Reparameterization mode
            >>> model = ContinuousQVAC(64, 6, 'reparameterization')
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma

        Examples (Critic):
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = ContinuousQVAC(obs_shape=(8, ), action_shape=1, action_space='regression')
            >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            QVAC forward computation graph for actor part, input observation tensor to predict action or action logit.
        Arguments:
            - obs (:obj:`torch.Tensor`): The input observation tensor data.
        Returns:
            - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output dict varying \
                from action_space: ``regression``, ``reparameterization``, ``hybrid``.
        ReturnsKeys (regression):
            - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``, usually in DDPG/TD3.
        ReturnsKeys (reparameterization):
            - logit (:obj:`Dict[str, torch.Tensor]`): The predicted reparameterization action logit, usually in SAC. \
                It is a list containing two tensors: ``mu`` and ``sigma``. The former is the mean of the gaussian \
                distribution, the latter is the standard deviation of the gaussian distribution.
        ReturnsKeys (hybrid):
            - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \
                as ``action_type_shape``, i.e., all the possible discrete action types.
            - action_args (:obj:`torch.Tensor`): Continuous action arguments with same size as ``action_args_shape``.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
            - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
            - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
                ``action_shape.action_type_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
                ``action_shape.action_args_shape``.
        Examples:
            >>> # Regression mode
            >>> model = ContinuousQVAC(64, 6, 'regression')
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
            >>> # Reparameterization mode
            >>> model = ContinuousQVAC(64, 6, 'reparameterization')
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs, 'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6])  # sigma
        """
        obs = self.actor_encoder(obs)
        if self.action_space == 'regression':
            x = self.actor_head(obs)
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            x = self.actor_head(obs)
            return {'logit': [x['mu'], x['sigma']]}
        elif self.action_space == 'hybrid':
            logit = self.actor_head[0](obs)
            action_args = self.actor_head[1](obs)
            return {'logit': logit['logit'], 'action_args': action_args['pred']}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            QVAC forward computation graph for critic part, input observation and action tensor to predict Q-value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The dict of input data, including ``obs`` and ``action`` \
                tensor, also contains ``logit`` and ``action_args`` tensor in hybrid action_space.
        ArgumentsKeys:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
            - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
            - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space.
            - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space.
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): The output of QVAC's forward computation graph for critic, \
                including ``q_value``.
        ReturnsKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``.
            - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
                ``action_shape.action_type_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
                ``action_shape.action_args_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N4)`, where B is batch size and N4 is ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size.

        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = ContinuousQVAC(obs_shape=(8, ), action_shape=1, action_space='regression')
            >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, )  # q value
        """

        obs, action = inputs['obs'], inputs['action']
        obs = self.critic_encoder(obs)
        assert len(obs.shape) == 2
        if self.action_space == 'hybrid':
            action_type_logit = inputs['logit']
            action_type_logit = torch.softmax(action_type_logit, dim=-1)
            action_args = action['action_args']
            if len(action_args.shape) == 1:
                action_args = action_args.unsqueeze(1)
            x = torch.cat([obs, action_type_logit, action_args], dim=1)
        else:
            if len(action.shape) == 1:  # (B, ) -> (B, 1)
                action = action.unsqueeze(1)
            x = torch.cat([obs, action], dim=1)
        if self.twin_critic:
            x = [m(x)['pred'] for m in self.critic_q_head]
            y = self.critic_v_head(obs)['pred']
        else:
            x = self.critic_q_head(x)['pred']
            y = self.critic_v_head(obs)['pred']
        return {'q_value': x, 'v_value': y}