
ding.model.template.vac

VAC

Bases: Module

Overview

The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC), such as A2C/PPO/IMPALA. This model now supports discrete, continuous and hybrid action spaces. The VAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders are used to extract features from various observations. Heads are used to predict the corresponding value or action logit. In high-dimensional observation spaces such as 2D images, we often use a shared encoder for both actor_encoder and critic_encoder. In low-dimensional observation spaces such as 1D vectors, we often use separate encoders.

Interfaces: __init__, forward, compute_actor, compute_critic, compute_actor_critic.

__init__(obs_shape, action_shape, action_space='discrete', share_encoder=True, encoder_hidden_size_list=[128, 128, 64], actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.ReLU(), norm_type=None, sigma_type='independent', fixed_sigma_value=0.3, bound_type=None, encoder=None, impala_cnn_encoder=False)

Overview

Initialize the VAC model according to corresponding input arguments.

Arguments:

- obs_shape (:obj:Union[int, SequenceType]): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:Union[int, SequenceType]): Action space shape, such as 6 or [2, 3, 3].
- action_space (:obj:str): The type of action space, one of ['discrete', 'continuous', 'hybrid']; the corresponding head (DiscreteHead, ReparameterizationHead, or hybrid heads) is instantiated accordingly.
- share_encoder (:obj:bool): Whether to share the observation encoder between actor and critic.
- encoder_hidden_size_list (:obj:SequenceType): Collection of hidden_size to pass to Encoder; the last element is used as the input size of actor_head and critic_head.
- actor_head_hidden_size (:obj:Optional[int]): The hidden_size of the actor_head network, i.e., the hidden size of its last layer; defaults to 64.
- actor_head_layer_num (:obj:int): The number of layers used in the actor_head network to compute actions.
- critic_head_hidden_size (:obj:Optional[int]): The hidden_size of the critic_head network, i.e., the hidden size of its last layer; defaults to 64.
- critic_head_layer_num (:obj:int): The number of layers used in the critic_head network.
- activation (:obj:Optional[nn.Module]): The activation function used in networks; if None, defaults to nn.ReLU().
- norm_type (:obj:Optional[str]): The type of normalization in networks, one of ['BN', 'IN', 'SyncBN', 'LN']; see ding.torch_utils.fc_block for more details.
- sigma_type (:obj:Optional[str]): The type of sigma in continuous action space, see ding.torch_utils.network.dreamer.ReparameterizationHead for more details; in A2C/PPO it defaults to independent, which means state-independent sigma parameters.
- fixed_sigma_value (:obj:Optional[int]): If sigma_type is fixed, use this value as sigma.
- bound_type (:obj:Optional[str]): The type of action-bound method in continuous action space; defaults to None, which means no bound.
- encoder (:obj:Optional[torch.nn.Module]): The encoder module; defaults to None. You can define your own encoder module and pass it into VAC to handle a different observation space.
- impala_cnn_encoder (:obj:bool): Whether to use the IMPALA CNN encoder; defaults to False.
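The share_encoder flag only changes how the four parts are wired together: with sharing, one encoder object is referenced by both the actor and critic branches; without it, two independent encoders are built. A minimal pure-Python sketch of that wiring (hypothetical stand-in classes, not the actual DI-engine modules):

```python
class IdEncoder:
    """Stand-in for FCEncoder/ConvEncoder: maps an observation to a feature."""

    def __call__(self, obs):
        return [2 * v for v in obs]  # placeholder feature extraction


class SumHead:
    """Stand-in for an actor/critic head: maps a feature to an output."""

    def __call__(self, feature):
        return sum(feature)


class TinyVAC:
    """Structural sketch of the four-part VAC composition."""

    def __init__(self, share_encoder=True):
        if share_encoder:
            # One encoder object, referenced by both branches.
            self.actor_encoder = self.critic_encoder = IdEncoder()
        else:
            # Two independent encoders with separate parameters.
            self.actor_encoder, self.critic_encoder = IdEncoder(), IdEncoder()
        self.actor_head, self.critic_head = SumHead(), SumHead()

    def compute_actor(self, obs):
        return {'logit': self.actor_head(self.actor_encoder(obs))}

    def compute_critic(self, obs):
        return {'value': self.critic_head(self.critic_encoder(obs))}


shared = TinyVAC(share_encoder=True)
assert shared.actor_encoder is shared.critic_encoder      # one module, two consumers
separate = TinyVAC(share_encoder=False)
assert separate.actor_encoder is not separate.critic_encoder
```

In the real model the same idea applies, with FCEncoder/ConvEncoder/IMPALAConvEncoder as encoders and DiscreteHead/ReparameterizationHead/RegressionHead as heads.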

forward(x, mode)

Overview

VAC forward computation graph, input observation tensor to predict state value or action logit. Different mode will forward with different network modules to get different outputs and save computation.

Arguments:

- x (:obj:torch.Tensor): The input observation tensor data.
- mode (:obj:str): The forward mode; all modes are defined at the beginning of this class.

Returns:

- outputs (:obj:Dict): The output dict of VAC's forward computation graph, whose key-values vary with mode.

Examples (Actor):

>>> model = VAC(64, 128)
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs, 'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([4, 128])

Examples (Critic):

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> critic_outputs = model(inputs, 'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([4])

Examples (Actor-Critic):

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> outputs = model(inputs, 'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([4])
>>> assert outputs['logit'].shape == torch.Size([4, 64])
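Internally, forward simply validates mode against the class-level whitelist and delegates to the method of the same name via getattr (this is exactly what the source does). A minimal sketch of that dispatch pattern:

```python
class Dispatcher:
    """Sketch of VAC's mode dispatch: forward(x, mode) delegates via getattr."""

    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def forward(self, x, mode):
        # Reject unknown modes before dispatching.
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(x)

    def compute_actor(self, x):
        return {'logit': x}

    def compute_critic(self, x):
        return {'value': x}

    def compute_actor_critic(self, x):
        return {'logit': x, 'value': x}


d = Dispatcher()
assert d.forward(1.0, 'compute_critic') == {'value': 1.0}
try:
    d.forward(1.0, 'compute_q')  # not in the whitelist -> AssertionError
except AssertionError:
    pass
```

This is why calling model(inputs, 'compute_actor') and model.compute_actor(inputs) are equivalent.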

compute_actor(x)

Overview

VAC forward computation graph for actor part, input observation tensor to predict action logit.

Arguments:

- x (:obj:Union[torch.Tensor, Dict]): The input observation tensor data. If a dictionary is provided, it should contain the key 'observation' and optionally 'action_mask'.

Returns:

- outputs (:obj:Dict): The output dict of VAC's forward computation graph for actor, including logit and optionally action_mask if the input is a dictionary.

ReturnsKeys:

- logit (:obj:torch.Tensor): The predicted action logit tensor. For discrete action space, it is a real-valued tensor with one entry per possible action choice; for continuous action space, it is the mu and sigma of the Gaussian distribution, with one mu and sigma per continuous action. Hybrid action space is a combination of discrete and continuous action spaces, so the logit is a dict with action_type and action_args.
- action_mask (:obj:Optional[torch.Tensor]): The action mask tensor, included if the input is a dictionary containing 'action_mask'.

Shapes:

- logit (:obj:torch.Tensor): :math:(B, N), where B is batch size and N is action_shape.

Examples:

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs, 'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
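When x is a dict, compute_actor encodes the 'observation' entry and passes any 'action_mask' through unchanged, so the policy can mask invalid actions downstream. A pure-Python sketch of that input handling (compute_actor_sketch and its encode/head parameters are hypothetical stand-ins):

```python
def compute_actor_sketch(x, encode, head):
    """Sketch of compute_actor's dict-vs-tensor input handling."""
    if isinstance(x, dict):
        action_mask = x.get('action_mask')   # optional mask, passed through untouched
        feature = encode(x['observation'])
    else:
        action_mask = None
        feature = encode(x)
    result = {'logit': head(feature)}
    if action_mask is not None:
        result['action_mask'] = action_mask
    return result


out = compute_actor_sketch(
    {'observation': [1.0, 2.0], 'action_mask': [1, 0]},
    encode=lambda obs: obs,              # identity stand-in for the encoder
    head=lambda f: [v + 1 for v in f],   # stand-in for the actor head
)
assert out == {'logit': [2.0, 3.0], 'action_mask': [1, 0]}
```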

compute_critic(x)

Overview

VAC forward computation graph for critic part, input observation tensor to predict state value.

Arguments:

- x (:obj:Union[torch.Tensor, Dict]): The input observation tensor data. If a dictionary is provided, it should contain the key 'observation'.

Returns:

- outputs (:obj:Dict): The output dict of VAC's forward computation graph for critic, including value.

ReturnsKeys:

- value (:obj:torch.Tensor): The predicted state value tensor.

Shapes:

- value (:obj:torch.Tensor): :math:(B, ), where B is batch size; (B, 1) is squeezed to (B, ).

Examples:

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> critic_outputs = model(inputs, 'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([4])

compute_actor_critic(x)

Overview

VAC forward computation graph for both actor and critic part, input observation tensor to predict action logit and state value.

Arguments:

- x (:obj:Union[torch.Tensor, Dict]): The input observation tensor data. If a dictionary is provided, it should contain the key 'observation' and optionally 'action_mask'.

Returns:

- outputs (:obj:Dict): The output dict of VAC's forward computation graph for both actor and critic, including logit, value, and optionally action_mask if the input is a dictionary.

ReturnsKeys:

- logit (:obj:torch.Tensor): The predicted action logit tensor. For discrete action space, it is a real-valued tensor with one entry per possible action choice; for continuous action space, it is the mu and sigma of the Gaussian distribution, with one mu and sigma per continuous action. Hybrid action space is a combination of discrete and continuous action spaces, so the logit is a dict with action_type and action_args.
- value (:obj:torch.Tensor): The predicted state value tensor.
- action_mask (:obj:torch.Tensor, optional): The action mask tensor, included if the input is a dictionary containing 'action_mask'.

Shapes:

- logit (:obj:torch.Tensor): :math:(B, N), where B is batch size and N is action_shape.
- value (:obj:torch.Tensor): :math:(B, ), where B is batch size; (B, 1) is squeezed to (B, ).

Examples:

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> outputs = model(inputs, 'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([4])
>>> assert outputs['logit'].shape == torch.Size([4, 64])

.. note:: The compute_actor_critic interface aims to save computation when the encoder is shared, and returns the combined actor and critic outputs in one dict.
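The saving comes from running the shared encoder once and feeding its output to both heads, rather than encoding the observation twice via separate compute_actor and compute_critic calls. A small sketch with a hypothetical call-counting encoder makes this concrete:

```python
class CountingEncoder:
    """Hypothetical stand-in encoder that records how often it runs."""

    def __init__(self):
        self.calls = 0

    def __call__(self, obs):
        self.calls += 1
        return obs


def compute_actor_critic_sketch(obs, encoder, actor_head, critic_head):
    embedding = encoder(obs)  # single encoder pass feeds both heads
    return {'logit': actor_head(embedding), 'value': critic_head(embedding)}


enc = CountingEncoder()
out = compute_actor_critic_sketch([1.0], enc, actor_head=sum, critic_head=len)
assert enc.calls == 1   # one pass instead of two separate actor/critic passes
assert out == {'logit': 1.0, 'value': 1}
```

With share_encoder=False there is nothing to save: the actor and critic encoders must each run once regardless.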

GTrXLVAC

Bases: Module

Overview

VAC-style actor-critic model with a GTrXL core. This model is intended for policies (e.g., VMPO/PPO variants) that use the VAC interfaces: compute_actor, compute_critic, compute_actor_critic.

Notes
  • By default, this model is used with memory_len=0 in on-policy pipelines, where sequence state is not tracked per environment in the policy.
  • It still runs observation features through GTrXL layers at every forward call.
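The GTrXL core consumes sequences, so the model's internal _encode_core lifts batched 2-D features (B, D) to a length-1 sequence (1, B, D) before the transformer and passes 3-D sequence features (T, B, D) straight through. A sketch of just that shape rule, modeling shapes as tuples for illustration:

```python
def core_input_shape(enc_shape):
    """Sketch of GTrXLVAC's shape handling before the GTrXL core."""
    if len(enc_shape) == 2:        # (B, D) -> (1, B, D): a length-1 sequence
        return (1,) + enc_shape
    elif len(enc_shape) == 3:      # (T, B, D): already a sequence, unchanged
        return enc_shape
    raise RuntimeError("Unsupported encoded tensor rank {}".format(len(enc_shape)))


assert core_input_shape((4, 64)) == (1, 4, 64)
assert core_input_shape((8, 4, 64)) == (8, 4, 64)
```

After the core, the singleton time dimension is squeezed away again for batched inputs, so callers see (B, D) in and (B, D) out.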

DREAMERVAC

Bases: Module

Overview

The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC). This model supports discrete and continuous action spaces.

Interfaces: __init__, forward.

__init__(action_shape, dyn_stoch=32, dyn_deter=512, dyn_discrete=32, actor_layers=2, value_layers=2, units=512, act='SiLU', norm='LayerNorm', actor_dist='normal', actor_init_std=1.0, actor_min_std=0.1, actor_max_std=1.0, actor_temp=0.1, action_unimix_ratio=0.01)

Overview

Initialize the DREAMERVAC model according to arguments.

Arguments:

- obs_shape (:obj:Union[int, SequenceType]): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:Union[int, SequenceType]): Action space shape, such as 6 or [2, 3, 3].

Full Source Code

../ding/model/template/vac.py

from typing import Union, Dict, Optional
from easydict import EasyDict
import torch
import torch.nn as nn
from copy import deepcopy
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead
from ding.torch_utils.network.gtrxl import GTrXL


@MODEL_REGISTRY.register('vac')
class VAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC), such as \
        A2C/PPO/IMPALA. This model now supports discrete, continuous and hybrid action space. The VAC is composed of \
        four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \
        extract the feature from various observation. Heads are used to predict corresponding value or action logit. \
        In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \
        and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str = 'discrete',
            share_encoder: bool = True,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            fixed_sigma_value: Optional[int] = 0.3,
            bound_type: Optional[str] = None,
            encoder: Optional[torch.nn.Module] = None,
            impala_cnn_encoder: bool = False,
    ) -> None:
        """
        Overview:
            Initialize the VAC model according to corresponding input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
            - action_space (:obj:`str`): The type of different action spaces, including ['discrete', 'continuous', \
                'hybrid'], then will instantiate corresponding head, including ``DiscreteHead``, \
                ``ReparameterizationHead``, and hybrid heads.
            - share_encoder (:obj:`bool`): Whether to share observation encoders between actor and critic.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
                the last element is used as the input size of ``actor_head`` and ``critic_head``.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
                to 64, it is the hidden size of the last layer of the ``actor_head`` network.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
                to 64, it is the hidden size of the last layer of the ``critic_head`` network.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
                if ``None`` then default set it to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
            - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
                ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in A2C/PPO, it defaults \
                to ``independent``, which means state-independent sigma parameters.
            - fixed_sigma_value (:obj:`Optional[int]`): If ``sigma_type`` is ``fixed``, then use this value as sigma.
            - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
                to ``None``, which means no bound.
            - encoder (:obj:`Optional[torch.nn.Module]`): The encoder module, defaults to ``None``, you can define \
                your own encoder module and pass it into VAC to deal with different observation space.
            - impala_cnn_encoder (:obj:`bool`): Whether to use IMPALA CNN encoder, defaults to ``False``.
        """
        super(VAC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.impala_cnn_encoder = impala_cnn_encoder
        self.share_encoder = share_encoder

        # Encoder Type
        def new_encoder(outsize, activation):
            if impala_cnn_encoder:
                return IMPALAConvEncoder(obs_shape=obs_shape, channels=encoder_hidden_size_list, outsize=outsize)
            else:
                if isinstance(obs_shape, int) or len(obs_shape) == 1:
                    return FCEncoder(
                        obs_shape=obs_shape,
                        hidden_size_list=encoder_hidden_size_list,
                        activation=activation,
                        norm_type=norm_type
                    )
                elif len(obs_shape) == 3:
                    return ConvEncoder(
                        obs_shape=obs_shape,
                        hidden_size_list=encoder_hidden_size_list,
                        activation=activation,
                        norm_type=norm_type
                    )
                else:
                    raise RuntimeError(
                        "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                        format(obs_shape)
                    )

        if self.share_encoder:
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.encoder = encoder
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
        else:
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.actor_encoder = encoder
                    self.critic_encoder = deepcopy(encoder)
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
                self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)

        # Head Type
        self.critic_head = RegressionHead(
            encoder_hidden_size_list[-1],
            1,
            critic_head_layer_num,
            activation=activation,
            norm_type=norm_type,
            hidden_size=critic_head_hidden_size
        )
        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                encoder_hidden_size_list[-1],
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
                hidden_size=actor_head_hidden_size,
            )
        elif self.action_space == 'discrete':
            actor_head_cls = DiscreteHead
            multi_head = not isinstance(action_shape, int)
            self.multi_head = multi_head
            if multi_head:
                self.actor_head = MultiHead(
                    actor_head_cls,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = actor_head_cls(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':  # HPPO
            # hybrid action space: action_type(discrete) + action_args(continuous),
            # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                encoder_hidden_size_list[-1],
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
                hidden_size=actor_head_hidden_size,
            )
            actor_action_type = DiscreteHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        if self.share_encoder:
            self.actor = [self.encoder, self.actor_head]
            self.critic = [self.encoder, self.critic_head]
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        # Convenient for calling some apis (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, x: torch.Tensor, mode: str) -> Dict:
        """
        Overview:
            VAC forward computation graph, input observation tensor to predict state value or action logit. Different \
            ``mode`` will forward with different network modules to get different outputs and save computation.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph, whose key-values vary from \
                different ``mode``.

        Examples (Actor):
            >>> model = VAC(64, 128)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 128])

        Examples (Critic):
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([4])

        Examples (Actor-Critic):
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([4])
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(x)

    def compute_actor(self, x: Union[torch.Tensor, Dict]) -> Dict:
        """
        Overview:
            VAC forward computation graph for actor part, input observation tensor to predict action logit.
        Arguments:
            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
                it should contain keys 'observation' and optionally 'action_mask'.
        Returns:
            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit`` \
                and optionally ``action_mask`` if the input is a dictionary.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
                same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \
                and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``.
            - action_mask (:obj:`Optional[torch.Tensor]`): The action mask tensor, included if the input is a \
                dictionary containing 'action_mask'.
        Shapes:
            - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``

        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
        """
        if isinstance(x, dict):
            action_mask = x['action_mask']
            x = self.encoder(x['observation']) if self.share_encoder else self.actor_encoder(x['observation'])
        else:
            action_mask = None
            x = self.encoder(x) if self.share_encoder else self.actor_encoder(x)

        if self.action_space == 'discrete':
            result = {'logit': self.actor_head(x)['logit']}
            if action_mask is not None:
                result['action_mask'] = action_mask
            return result
        elif self.action_space == 'continuous':
            x = self.actor_head(x)  # mu, sigma
            return {'logit': x}
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](x)
            action_args = self.actor_head[1](x)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}}

    def compute_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
        """
        Overview:
            VAC forward computation graph for critic part, input observation tensor to predict state value.
        Arguments:
            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
                it should contain the key 'observation'.
        Returns:
            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
        ReturnsKeys:
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).

        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([4])
        """
        if isinstance(x, dict):
            x = self.encoder(x['observation']) if self.share_encoder else self.critic_encoder(x['observation'])
        else:
            x = self.encoder(x) if self.share_encoder else self.critic_encoder(x)
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
        """
        Overview:
            VAC forward computation graph for both actor and critic part, input observation tensor to predict action \
            logit and state value.
        Arguments:
            - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \
                it should contain keys 'observation' and optionally 'action_mask'.
        Returns:
            - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for both actor and critic, \
                including ``logit``, ``value``, and optionally ``action_mask`` if the input is a dictionary.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
                same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \
                and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``.
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
            - action_mask (:obj:`torch.Tensor`, optional): The action mask tensor, included if the input is a \
                dictionary containing 'action_mask'.
        Shapes:
            - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).

        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([4])
            >>> assert outputs['logit'].shape == torch.Size([4, 64])

        .. note::
            ``compute_actor_critic`` interface aims to save computation when the encoder is shared, and returns the \
            combined dict output.
        """
        if isinstance(x, dict):
            action_mask = x['action_mask']
            if self.share_encoder:
                actor_embedding = critic_embedding = self.encoder(x['observation'])
            else:
                actor_embedding = self.actor_encoder(x['observation'])
                critic_embedding = self.critic_encoder(x['observation'])
        else:
            action_mask = None
            if self.share_encoder:
                actor_embedding = critic_embedding = self.encoder(x)
            else:
                actor_embedding = self.actor_encoder(x)
                critic_embedding = self.critic_encoder(x)

        value = self.critic_head(critic_embedding)['pred']

        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)['logit']
            result = {'logit': logit, 'value': value}
            if action_mask is not None:
                result['action_mask'] = action_mask
            return result
        elif self.action_space == 'continuous':
            x = self.actor_head(actor_embedding)
            return {'logit': x, 'value': value}
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value}


@MODEL_REGISTRY.register('gtrxl_vac')
class GTrXLVAC(nn.Module):
    """
    Overview:
        VAC-style actor-critic model with a GTrXL core.
        This model is intended for policies (e.g., VMPO/PPO variants) that use the VAC interfaces:
        ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.

    Notes:
        - By default, this model is used with ``memory_len=0`` in on-policy pipelines, where sequence state is
          not tracked per environment in the policy.
        - It still runs observation features through GTrXL layers at every forward call.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str = 'discrete',
            encoder_hidden_size_list: SequenceType = [128, 512, 1024],
            hidden_size: int = 1024,
            actor_head_hidden_size: int = 1024,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 1024,
            critic_head_layer_num: int = 1,
            att_head_dim: int = 16,
            att_head_num: int = 8,
            att_mlp_num: int = 2,
            att_layer_num: int = 3,
            memory_len: int = 0,
            dropout: float = 0.,
            gru_gating: bool = True,
            gru_bias: float = 2.,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            fixed_sigma_value: Optional[int] = 0.3,
            bound_type: Optional[str] = None,
            encoder: Optional[torch.nn.Module] = None,
    ) -> None:
        super(GTrXLVAC, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space

        # Observation encoder (vector -> FC, image -> Conv), then projection to GTrXL embedding size.
        if encoder is not None:
            if not isinstance(encoder, torch.nn.Module):
                raise ValueError("illegal encoder instance.")
            self.encoder = encoder
            encoder_out_dim = hidden_size
        else:
            if isinstance(obs_shape, int) or len(obs_shape) == 1:
                self.encoder = FCEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            elif len(obs_shape) == 3:
                self.encoder = ConvEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                raise RuntimeError(
                    "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".format(
                        obs_shape
                    )
                )
            encoder_out_dim = encoder_hidden_size_list[-1]

        self.encoder_proj = nn.Identity() if encoder_out_dim == hidden_size else nn.Linear(encoder_out_dim, hidden_size)

        # GTrXL over encoded features.
        self.core = GTrXL(
            input_dim=hidden_size,
            head_dim=att_head_dim,
            embedding_dim=hidden_size,
            head_num=att_head_num,
            mlp_num=att_mlp_num,
            layer_num=att_layer_num,
            memory_len=memory_len,
            dropout_ratio=dropout,
            activation=activation,
            gru_gating=gru_gating,
            gru_bias=gru_bias,
            use_embedding_layer=False,
        )

        # Separate projections for actor/critic heads.
        self.actor_proj = nn.Identity() if actor_head_hidden_size == hidden_size else nn.Linear(
            hidden_size, actor_head_hidden_size
        )
        self.critic_proj = nn.Identity() if critic_head_hidden_size == hidden_size else nn.Linear(
            hidden_size, critic_head_hidden_size
        )

        self.critic_head = RegressionHead(
            critic_head_hidden_size,
            1,
            critic_head_layer_num,
            activation=activation,
            norm_type=norm_type,
            hidden_size=critic_head_hidden_size
        )

        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
                hidden_size=actor_head_hidden_size,
            )
        elif self.action_space == 'discrete':
            self.multi_head = not isinstance(action_shape, int)
            if self.multi_head:
                self.actor_head = MultiHead(
                    DiscreteHead,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = DiscreteHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        else:  # hybrid
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
                hidden_size=actor_head_hidden_size,
            )
            actor_action_type = DiscreteHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

    def reset(self, *args, **kwargs) -> None:
        # Keep compatibility with model wrappers that call model.reset().
        state = kwargs.get('state', None)
        batch_size = kwargs.get('batch_size', None)
        if state is not None:
            self.core.reset_memory(state=state)
        elif batch_size is not None:
            self.core.reset_memory(batch_size=batch_size)
        else:
            # Defer memory initialization to the next forward with actual batch size.
            self.core.memory = None

    def _encode_core(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode observations, run GTrXL, and return feature tensor.
        Returns shape:
            - (B, D) for batched observations
            - (T, B, D) for sequence observations
        """
        if isinstance(self.obs_shape, int) or len(self.obs_shape) == 1:
            obs_dims = 1
        else:
            obs_dims = 3

        leading_shape = x.shape[:-obs_dims]
        x_flat = x.reshape(-1, *x.shape[-obs_dims:])
        enc = self.encoder(x_flat)
        enc = self.encoder_proj(enc)
        enc = enc.reshape(*leading_shape, -1)

        if enc.dim() == 2:
            seq_in = enc.unsqueeze(0)  # (1, B, D)
            core_out = self.core(seq_in)['logit'].squeeze(0)  # (B, D)
        elif enc.dim() == 3:
            core_out = self.core(enc)['logit']  # (T, B, D)
        else:
            raise RuntimeError(f"Unsupported encoded tensor rank {enc.dim()} for GTrXLVAC.")
        return core_out

    def forward(self, x: torch.Tensor, mode: str) -> Dict:
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(x)

    def compute_actor(self, x: Union[torch.Tensor, Dict]) -> Dict:
        if isinstance(x, dict):
            action_mask = x['action_mask']
            obs = x['observation']
        else:
            action_mask = None
            obs = x

        actor_embedding = self.actor_proj(self._encode_core(obs))
        if self.action_space == 'discrete':
            result = {'logit': self.actor_head(actor_embedding)['logit']}
            if action_mask is not None:
                result['action_mask'] = action_mask
            return result
        elif self.action_space == 'continuous':
            return {'logit': self.actor_head(actor_embedding)}
        else:
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}}

    def compute_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
        obs = x['observation'] if isinstance(x, dict) else x
        critic_embedding = self.critic_proj(self._encode_core(obs))
        value = self.critic_head(critic_embedding)['pred']
        return {'value': value}

    def compute_actor_critic(self, x: Union[torch.Tensor, Dict]) -> Dict:
        if isinstance(x, dict):
            action_mask = x['action_mask']
            obs = x['observation']
        else:
            action_mask = None
            obs = x

        core_embedding = self._encode_core(obs)
        actor_embedding = self.actor_proj(core_embedding)
        critic_embedding = self.critic_proj(core_embedding)
        value = self.critic_head(critic_embedding)['pred']

        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)['logit']
            result = {'logit': logit, 'value': value}
            if action_mask is not None:
                result['action_mask'] = action_mask
            return result
        elif self.action_space == 'continuous':
            return {'logit': self.actor_head(actor_embedding), 'value': value}
        else:
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value}


@MODEL_REGISTRY.register('dreamervac')
class DREAMERVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC).
        This model now supports discrete, continuous action space.
657 Interfaces: 658 ``__init__``, ``forward``. 659 """ 660 mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] 661 662 def __init__( 663 self, 664 action_shape: Union[int, SequenceType, EasyDict], 665 dyn_stoch=32, 666 dyn_deter=512, 667 dyn_discrete=32, 668 actor_layers=2, 669 value_layers=2, 670 units=512, 671 act='SiLU', 672 norm='LayerNorm', 673 actor_dist='normal', 674 actor_init_std=1.0, 675 actor_min_std=0.1, 676 actor_max_std=1.0, 677 actor_temp=0.1, 678 action_unimix_ratio=0.01, 679 ) -> None: 680 """ 681 Overview: 682 Initialize the ``DREAMERVAC`` model according to arguments. 683 Arguments: 684 - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. 685 - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 686 """ 687 super(DREAMERVAC, self).__init__() 688 action_shape = squeeze(action_shape) 689 self.action_shape = action_shape 690 691 if dyn_discrete: 692 feat_size = dyn_stoch * dyn_discrete + dyn_deter 693 else: 694 feat_size = dyn_stoch + dyn_deter 695 self.actor = ActionHead( 696 feat_size, # pytorch version 697 action_shape, 698 actor_layers, 699 units, 700 act, 701 norm, 702 actor_dist, 703 actor_init_std, 704 actor_min_std, 705 actor_max_std, 706 actor_temp, 707 outscale=1.0, 708 unimix_ratio=action_unimix_ratio, 709 ) 710 self.critic = DenseHead( 711 feat_size, # pytorch version 712 (255, ), 713 value_layers, 714 units, 715 'SiLU', # act 716 'LN', # norm 717 'twohot_symlog', 718 outscale=0.0, 719 device='cuda' if torch.cuda.is_available() else 'cpu', 720 )
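The shape handling in `_encode_core` above is the subtle part: leading time/batch dimensions are flattened so the encoder sees plain `(N, *obs_shape)` input, and a single batch is presented to the transformer core as a length-1 sequence so the same core (with its memory) serves both per-step collection and sequence training. A minimal, hypothetical torch-only sketch of that logic, with a stand-in `nn.Linear` encoder and an identity module in place of GTrXL (the real core returns a dict with a `'logit'` key, which this sketch omits):

```python
import torch
import torch.nn as nn


def encode_core(x: torch.Tensor, encoder: nn.Module, core: nn.Module, obs_dims: int = 1) -> torch.Tensor:
    # Flatten all leading (time/batch) dims so the encoder sees (N, *obs_shape).
    leading_shape = x.shape[:-obs_dims]
    x_flat = x.reshape(-1, *x.shape[-obs_dims:])
    enc = encoder(x_flat).reshape(*leading_shape, -1)
    if enc.dim() == 2:
        # (B, D): wrap as a length-1 sequence for the core, then unwrap.
        return core(enc.unsqueeze(0)).squeeze(0)
    elif enc.dim() == 3:
        # (T, B, D): already a sequence.
        return core(enc)
    raise RuntimeError(f"unsupported encoded tensor rank {enc.dim()}")


encoder = nn.Linear(8, 64)  # obs_shape=8, feature dim D=64
core = nn.Identity()        # placeholder for the transformer core

batched = encode_core(torch.randn(5, 8), encoder, core)      # -> (B=5, D=64)
sequence = encode_core(torch.randn(7, 5, 8), encoder, core)  # -> (T=7, B=5, D=64)
```

Note that `reshape(*leading_shape, -1)` restores whatever leading layout the input had, so the same function covers both call patterns without branching on the caller's side.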