
ding.model.template.havac

RNNLayer

Bases: Module

forward(x, prev_state, inference=False)

Forward pass of the RNN layer. If ``inference`` is True, the sequence length of the input is set to 1. If ``res_link`` is True, a residual link is added to the output.
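The two flags can be sketched with a plain PyTorch GRU (a simplified, hypothetical stand-in: the actual layer uses ``ding.torch_utils.get_lstm`` and collects per-timestep hidden states during training):

```python
import torch
import torch.nn as nn

# Hypothetical sketch of RNNLayer.forward with a plain nn.GRU.
# Note: res_link requires input_size == hidden_size so the residual add is valid.
rnn = nn.GRU(input_size=64, hidden_size=64)

def rnn_forward(x, prev_state, inference=False, res_link=True):
    residual = x
    if inference:
        x = x.unsqueeze(0)                    # (B, N) -> (1, B, N): seq_len = 1
        out, next_state = rnn(x, prev_state)
        out = out.squeeze(0)                  # back to (B, N) for the head
    else:
        out, next_state = rnn(x, prev_state)  # x: (T, B, N), full sequence
    if res_link:
        out = out + residual                  # residual link around the RNN
    return {'output': out, 'next_state': next_state}

# Inference step: a single (B, N) embedding.
step = rnn_forward(torch.randn(4, 64), None, inference=True)
assert step['output'].shape == (4, 64)
```

The real layer additionally unrolls timestep by timestep in training mode so it can return every hidden state, not just the last one.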

HAVAC

Bases: Module

Overview

The HAVAC model of each agent for HAPPO.

Interfaces: __init__, forward

__init__(agent_obs_shape, global_obs_shape, action_shape, agent_num, use_lstm=False, lstm_type='gru', encoder_hidden_size_list=[128, 128, 64], actor_head_hidden_size=64, actor_head_layer_num=2, critic_head_hidden_size=64, critic_head_layer_num=1, action_space='discrete', activation=nn.ReLU(), norm_type=None, sigma_type='independent', bound_type=None, res_link=False)

Overview

Init the VAC Model for HAPPO according to arguments.

Arguments:

- agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of a single agent.
- global_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of the global state.
- action_shape (:obj:`Union[int, SequenceType]`): Action space.
- agent_num (:obj:`int`): Number of agents.
- lstm_type (:obj:`str`): Use lstm or gru, default to gru.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
- actor_head_layer_num (:obj:`int`): The number of layers used in the actor's head network.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
- critic_head_layer_num (:obj:`int`): The number of layers used in the critic's head network.
- activation (:obj:`Optional[nn.Module]`): The type of activation function to use in the ``MLP`` after ``layer_fn``; if ``None``, defaults to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization to use; see ``ding.torch_utils.fc_block`` for more details.
- res_link (:obj:`bool`): Whether to use the residual link, default to False.
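As the source below shows, `HAVAC` holds one `HAVACAgent` per agent in an `nn.ModuleList` and routes `forward` by agent index. A minimal sketch of this wrapper pattern (``TinyAgent``/``TinyHAVAC`` are hypothetical stand-ins, not the ding classes):

```python
import torch
import torch.nn as nn

# Sketch of the per-agent wrapper pattern HAVAC uses: one sub-model per
# agent in an nn.ModuleList, with forward routed by agent index.
class TinyAgent(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Linear(obs_dim, act_dim)

    def forward(self, obs):
        return {'logit': self.net(obs)}

class TinyHAVAC(nn.Module):
    def __init__(self, obs_dim, act_dim, agent_num):
        super().__init__()
        # nn.ModuleList registers every agent's parameters with the wrapper.
        self.agent_models = nn.ModuleList(
            [TinyAgent(obs_dim, act_dim) for _ in range(agent_num)]
        )

    def forward(self, agent_idx, obs):
        return self.agent_models[agent_idx](obs)

model = TinyHAVAC(obs_dim=8, act_dim=4, agent_num=3)
out = model(1, torch.randn(2, 8))
assert out['logit'].shape == (2, 4)
```

Each agent keeps its own parameters, which is what HAPPO's sequential per-agent updates require.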

HAVACAgent

Bases: Module

Overview

The HAVAC model of each agent for HAPPO.

Interfaces: __init__, forward, compute_actor, compute_critic, compute_actor_critic

__init__(agent_obs_shape, global_obs_shape, action_shape, use_lstm=False, lstm_type='gru', encoder_hidden_size_list=[128, 128, 64], actor_head_hidden_size=64, actor_head_layer_num=2, critic_head_hidden_size=64, critic_head_layer_num=1, action_space='discrete', activation=nn.ReLU(), norm_type=None, sigma_type='happo', bound_type=None, res_link=False)

Overview

Init the VAC Model for HAPPO according to arguments.

Arguments:

- agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of a single agent.
- global_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of the global state.
- action_shape (:obj:`Union[int, SequenceType]`): Action space.
- lstm_type (:obj:`str`): Use lstm or gru, default to gru.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
- actor_head_layer_num (:obj:`int`): The number of layers used in the actor's head network.
- critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
- critic_head_layer_num (:obj:`int`): The number of layers used in the critic's head network.
- activation (:obj:`Optional[nn.Module]`): The type of activation function to use in the ``MLP`` after ``layer_fn``; if ``None``, defaults to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization to use; see ``ding.torch_utils.fc_block`` for more details.
- res_link (:obj:`bool`): Whether to use the residual link, default to False.

forward(inputs, mode)

Overview

Use the encoded embedding tensor to predict output. Parameter updates follow VAC's MLP forward setup.

Arguments:

Forward with ``'compute_actor'`` or ``'compute_critic'``:

- inputs (:obj:`torch.Tensor`): The encoded embedding tensor, determined by the given ``hidden_size``, i.e. ``(B, N=hidden_size)``. Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` applies depends on ``mode``.

Returns:

- outputs (:obj:`Dict`): Output of the encoder and head.

    Forward with ``'compute_actor'``, Necessary Keys:
        - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.

    Forward with ``'compute_critic'``, Necessary Keys:
        - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.

Shapes:

- inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the corresponding ``hidden_size``.
- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``.
- value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

Actor Examples

>>> model = VAC(64, 128)
>>> inputs = torch.randn(4, 64)
>>> actor_outputs = model(inputs, 'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([4, 128])

Critic Examples

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> critic_outputs = model(inputs, 'compute_critic')
>>> critic_outputs['value']
tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)

Actor-Critic Examples

>>> model = VAC(64, 64)
>>> inputs = torch.randn(4, 64)
>>> outputs = model(inputs, 'compute_actor_critic')
>>> outputs['value']
tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
>>> assert outputs['logit'].shape == torch.Size([4, 64])

compute_actor(inputs, inference=False)

Overview

Execute forward computation in ``'compute_actor'`` mode, using the encoded embedding tensor to predict output.

Arguments:

- inputs (:obj:`Dict`): Input data dict with keys ``['obs', 'actor_prev_state']``, where ``obs`` contains keys ``['agent_state', 'global_state', 'action_mask']``.

Returns:

- outputs (:obj:`Dict`): Output of the encoder, RNN (optional), and head.

ReturnsKeys
  • logit (:obj:`torch.Tensor`): Logit encoding tensor.
  • actor_next_state: next RNN state of the actor (returned when ``use_lstm`` is True).
  • hidden_state: hidden states of the actor RNN at every timestep (returned when ``use_lstm`` is True).

Shapes:

- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``.
- actor_next_state: :math:`(B, )`.
- hidden_state: :math:`(T, B, N)`, where N is the actor head ``hidden_size``.

Examples:

>>> model = HAVAC(
        agent_obs_shape=obs_dim,
        global_obs_shape=global_obs_dim,
        action_shape=action_dim,
        use_lstm = True,
        )
>>> inputs = {
        'obs': {
            'agent_state': torch.randn(T, bs, obs_dim),
            'global_state': torch.randn(T, bs, global_obs_dim),
            'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
        },
        'actor_prev_state': [None for _ in range(bs)],
    }
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['logit'].shape == (T, bs, action_dim)
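For discrete action spaces, `compute_actor` applies the action mask by overwriting masked logits with a very negative number (see the source below). A standalone sketch of that masking step:

```python
import torch

# Sketch of the discrete-action masking applied at the end of compute_actor:
# invalid actions (mask == 0) get a very negative logit, so softmax assigns
# them (numerically) zero probability.
logit = torch.randn(4, 6)                                    # (B, action_shape)
action_mask = torch.tensor([[1, 0, 1, 1, 0, 1]]).repeat(4, 1)
logit[action_mask == 0] = -99999999.0
probs = torch.softmax(logit, dim=-1)
assert torch.all(probs[action_mask == 0] < 1e-6)
```

Masking in logit space (rather than zeroing probabilities) keeps the distribution valid without renormalization.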

compute_critic(inputs, inference=False)

Overview

Execute forward computation in ``'compute_critic'`` mode, using the encoded embedding tensor to predict output.

Arguments:

- inputs (:obj:`Dict`): Input data dict with keys ``['obs', 'critic_prev_state']`` (``critic_prev_state`` is needed when using an RNN), where ``obs`` contains keys ``['agent_state', 'global_state', 'action_mask']``.

Returns:

- outputs (:obj:`Dict`): Output of the encoder, RNN (optional), and head.

    Necessary Keys:
        - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.

Shapes:

- value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

Examples:

>>> model = HAVAC(
        agent_obs_shape=obs_dim,
        global_obs_shape=global_obs_dim,
        action_shape=action_dim,
        use_lstm = True,
        )
>>> inputs = {
        'obs': {
            'agent_state': torch.randn(T, bs, obs_dim),
            'global_state': torch.randn(T, bs, global_obs_dim),
            'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
        },
        'critic_prev_state': [None for _ in range(bs)],
    }
>>> critic_outputs = model(inputs,'compute_critic')
>>> assert critic_outputs['value'].shape == (T, bs)
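In training mode the RNN path unrolls one timestep at a time so every hidden state can be returned (see `RNNLayer` in the source below). A simplified sketch of that unroll with a plain `nn.GRU` (hypothetical names; ding's version also wraps the encoder and head with `parallel_wrapper`):

```python
import torch
import torch.nn as nn

# Sketch of the per-timestep unroll used by RNNLayer in training mode:
# feed x[t:t+1] (keeping the seq_len dim) and collect outputs and hidden
# states at every step.
rnn = nn.GRU(input_size=32, hidden_size=32)
T, B, N = 5, 4, 32
x = torch.randn(T, B, N)

prev_state = None
outputs, hiddens = [], []
for t in range(T):
    out, prev_state = rnn(x[t:t + 1], prev_state)  # out: (1, B, N)
    outputs.append(out)
    hiddens.append(prev_state)                     # h_t: (1, B, N)

y = torch.cat(outputs, dim=0)                      # (T, B, N)
all_h = torch.cat(hiddens, dim=0)                  # (T, B, N)
assert y.shape == (T, B, N) and all_h.shape == (T, B, N)
```

Keeping `x[t:t+1]` instead of `x[t]` preserves the sequence dimension the RNN expects, which is the same trick the source uses.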

compute_actor_critic(inputs, inference=False)

Overview

Execute forward computation in ``'compute_actor_critic'`` mode, using the encoded embedding tensor to predict output.

Arguments:

- inputs (:obj:`Dict`): Input data dict with keys ``['obs', 'actor_prev_state', 'critic_prev_state']`` (the prev-state keys are needed when using an RNN), where ``obs`` contains keys ``['agent_state', 'global_state', 'action_mask']``.

Returns:

- outputs (:obj:`Dict`): Output of the encoder and heads.

ReturnsKeys
  • logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
  • value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.

Shapes:

- logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``.
- value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

Examples:

>>> model = VAC(64,64)
>>> inputs = torch.randn(4, 64)
>>> outputs = model(inputs,'compute_actor_critic')
>>> outputs['value']
tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
>>> assert outputs['logit'].shape == torch.Size([4, 64])

.. note:: The ``compute_actor_critic`` interface aims to save computation when the actor and critic share an encoder. It returns the combined dictionary.
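The note refers to the general VAC pattern of running one shared encoder and feeding both heads. A minimal sketch of that saving (``TinyVAC`` is a hypothetical module, not the HAVAC code below, whose actor and critic use separate encoders):

```python
import torch
import torch.nn as nn

# Minimal shared-encoder actor-critic: the encoder runs once and feeds both
# heads, which is the computation compute_actor_critic aims to save.
class TinyVAC(nn.Module):
    def __init__(self, obs_dim=64, act_dim=6, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.actor_head = nn.Linear(hidden, act_dim)
        self.critic_head = nn.Linear(hidden, 1)

    def compute_actor_critic(self, obs):
        h = self.encoder(obs)  # encoded once, used by both heads
        return {
            'logit': self.actor_head(h),
            'value': self.critic_head(h).squeeze(-1),
        }

model = TinyVAC()
out = model.compute_actor_critic(torch.randn(4, 64))
assert out['logit'].shape == (4, 6) and out['value'].shape == (4,)
```

Calling `compute_actor` and `compute_critic` separately would run the shared encoder twice; the combined interface avoids that.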

Full Source Code

../ding/model/template/havac.py

from typing import Union, Dict, Optional
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ding.model.template.q_learning import parallel_wrapper
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, \
    FCEncoder, ConvEncoder


class RNNLayer(nn.Module):

    def __init__(self, lstm_type, input_size, hidden_size, res_link: bool = False):
        super(RNNLayer, self).__init__()
        self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=hidden_size)
        self.res_link = res_link

    def forward(self, x, prev_state, inference: bool = False):
        """
        Forward pass of the RNN layer.
        If inference is True, sequence length of input is set to 1.
        If res_link is True, a residual link is added to the output.
        """
        # x: obs_embedding
        if self.res_link:
            a = x
        if inference:
            x = x.unsqueeze(0)  # for rnn input, put the seq_len of x as 1 instead of none.
            # prev_state: DataType: List[Tuple[torch.Tensor]]; Initially, it is a list of None
            x, next_state = self.rnn(x, prev_state)
            x = x.squeeze(0)  # to delete the seq_len dim to match head network input
            if self.res_link:
                x = x + a
            return {'output': x, 'next_state': next_state}
        else:
            # lstm_embedding stores all hidden_state
            lstm_embedding = []
            hidden_state_list = []
            for t in range(x.shape[0]):  # T timesteps
                # use x[t:t+1] but not x[t] can keep original dimension
                output, prev_state = self.rnn(x[t:t + 1], prev_state)  # output: (1, B, head_hidden_size)
                lstm_embedding.append(output)
                hidden_state = [p['h'] for p in prev_state]
                # only keep ht, {list: x.shape[0]{Tensor:(1, batch_size, head_hidden_size)}}
                hidden_state_list.append(torch.cat(hidden_state, dim=1))
            x = torch.cat(lstm_embedding, 0)  # (T, B, head_hidden_size)
            if self.res_link:
                x = x + a
            all_hidden_state = torch.cat(hidden_state_list, dim=0)
            return {'output': x, 'next_state': prev_state, 'hidden_state': all_hidden_state}


@MODEL_REGISTRY.register('havac')
class HAVAC(nn.Module):
    """
    Overview:
        The HAVAC model of each agent for HAPPO.
    Interfaces:
        ``__init__``, ``forward``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        agent_obs_shape: Union[int, SequenceType],
        global_obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        agent_num: int,
        use_lstm: bool = False,
        lstm_type: str = 'gru',
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 2,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        action_space: str = 'discrete',
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        sigma_type: Optional[str] = 'independent',
        bound_type: Optional[str] = None,
        res_link: bool = False,
    ) -> None:
        r"""
        Overview:
            Init the VAC Model for HAPPO according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of a single agent.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of the global state.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space.
            - agent_num (:obj:`int`): Number of agents.
            - lstm_type (:obj:`str`): Use lstm or gru, default to gru.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
            - actor_head_layer_num (:obj:`int`):
                The number of layers used in the actor's head network.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
            - critic_head_layer_num (:obj:`int`):
                The number of layers used in the critic's head network.
            - activation (:obj:`Optional[nn.Module]`):
                The type of activation function to use in the ``MLP`` after ``layer_fn``;
                if ``None``, defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
            - res_link (:obj:`bool`): Whether to use the residual link, default to False.
        """
        super(HAVAC, self).__init__()
        self.agent_num = agent_num
        self.agent_models = nn.ModuleList(
            [
                HAVACAgent(
                    agent_obs_shape=agent_obs_shape,
                    global_obs_shape=global_obs_shape,
                    action_shape=action_shape,
                    use_lstm=use_lstm,
                    action_space=action_space,
                ) for _ in range(agent_num)
            ]
        )

    def forward(self, agent_idx, input_data, mode):
        selected_agent_model = self.agent_models[agent_idx]
        output = selected_agent_model(input_data, mode)
        return output


class HAVACAgent(nn.Module):
    """
    Overview:
        The HAVAC model of each agent for HAPPO.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
        self,
        agent_obs_shape: Union[int, SequenceType],
        global_obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        use_lstm: bool = False,
        lstm_type: str = 'gru',
        encoder_hidden_size_list: SequenceType = [128, 128, 64],
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 2,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        action_space: str = 'discrete',
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
        sigma_type: Optional[str] = 'happo',
        bound_type: Optional[str] = None,
        res_link: bool = False,
    ) -> None:
        r"""
        Overview:
            Init the VAC Model for HAPPO according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of a single agent.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation space of the global state.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space.
            - lstm_type (:obj:`str`): Use lstm or gru, default to gru.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
            - actor_head_layer_num (:obj:`int`):
                The number of layers used in the actor's head network.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
            - critic_head_layer_num (:obj:`int`):
                The number of layers used in the critic's head network.
            - activation (:obj:`Optional[nn.Module]`):
                The type of activation function to use in the ``MLP`` after ``layer_fn``;
                if ``None``, defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`):
                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
            - res_link (:obj:`bool`): Whether to use the residual link, default to False.
        """
        super(HAVACAgent, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space
        # Encoder Type
        if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1:
            actor_encoder_cls = FCEncoder
        elif len(agent_obs_shape) == 3:
            actor_encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".
                format(agent_obs_shape)
            )
        if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1:
            critic_encoder_cls = FCEncoder
        elif len(global_obs_shape) == 3:
            critic_encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".
                format(global_obs_shape)
            )

        # We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
        # In the SMAC task it can obviously improve the performance.
        # Users can change the model according to their own needs.
        self.actor_encoder = actor_encoder_cls(
            obs_shape=agent_obs_shape,
            hidden_size_list=encoder_hidden_size_list,
            activation=activation,
            norm_type=norm_type
        )
        self.critic_encoder = critic_encoder_cls(
            obs_shape=global_obs_shape,
            hidden_size_list=encoder_hidden_size_list,
            activation=activation,
            norm_type=norm_type
        )
        # RNN part
        self.use_lstm = use_lstm
        if self.use_lstm:
            self.actor_rnn = RNNLayer(
                lstm_type,
                input_size=encoder_hidden_size_list[-1],
                hidden_size=actor_head_hidden_size,
                res_link=res_link
            )
            self.critic_rnn = RNNLayer(
                lstm_type,
                input_size=encoder_hidden_size_list[-1],
                hidden_size=critic_head_hidden_size,
                res_link=res_link
            )
        # Head Type
        self.critic_head = RegressionHead(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        assert self.action_space in ['discrete', 'continuous'], self.action_space
        if self.action_space == 'discrete':
            self.actor_head = DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        elif self.action_space == 'continuous':
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        # must use list, not nn.ModuleList
        self.actor = [self.actor_encoder, self.actor_rnn, self.actor_head] if self.use_lstm \
            else [self.actor_encoder, self.actor_head]
        self.critic = [self.critic_encoder, self.critic_rnn, self.critic_head] if self.use_lstm \
            else [self.critic_encoder, self.critic_head]
        # for convenience of calling some apis (such as: self.critic.parameters()), but may cause
        # misunderstanding when print(self)
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        r"""
        Overview:
            Use encoded embedding tensor to predict output.
            Parameter updates with VAC's MLPs forward setup.
        Arguments:
            Forward with ``'compute_actor'`` or ``'compute_critic'``:
                - inputs (:obj:`torch.Tensor`):
                    The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
                    Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depends on ``mode``.
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder and head.

            Forward with ``'compute_actor'``, Necessary Keys:
                - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.

            Forward with ``'compute_critic'``, Necessary Keys:
                - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponding ``hidden_size``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

        Actor Examples:
            >>> model = VAC(64, 128)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 128])

        Critic Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> critic_outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)

        Actor-Critic Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Execute forward computation with ``'compute_actor'`` mode,
            using the encoded embedding tensor to predict output.
        Arguments:
            - inputs (:obj:`Dict`):
                input data dict with keys ['obs' (with keys ['agent_state', 'global_state', 'action_mask']),
                'actor_prev_state']
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder, RNN (optional) and head.

        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Logit encoding tensor.
            - actor_next_state: next RNN state of the actor (when using an RNN).
            - hidden_state: hidden states of the actor RNN at every timestep (when using an RNN).
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - actor_next_state: :math:`(B, )`
            - hidden_state: :math:`(T, B, N)`, where N is the actor head ``hidden_size``

        Examples:
            >>> model = HAVAC(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                    use_lstm=True,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(T, bs, obs_dim),
                        'global_state': torch.randn(T, bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
                    },
                    'actor_prev_state': [None for _ in range(bs)],
                }
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == (T, bs, action_dim)
        """
        x = inputs['obs']['agent_state']
        output = {}
        if self.use_lstm:
            rnn_actor_prev_state = inputs['actor_prev_state']
            if inference:
                x = self.actor_encoder(x)
                rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
                x = rnn_output['output']
                x = self.actor_head(x)
                output['next_state'] = rnn_output['next_state']
                # output: 'logit'/'next_state'
            else:
                assert len(x.shape) in [3, 5], x.shape
                x = parallel_wrapper(self.actor_encoder)(x)  # (T, B, N)
                rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
                x = rnn_output['output']
                x = parallel_wrapper(self.actor_head)(x)
                output['actor_next_state'] = rnn_output['next_state']
                output['actor_hidden_state'] = rnn_output['hidden_state']
                # output: 'logit'/'actor_next_state'/'hidden_state'
        else:
            x = self.actor_encoder(x)
            x = self.actor_head(x)
            # output: 'logit'

        if self.action_space == 'discrete':
            action_mask = inputs['obs']['action_mask']
            logit = x['logit']
            logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            logit = x
        output['logit'] = logit
        return output

    def compute_critic(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Execute forward computation with ``'compute_critic'`` mode,
            using the encoded embedding tensor to predict output.
        Arguments:
            - inputs (:obj:`Dict`):
                input data dict with keys ['obs' (with keys ['agent_state', 'global_state', 'action_mask']),
                'critic_prev_state' (when you are using an RNN)]
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder, RNN (optional) and head.

            Necessary Keys:
                - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

        Examples:
            >>> model = HAVAC(
                    agent_obs_shape=obs_dim,
                    global_obs_shape=global_obs_dim,
                    action_shape=action_dim,
                    use_lstm=True,
                )
            >>> inputs = {
                    'obs': {
                        'agent_state': torch.randn(T, bs, obs_dim),
                        'global_state': torch.randn(T, bs, global_obs_dim),
                        'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
                    },
                    'critic_prev_state': [None for _ in range(bs)],
                }
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> assert critic_outputs['value'].shape == (T, bs)
        """
        global_obs = inputs['obs']['global_state']
        output = {}
        if self.use_lstm:
            rnn_critic_prev_state = inputs['critic_prev_state']
            if inference:
                x = self.critic_encoder(global_obs)
                rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
                x = rnn_output['output']
                x = self.critic_head(x)
                output['next_state'] = rnn_output['next_state']
                # output: 'value'/'next_state'
            else:
                assert len(global_obs.shape) in [3, 5], global_obs.shape
                x = parallel_wrapper(self.critic_encoder)(global_obs)  # (T, B, N)
                rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
                x = rnn_output['output']
                x = parallel_wrapper(self.critic_head)(x)
                output['critic_next_state'] = rnn_output['next_state']
                output['critic_hidden_state'] = rnn_output['hidden_state']
                # output: 'value'/'critic_next_state'/'hidden_state'
        else:
            x = self.critic_encoder(global_obs)
            x = self.critic_head(x)
            # output: 'value'
        output['value'] = x['pred']
        return output

    def compute_actor_critic(self, inputs: Dict, inference: bool = False) -> Dict:
        r"""
        Overview:
            Execute forward computation with ``'compute_actor_critic'`` mode,
            using the encoded embedding tensor to predict output.
        Arguments:
            - inputs (:obj:`Dict`): input data dict with keys
                ['obs' (with keys ['agent_state', 'global_state', 'action_mask']),
                'actor_prev_state', 'critic_prev_state' (when you are using an RNN)]
        Returns:
            - outputs (:obj:`Dict`):
                Run with encoder and head.

        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
            - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])

        .. note::
            The ``compute_actor_critic`` interface aims to save computation when the actor and critic
            share an encoder. It returns the combined dictionary.
        """
        actor_output = self.compute_actor(inputs, inference)
        critic_output = self.compute_critic(inputs, inference)
        if self.use_lstm:
            return {
                'logit': actor_output['logit'],
                'value': critic_output['value'],
                'actor_next_state': actor_output['actor_next_state'],
                'actor_hidden_state': actor_output['actor_hidden_state'],
                'critic_next_state': critic_output['critic_next_state'],
                'critic_hidden_state': critic_output['critic_hidden_state'],
            }
        else:
            return {
                'logit': actor_output['logit'],
                'value': critic_output['value'],
            }