
ding.model.template.coma

COMAActorNetwork

Bases: Module

Overview

Decentralized actor network in COMA algorithm.

Interface: __init__, forward

__init__(obs_shape, action_shape, hidden_size_list=[128, 128, 64])

Overview

Initialize COMA actor network

Arguments:

- `obs_shape` (`int`): the dimension of each agent's observation state.
- `action_shape` (`int`): the dimension of the action space.
- `hidden_size_list` (`list`): the list of hidden layer sizes; defaults to `[128, 128, 64]`.

forward(inputs)

Overview

The forward computation graph of COMA actor network

Arguments:

- `inputs` (`dict`): input data dict with keys `['obs', 'prev_state']`.
    - `agent_state` (`torch.Tensor`): each agent's local state (observation).
    - `action_mask` (`torch.Tensor`): the mask over available actions.
    - `prev_state` (`torch.Tensor`): the previous hidden state.

Returns:

- `output` (`dict`): output data dict with keys `['logit', 'next_state', 'action_mask']`.

ArgumentsKeys:

- necessary: `obs` { `agent_state`, `action_mask` }, `prev_state`

ReturnsKeys:

- necessary: `logit`, `next_state`, `action_mask`

Examples:

    >>> T, B, A, N = 4, 8, 3, 32
    >>> embedding_dim = 64
    >>> action_dim = 6
    >>> data = torch.randn(T, B, A, N)
    >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
    >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
    >>> for t in range(T):
    >>>     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
    >>>     outputs = model(inputs)
    >>>     logit, prev_state = outputs['logit'], outputs['next_state']
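For intuition, here is a minimal pure-Python sketch of the recurrent-state bookkeeping this forward pass performs: one hidden state is kept per (sample, agent) pair, flattened from `B x A` to `B*A` before the underlying DRQN call and regrouped afterwards. The real implementation uses `functools.reduce` and `ding.utils.list_split`; `B` and `A` below are illustrative values.

```python
from functools import reduce

B, A = 8, 3  # batch size and number of agents, as in the example above

# One recurrent hidden state per (sample, agent) pair, nested as a B x A list.
# `None` means "let the underlying DRQN initialize a zero state".
prev_state = [[None for _ in range(A)] for _ in range(B)]

# Before calling the underlying DRQN, the actor flattens B x A -> B*A ...
flat_state = reduce(lambda x, y: x + y, prev_state)
assert len(flat_state) == B * A

# ... and afterwards regroups the flat list into chunks of A agents,
# mirroring what `ding.utils.list_split(next_state, step=A)` does.
regrouped = [flat_state[i:i + A] for i in range(0, len(flat_state), A)]
assert len(regrouped) == B and all(len(g) == A for g in regrouped)
```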

COMACriticNetwork

Bases: Module

Overview

Centralized critic network in COMA algorithm.

Interface: __init__, forward

__init__(input_size, action_shape, hidden_size=128)

Overview

Initialize COMA critic network

Arguments:

- `input_size` (`int`): the size of the input global observation.
- `action_shape` (`int`): the dimension of the action space.
- `hidden_size` (`int`): the hidden layer size; defaults to `128`.
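The `input_size` the critic expects is not a free choice: `COMA.__init__` derives it from the observation and action dimensions (local obs + global obs + all agents' one-hot last actions + the other `agent_num - 1` agents' one-hot current actions). A small sketch of that arithmetic, using the dimensions from the docstring examples; the helper name is ours, not part of the library:

```python
def coma_critic_input_size(obs_dim: int, global_obs_dim: int, action_dim: int, agent_num: int) -> int:
    """Input width of the centralized critic, following the formula in COMA.__init__:
    local obs + global obs + agent_num one-hot last actions
    + (agent_num - 1) other agents' one-hot current actions."""
    return obs_dim + global_obs_dim + agent_num * action_dim + (agent_num - 1) * action_dim

# Values from the docstring examples:
obs_dim, global_obs_dim, action_dim, agent_num = 32, 32 * 4, 9, 4
size = coma_critic_input_size(obs_dim, global_obs_dim, action_dim, agent_num)

# Algebraically identical to the expression used in the forward() example below:
assert size == obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num
```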

forward(data)

Overview

The forward computation graph of COMA critic network

Arguments:

- `data` (`dict`): input data dict with keys `['obs', 'prev_state', 'action']`.
    - `agent_state` (`torch.Tensor`): each agent's local state (observation).
    - `global_state` (`torch.Tensor`): the global state (observation).
    - `action` (`torch.Tensor`): the taken actions.

Returns:

- `output` (`dict`): output data dict with keys `['q_value']`.

ArgumentsKeys:

- necessary: `obs` { `agent_state`, `global_state` }, `action`, `prev_state`

ReturnsKeys:

- necessary: `q_value`

Examples:

    >>> agent_num, bs, T = 4, 3, 8
    >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    >>> coma_model = COMACriticNetwork(
    >>>     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
    >>> data = {
    >>>     'obs': {
    >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
    >>>         'global_state': torch.randn(T, bs, global_obs_dim),
    >>>     },
    >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
    >>> }
    >>> output = coma_model(data)
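Internally, the critic's `_preprocess_data` builds a 0/1 mask via `1 - torch.eye(agent_num)` so that each agent's critic input contains the *other* agents' one-hot actions but not its own, which is what COMA's counterfactual baseline needs. A pure-Python sketch of that mask (the helper name is ours; the real code does this with tensor `view`/`repeat` operations):

```python
def counterfactual_action_mask(agent_num: int, action_dim: int):
    """Build the (agent_num, agent_num * action_dim) 0/1 mask used in
    _preprocess_data: row i zeroes out agent i's own one-hot action block,
    so the critic input for agent i only keeps the other agents' actions."""
    mask = []
    for i in range(agent_num):
        row = []
        for j in range(agent_num):
            # 0 on the diagonal block (own action), 1 elsewhere, action_dim wide
            row.extend([0 if i == j else 1] * action_dim)
        mask.append(row)
    return mask

mask = counterfactual_action_mask(agent_num=3, action_dim=2)
assert mask[0] == [0, 0, 1, 1, 1, 1]  # agent 0's own action block is hidden
assert mask[2] == [1, 1, 1, 1, 0, 0]  # agent 2's own action block is hidden
```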

COMA

Bases: Module

Overview

The network of COMA algorithm, which is QAC-type actor-critic.

Interface: `__init__`, `forward`

Properties:

- `mode` (`list`): The list of forward modes, including `compute_actor` and `compute_critic`.

__init__(agent_num, obs_shape, action_shape, actor_hidden_size_list)

Overview

Initialize COMA network

Arguments:

- `agent_num` (`int`): the number of agents.
- `obs_shape` (`Dict`): the observation information, including `agent_state` and `global_state`.
- `action_shape` (`Union[int, SequenceType]`): the dimension of the action space.
- `actor_hidden_size_list` (`SequenceType`): the list of hidden layer sizes for the actor.

forward(inputs, mode)

Overview

The forward computation graph of COMA network

Arguments:

- `inputs` (`dict`): input data dict with keys `['obs', 'prev_state', 'action']`.
    - `agent_state` (`torch.Tensor`): each agent's local state (observation).
    - `global_state` (`torch.Tensor`): the global state (observation).
    - `action` (`torch.Tensor`): the taken actions.
- `mode` (`str`): the forward mode, either `'compute_actor'` or `'compute_critic'`.

ArgumentsKeys:

- necessary: `obs` { `agent_state`, `global_state`, `action_mask` }, `action`, `prev_state`

ReturnsKeys:

- necessary:
    - compute_critic: `q_value`
    - compute_actor: `logit`, `next_state`, `action_mask`

Shapes:

- `obs` (`dict`): `agent_state`: `(T, B, A, N, D)`, `action_mask`: `(T, B, A, N, A)`
- `prev_state` (`list`): `[[[h, c] for _ in range(A)] for _ in range(B)]`
- `logit` (`torch.Tensor`): `(T, B, A, N, A)`
- `next_state` (`list`): `[[[h, c] for _ in range(A)] for _ in range(B)]`
- `action_mask` (`torch.Tensor`): `(T, B, A, N, A)`
- `q_value` (`torch.Tensor`): `(T, B, A, N, A)`

Examples:

    >>> agent_num, bs, T = 4, 3, 8
    >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
    >>> coma_model = COMA(
    >>>     agent_num=agent_num,
    >>>     obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
    >>>     action_shape=action_dim,
    >>>     actor_hidden_size_list=[128, 64],
    >>> )
    >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
    >>> data = {
    >>>     'obs': {
    >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
    >>>         'action_mask': None,
    >>>     },
    >>>     'prev_state': prev_state,
    >>> }
    >>> output = coma_model(data, mode='compute_actor')
    >>> data = {
    >>>     'obs': {
    >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
    >>>         'global_state': torch.randn(T, bs, global_obs_dim),
    >>>     },
    >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
    >>> }
    >>> output = coma_model(data, mode='compute_critic')

Full Source Code

../ding/model/template/coma.py

```python
from typing import Dict, Union
import torch
import torch.nn as nn

from functools import reduce
from ding.torch_utils import one_hot, MLP
from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType
from .q_learning import DRQN


class COMAActorNetwork(nn.Module):
    """
    Overview:
        Decentralized actor network in COMA algorithm.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: SequenceType = [128, 128, 64],
    ):
        """
        Overview:
            Initialize COMA actor network
        Arguments:
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size, default to [128, 128, 64]
        """
        super(COMAActorNetwork, self).__init__()
        self.main = DRQN(obs_shape, action_shape, hidden_size_list)

    def forward(self, inputs: Dict) -> Dict:
        """
        Overview:
            The forward computation graph of COMA actor network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state']
                - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                - action_mask (:obj:`torch.Tensor`): the masked action
                - prev_state (:obj:`torch.Tensor`): the previous hidden state
        Returns:
            - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
        ReturnsKeys:
            - necessary: ``logit``, ``next_state``, ``action_mask``
        Examples:
            >>> T, B, A, N = 4, 8, 3, 32
            >>> embedding_dim = 64
            >>> action_dim = 6
            >>> data = torch.randn(T, B, A, N)
            >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
            >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
            >>> for t in range(T):
            >>>     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
            >>>     outputs = model(inputs)
            >>>     logit, prev_state = outputs['logit'], outputs['next_state']
        """
        agent_state = inputs['obs']['agent_state']
        prev_state = inputs['prev_state']
        if len(agent_state.shape) == 3:  # B, A, N
            agent_state = agent_state.unsqueeze(0)
            unsqueeze_flag = True
        else:
            unsqueeze_flag = False
        T, B, A = agent_state.shape[:3]
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        prev_state = reduce(lambda x, y: x + y, prev_state)
        output = self.main({'obs': agent_state, 'prev_state': prev_state})
        logit, next_state = output['logit'], output['next_state']
        next_state, _ = list_split(next_state, step=A)
        logit = logit.reshape(T, B, A, -1)
        if unsqueeze_flag:
            logit = logit.squeeze(0)
        return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']}


class COMACriticNetwork(nn.Module):
    """
    Overview:
        Centralized critic network in COMA algorithm.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            input_size: int,
            action_shape: int,
            hidden_size: int = 128,
    ):
        """
        Overview:
            Initialize COMA critic network
        Arguments:
            - input_size (:obj:`int`): the size of input global observation
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size (:obj:`int`): the hidden layer size, default to 128
        Returns:
            - output (:obj:`dict`): output data dict with keys ['q_value']
        Shapes:
            - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)`
            - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
            - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
        """
        super(COMACriticNetwork, self).__init__()
        self.action_shape = action_shape
        self.act = nn.ReLU()
        self.mlp = nn.Sequential(
            MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape)
        )

    def forward(self, data: Dict) -> Dict:
        """
        Overview:
            The forward computation graph of COMA critic network
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
                - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                - global_state (:obj:`torch.Tensor`): global state(obs)
                - action (:obj:`torch.Tensor`): the masked action
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state``
        ReturnsKeys:
            - necessary: ``q_value``
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMACriticNetwork(
            >>>     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data)
        """
        x = self._preprocess_data(data)
        q = self.mlp(x)
        return {'q_value': q}

    def _preprocess_data(self, data: Dict) -> torch.Tensor:
        """
        Overview:
            Preprocess data so that it can be used by the MLP net
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
                - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                - global_state (:obj:`torch.Tensor`): global state(obs)
                - action (:obj:`torch.Tensor`): the masked action
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state``
        Return:
            - x (:obj:`torch.Tensor`): the data can be used by MLP net, including \
                ``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id``
        """
        t_size, batch_size, agent_num = data['obs']['agent_state'].shape[:3]
        agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state']

        # split obs, last_action and agent_id
        agent_state = agent_state_ori[..., :-self.action_shape - agent_num]
        last_action = agent_state_ori[..., -self.action_shape - agent_num:-agent_num]
        last_action = last_action.reshape(t_size, batch_size, 1, -1).repeat(1, 1, agent_num, 1)
        agent_id = agent_state_ori[..., -agent_num:]

        action = one_hot(data['action'], self.action_shape)  # T, B, A, N
        action = action.reshape(t_size, batch_size, -1, agent_num * self.action_shape).repeat(1, 1, agent_num, 1)
        action_mask = (1 - torch.eye(agent_num).to(action.device))
        action_mask = action_mask.view(-1, 1).repeat(1, self.action_shape).view(agent_num, -1)  # A, A*N
        action = (action_mask.unsqueeze(0).unsqueeze(0)) * action  # T, B, A, A*N
        global_state = global_state.unsqueeze(2).repeat(1, 1, agent_num, 1)

        x = torch.cat([global_state, agent_state, last_action, action, agent_id], -1)
        return x


@MODEL_REGISTRY.register('coma')
class COMA(nn.Module):
    """
    Overview:
        The network of COMA algorithm, which is QAC-type actor-critic.
    Interface:
        ``__init__``, ``forward``
    Properties:
        - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic``
    """

    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType],
            actor_hidden_size_list: SequenceType
    ) -> None:
        """
        Overview:
            Initialize COMA network
        Arguments:
            - agent_num (:obj:`int`): the number of agent
            - obs_shape (:obj:`Dict`): the observation information, including agent_state and \
                global_state
            - action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape
            - actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size
        """
        super(COMA, self).__init__()
        action_shape = squeeze(action_shape)
        actor_input_size = squeeze(obs_shape['agent_state'])
        critic_input_size = squeeze(obs_shape['agent_state']) + squeeze(obs_shape['global_state']) + \
            agent_num * action_shape + (agent_num - 1) * action_shape
        critic_hidden_size = actor_hidden_size_list[-1]
        self.actor = COMAActorNetwork(actor_input_size, action_shape, actor_hidden_size_list)
        self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size)

    def forward(self, inputs: Dict, mode: str) -> Dict:
        """
        Overview:
            The forward computation graph of COMA network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
                - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
                - global_state (:obj:`torch.Tensor`): global state(obs)
                - action (:obj:`torch.Tensor`): the masked action
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state``
        ReturnsKeys:
            - necessary:
                - compute_critic: ``q_value``
                - compute_actor: ``logit``, ``next_state``, ``action_mask``
        Shapes:
            - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)`
            - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
            - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
            - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMA(
            >>>     agent_num=agent_num,
            >>>     obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
            >>>     action_shape=action_dim,
            >>>     actor_hidden_size_list=[128, 64],
            >>> )
            >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'action_mask': None,
            >>>     },
            >>>     'prev_state': prev_state,
            >>> }
            >>> output = coma_model(data, mode='compute_actor')
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data, mode='compute_critic')
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        if mode == 'compute_actor':
            return self.actor(inputs)
        elif mode == 'compute_critic':
            return self.critic(inputs)
```