from functools import reduce
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.torch_utils import MLP, fc_block
from ding.utils import MODEL_REGISTRY, list_split

from ..common import ConvEncoder
from .q_learning import DRQN


class Mixer(nn.Module):
    """
    Overview:
        Mixer network in QMIX, which mixes up the independent q_value of each agent to a total q_value. \
        The weights (but not the biases) of the Mixer network are restricted to be non-negative and \
        produced by separate hypernetworks. Each hypernetwork takes the global state s as input and generates \
        the weights of one layer of the Mixer network.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
        self,
        agent_num: int,
        state_dim: int,
        mixing_embed_dim: int,
        hypernet_embed: int = 64,
        activation: nn.Module = nn.ReLU()
    ):
        """
        Overview:
            Initialize mixer network proposed in QMIX according to arguments. Each hypernetwork consists of \
            linear layers, followed by an absolute activation function, to ensure that the Mixer network weights are \
            non-negative.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, such as 8.
            - state_dim (:obj:`int`): The dimension of global observation state, such as 16.
            - mixing_embed_dim (:obj:`int`): The dimension of mixing state embedding, such as 128.
            - hypernet_embed (:obj:`int`): The dimension of hypernet embedding, default to 64.
            - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
        """
        super(Mixer, self).__init__()

        self.n_agents = agent_num
        self.state_dim = state_dim
        self.embed_dim = mixing_embed_dim
        self.act = activation
        # Hypernetwork producing the first-layer weights (one row per agent).
        self.hyper_w_1 = nn.Sequential(
            nn.Linear(self.state_dim, hypernet_embed), self.act,
            nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)
        )
        # Hypernetwork producing the final-layer weights.
        self.hyper_w_final = nn.Sequential(
            nn.Linear(self.state_dim, hypernet_embed), self.act, nn.Linear(hypernet_embed, self.embed_dim)
        )

        # State dependent bias for hidden layer.
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers.
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), self.act, nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        """
        Overview:
            Forward computation graph of pymarl mixer network. Mix up the input independent q_value of each agent \
            to a total q_value with weights generated by hypernetwork according to global ``states``.
        Arguments:
            - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
            - states (:obj:`torch.FloatTensor`): The embedding vector of global state.
        Returns:
            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
        Shapes:
            - agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is agent_num.
            - states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is embedding_size.
            - q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`.
        """
        # Remember leading batch dims so the output can be reshaped back at the end.
        bs = agent_qs.shape[:-1]
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.view(-1, 1, self.n_agents)
        # First layer: abs() enforces the QMIX monotonicity constraint (non-negative weights).
        w1 = torch.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(agent_qs, w1) + b1)
        # Second layer, also with non-negative weights.
        w_final = torch.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias V(s).
        v = self.V(states).view(-1, 1, 1)
        # Compute final output.
        y = torch.bmm(hidden, w_final) + v
        # Reshape and return.
        q_tot = y.view(*bs)
        return q_tot


@MODEL_REGISTRY.register('qmix')
class QMix(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to QMIX(https://arxiv.org/abs/1803.11485). \
        The QMIX is composed of two parts: agent Q network and mixer(optional). The QMIX paper mentions that all \
        agents share local Q network parameters, so only one Q network is initialized here. Then use summation or \
        Mixer network to process the local Q according to the ``mixer`` settings to obtain the global Q.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
        self,
        agent_num: int,
        obs_shape: int,
        global_obs_shape: Union[int, List[int]],
        action_shape: int,
        hidden_size_list: list,
        mixer: bool = True,
        lstm_type: str = 'gru',
        activation: nn.Module = nn.ReLU(),
        dueling: bool = False
    ) -> None:
        """
        Overview:
            Initialize QMIX neural network according to arguments, i.e. agent Q network and mixer.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, such as 8.
            - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84].
            - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8 or [4, 84, 84].
            - action_shape (:obj:`int`): The dimension of action shape, such as 6 or [2, 3, 3].
            - hidden_size_list (:obj:`list`): The list of hidden size for ``q_network``, \
                the last element must match mixer's ``mixing_embed_dim``.
            - mixer (:obj:`bool`): Use mixer net or not, default to True. If it is false, \
                the final local Q is added to obtain the global Q.
            - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now support \
                ['normal', 'pytorch', 'gru'], default to gru.
            - activation (:obj:`nn.Module`): The type of activation function to use in ``MLP`` the after \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
            - dueling (:obj:`bool`): Whether choose ``DuelingHead`` (True) or ``DiscreteHead (False)``, \
                default to False.
        """
        super(QMix, self).__init__()
        self._act = activation
        self._q_network = DRQN(
            obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling, activation=activation
        )
        embedding_size = hidden_size_list[-1]
        self.mixer = mixer
        if self.mixer:
            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

            if global_obs_shape_type == "flat":
                # A length-1 list such as [16] is also classified as "flat"; unwrap it so that
                # the nn.Linear layers inside Mixer receive a plain int dimension.
                state_dim = global_obs_shape[0] if isinstance(global_obs_shape, list) else global_obs_shape
                self._mixer = Mixer(agent_num, state_dim, embedding_size, activation=activation)
                self._global_state_encoder = nn.Identity()
            elif global_obs_shape_type == "image":
                # Image global states are first encoded to a flat embedding before mixing.
                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
                self._global_state_encoder = ConvEncoder(
                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
                )
            else:
                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
        """
        Overview:
            Determine the type of global observation shape.
        Arguments:
            - global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
        Returns:
            - obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
        """
        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
            return "flat"
        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
            return "image"
        else:
            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
        Overview:
            QMIX forward computation graph, input dict including time series observation and related data to predict \
            total q_value and each agent q_value.
        Arguments:
            - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
                - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agents.
                - global_state (:obj:`torch.Tensor`): Time series global observation data.
                - prev_state (:obj:`list`): Previous rnn state for ``q_network``.
                - action (:obj:`torch.Tensor` or None): The actions of each agent given outside the function. \
                    If action is None, use argmax q_value index as action to calculate ``agent_q_act``.
            - single_step (:obj:`bool`): Whether single_step forward, if so, add timestep dim before forward and \
                remove it after forward.
        Returns:
            - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``].
        ReturnsKeys:
            - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of mixer network.
            - agent_q (:obj:`torch.Tensor`): Each agent q_value.
            - next_state (:obj:`list`): Next rnn state for ``q_network``.
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\
                A is agent_num, N is obs_shape.
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
        """
        agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
            'prev_state']
        action = data.get('action', None)
        # If single_step is True, add a new dimension at the front of agent_state
        # This is necessary to maintain the expected input shape for the model,
        # which requires a time step dimension even when processing a single step.
        if single_step:
            agent_state = agent_state.unsqueeze(0)
        # If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state
        # This ensures that global_state has the same number of dimensions as agent_state,
        # allowing for consistent processing in the forward computation.
        if single_step and len(global_state.shape) == 2:
            global_state = global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]
        assert len(prev_state) == B and all(
            [len(p) == A for p in prev_state]
        ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
        # Flatten the per-sample, per-agent rnn states into one list for the shared Q network.
        prev_state = reduce(lambda x, y: x + y, prev_state)
        # Merge batch and agent dims so all agents share one DRQN forward pass.
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
        agent_q, next_state = output['logit'], output['next_state']
        next_state, _ = list_split(next_state, step=A)
        agent_q = agent_q.reshape(T, B, A, -1)
        if action is None:
            # for target forward process
            if len(data['obs']['action_mask'].shape) == 3:
                action_mask = data['obs']['action_mask'].unsqueeze(0)
            else:
                action_mask = data['obs']['action_mask']
            # Mask out invalid actions with a large negative value before the argmax.
            agent_q[action_mask == 0.0] = -9999999
            action = agent_q.argmax(dim=-1)
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
        if self.mixer:
            global_state_embedding = self._process_global_state(global_state)
            total_q = self._mixer(agent_q_act, global_state_embedding)
        else:
            # Without a mixer, the global Q is the plain sum of local Qs (VDN-style).
            total_q = agent_q_act.sum(dim=-1)

        if single_step:
            total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

        return {
            'total_q': total_q,
            'logit': agent_q,
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }

    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Process the global state to obtain an embedding.
        Arguments:
            - global_state (:obj:`torch.Tensor`): The global state tensor.

        Returns:
            - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
        """
        # If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
        if global_state.dim() == 5:
            # Reshape and apply the global state encoder
            batch_time_shape = global_state.shape[:2]  # [batch_size, time_steps]
            reshaped_state = global_state.view(-1, *global_state.shape[-3:])  # Collapse batch and time dims
            encoded_state = self._global_state_encoder(reshaped_state)
            return encoded_state.view(*batch_time_shape, -1)  # Reshape back to [batch_size, time_steps, embedding_dim]
        else:
            # For lower-dimensional states, apply the encoder directly
            return self._global_state_encoder(global_state)