
ding.model.template.maqac


DiscreteMAQAC

Bases: Module

Overview

The neural network and computation graph of algorithms related to the discrete-action Multi-Agent Q-value Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both of which are MLP networks. The actor network predicts the action probability distribution, and the critic network predicts the Q value of the state-action pair.

Interfaces: __init__, forward, compute_actor, compute_critic
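The `forward` interface dispatches to `compute_actor` or `compute_critic` by mode name, via an assertion plus `getattr` (as the full source below shows). A minimal pure-Python sketch of this dispatch pattern; the class and the toy method bodies here are illustrative, not the actual ding implementation:

```python
class ModeDispatch:
    # Allowed forward modes, mirroring DiscreteMAQAC.mode.
    mode = ['compute_actor', 'compute_critic']

    def forward(self, inputs, mode):
        # Reject unknown modes early, then dispatch by attribute name.
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    # Toy stand-ins for the real computation graphs.
    def compute_actor(self, inputs):
        return {'logit': inputs}

    def compute_critic(self, inputs):
        return {'q_value': inputs}


m = ModeDispatch()
m.forward(3, 'compute_actor')   # -> {'logit': 3}
m.forward(5, 'compute_critic')  # -> {'q_value': 5}
```

Calling `forward` with any other mode string raises an `AssertionError`, which is the same guard the real model applies.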

__init__(agent_obs_shape, global_obs_shape, action_shape, twin_critic=False, actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.ReLU(), norm_type=None)

Overview

Initialize the DiscreteMAQAC Model according to arguments.

Arguments:
- agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation space.
- global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
- action_shape (:obj:`Union[int, SequenceType]`): Action space.
- twin_critic (:obj:`bool`): Whether to include a twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The hidden_size to pass to the actor head.
- actor_head_layer_num (:obj:`int`): The number of layers used in the actor head to compute the action logits.
- critic_head_hidden_size (:obj:`Optional[int]`): The hidden_size to pass to the critic head.
- critic_head_layer_num (:obj:`int`): The number of layers used in the critic head to compute the Q value output.
- activation (:obj:`Optional[nn.Module]`): The activation function used in the MLP after each layer_fn; if None, defaults to nn.ReLU().
- norm_type (:obj:`Optional[str]`): The type of normalization to use; see ding.torch_utils.fc_block for more details.

forward(inputs, mode)

Overview

Use observation tensor to predict output, with compute_actor or compute_critic mode.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - obs (:obj:`Dict[str, torch.Tensor]`): The observation dict, with keys:
        - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.
        - global_state (:obj:`torch.Tensor`): The global observation tensor, with shape :math:`(B, A, N1)`. N1 corresponds to global_obs_shape.
        - action_mask (:obj:`torch.Tensor`): The action mask tensor, with shape :math:`(B, A, N2)`. N2 corresponds to action_shape.

- mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.

Returns:
- output (:obj:`Dict[str, torch.Tensor]`): The output dict of the DiscreteMAQAC forward computation graph, whose key-values vary with the forward mode.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
>>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
>>>     }
>>> }
>>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
>>> logit = model(data, mode='compute_actor')['logit']
>>> value = model(data, mode='compute_critic')['q_value']

compute_actor(inputs)

Overview

Use observation tensor to predict action logits.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - obs (:obj:`Dict[str, torch.Tensor]`): The observation dict, with keys:
        - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.
        - global_state (:obj:`torch.Tensor`): The global observation tensor, with shape :math:`(B, A, N1)`. N1 corresponds to global_obs_shape.
        - action_mask (:obj:`torch.Tensor`): The action mask tensor, with shape :math:`(B, A, N2)`. N2 corresponds to action_shape.

Returns:
- output (:obj:`Dict[str, torch.Tensor]`): The output dict of the DiscreteMAQAC forward computation graph, with keys:
    - logit (:obj:`torch.Tensor`): The action logits (real value range), with shape :math:`(B, A, N2)`, where N2 corresponds to action_shape.
    - action_mask (:obj:`torch.Tensor`): The action mask tensor, passed through with the same size as action_shape.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
>>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
>>>     }
>>> }
>>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
>>> logit = model.compute_actor(data)['logit']
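A downstream policy typically combines the returned `logit` with `action_mask` by treating masked-out actions as minus infinity before taking an argmax (or sampling). A minimal pure-Python sketch over one agent's logits; the helper name is an assumption for illustration, not a ding API, and real policies do this with tensor operations:

```python
def masked_greedy_action(logit, action_mask):
    """Pick the highest-logit action among those the mask allows.

    logit: list of floats, one per discrete action (length action_shape).
    action_mask: list of 0/1 flags of the same length; 1 means allowed.
    """
    assert len(logit) == len(action_mask)
    best_action, best_value = None, float('-inf')
    for a, (l, m) in enumerate(zip(logit, action_mask)):
        # Masked actions get -inf so they can never be selected.
        value = l if m else float('-inf')
        if value > best_value:
            best_action, best_value = a, value
    return best_action


# Index 0 has the largest logit but is masked, so index 2 wins among {1, 2}.
masked_greedy_action([2.0, 0.5, 1.0], [0, 1, 1])  # -> 2
```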

compute_critic(inputs)

Overview

Use observation tensor to predict Q value.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - obs (:obj:`Dict[str, torch.Tensor]`): The observation dict, with keys:
        - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.
        - global_state (:obj:`torch.Tensor`): The global observation tensor, with shape :math:`(B, A, N1)`. N1 corresponds to global_obs_shape.
        - action_mask (:obj:`torch.Tensor`): The action mask tensor, with shape :math:`(B, A, N2)`. N2 corresponds to action_shape.

Returns:
- output (:obj:`Dict[str, torch.Tensor]`): The output dict of the DiscreteMAQAC forward computation graph, whose key-values depend on twin_critic:
    - q_value: If twin_critic=True, a list of 2 tensors, each with shape :math:`(B, A, N2)`, where N2 corresponds to action_shape; otherwise, a single :obj:`torch.Tensor` of that shape.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
>>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
>>>     }
>>> }
>>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
>>> value = model.compute_critic(data)['q_value']
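When twin_critic=True, algorithms in the TD3/SAC family typically combine the two Q estimates by taking the elementwise minimum (the clipped double-Q trick) to reduce overestimation. That combination happens in the policy, not in this model; a pure-Python sketch over per-action Q lists, with an illustrative helper name:

```python
def clipped_double_q(q_value):
    """Combine twin critic outputs by elementwise minimum.

    q_value: either a 2-element list of per-action Q lists (twin_critic=True),
    or a single flat list of Q values, which is passed through unchanged.
    """
    if isinstance(q_value[0], list):  # twin critic: [q1, q2]
        q1, q2 = q_value
        return [min(a, b) for a, b in zip(q1, q2)]
    return q_value  # single-critic output


clipped_double_q([[1.0, 3.0], [2.0, 0.5]])  # -> [1.0, 0.5]
```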

ContinuousMAQAC

Bases: Module

Overview

The neural network and computation graph of algorithms related to the continuous-action Multi-Agent Q-value Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both of which are MLP networks. The actor network predicts the action distribution, and the critic network predicts the Q value of the state-action pair.

Interfaces: __init__, forward, compute_actor, compute_critic

__init__(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=False, actor_head_hidden_size=64, actor_head_layer_num=1, critic_head_hidden_size=64, critic_head_layer_num=1, activation=nn.ReLU(), norm_type=None)

Overview

Initialize the ContinuousMAQAC Model according to arguments.

Arguments:
- agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation space.
- global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action space, such as 4 or (3, ).
- action_space (:obj:`str`): Which action-generation scheme to use, either 'regression' or 'reparameterization'.
- twin_critic (:obj:`bool`): Whether to include a twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The hidden_size to pass to the actor head.
- actor_head_layer_num (:obj:`int`): The number of layers used in the actor head.
- critic_head_hidden_size (:obj:`Optional[int]`): The hidden_size to pass to the critic head.
- critic_head_layer_num (:obj:`int`): The number of layers used in the critic head to compute the Q value output.
- activation (:obj:`Optional[nn.Module]`): The activation function used in the MLP after each layer_fn; if None, defaults to nn.ReLU().
- norm_type (:obj:`Optional[str]`): The type of normalization to use; see ding.torch_utils.fc_block for more details.

forward(inputs, mode)

Overview

Use observation and action tensor to predict output in compute_actor or compute_critic mode.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - obs (:obj:`Dict[str, torch.Tensor]`): The observation dict, with keys:
        - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.
        - global_state (:obj:`torch.Tensor`): The global observation tensor, with shape :math:`(B, A, N1)`. N1 corresponds to global_obs_shape.
        - action_mask (:obj:`torch.Tensor`): The action mask tensor, with shape :math:`(B, A, N2)`. N2 corresponds to action_shape.
    - action (:obj:`torch.Tensor`): The action tensor data, with shape :math:`(B, A, N3)`. N3 corresponds to action_shape.
- mode (:obj:`str`): Name of the forward mode.

Returns:
- outputs (:obj:`Dict`): Outputs of the network forward pass, whose key-values differ with mode, twin_critic, and action_space.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> act_space = 'reparameterization'  # or 'regression'
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
>>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
>>>     },
>>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
>>> }
>>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
>>> if act_space == 'regression':
>>>     action = model(data['obs'], mode='compute_actor')['action']
>>> elif act_space == 'reparameterization':
>>>     (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
>>> value = model(data, mode='compute_critic')['q_value']

compute_actor(inputs)

Overview

Use observation tensor to predict action logits.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.

Returns:
- outputs (:obj:`Dict`): Outputs of the network forward pass.

ReturnKeys (action_space == 'regression'):
- action (:obj:`torch.Tensor`): Action tensor with the same size as action_shape.

ReturnKeys (action_space == 'reparameterization'):
- logit (:obj:`list`): 2 elements (mu and sigma), each with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. N3 corresponds to action_shape.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> act_space = 'reparameterization'  # or 'regression'
>>> data = {
>>>     'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>> }
>>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
>>> if act_space == 'regression':
>>>     action = model.compute_actor(data)['action']
>>> elif act_space == 'reparameterization':
>>>     (mu, sigma) = model.compute_actor(data)['logit']
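In the 'reparameterization' setting, the actor outputs `(mu, sigma)` and the policy samples an action as `mu + sigma * eps` with `eps` drawn from a standard normal, so gradients can flow through `mu` and `sigma`. A scalar pure-Python sketch; the tanh squashing shown here is how SAC-style policies commonly bound actions to (-1, 1), and both the helper name and the squashing step are assumptions for illustration (the sampling lives in the policy, not in this model):

```python
import math
import random


def reparameterized_sample(mu, sigma, eps=None):
    """Sample a bounded action from N(mu, sigma^2) via the reparameterization trick."""
    if eps is None:
        eps = random.gauss(0.0, 1.0)  # noise drawn outside the network
    raw = mu + sigma * eps            # differentiable in mu and sigma
    return math.tanh(raw)             # squash into (-1, 1)


reparameterized_sample(0.0, 1.0, eps=0.0)  # -> 0.0
```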

compute_critic(inputs)

Overview

Use observation tensor and action tensor to predict Q value.

Arguments:
- inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, with keys:
    - obs (:obj:`Dict[str, torch.Tensor]`): The observation dict, with keys:
        - agent_state (:obj:`torch.Tensor`): The agent's observation tensor, with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. N0 corresponds to agent_obs_shape.
        - global_state (:obj:`torch.Tensor`): The global observation tensor, with shape :math:`(B, A, N1)`. N1 corresponds to global_obs_shape.
        - action_mask (:obj:`torch.Tensor`): The action mask tensor, with shape :math:`(B, A, N2)`. N2 corresponds to action_shape.
    - action (:obj:`torch.Tensor`): The action tensor data, with shape :math:`(B, A, N3)`. N3 corresponds to action_shape.

Returns:
- outputs (:obj:`Dict`): Outputs of the network forward pass.

ReturnKeys (twin_critic=True):
- q_value (:obj:`list`): 2 elements, each with shape :math:`(B, A)`, where B is batch size and A is agent num.

ReturnKeys (twin_critic=False):
- q_value (:obj:`torch.Tensor`): Shape :math:`(B, A)`, where B is batch size and A is agent num.

Examples:
>>> B = 32
>>> agent_obs_shape = 216
>>> global_obs_shape = 264
>>> agent_num = 8
>>> action_shape = 14
>>> act_space = 'reparameterization'  # or 'regression'
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
>>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
>>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
>>>     },
>>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
>>> }
>>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
>>> value = model.compute_critic(data)['q_value']
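The continuous critic consumes the concatenation of the global state and the action, so its input width is global_obs_shape + action_shape (the `critic_input_size` in the full source). A pure-Python sketch of building that input for one agent; the helper name is illustrative:

```python
def build_critic_input(global_state, action):
    """Concatenate one agent's global-state and action features.

    global_state: list of floats of length global_obs_shape.
    action: list of floats of length action_shape.
    Returns a flat list of length global_obs_shape + action_shape,
    mirroring torch.cat([global_state, action], dim=-1).
    """
    return list(global_state) + list(action)


x = build_critic_input([0.1, 0.2, 0.3], [0.9])
len(x)  # -> 4
```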

Full Source Code

../ding/model/template/maqac.py

```python
from typing import Union, Dict, Optional
from easydict import EasyDict
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('discrete_maqac')
class DiscreteMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to the discrete-action Multi-Agent Q-value \
        Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both MLP networks. The actor \
        network predicts the action probability distribution, and the critic network predicts the Q value of the \
        state-action pair.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
        self,
        agent_obs_shape: Union[int, SequenceType],
        global_obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType],
        twin_critic: bool = False,
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 1,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the DiscreteMAQAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation space.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
            - action_shape (:obj:`Union[int, SequenceType]`): Action space.
            - twin_critic (:obj:`bool`): Whether to include a twin critic.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor ``Head``.
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head to compute the \
                action logits.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic ``Head``.
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head to compute the \
                Q value output.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in ``MLP`` after each \
                ``layer_fn``; if ``None``, defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
                for more details.
        """
        super(DiscreteMAQAC, self).__init__()
        agent_obs_shape: int = squeeze(agent_obs_shape)
        action_shape: int = squeeze(action_shape)
        self.actor = nn.Sequential(
            nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
            DiscreteHead(
                actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
            )
        )

        self.twin_critic = twin_critic
        if self.twin_critic:
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                        DiscreteHead(
                            critic_head_hidden_size,
                            action_shape,
                            critic_head_layer_num,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
                DiscreteHead(
                    critic_head_hidden_size,
                    action_shape,
                    critic_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation tensor to predict output, with ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The observation dict, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`. N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`. N2 corresponds to ``action_shape``.
            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different forward modes.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model(data, mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The observation dict, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`. N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`. N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                with keys:
                - logit (:obj:`torch.Tensor`): Action's output logit (real value range), whose shape is \
                    :math:`(B, A, N2)`, where N2 corresponds to ``action_shape``.
                - action_mask (:obj:`torch.Tensor`): Action mask tensor with same size as ``action_shape``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> logit = model.compute_actor(data)['logit']
        """
        action_mask = inputs['obs']['action_mask']
        x = self.actor(inputs['obs']['agent_state'])
        return {'logit': x['logit'], 'action_mask': action_mask}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The observation dict, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`. N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`. N2 corresponds to ``action_shape``.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
                whose key-values vary in different values of ``twin_critic``.
                - q_value (:obj:`list`): If ``twin_critic=True``, q_value is a list of 2 tensors, each with shape \
                    :math:`(B, A, N2)`, where N2 corresponds to ``action_shape``. \
                    Otherwise, q_value is a single ``torch.Tensor`` of that shape.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     }
            >>> }
            >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
            >>> value = model.compute_critic(data)['q_value']
        """

        if self.twin_critic:
            x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
        else:
            x = self.critic(inputs['obs']['global_state'])['logit']
        return {'q_value': x}


@MODEL_REGISTRY.register('continuous_maqac')
class ContinuousMAQAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to the continuous-action Multi-Agent Q-value \
        Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both MLP networks. The actor \
        network predicts the action distribution, and the critic network predicts the Q value of the \
        state-action pair.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
        self,
        agent_obs_shape: Union[int, SequenceType],
        global_obs_shape: Union[int, SequenceType],
        action_shape: Union[int, SequenceType, EasyDict],
        action_space: str,
        twin_critic: bool = False,
        actor_head_hidden_size: int = 64,
        actor_head_layer_num: int = 1,
        critic_head_hidden_size: int = 64,
        critic_head_layer_num: int = 1,
        activation: Optional[nn.Module] = nn.ReLU(),
        norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ContinuousMAQAC Model according to arguments.
        Arguments:
            - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation space.
            - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation space.
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action space, such as 4, (3, ).
            - action_space (:obj:`str`): Whether to choose ``regression`` or ``reparameterization``.
            - twin_critic (:obj:`bool`): Whether to include a twin critic.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor ``Head``.
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic ``Head``.
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head to compute the \
                Q value output.
            - activation (:obj:`Optional[nn.Module]`): The activation function used in ``MLP`` after each \
                ``layer_fn``; if ``None``, defaults to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
                for more details.
        """
        super(ContinuousMAQAC, self).__init__()
        obs_shape: int = squeeze(agent_obs_shape)
        global_obs_shape: int = squeeze(global_obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], self.action_space
        if self.action_space == 'regression':  # DDPG, TD3
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        else:  # SAC
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        self.twin_critic = twin_critic
        critic_input_size = global_obs_shape + action_shape
        if self.twin_critic:
            self.critic = nn.ModuleList()
            for _ in range(2):
                self.critic.append(
                    nn.Sequential(
                        nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                        RegressionHead(
                            critic_head_hidden_size,
                            1,
                            critic_head_layer_num,
                            final_tanh=False,
                            activation=activation,
                            norm_type=norm_type
                        )
                    )
                )
        else:
            self.critic = nn.Sequential(
                nn.Linear(critic_input_size, critic_head_hidden_size), activation,
                RegressionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    final_tanh=False,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Use observation and action tensor to predict output in ``compute_actor`` or ``compute_critic`` mode.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The observation dict, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`. N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`. N2 corresponds to ``action_shape``.
                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`. N3 corresponds to ``action_shape``.
            - mode (:obj:`str`): Name of the forward mode.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward, whose key-values will be different for different \
                ``mode``, ``twin_critic``, ``action_space``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if act_space == 'regression':
            >>>     action = model(data['obs'], mode='compute_actor')['action']
            >>> elif act_space == 'reparameterization':
            >>>     (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
            >>> value = model(data, mode='compute_critic')['q_value']
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor to predict action logits.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                    with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                    N0 corresponds to ``agent_obs_shape``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``action_space == 'regression'``):
            - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``.
        ReturnKeys (``action_space == 'reparameterization'``):
            - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \
                A is agent num. N3 corresponds to ``action_shape``.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> if act_space == 'regression':
            >>>     action = model.compute_actor(data)['action']
            >>> elif act_space == 'reparameterization':
            >>>     (mu, sigma) = model.compute_actor(data)['logit']
        """
        inputs = inputs['agent_state']
        if self.action_space == 'regression':
            x = self.actor(inputs)
            return {'action': x['pred']}
        else:
            x = self.actor(inputs)
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Use observation tensor and action tensor to predict Q value.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
                - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The observation dict, has keys:
                    - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
                        with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
                        N0 corresponds to ``agent_obs_shape``.
                    - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
                        with shape :math:`(B, A, N1)`. N1 corresponds to ``global_obs_shape``.
                    - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
                        with shape :math:`(B, A, N2)`. N2 corresponds to ``action_shape``.
                - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
                    with shape :math:`(B, A, N3)`. N3 corresponds to ``action_shape``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.
        ReturnKeys (``twin_critic=True``):
            - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \
                A is agent num.
        ReturnKeys (``twin_critic=False``):
            - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num.
        Examples:
            >>> B = 32
            >>> agent_obs_shape = 216
            >>> global_obs_shape = 264
            >>> agent_num = 8
            >>> action_shape = 14
            >>> act_space = 'reparameterization'  # or 'regression'
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(B, agent_num, agent_obs_shape),
            >>>         'global_state': torch.randn(B, agent_num, global_obs_shape),
            >>>         'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
            >>>     },
            >>>     'action': torch.randn(B, agent_num, squeeze(action_shape))
            >>> }
            >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
            >>> value = model.compute_critic(data)['q_value']
        """

        obs, action = inputs['obs']['global_state'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if self.twin_critic:
            x = [m(x)['pred'] for m in self.critic]
        else:
            x = self.critic(x)['pred']
        return {'q_value': x}
```