
ding.model.template.qmix


Mixer

Bases: Module

Overview

Mixer network in QMIX, which mixes the independent q_value of each agent into a total q_value. The weights (but not the biases) of the Mixer network are restricted to be non-negative and are produced by separate hypernetworks. Each hypernetwork takes the global state s as input and generates the weights of one layer of the Mixer network.

Interface: __init__, forward.
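The non-negativity constraint is what makes the mixer monotonic: every mixing weight passes through an absolute value, so increasing any agent's q_value can never decrease q_tot. A minimal NumPy sketch of this property (all sizes, parameter names, and the single-linear-layer "hypernetworks" here are illustrative simplifications, not the DI-engine API; ReLU stands in for the ELU used in the real mixer):

```python
import numpy as np

rng = np.random.default_rng(0)
N, E, S = 3, 4, 5  # agent_num, mixing_embed_dim, state_dim (illustrative)

# Hypothetical hypernetwork parameters: each "hypernetwork" here is a single
# linear map of the global state; abs() makes the produced mixing weights >= 0.
params = {
    "hyper_w1": rng.normal(size=(N * E, S)),
    "hyper_b1": rng.normal(size=(E, S)),
    "hyper_w2": rng.normal(size=(E, S)),
    "hyper_v": rng.normal(size=S),
}

def toy_mixer(agent_qs, state, params):
    w1 = np.abs(params["hyper_w1"] @ state).reshape(N, E)  # non-negative first-layer weights
    b1 = params["hyper_b1"] @ state                        # bias may have any sign
    hidden = np.maximum(agent_qs @ w1 + b1, 0.0)           # ReLU stand-in for ELU
    w2 = np.abs(params["hyper_w2"] @ state)                # non-negative final weights
    v = params["hyper_v"] @ state                          # V(s) in place of a final bias
    return float(hidden @ w2 + v)

state = rng.normal(size=S)
qs = rng.normal(size=N)

# Monotonicity: raising any single agent's q_value never lowers q_tot.
base = toy_mixer(qs, state, params)
for i in range(N):
    bumped = qs.copy()
    bumped[i] += 1.0
    assert toy_mixer(bumped, state, params) >= base
```

This monotonicity is exactly what lets a greedy per-agent argmax coincide with the greedy joint action on q_tot.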

__init__(agent_num, state_dim, mixing_embed_dim, hypernet_embed=64, activation=nn.ReLU())

Overview

Initialize mixer network proposed in QMIX according to arguments. Each hypernetwork consists of linear layers, followed by an absolute activation function, to ensure that the Mixer network weights are non-negative.

Arguments:

- agent_num (:obj:`int`): The number of agents, such as 8.
- state_dim (:obj:`int`): The dimension of the global observation state, such as 16.
- mixing_embed_dim (:obj:`int`): The dimension of the mixing state embedding, such as 128.
- hypernet_embed (:obj:`int`): The dimension of the hypernet embedding, defaults to 64.
- activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().

forward(agent_qs, states)

Overview

Forward computation graph of the pymarl mixer network. Mixes the input independent q_value of each agent into a total q_value, with weights generated by the hypernetworks according to the global states.

Arguments:

- agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
- states (:obj:`torch.FloatTensor`): The embedding vector of the global state.

Returns:

- q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.

Shapes:

- agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is agent_num.
- states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is embedding_size.
- q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`.
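The shape bookkeeping inside this forward pass can be traced with NumPy stand-ins for the torch ops (`@` on 3D arrays behaves like `torch.bmm`); the sizes below are illustrative:

```python
import numpy as np

def elu(x):
    # NumPy stand-in for torch.nn.functional.elu
    return np.where(x > 0, x, np.exp(x) - 1)

B, N, E = 4, 3, 8  # batch size, agent_num, mixing_embed_dim (illustrative)
rng = np.random.default_rng(0)

agent_qs = rng.normal(size=(B, 1, N))      # after view(-1, 1, n_agents)
w1 = np.abs(rng.normal(size=(B, N, E)))    # hypernet output, abs() -> non-negative
b1 = rng.normal(size=(B, 1, E))            # state-dependent bias
hidden = elu(agent_qs @ w1 + b1)           # batched matmul, like torch.bmm -> (B, 1, E)
w_final = np.abs(rng.normal(size=(B, E, 1)))
v = rng.normal(size=(B, 1, 1))             # V(s) in place of a final bias
q_tot = (hidden @ w_final + v).reshape(B)  # final mixed q_value, shape (B, )
assert hidden.shape == (B, 1, E) and q_tot.shape == (B,)
```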

QMix

Bases: Module

Overview

The neural network and computation graph of algorithms related to QMIX (https://arxiv.org/abs/1803.11485). QMIX is composed of two parts: the agent Q network and the mixer (optional). The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is initialized here. The local Q values are then processed by either summation or the Mixer network, according to the mixer setting, to obtain the global Q.

Interface: __init__, forward.
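When mixer is set to False, no mixer network is built and the global Q is simply the sum of the chosen local Q values, i.e. a VDN-style decomposition. Sketched with NumPy (shapes illustrative):

```python
import numpy as np

T, B, A = 2, 4, 3  # timestep, batch_size, agent_num (illustrative)
agent_q_act = np.random.default_rng(0).normal(size=(T, B, A))

# mixer=False branch: total_q = agent_q_act.sum(dim=-1) in the torch source
total_q = agent_q_act.sum(axis=-1)
assert total_q.shape == (T, B)
```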

__init__(agent_num, obs_shape, global_obs_shape, action_shape, hidden_size_list, mixer=True, lstm_type='gru', activation=nn.ReLU(), dueling=False)

Overview

Initialize QMIX neural network according to arguments, i.e. agent Q network and mixer.

Arguments:

- agent_num (:obj:`int`): The number of agents, such as 8.
- obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84].
- global_obs_shape (:obj:`int`): The dimension of the global observation state, such as 8 or [4, 84, 84].
- action_shape (:obj:`int`): The dimension of the action shape, such as 6 or [2, 3, 3].
- hidden_size_list (:obj:`list`): The list of hidden sizes for q_network; the last element must match the mixer's mixing_embed_dim.
- mixer (:obj:`bool`): Whether to use the mixer network, defaults to True. If False, the final local Q values are summed to obtain the global Q.
- lstm_type (:obj:`str`): The type of RNN module in q_network, now supports ['normal', 'pytorch', 'gru'], defaults to gru.
- activation (:obj:`nn.Module`): The type of activation function to use in MLP after layer_fn; if None, defaults to nn.ReLU().
- dueling (:obj:`bool`): Whether to choose DuelingHead (True) or DiscreteHead (False), defaults to False.
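How global_obs_shape selects the mixer input can be seen in _get_global_obs_shape_type in the full source below: a flat vector state feeds the Mixer directly through nn.Identity, while an image state is first encoded by a ConvEncoder. The dispatch itself is plain Python, reproduced here standalone:

```python
from typing import List, Union

def get_global_obs_shape_type(global_obs_shape: Union[int, List[int]]) -> str:
    """Mirror of QMix._get_global_obs_shape_type: an int or a length-1 list is a
    flat vector observation; a length-3 list is an image observation (C, H, W)."""
    if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
        return "flat"
    elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
        return "image"
    raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

assert get_global_obs_shape_type(16) == "flat"
assert get_global_obs_shape_type([4, 84, 84]) == "image"
```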

forward(data, single_step=True)

Overview

QMIX forward computation graph. The input dict includes time-series observations and related data, used to predict the total q_value and each agent's q_value.

Arguments:

- data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
  - agent_state (:obj:`torch.Tensor`): Time-series local observation data of each agent.
  - global_state (:obj:`torch.Tensor`): Time-series global observation data.
  - prev_state (:obj:`list`): Previous RNN state for q_network.
  - action (:obj:`torch.Tensor` or None): The actions of each agent, given outside the function. If action is None, the argmax q_value index is used as the action to calculate agent_q_act.
- single_step (:obj:`bool`): Whether to do a single-step forward; if so, a timestep dim is added before the forward pass and removed after it.

Returns:

- ret (:obj:`dict`): Output data dict with keys [total_q, logit, next_state].

ReturnsKeys:

- total_q (:obj:`torch.Tensor`): Total q_value, which is the result of the mixer network.
- agent_q (:obj:`torch.Tensor`): Each agent's q_value.
- next_state (:obj:`list`): Next RNN state for q_network.

Shapes:

- agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, A is agent_num, and N is obs_shape.
- global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
- prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, each element a list of length A.
- action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
- total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
- agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
- next_state (:obj:`list`): :math:`(B, A)`, a list of length B, each element a list of length A.
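Because all agents share one DRQN, the nested prev_state (a length-B list of length-A lists) is flattened to length B * A before the Q network runs and regrouped afterwards with ding.utils.list_split. A pure-Python sketch of that round trip (the list_split below is a minimal stand-in with an assumed (chunks, remainder) return shape, not the DI-engine implementation):

```python
from functools import reduce

B, A = 2, 3
# prev_state: length-B list whose elements are length-A lists (one RNN state per agent)
prev_state = [[f"s{b}{a}" for a in range(A)] for b in range(B)]

# Before the shared Q network runs, the nested list is flattened to length B * A
flat = reduce(lambda x, y: x + y, prev_state)
assert len(flat) == B * A

def list_split(data, step):
    # Minimal stand-in for ding.utils.list_split: regroup into chunks of `step`
    return [data[i:i + step] for i in range(0, len(data), step)], None

# After the forward pass, the flat state list is regrouped back to (B, A)
next_state, _ = list_split(flat, step=A)
assert next_state == prev_state
```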

Full Source Code

../ding/model/template/qmix.py

```python
from functools import reduce
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.torch_utils import MLP, fc_block
from ding.utils import MODEL_REGISTRY, list_split

from ..common import ConvEncoder
from .q_learning import DRQN


class Mixer(nn.Module):
    """
    Overview:
        Mixer network in QMIX, which mixes the independent q_value of each agent into a total q_value. \
        The weights (but not the biases) of the Mixer network are restricted to be non-negative and \
        produced by separate hypernetworks. Each hypernetwork takes the global state s as input and generates \
        the weights of one layer of the Mixer network.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            agent_num: int,
            state_dim: int,
            mixing_embed_dim: int,
            hypernet_embed: int = 64,
            activation: nn.Module = nn.ReLU()
    ):
        """
        Overview:
            Initialize mixer network proposed in QMIX according to arguments. Each hypernetwork consists of \
            linear layers, followed by an absolute activation function, to ensure that the Mixer network weights are \
            non-negative.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, such as 8.
            - state_dim (:obj:`int`): The dimension of the global observation state, such as 16.
            - mixing_embed_dim (:obj:`int`): The dimension of the mixing state embedding, such as 128.
            - hypernet_embed (:obj:`int`): The dimension of the hypernet embedding, defaults to 64.
            - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
        """
        super(Mixer, self).__init__()

        self.n_agents = agent_num
        self.state_dim = state_dim
        self.embed_dim = mixing_embed_dim
        self.act = activation
        self.hyper_w_1 = nn.Sequential(
            nn.Linear(self.state_dim, hypernet_embed), self.act,
            nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)
        )
        self.hyper_w_final = nn.Sequential(
            nn.Linear(self.state_dim, hypernet_embed), self.act, nn.Linear(hypernet_embed, self.embed_dim)
        )

        # state dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), self.act, nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        """
        Overview:
            Forward computation graph of pymarl mixer network. Mixes the input independent q_value of each agent \
            into a total q_value with weights generated by the hypernetworks according to global ``states``.
        Arguments:
            - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
            - states (:obj:`torch.FloatTensor`): The embedding vector of the global state.
        Returns:
            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
        Shapes:
            - agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is agent_num.
            - states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is embedding_size.
            - q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`.
        """
        bs = agent_qs.shape[:-1]
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.view(-1, 1, self.n_agents)
        # First layer
        w1 = torch.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = torch.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)
        # Compute final output
        y = torch.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.view(*bs)
        return q_tot


@MODEL_REGISTRY.register('qmix')
class QMix(nn.Module):
    """
    Overview:
        The neural network and computation graph of algorithms related to QMIX(https://arxiv.org/abs/1803.11485). \
        QMIX is composed of two parts: agent Q network and mixer (optional). The QMIX paper mentions that all \
        agents share local Q network parameters, so only one Q network is initialized here. Then use summation or \
        the Mixer network to process the local Q according to the ``mixer`` settings to obtain the global Q.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            global_obs_shape: Union[int, List[int]],
            action_shape: int,
            hidden_size_list: list,
            mixer: bool = True,
            lstm_type: str = 'gru',
            activation: nn.Module = nn.ReLU(),
            dueling: bool = False
    ) -> None:
        """
        Overview:
            Initialize QMIX neural network according to arguments, i.e. agent Q network and mixer.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, such as 8.
            - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84].
            - global_obs_shape (:obj:`int`): The dimension of the global observation state, such as 8 or [4, 84, 84].
            - action_shape (:obj:`int`): The dimension of the action shape, such as 6 or [2, 3, 3].
            - hidden_size_list (:obj:`list`): The list of hidden sizes for ``q_network``, \
                the last element must match the mixer's ``mixing_embed_dim``.
            - mixer (:obj:`bool`): Use the mixer net or not, defaults to True. If it is False, \
                the final local Q values are summed to obtain the global Q.
            - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now supports \
                ['normal', 'pytorch', 'gru'], defaults to gru.
            - activation (:obj:`nn.Module`): The type of activation function to use in ``MLP`` after \
                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
            - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
                defaults to False.
        """
        super(QMix, self).__init__()
        self._act = activation
        self._q_network = DRQN(
            obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling, activation=activation
        )
        embedding_size = hidden_size_list[-1]
        self.mixer = mixer
        if self.mixer:
            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)

            if global_obs_shape_type == "flat":
                self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
                self._global_state_encoder = nn.Identity()
            elif global_obs_shape_type == "image":
                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
                self._global_state_encoder = ConvEncoder(
                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
                )
            else:
                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
        """
        Overview:
            Determine the type of global observation shape.
        Arguments:
            - global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
        Returns:
            - obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
        """
        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
            return "flat"
        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
            return "image"
        else:
            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
        Overview:
            QMIX forward computation graph, input dict including time series observation and related data to predict \
            total q_value and each agent q_value.
        Arguments:
            - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
                - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agent.
                - global_state (:obj:`torch.Tensor`): Time series global observation data.
                - prev_state (:obj:`list`): Previous rnn state for ``q_network``.
                - action (:obj:`torch.Tensor` or None): The actions of each agent given outside the function. \
                    If action is None, use argmax q_value index as action to calculate ``agent_q_act``.
            - single_step (:obj:`bool`): Whether single_step forward, if so, add timestep dim before forward and \
                remove it after forward.
        Returns:
            - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``].
        ReturnsKeys:
            - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of the mixer network.
            - agent_q (:obj:`torch.Tensor`): Each agent q_value.
            - next_state (:obj:`list`): Next rnn state for ``q_network``.
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, \
                A is agent_num, N is obs_shape.
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
        """
        agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
            'prev_state']
        action = data.get('action', None)
        # If single_step is True, add a new dimension at the front of agent_state
        # This is necessary to maintain the expected input shape for the model,
        # which requires a time step dimension even when processing a single step.
        if single_step:
            agent_state = agent_state.unsqueeze(0)
        # If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state
        # This ensures that global_state has the same number of dimensions as agent_state,
        # allowing for consistent processing in the forward computation.
        if single_step and len(global_state.shape) == 2:
            global_state = global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]
        assert len(prev_state) == B and all(
            [len(p) == A for p in prev_state]
        ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
        prev_state = reduce(lambda x, y: x + y, prev_state)
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
        agent_q, next_state = output['logit'], output['next_state']
        next_state, _ = list_split(next_state, step=A)
        agent_q = agent_q.reshape(T, B, A, -1)
        if action is None:
            # for target forward process
            if len(data['obs']['action_mask'].shape) == 3:
                action_mask = data['obs']['action_mask'].unsqueeze(0)
            else:
                action_mask = data['obs']['action_mask']
            agent_q[action_mask == 0.0] = -9999999
            action = agent_q.argmax(dim=-1)
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
        if self.mixer:
            global_state_embedding = self._process_global_state(global_state)
            total_q = self._mixer(agent_q_act, global_state_embedding)
        else:
            total_q = agent_q_act.sum(dim=-1)

        if single_step:
            total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)

        return {
            'total_q': total_q,
            'logit': agent_q,
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }

    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Process the global state to obtain an embedding.
        Arguments:
            - global_state (:obj:`torch.Tensor`): The global state tensor.
        Returns:
            - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
        """
        # If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
        if global_state.dim() == 5:
            # Reshape and apply the global state encoder
            batch_time_shape = global_state.shape[:2]  # [batch_size, time_steps]
            reshaped_state = global_state.view(-1, *global_state.shape[-3:])  # Collapse batch and time dims
            encoded_state = self._global_state_encoder(reshaped_state)
            return encoded_state.view(*batch_time_shape, -1)  # Reshape back to [batch_size, time_steps, embedding_dim]
        else:
            # For lower-dimensional states, apply the encoder directly
            return self._global_state_encoder(global_state)
```