ding.model.template.mavac

MAVAC
Bases: Module
Overview
The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) in the multi-agent setting, such as MAPPO (https://arxiv.org/abs/2103.01955). This model supports both discrete and continuous action spaces. The MAVAC is composed of four parts: actor_encoder, critic_encoder, actor_head and critic_head. Encoders extract features from the various observations; heads predict the corresponding value or action logit.
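The four-part decomposition above can be illustrated with a minimal plain-PyTorch sketch (hypothetical module `TinyMAVAC`, not the actual DI-engine implementation): local observations flow through the actor branch to produce per-agent action logits, while the global observation flows through the critic branch to produce per-agent state values.

```python
import torch
import torch.nn as nn


class TinyMAVAC(nn.Module):
    """Minimal sketch of the actor/critic encoder + head decomposition."""

    def __init__(self, agent_obs_dim, global_obs_dim, action_dim, hidden=64):
        super().__init__()
        # Encoders extract features from local / global observations.
        self.actor_encoder = nn.Sequential(nn.Linear(agent_obs_dim, hidden), nn.ReLU())
        self.critic_encoder = nn.Sequential(nn.Linear(global_obs_dim, hidden), nn.ReLU())
        # Heads map features to action logits / a scalar state value.
        self.actor_head = nn.Linear(hidden, action_dim)
        self.critic_head = nn.Linear(hidden, 1)

    def forward(self, agent_state, global_state):
        logit = self.actor_head(self.actor_encoder(agent_state))
        value = self.critic_head(self.critic_encoder(global_state)).squeeze(-1)
        return {'logit': logit, 'value': value}


# Shapes follow the (batch, agent, obs_dim) convention used in the examples below.
model = TinyMAVAC(agent_obs_dim=64, global_obs_dim=128, action_dim=14)
out = model(torch.randn(10, 8, 64), torch.randn(10, 8, 128))
```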
Interfaces:
__init__, forward, compute_actor, compute_critic, compute_actor_critic.
__init__(agent_obs_shape, global_obs_shape, action_shape, agent_num, actor_head_hidden_size=256, actor_head_layer_num=2, critic_head_hidden_size=512, critic_head_layer_num=1, action_space='discrete', activation=nn.ReLU(), norm_type=None, sigma_type='independent', bound_type=None, encoder=None)
Overview
Initialize the MAVAC model according to the input arguments.
Arguments:
- agent_obs_shape (:obj:Union[int, SequenceType]): Observation space of a single agent, such as 8 or [4, 84, 84].
- global_obs_shape (:obj:Union[int, SequenceType]): Global observation space, such as 8 or [4, 84, 84].
- action_shape (:obj:Union[int, SequenceType]): Action space shape for a single agent, such as 6 or [2, 3, 3].
- agent_num (:obj:int): This argument is temporarily reserved and may be required by future changes to the model.
- actor_head_hidden_size (:obj:Optional[int]): The hidden_size of the actor_head network, defaults to 256; it must match the last element of agent_obs_shape.
- actor_head_layer_num (:obj:int): The number of layers used in the actor_head network to compute the action.
- critic_head_hidden_size (:obj:Optional[int]): The hidden_size of the critic_head network, defaults to 512; it must match the last element of global_obs_shape.
- critic_head_layer_num (:obj:int): The number of layers used in the critic_head network to compute the value output.
- action_space (:obj:str): The type of action space, chosen from ['discrete', 'continuous'], which determines the head to instantiate, i.e. DiscreteHead or ReparameterizationHead.
- activation (:obj:Optional[nn.Module]): The type of activation function used in the MLP after each layer_fn; defaults to nn.ReLU() if None.
- norm_type (:obj:Optional[str]): The type of normalization in networks, see ding.torch_utils.fc_block for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
- sigma_type (:obj:Optional[str]): The type of sigma in continuous action space, see ding.torch_utils.network.dreamer.ReparameterizationHead for more details. In MAPPO it defaults to independent, which means state-independent sigma parameters.
- bound_type (:obj:Optional[str]): The type of action bound method in continuous action space, defaults to None, which means no bound.
- encoder (:obj:Optional[Tuple[torch.nn.Module, torch.nn.Module]]): The encoder module pair, defaults to None. You can define your own actor and critic encoder modules and pass them into MAVAC to deal with different observation spaces.
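One plausible shape for a custom encoder pair follows (a hypothetical sketch, assuming each encoder must map its observation to the matching head hidden size: 256 for the actor and 512 for the critic by default):

```python
import torch
import torch.nn as nn

# Hypothetical custom encoders: the actor encoder should output features of
# actor_head_hidden_size (256) and the critic encoder of critic_head_hidden_size (512).
actor_encoder = nn.Sequential(nn.Linear(64, 256), nn.ReLU())
critic_encoder = nn.Sequential(nn.Linear(128, 512), nn.ReLU())

# Sanity-check the feature shapes the heads would receive,
# using the (batch, agent, obs_dim) convention from the examples.
actor_feat = actor_encoder(torch.randn(10, 8, 64))
critic_feat = critic_encoder(torch.randn(10, 8, 128))
```

Such a pair would then be passed as the encoder argument, e.g. `encoder=(actor_encoder, critic_encoder)`.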
forward(inputs, mode)
Overview
MAVAC forward computation graph, input observation tensor to predict state value or action logit. mode includes compute_actor, compute_critic, compute_actor_critic.
Different mode will forward with different network modules to get different outputs and save computation.
Arguments:
- inputs (:obj:Dict): The input dict including observation and related info, whose key-values vary with different mode.
- mode (:obj:str): The forward mode, all the modes are defined in the beginning of this class.
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph, whose key-values vary with different mode.
Examples (Actor):
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> actor_outputs = model(inputs, 'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
Examples (Critic):
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> critic_outputs = model(inputs, 'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([10, 8])
Examples (Actor-Critic):
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
        'agent_state': torch.randn(10, 8, 64),
        'global_state': torch.randn(10, 8, 128),
        'action_mask': torch.randint(0, 2, size=(10, 8, 14))
    }
>>> outputs = model(inputs, 'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
compute_actor(x)
Overview
MAVAC forward computation graph for actor part, predicting action logit with agent observation tensor in x.
Arguments:
- x (:obj:Dict): Input data dict with keys ['agent_state', 'action_mask'(optional)].
- agent_state: (:obj:torch.Tensor): The local state (observation) of each agent.
- action_mask (optional): (:obj:torch.Tensor): When action_space is discrete, action_mask needs to be provided to mask illegal actions.
Returns:
- outputs (:obj:Dict): The output dict of the forward computation graph for actor, including logit.
ReturnsKeys:
- logit (:obj:torch.Tensor): The predicted action logit tensor. For discrete action spaces, it is a real-valued tensor with the same dimension as the number of possible action choices; for continuous action spaces, it is the mu and sigma of a Gaussian distribution, each with the same dimension as the number of continuous actions.
Shapes:
- logit (:obj:torch.FloatTensor): :math:(B, M, N), where B is batch size, M is agent_num and N is action_shape.
Examples:
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
'agent_state': torch.randn(10, 8, 64),
'global_state': torch.randn(10, 8, 128),
'action_mask': torch.randint(0, 2, size=(10, 8, 14))
}
>>> actor_outputs = model(inputs, 'compute_actor')
>>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
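Illegal actions are typically excluded by pushing their logits toward negative infinity before sampling, so that softmax assigns them near-zero probability. A minimal sketch of this common masking pattern follows (the exact mechanism inside MAVAC may differ):

```python
import torch

logit = torch.randn(10, 8, 14)                       # (B, M, N) raw action logits
action_mask = torch.randint(0, 2, size=(10, 8, 14))  # 1 = legal, 0 = illegal
action_mask[..., 0] = 1                              # ensure at least one legal action per agent

# Illegal actions get a large negative logit, so softmax
# assigns them (near-)zero probability.
masked_logit = logit.masked_fill(action_mask == 0, -1e9)
prob = torch.softmax(masked_logit, dim=-1)
```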
compute_critic(x)
Overview
MAVAC forward computation graph for critic part. Predict state value with global observation tensor in x.
Arguments:
- x (:obj:Dict): Input data dict with keys ['global_state'].
- global_state: (:obj:torch.Tensor): The global state (observation).
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph for critic, including value.
ReturnsKeys:
- value (:obj:torch.Tensor): The predicted state value tensor.
Shapes:
- value (:obj:torch.FloatTensor): :math:(B, M), where B is batch size and M is agent_num.
Examples:
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
'agent_state': torch.randn(10, 8, 64),
'global_state': torch.randn(10, 8, 128),
'action_mask': torch.randint(0, 2, size=(10, 8, 14))
}
>>> critic_outputs = model(inputs, 'compute_critic')
>>> assert critic_outputs['value'].shape == torch.Size([10, 8])
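Note that global_state in the example carries an agent dimension. One way this can arise in practice is by repeating the shared global observation per agent (a hypothetical construction; environments may instead supply agent-specific global features):

```python
import torch

B, M = 10, 8                              # batch size, agent number
global_obs = torch.randn(B, 128)          # one shared global observation per timestep

# Repeat the shared global state along a new agent dimension,
# giving the (B, M, obs_dim) layout the critic expects.
global_state = global_obs.unsqueeze(1).expand(B, M, 128)
```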
compute_actor_critic(x)
Overview
MAVAC forward computation graph for both actor and critic part, input observation to predict action logit and state value.
Arguments:
- x (:obj:Dict): The input dict contains agent_state, global_state and other related info.
Returns:
- outputs (:obj:Dict): The output dict of MAVAC's forward computation graph for both actor and critic, including logit and value.
ReturnsKeys:
- logit (:obj:torch.Tensor): The predicted action logit tensor.
- value (:obj:torch.Tensor): The predicted state value tensor.
Shapes:
- logit (:obj:torch.FloatTensor): :math:(B, M, N), where B is batch size, M is agent_num and N is action_shape.
- value (:obj:torch.FloatTensor): :math:(B, M), where B is batch size and M is agent_num.
Examples:
>>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
>>> inputs = {
'agent_state': torch.randn(10, 8, 64),
'global_state': torch.randn(10, 8, 128),
'action_mask': torch.randint(0, 2, size=(10, 8, 14))
}
>>> outputs = model(inputs, 'compute_actor_critic')
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
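In MAPPO training, the joint logit and value outputs feed the clipped PPO objective. A simplified sketch of how they might be consumed follows (random placeholder data; not DI-engine's actual loss code):

```python
import torch
import torch.nn.functional as F

B, M, N = 10, 8, 14
logit = torch.randn(B, M, N)            # stands in for outputs['logit']
value = torch.randn(B, M)               # stands in for outputs['value']
action = torch.randint(0, N, (B, M))    # actions taken by each agent
old_logp = torch.randn(B, M)            # log-probs under the behavior policy
adv = torch.randn(B, M)                 # advantage estimates
ret = torch.randn(B, M)                 # value targets (returns)

# Clipped PPO surrogate: ratio of new to old action probabilities.
dist = torch.distributions.Categorical(logits=logit)
ratio = torch.exp(dist.log_prob(action) - old_logp)
clip = 0.2
policy_loss = -torch.min(ratio * adv, torch.clamp(ratio, 1 - clip, 1 + clip) * adv).mean()
value_loss = F.mse_loss(value, ret)
loss = policy_loss + 0.5 * value_loss   # 0.5 value-loss weight is a common default
```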
Full Source Code
../ding/model/template/mavac.py