ding.model.template.mavac

MAVAC

Bases: Module

Overview

The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for multi-agent settings, such as MAPPO (https://arxiv.org/abs/2103.01955). This model supports both discrete and continuous action spaces. The MAVAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders are used to extract features from various observations. Heads are used to predict the corresponding value or action logit.

Interfaces: __init__, forward, compute_actor, compute_critic, compute_actor_critic.

__init__(agent_obs_shape, global_obs_shape, action_shape, agent_num, actor_head_hidden_size=256, actor_head_layer_num=2, critic_head_hidden_size=512, critic_head_layer_num=1, action_space='discrete', activation=nn.ReLU(), norm_type=None, sigma_type='independent', bound_type=None, encoder=None)

Overview

Init the MAVAC Model according to arguments.

Arguments:
- agent_obs_shape (:obj:Union[int, SequenceType]): Observation space of a single agent, such as 8 or [4, 84, 84].
- global_obs_shape (:obj:Union[int, SequenceType]): Global observation space, such as 8 or [4, 84, 84].
- action_shape (:obj:Union[int, SequenceType]): Action space shape of a single agent, such as 6 or [2, 3, 3].
- agent_num (:obj:int): This parameter is temporarily reserved; it may be required by subsequent changes to the model.
- actor_head_hidden_size (:obj:Optional[int]): The hidden_size of the actor_head network, defaults to 256; it must match the output size of the actor encoder.
- actor_head_layer_num (:obj:int): The number of layers used in the actor_head network to compute the action.
- critic_head_hidden_size (:obj:Optional[int]): The hidden_size of the critic_head network, defaults to 512; it must match the output size of the critic encoder.
- critic_head_layer_num (:obj:int): The number of layers used in the critic_head network to compute the value output.
- action_space (:obj:str): The type of action space, one of ['discrete', 'continuous']; the corresponding head, DiscreteHead or ReparameterizationHead, is instantiated accordingly.
- activation (:obj:Optional[nn.Module]): The activation function used in the MLP after each layer_fn; if None, defaults to nn.ReLU().
- norm_type (:obj:Optional[str]): The type of normalization in the networks, see ding.torch_utils.fc_block for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
- sigma_type (:obj:Optional[str]): The type of sigma in continuous action space, see ding.torch_utils.network.dreamer.ReparameterizationHead for more details. In MAPPO it defaults to independent, which means state-independent sigma parameters.
- bound_type (:obj:Optional[str]): The type of action-bound method in continuous action space, defaults to None, which means no bound.
- encoder (:obj:Optional[Tuple[torch.nn.Module, torch.nn.Module]]): A pair of (actor, critic) encoder modules, defaults to None. You can define your own actor and critic encoder modules and pass them into MAVAC to handle different observation spaces.
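The head selection driven by action_space can be sketched in plain Python. The placeholder classes below merely stand in for ding's DiscreteHead and ReparameterizationHead; this is an illustration of the branching logic, not the DI-engine implementation:

```python
class DiscreteHead:
    """Placeholder for ding's DiscreteHead (illustration only)."""
    pass


class ReparameterizationHead:
    """Placeholder for ding's ReparameterizationHead (illustration only)."""
    pass


def build_actor_head(action_space: str):
    # Mirror MAVAC's branch: the action_space string decides the head type.
    assert action_space in ['discrete', 'continuous'], action_space
    if action_space == 'discrete':
        return DiscreteHead()
    return ReparameterizationHead()
```

Any other string fails the same assertion that guards the real constructor.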

forward(inputs, mode)

Overview

MAVAC forward computation graph: input an observation tensor to predict state value or action logit. mode includes compute_actor, compute_critic and compute_actor_critic. Different modes forward through different network modules to produce different outputs and save computation.

Arguments:
- inputs (:obj:Dict): The input dict including observation and related info, whose key-values vary with mode.
- mode (:obj:str): The forward mode; all modes are defined at the beginning of this class.
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph, whose key-values vary with mode.

Examples (Actor):

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

Examples (Critic):

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> critic_outputs = model(inputs,'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([10, 8])

Examples (Actor-Critic):

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> outputs = model(inputs,'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
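The forward dispatch itself is just a getattr lookup guarded by the class's mode list. A minimal pure-Python sketch of that pattern (a toy class, not the real model):

```python
class ToyVAC:
    # Same dispatch pattern as MAVAC.forward: the mode string names a method.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def forward(self, inputs, mode):
        # Reject unknown modes, then call the matching method by name.
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs):
        return {'logit': inputs}

    def compute_critic(self, inputs):
        return {'value': inputs}

    def compute_actor_critic(self, inputs):
        return {'logit': inputs, 'value': inputs}
```

Calling `ToyVAC().forward(obs, 'compute_critic')` routes to compute_critic, exactly as MAVAC routes its three modes.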

compute_actor(x)

Overview

MAVAC forward computation graph for actor part, predicting action logit with agent observation tensor in x.

Arguments:
- x (:obj:Dict): Input data dict with keys ['agent_state', 'action_mask' (optional)].
    - agent_state (:obj:torch.Tensor): Each agent's local state (obs).
    - action_mask (optional) (:obj:torch.Tensor): When action_space is discrete, an action_mask must be provided to mask illegal actions.
Returns:
- outputs (:obj:Dict): The output dict of the forward computation graph for the actor, including logit.
ReturnsKeys:
- logit (:obj:torch.Tensor): The predicted action logit tensor. For discrete action space, it is a real-valued tensor with one entry per possible action choice; for continuous action space, it is the mu and sigma of the Gaussian distribution, where the number of mu and sigma equals the number of continuous actions.
Shapes:
- logit (:obj:torch.FloatTensor): :math:(B, M, N), where B is batch size, M is agent_num and N is action_shape.

Examples:

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> actor_outputs = model(inputs,'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
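For discrete action spaces, compute_actor fills the logits of masked-out actions with a large negative number, so that a subsequent softmax assigns them effectively zero probability. A pure-Python sketch of that masking step (illustrative only, no torch):

```python
import math


def mask_logits(logits, action_mask, fill=-99999999.0):
    # Same trick as MAVAC: illegal actions (mask == 0) get a huge negative logit.
    return [l if m == 1 else fill for l, m in zip(logits, action_mask)]


def softmax(logits):
    # Numerically stable softmax over a flat list of logits.
    mx = max(logits)
    exps = [math.exp(l - mx) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `softmax(mask_logits([1.0, 2.0, 0.5], [1, 0, 1]))` gives the masked middle action a probability indistinguishable from zero.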

compute_critic(x)

Overview

MAVAC forward computation graph for critic part. Predict state value with global observation tensor in x.

Arguments:
- x (:obj:Dict): Input data dict with keys ['global_state'].
    - global_state (:obj:torch.Tensor): Global state (obs).
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph for critic, including value.
ReturnsKeys:
- value (:obj:torch.Tensor): The predicted state value tensor.
Shapes:
- value (:obj:torch.FloatTensor): :math:(B, M), where B is batch size and M is agent_num.

Examples:

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> critic_outputs = model(inputs,'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([10, 8])
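The critic head is a regression head with output size 1, so the raw per-agent prediction carries a trailing singleton dimension, while the documented value shape is (B, M). Assuming the head drops that trailing axis, the shape reduction can be sketched with nested lists standing in for tensors (illustration only):

```python
def squeeze_last(pred):
    # (B, M, 1) -> (B, M): drop the trailing singleton produced by a
    # regression head whose output size is 1.
    return [[per_agent[0] for per_agent in batch] for batch in pred]
```

So a prediction for B=2 batches of M=2 agents, `[[[0.1], [0.2]], [[0.3], [0.4]]]`, becomes `[[0.1, 0.2], [0.3, 0.4]]`.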

compute_actor_critic(x)

Overview

MAVAC forward computation graph for both actor and critic part, input observation to predict action logit and state value.

Arguments:
- x (:obj:Dict): The input dict containing agent_state, global_state and other related info.
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph for both actor and critic, including logit and value.
ReturnsKeys:
- logit (:obj:torch.Tensor): The predicted action logit tensor, one set of logits per agent.
- value (:obj:torch.Tensor): The predicted state value tensor, one value per agent.
Shapes:
- logit (:obj:torch.FloatTensor): :math:(B, M, N), where B is batch size, M is agent_num and N is action_shape.
- value (:obj:torch.FloatTensor): :math:(B, M), where B is batch size and M is agent_num.

Examples:

>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> outputs = model(inputs,'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
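In the continuous case the actor output is the mu and sigma of a Gaussian, and with sigma_type='independent' the sigma does not depend on the state. Drawing an action from such a head can be sketched in plain Python; the sample_action helper below is hypothetical and not part of DI-engine:

```python
import random


def sample_action(mu, sigma, rng=random):
    # Reparameterized Gaussian sample: a = mu + sigma * eps, with eps ~ N(0, 1).
    # mu and sigma are flat lists, one entry per continuous action dimension.
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
```

With sigma set to zero the sample collapses to the mean, which makes the formula easy to check.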

Full Source Code

../ding/model/template/mavac.py

from typing import Union, Dict, Tuple, Optional
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead


@MODEL_REGISTRY.register('mavac')
class MAVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \
        multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model now supports discrete and \
        continuous action space. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \
        ``actor_head`` and ``critic_head``. Encoders are used to extract features from various observations. \
        Heads are used to predict the corresponding value or action logit.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            agent_obs_shape: Union[int, SequenceType],
            global_obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            agent_num: int,
            actor_head_hidden_size: int = 256,
            actor_head_layer_num: int = 2,
            critic_head_hidden_size: int = 512,
            critic_head_layer_num: int = 1,
            action_space: str = 'discrete',
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            bound_type: Optional[str] = None,
            encoder: Optional[Tuple[torch.nn.Module, torch.nn.Module]] = None,
    ) -> None:
        """
        Overview:
            Init the MAVAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent, \
                such as 8 or [4, 84, 84].
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8 or [4, 84, 84].
            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for single agent, such as 6 \
                or [2, 3, 3].
            - agent_num (:obj:`int`): This parameter is temporarily reserved. It may be required for \
                subsequent changes to the model.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
                to 256, it must match the output size of the actor encoder.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
                to 512, it must match the output size of the critic encoder.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute value output for \
                the critic's nn.
            - action_space (:obj:`str`): The type of action space, including \
                ['discrete', 'continuous'], then will instantiate corresponding head, including ``DiscreteHead`` \
                and ``ReparameterizationHead``.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` after each \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
            - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
                ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in MAPPO, it defaults \
                to ``independent``, which means state-independent sigma parameters.
            - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
                to ``None``, which means no bound.
            - encoder (:obj:`Optional[Tuple[torch.nn.Module, torch.nn.Module]]`): The encoder module list, defaults \
                to ``None``, you can define your own actor and critic encoder module and pass it into MAVAC to \
                deal with different observation space.
        """
        super(MAVAC, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
        self.action_space = action_space
        # Encoder Type
        if encoder:
            self.actor_encoder, self.critic_encoder = encoder
        else:
            # We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
            # In the SMAC task it can obviously improve the performance.
            # Users can change the model according to their own needs.
            self.actor_encoder = nn.Sequential(
                nn.Linear(agent_obs_shape, actor_head_hidden_size),
                activation,
            )
            self.critic_encoder = nn.Sequential(
                nn.Linear(global_obs_shape, critic_head_hidden_size),
                activation,
            )
        # Head Type
        self.critic_head = RegressionHead(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        assert self.action_space in ['discrete', 'continuous'], self.action_space
        if self.action_space == 'discrete':
            self.actor_head = DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        elif self.action_space == 'continuous':
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        # must use list, not nn.ModuleList
        self.actor = [self.actor_encoder, self.actor_head]
        self.critic = [self.critic_encoder, self.critic_head]
        # for convenience of calling some APIs (such as self.critic.parameters()), but may cause
        # misunderstanding when print(self)
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            MAVAC forward computation graph, input observation tensor to predict state value or action logit. \
            ``mode`` includes ``compute_actor``, ``compute_critic``, ``compute_actor_critic``. \
            Different ``mode`` will forward with different network modules to get different outputs and save \
            computation.
        Arguments:
            - inputs (:obj:`Dict`): The input dict including observation and related info, \
                whose key-values vary from different ``mode``.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary from \
                different ``mode``.

        Examples (Actor):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> actor_outputs = model(inputs,'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

        Examples (Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> critic_outputs = model(inputs,'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])

        Examples (Actor-Critic):
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> outputs = model(inputs,'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])

        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for actor part, \
            predicting action logit with agent observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'(optional)].
                - agent_state: (:obj:`torch.Tensor`): Each agent local state(obs).
                - action_mask(optional): (:obj:`torch.Tensor`): When ``action_space`` is discrete, action_mask needs \
                    to be provided to mask illegal actions.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the forward computation graph for actor, including ``logit``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
                the same dimension real-value ranged tensor of possible action choices, and for continuous action \
                space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
                same as the number of continuous actions.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
                and M is ``agent_num``.

        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> actor_outputs = model(inputs,'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

        """
        if self.action_space == 'discrete':
            action_mask = x['action_mask']
            x = x['agent_state']
            x = self.actor_encoder(x)
            x = self.actor_head(x)
            logit = x['logit']
            logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            x = x['agent_state']
            x = self.actor_encoder(x)
            x = self.actor_head(x)
            logit = x
        return {'logit': logit}

    def compute_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for critic part. \
            Predict state value with global observation tensor in ``x``.
        Arguments:
            - x (:obj:`Dict`): Input data dict with keys ['global_state'].
                - global_state: (:obj:`torch.Tensor`): Global state(obs).
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \
                including ``value``.
        ReturnsKeys:
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.

        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> critic_outputs = model(inputs,'compute_critic')
            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
        """

        x = self.critic_encoder(x['global_state'])
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: Dict) -> Dict:
        """
        Overview:
            MAVAC forward computation graph for both actor and critic part, input observation to predict action \
            logit and state value.
        Arguments:
            - x (:obj:`Dict`): The input dict contains ``agent_state``, ``global_state`` and other related info.
        Returns:
            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
                including ``logit`` and ``value``.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The predicted action logit tensor, one set of logits per agent.
            - value (:obj:`torch.Tensor`): The predicted state value tensor, one value per agent.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
                and M is ``agent_num``.
            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.

        Examples:
            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
            >>> outputs = model(inputs,'compute_actor_critic')
            >>> assert outputs['value'].shape == torch.Size([10, 8])
            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
        """
        x_actor = self.actor_encoder(x['agent_state'])
        x_critic = self.critic_encoder(x['global_state'])

        if self.action_space == 'discrete':
            action_mask = x['action_mask']
            x = self.actor_head(x_actor)
            logit = x['logit']
            logit[action_mask == 0.0] = -99999999
        elif self.action_space == 'continuous':
            x = self.actor_head(x_actor)
            logit = x
        value = self.critic_head(x_critic)['pred']
        return {'logit': logit, 'value': value}