1from typing import Union, Dict, Optional 2from easydict import EasyDict 3import torch 4import torch.nn as nn 5from copy import deepcopy 6from ding.utils import SequenceType, squeeze, MODEL_REGISTRY 7from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ 8 FCEncoder, ConvEncoder, IMPALAConvEncoder 9from ding.torch_utils.network.dreamer import ActionHead, DenseHead 10from ding.torch_utils.network.gtrxl import GTrXL 11 12 13@MODEL_REGISTRY.register('vac') 14class VAC(nn.Module): 15 """ 16 Overview: 17 The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC), such as \ 18 A2C/PPO/IMPALA. This model now supports discrete, continuous and hybrid action space. The VAC is composed of \ 19 four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \ 20 extract the feature from various observation. Heads are used to predict corresponding value or action logit. \ 21 In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \ 22 and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders. 23 Interfaces: 24 ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``. 
25 """ 26 mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] 27 28 def __init__( 29 self, 30 obs_shape: Union[int, SequenceType], 31 action_shape: Union[int, SequenceType, EasyDict], 32 action_space: str = 'discrete', 33 share_encoder: bool = True, 34 encoder_hidden_size_list: SequenceType = [128, 128, 64], 35 actor_head_hidden_size: int = 64, 36 actor_head_layer_num: int = 1, 37 critic_head_hidden_size: int = 64, 38 critic_head_layer_num: int = 1, 39 activation: Optional[nn.Module] = nn.ReLU(), 40 norm_type: Optional[str] = None, 41 sigma_type: Optional[str] = 'independent', 42 fixed_sigma_value: Optional[int] = 0.3, 43 bound_type: Optional[str] = None, 44 encoder: Optional[torch.nn.Module] = None, 45 impala_cnn_encoder: bool = False, 46 ) -> None: 47 """ 48 Overview: 49 Initialize the VAC model according to corresponding input arguments. 50 Arguments: 51 - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. 52 - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 53 - action_space (:obj:`str`): The type of different action spaces, including ['discrete', 'continuous', \ 54 'hybrid'], then will instantiate corresponding head, including ``DiscreteHead``, \ 55 ``ReparameterizationHead``, and hybrid heads. 56 - share_encoder (:obj:`bool`): Whether to share observation encoders between actor and decoder. 57 - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ 58 the last element is used as the input size of ``actor_head`` and ``critic_head``. 59 - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \ 60 to 64, it is the hidden size of the last layer of the ``actor_head`` network. 61 - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action. 
62 - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \ 63 to 64, it is the hidden size of the last layer of the ``critic_head`` network. 64 - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network. 65 - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ 66 if ``None`` then default set it to ``nn.ReLU()``. 67 - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ 68 ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN'] 69 - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \ 70 ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in A2C/PPO, it defaults \ 71 to ``independent``, which means state-independent sigma parameters. 72 - fixed_sigma_value (:obj:`Optional[int]`): If ``sigma_type`` is ``fixed``, then use this value as sigma. 73 - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \ 74 to ``None``, which means no bound. 75 - encoder (:obj:`Optional[torch.nn.Module]`): The encoder module, defaults to ``None``, you can define \ 76 your own encoder module and pass it into VAC to deal with different observation space. 77 - impala_cnn_encoder (:obj:`bool`): Whether to use IMPALA CNN encoder, defaults to ``False``. 
78 """ 79 super(VAC, self).__init__() 80 obs_shape: int = squeeze(obs_shape) 81 action_shape = squeeze(action_shape) 82 self.obs_shape, self.action_shape = obs_shape, action_shape 83 self.impala_cnn_encoder = impala_cnn_encoder 84 self.share_encoder = share_encoder 85 86 # Encoder Type 87 def new_encoder(outsize, activation): 88 if impala_cnn_encoder: 89 return IMPALAConvEncoder(obs_shape=obs_shape, channels=encoder_hidden_size_list, outsize=outsize) 90 else: 91 if isinstance(obs_shape, int) or len(obs_shape) == 1: 92 return FCEncoder( 93 obs_shape=obs_shape, 94 hidden_size_list=encoder_hidden_size_list, 95 activation=activation, 96 norm_type=norm_type 97 ) 98 elif len(obs_shape) == 3: 99 return ConvEncoder( 100 obs_shape=obs_shape, 101 hidden_size_list=encoder_hidden_size_list, 102 activation=activation, 103 norm_type=norm_type 104 ) 105 else: 106 raise RuntimeError( 107 "not support obs_shape for pre-defined encoder: {}, please customize your own encoder". 108 format(obs_shape) 109 ) 110 111 if self.share_encoder: 112 if encoder: 113 if isinstance(encoder, torch.nn.Module): 114 self.encoder = encoder 115 else: 116 raise ValueError("illegal encoder instance.") 117 else: 118 self.encoder = new_encoder(encoder_hidden_size_list[-1], activation) 119 else: 120 if encoder: 121 if isinstance(encoder, torch.nn.Module): 122 self.actor_encoder = encoder 123 self.critic_encoder = deepcopy(encoder) 124 else: 125 raise ValueError("illegal encoder instance.") 126 else: 127 self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation) 128 self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation) 129 130 # Head Type 131 self.critic_head = RegressionHead( 132 encoder_hidden_size_list[-1], 133 1, 134 critic_head_layer_num, 135 activation=activation, 136 norm_type=norm_type, 137 hidden_size=critic_head_hidden_size 138 ) 139 self.action_space = action_space 140 assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space 141 if 
self.action_space == 'continuous': 142 self.multi_head = False 143 self.actor_head = ReparameterizationHead( 144 encoder_hidden_size_list[-1], 145 action_shape, 146 actor_head_layer_num, 147 sigma_type=sigma_type, 148 activation=activation, 149 norm_type=norm_type, 150 bound_type=bound_type, 151 hidden_size=actor_head_hidden_size, 152 ) 153 elif self.action_space == 'discrete': 154 actor_head_cls = DiscreteHead 155 multi_head = not isinstance(action_shape, int) 156 self.multi_head = multi_head 157 if multi_head: 158 self.actor_head = MultiHead( 159 actor_head_cls, 160 actor_head_hidden_size, 161 action_shape, 162 layer_num=actor_head_layer_num, 163 activation=activation, 164 norm_type=norm_type 165 ) 166 else: 167 self.actor_head = actor_head_cls( 168 actor_head_hidden_size, 169 action_shape, 170 actor_head_layer_num, 171 activation=activation, 172 norm_type=norm_type 173 ) 174 elif self.action_space == 'hybrid': # HPPO 175 # hybrid action space: action_type(discrete) + action_args(continuous), 176 # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])} 177 action_shape.action_args_shape = squeeze(action_shape.action_args_shape) 178 action_shape.action_type_shape = squeeze(action_shape.action_type_shape) 179 actor_action_args = ReparameterizationHead( 180 encoder_hidden_size_list[-1], 181 action_shape.action_args_shape, 182 actor_head_layer_num, 183 sigma_type=sigma_type, 184 fixed_sigma_value=fixed_sigma_value, 185 activation=activation, 186 norm_type=norm_type, 187 bound_type=bound_type, 188 hidden_size=actor_head_hidden_size, 189 ) 190 actor_action_type = DiscreteHead( 191 actor_head_hidden_size, 192 action_shape.action_type_shape, 193 actor_head_layer_num, 194 activation=activation, 195 norm_type=norm_type, 196 ) 197 self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) 198 199 if self.share_encoder: 200 self.actor = [self.encoder, self.actor_head] 201 self.critic = [self.encoder, 
self.critic_head] 202 else: 203 self.actor = [self.actor_encoder, self.actor_head] 204 self.critic = [self.critic_encoder, self.critic_head] 205 # Convenient for calling some apis (e.g. self.critic.parameters()), 206 # but may cause misunderstanding when `print(self)` 207 self.actor = nn.ModuleList(self.actor) 208 self.critic = nn.ModuleList(self.critic) 209 210 def forward(self, x: torch.Tensor, mode: str) -> Dict: 211 """ 212 Overview: 213 VAC forward computation graph, input observation tensor to predict state value or action logit. Different \ 214 ``mode`` will forward with different network modules to get different outputs and save computation. 215 Arguments: 216 - x (:obj:`torch.Tensor`): The input observation tensor data. 217 - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. 218 Returns: 219 - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph, whose key-values vary from \ 220 different ``mode``. 221 222 Examples (Actor): 223 >>> model = VAC(64, 128) 224 >>> inputs = torch.randn(4, 64) 225 >>> actor_outputs = model(inputs,'compute_actor') 226 >>> assert actor_outputs['logit'].shape == torch.Size([4, 128]) 227 228 Examples (Critic): 229 >>> model = VAC(64, 64) 230 >>> inputs = torch.randn(4, 64) 231 >>> critic_outputs = model(inputs,'compute_critic') 232 >>> assert actor_outputs['logit'].shape == torch.Size([4, 64]) 233 234 Examples (Actor-Critic): 235 >>> model = VAC(64, 64) 236 >>> inputs = torch.randn(4, 64) 237 >>> outputs = model(inputs,'compute_actor_critic') 238 >>> assert critic_outputs['value'].shape == torch.Size([4]) 239 >>> assert outputs['logit'].shape == torch.Size([4, 64]) 240 241 """ 242 assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) 243 return getattr(self, mode)(x) 244 245 def compute_actor(self, x: Union[torch.Tensor, Dict]) -> Dict: 246 """ 247 Overview: 248 VAC forward computation graph for actor part, input observation tensor to 
predict action logit. 249 Arguments: 250 - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \ 251 it should contain keys 'observation' and optionally 'action_mask'. 252 Returns: 253 - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit`` \ 254 and optionally ``action_mask`` if the input is a dictionary. 255 ReturnsKeys: 256 - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \ 257 the same dimension real-value ranged tensor of possible action choices, and for continuous action \ 258 space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \ 259 same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \ 260 and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``. 261 - action_mask (:obj:`Optional[torch.Tensor]`): The action mask tensor, included if the input is a \ 262 dictionary containing 'action_mask'. 
263 Shapes: 264 - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` 265 266 Examples: 267 >>> model = VAC(64, 64) 268 >>> inputs = torch.randn(4, 64) 269 >>> actor_outputs = model(inputs,'compute_actor') 270 >>> assert actor_outputs['logit'].shape == torch.Size([4, 64]) 271 """ 272 if isinstance(x, dict): 273 action_mask = x['action_mask'] 274 x = self.encoder(x['observation']) if self.share_encoder else self.actor_encoder(x['observation']) 275 else: 276 action_mask = None 277 x = self.encoder(x) if self.share_encoder else self.actor_encoder(x) 278 279 if self.action_space == 'discrete': 280 result = {'logit': self.actor_head(x)['logit']} 281 if action_mask is not None: 282 result['action_mask'] = action_mask 283 return result 284 elif self.action_space == 'continuous': 285 x = self.actor_head(x) # mu, sigma 286 return {'logit': x} 287 elif self.action_space == 'hybrid': 288 action_type = self.actor_head[0](x) 289 action_args = self.actor_head[1](x) 290 return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}} 291 292 def compute_critic(self, x: Union[torch.Tensor, Dict]) -> Dict: 293 """ 294 Overview: 295 VAC forward computation graph for critic part, input observation tensor to predict state value. 296 Arguments: 297 - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \ 298 it should contain the key 'observation'. 299 Returns: 300 - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``. 301 ReturnsKeys: 302 - value (:obj:`torch.Tensor`): The predicted state value tensor. 303 Shapes: 304 - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ). 
305 306 Examples: 307 >>> model = VAC(64, 64) 308 >>> inputs = torch.randn(4, 64) 309 >>> critic_outputs = model(inputs,'compute_critic') 310 >>> assert critic_outputs['value'].shape == torch.Size([4]) 311 """ 312 if isinstance(x, dict): 313 x = self.encoder(x['observation']) if self.share_encoder else self.critic_encoder(x['observation']) 314 else: 315 x = self.encoder(x) if self.share_encoder else self.critic_encoder(x) 316 x = self.critic_head(x) 317 return {'value': x['pred']} 318 319 def compute_actor_critic(self, x: Union[torch.Tensor, Dict]) -> Dict: 320 """ 321 Overview: 322 VAC forward computation graph for both actor and critic part, input observation tensor to predict action \ 323 logit and state value. 324 Arguments: 325 - x (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data. If a dictionary is provided, \ 326 it should contain keys 'observation' and optionally 'action_mask'. 327 Returns: 328 - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for both actor and critic, \ 329 including ``logit``, ``value``, and optionally ``action_mask`` if the input is a dictionary. 330 ReturnsKeys: 331 - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \ 332 the same dimension real-value ranged tensor of possible action choices, and for continuous action \ 333 space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \ 334 same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \ 335 and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``. 336 - value (:obj:`torch.Tensor`): The predicted state value tensor. 337 - action_mask (:obj:`torch.Tensor`, optional): The action mask tensor, included if the input is a \ 338 dictionary containing 'action_mask'. 
339 Shapes: 340 - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` 341 - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ). 342 343 Examples: 344 >>> model = VAC(64, 64) 345 >>> inputs = torch.randn(4, 64) 346 >>> outputs = model(inputs,'compute_actor_critic') 347 >>> assert critic_outputs['value'].shape == torch.Size([4]) 348 >>> assert outputs['logit'].shape == torch.Size([4, 64]) 349 350 351 .. note:: 352 ``compute_actor_critic`` interface aims to save computation when shares encoder and return the combination \ 353 dict output. 354 """ 355 if isinstance(x, dict): 356 action_mask = x['action_mask'] 357 if self.share_encoder: 358 actor_embedding = critic_embedding = self.encoder(x['observation']) 359 else: 360 actor_embedding = self.actor_encoder(x['observation']) 361 critic_embedding = self.critic_encoder(x['observation']) 362 else: 363 action_mask = None 364 if self.share_encoder: 365 actor_embedding = critic_embedding = self.encoder(x) 366 else: 367 actor_embedding = self.actor_encoder(x) 368 critic_embedding = self.critic_encoder(x) 369 370 value = self.critic_head(critic_embedding)['pred'] 371 372 if self.action_space == 'discrete': 373 logit = self.actor_head(actor_embedding)['logit'] 374 result = {'logit': logit, 'value': value} 375 if action_mask is not None: 376 result['action_mask'] = action_mask 377 return result 378 elif self.action_space == 'continuous': 379 x = self.actor_head(actor_embedding) 380 return {'logit': x, 'value': value} 381 elif self.action_space == 'hybrid': 382 action_type = self.actor_head[0](actor_embedding) 383 action_args = self.actor_head[1](actor_embedding) 384 return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value} 385 386 387@MODEL_REGISTRY.register('gtrxl_vac') 388class GTrXLVAC(nn.Module): 389 """ 390 Overview: 391 VAC-style actor-critic model with a GTrXL core. 
392 This model is intended for policies (e.g., VMPO/PPO variants) that use the VAC interfaces: 393 ``compute_actor``, ``compute_critic``, ``compute_actor_critic``. 394 395 Notes: 396 - By default, this model is used with ``memory_len=0`` in on-policy pipelines, where sequence state is 397 not tracked per environment in the policy. 398 - It still runs observation features through GTrXL layers at every forward call. 399 """ 400 mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] 401 402 def __init__( 403 self, 404 obs_shape: Union[int, SequenceType], 405 action_shape: Union[int, SequenceType, EasyDict], 406 action_space: str = 'discrete', 407 encoder_hidden_size_list: SequenceType = [128, 512, 1024], 408 hidden_size: int = 1024, 409 actor_head_hidden_size: int = 1024, 410 actor_head_layer_num: int = 1, 411 critic_head_hidden_size: int = 1024, 412 critic_head_layer_num: int = 1, 413 att_head_dim: int = 16, 414 att_head_num: int = 8, 415 att_mlp_num: int = 2, 416 att_layer_num: int = 3, 417 memory_len: int = 0, 418 dropout: float = 0., 419 gru_gating: bool = True, 420 gru_bias: float = 2., 421 activation: Optional[nn.Module] = nn.ReLU(), 422 norm_type: Optional[str] = None, 423 sigma_type: Optional[str] = 'independent', 424 fixed_sigma_value: Optional[int] = 0.3, 425 bound_type: Optional[str] = None, 426 encoder: Optional[torch.nn.Module] = None, 427 ) -> None: 428 super(GTrXLVAC, self).__init__() 429 obs_shape = squeeze(obs_shape) 430 action_shape = squeeze(action_shape) 431 self.obs_shape, self.action_shape = obs_shape, action_shape 432 self.action_space = action_space 433 assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space 434 435 # Observation encoder (vector -> FC, image -> Conv), then projection to GTrXL embedding size. 
436 if encoder is not None: 437 if not isinstance(encoder, torch.nn.Module): 438 raise ValueError("illegal encoder instance.") 439 self.encoder = encoder 440 encoder_out_dim = hidden_size 441 else: 442 if isinstance(obs_shape, int) or len(obs_shape) == 1: 443 self.encoder = FCEncoder( 444 obs_shape=obs_shape, 445 hidden_size_list=encoder_hidden_size_list, 446 activation=activation, 447 norm_type=norm_type 448 ) 449 elif len(obs_shape) == 3: 450 self.encoder = ConvEncoder( 451 obs_shape=obs_shape, 452 hidden_size_list=encoder_hidden_size_list, 453 activation=activation, 454 norm_type=norm_type 455 ) 456 else: 457 raise RuntimeError( 458 "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".format( 459 obs_shape 460 ) 461 ) 462 encoder_out_dim = encoder_hidden_size_list[-1] 463 464 self.encoder_proj = nn.Identity() if encoder_out_dim == hidden_size else nn.Linear(encoder_out_dim, hidden_size) 465 466 # GTrXL over encoded features. 467 self.core = GTrXL( 468 input_dim=hidden_size, 469 head_dim=att_head_dim, 470 embedding_dim=hidden_size, 471 head_num=att_head_num, 472 mlp_num=att_mlp_num, 473 layer_num=att_layer_num, 474 memory_len=memory_len, 475 dropout_ratio=dropout, 476 activation=activation, 477 gru_gating=gru_gating, 478 gru_bias=gru_bias, 479 use_embedding_layer=False, 480 ) 481 482 # Separate projections for actor/critic heads. 
483 self.actor_proj = nn.Identity() if actor_head_hidden_size == hidden_size else nn.Linear( 484 hidden_size, actor_head_hidden_size 485 ) 486 self.critic_proj = nn.Identity() if critic_head_hidden_size == hidden_size else nn.Linear( 487 hidden_size, critic_head_hidden_size 488 ) 489 490 self.critic_head = RegressionHead( 491 critic_head_hidden_size, 492 1, 493 critic_head_layer_num, 494 activation=activation, 495 norm_type=norm_type, 496 hidden_size=critic_head_hidden_size 497 ) 498 499 if self.action_space == 'continuous': 500 self.multi_head = False 501 self.actor_head = ReparameterizationHead( 502 actor_head_hidden_size, 503 action_shape, 504 actor_head_layer_num, 505 sigma_type=sigma_type, 506 fixed_sigma_value=fixed_sigma_value, 507 activation=activation, 508 norm_type=norm_type, 509 bound_type=bound_type, 510 hidden_size=actor_head_hidden_size, 511 ) 512 elif self.action_space == 'discrete': 513 self.multi_head = not isinstance(action_shape, int) 514 if self.multi_head: 515 self.actor_head = MultiHead( 516 DiscreteHead, 517 actor_head_hidden_size, 518 action_shape, 519 layer_num=actor_head_layer_num, 520 activation=activation, 521 norm_type=norm_type 522 ) 523 else: 524 self.actor_head = DiscreteHead( 525 actor_head_hidden_size, 526 action_shape, 527 actor_head_layer_num, 528 activation=activation, 529 norm_type=norm_type 530 ) 531 else: # hybrid 532 action_shape.action_args_shape = squeeze(action_shape.action_args_shape) 533 action_shape.action_type_shape = squeeze(action_shape.action_type_shape) 534 actor_action_args = ReparameterizationHead( 535 actor_head_hidden_size, 536 action_shape.action_args_shape, 537 actor_head_layer_num, 538 sigma_type=sigma_type, 539 fixed_sigma_value=fixed_sigma_value, 540 activation=activation, 541 norm_type=norm_type, 542 bound_type=bound_type, 543 hidden_size=actor_head_hidden_size, 544 ) 545 actor_action_type = DiscreteHead( 546 actor_head_hidden_size, 547 action_shape.action_type_shape, 548 actor_head_layer_num, 549 
activation=activation, 550 norm_type=norm_type, 551 ) 552 self.actor_head = nn.ModuleList([actor_action_type, actor_action_args]) 553 554 def reset(self, *args, **kwargs) -> None: 555 # Keep compatibility with model wrappers that call model.reset(). 556 state = kwargs.get('state', None) 557 batch_size = kwargs.get('batch_size', None) 558 if state is not None: 559 self.core.reset_memory(state=state) 560 elif batch_size is not None: 561 self.core.reset_memory(batch_size=batch_size) 562 else: 563 # Defer memory initialization to the next forward with actual batch size. 564 self.core.memory = None 565 566 def _encode_core(self, x: torch.Tensor) -> torch.Tensor: 567 """ 568 Encode observations, run GTrXL, and return feature tensor. 569 Returns shape: 570 - (B, D) for batched observations 571 - (T, B, D) for sequence observations 572 """ 573 if isinstance(self.obs_shape, int) or len(self.obs_shape) == 1: 574 obs_dims = 1 575 else: 576 obs_dims = 3 577 578 leading_shape = x.shape[:-obs_dims] 579 x_flat = x.reshape(-1, *x.shape[-obs_dims:]) 580 enc = self.encoder(x_flat) 581 enc = self.encoder_proj(enc) 582 enc = enc.reshape(*leading_shape, -1) 583 584 if enc.dim() == 2: 585 seq_in = enc.unsqueeze(0) # (1, B, D) 586 core_out = self.core(seq_in)['logit'].squeeze(0) # (B, D) 587 elif enc.dim() == 3: 588 core_out = self.core(enc)['logit'] # (T, B, D) 589 else: 590 raise RuntimeError(f"Unsupported encoded tensor rank {enc.dim()} for GTrXLVAC.") 591 return core_out 592 593 def forward(self, x: torch.Tensor, mode: str) -> Dict: 594 assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) 595 return getattr(self, mode)(x) 596 597 def compute_actor(self, x: Union[torch.Tensor, Dict]) -> Dict: 598 if isinstance(x, dict): 599 action_mask = x['action_mask'] 600 obs = x['observation'] 601 else: 602 action_mask = None 603 obs = x 604 605 actor_embedding = self.actor_proj(self._encode_core(obs)) 606 if self.action_space == 'discrete': 607 result = {'logit': 
self.actor_head(actor_embedding)['logit']} 608 if action_mask is not None: 609 result['action_mask'] = action_mask 610 return result 611 elif self.action_space == 'continuous': 612 return {'logit': self.actor_head(actor_embedding)} 613 else: 614 action_type = self.actor_head[0](actor_embedding) 615 action_args = self.actor_head[1](actor_embedding) 616 return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}} 617 618 def compute_critic(self, x: Union[torch.Tensor, Dict]) -> Dict: 619 obs = x['observation'] if isinstance(x, dict) else x 620 critic_embedding = self.critic_proj(self._encode_core(obs)) 621 value = self.critic_head(critic_embedding)['pred'] 622 return {'value': value} 623 624 def compute_actor_critic(self, x: Union[torch.Tensor, Dict]) -> Dict: 625 if isinstance(x, dict): 626 action_mask = x['action_mask'] 627 obs = x['observation'] 628 else: 629 action_mask = None 630 obs = x 631 632 core_embedding = self._encode_core(obs) 633 actor_embedding = self.actor_proj(core_embedding) 634 critic_embedding = self.critic_proj(core_embedding) 635 value = self.critic_head(critic_embedding)['pred'] 636 637 if self.action_space == 'discrete': 638 logit = self.actor_head(actor_embedding)['logit'] 639 result = {'logit': logit, 'value': value} 640 if action_mask is not None: 641 result['action_mask'] = action_mask 642 return result 643 elif self.action_space == 'continuous': 644 return {'logit': self.actor_head(actor_embedding), 'value': value} 645 else: 646 action_type = self.actor_head[0](actor_embedding) 647 action_args = self.actor_head[1](actor_embedding) 648 return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value} 649 650 651@MODEL_REGISTRY.register('dreamervac') 652class DREAMERVAC(nn.Module): 653 """ 654 Overview: 655 The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC). 656 This model now supports discrete, continuous action space. 
    Interfaces:
        ``__init__``, ``forward``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        action_shape: Union[int, SequenceType, EasyDict],
        dyn_stoch=32,
        dyn_deter=512,
        dyn_discrete=32,
        actor_layers=2,
        value_layers=2,
        units=512,
        act='SiLU',
        norm='LayerNorm',
        actor_dist='normal',
        actor_init_std=1.0,
        actor_min_std=0.1,
        actor_max_std=1.0,
        actor_temp=0.1,
        action_unimix_ratio=0.01,
    ) -> None:
        """
        Overview:
            Initialize the ``DREAMERVAC`` model according to arguments: build an ``ActionHead`` actor and a \
            ``DenseHead`` critic on top of the world-model feature vector.
        Arguments:
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action space shape, such as 6 or [2, 3, 3].
            - dyn_stoch (:obj:`int`): Size of the stochastic part of the world-model latent state.
            - dyn_deter (:obj:`int`): Size of the deterministic part of the world-model latent state.
            - dyn_discrete (:obj:`int`): Number of discrete categories per stochastic unit; 0 means a \
                continuous stochastic state.
            - actor_layers (:obj:`int`): Layer num of the actor head.
            - value_layers (:obj:`int`): Layer num of the critic head.
            - units (:obj:`int`): Hidden units per layer in both heads.
            - act (:obj:`str`): Activation name passed to the actor head.
            - norm (:obj:`str`): Normalization name passed to the actor head.
            - actor_dist (:obj:`str`): Actor output distribution type, defaults to 'normal'.
            - actor_init_std, actor_min_std, actor_max_std (:obj:`float`): Std settings of the actor distribution.
            - actor_temp (:obj:`float`): Actor temperature.
            - action_unimix_ratio (:obj:`float`): Uniform-mixture ratio applied to the action distribution.
        """
        super(DREAMERVAC, self).__init__()
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape

        # Feature vector fed to both heads: concatenation of stochastic and deterministic latent parts.
        # With a discrete latent, the stochastic part is flattened one-hot categories.
        if dyn_discrete:
            feat_size = dyn_stoch * dyn_discrete + dyn_deter
        else:
            feat_size = dyn_stoch + dyn_deter
        self.actor = ActionHead(
            feat_size,  # pytorch version
            action_shape,
            actor_layers,
            units,
            act,
            norm,
            actor_dist,
            actor_init_std,
            actor_min_std,
            actor_max_std,
            actor_temp,
            outscale=1.0,
            unimix_ratio=action_unimix_ratio,
        )
        # NOTE(review): the critic hardcodes 'SiLU'/'LN' instead of using the ``act``/``norm``
        # parameters that the actor receives — confirm this asymmetry is intentional.
        self.critic = DenseHead(
            feat_size,  # pytorch version
            (255, ),
            value_layers,
            units,
            'SiLU',  # act
            'LN',  # norm
            'twohot_symlog',
            outscale=0.0,
            device='cuda' if torch.cuda.is_available() else 'cpu',
        )