
ding.model.template.wqmix


MixerStar

Bases: Module

Overview

Mixer network for Q_star in WQMIX (https://arxiv.org/abs/2006.10800), which mixes the independent q_value of each agent into a total q_value. Unlike QMIX's mixer network, the mixing network here is a feedforward network with 3 hidden layers of 256 dim. This Q_star mixing network is not constrained to be monotonic: it takes the state and agent_q directly as inputs and uses unrestricted weights, as opposed to QMIX, where hypernetworks take the state as input and generate non-negative weights.

Interface: __init__, forward.

__init__(agent_num, state_dim, mixing_embed_dim)

Overview

Initialize the mixer network of Q_star in WQMIX.

Arguments:
- agent_num (:obj:`int`): The number of agents, e.g., 8.
- state_dim (:obj:`int`): The dimension of the global observation state, e.g., 16.
- mixing_embed_dim (:obj:`int`): The dimension of the mixing state embedding, e.g., 128.

forward(agent_qs, states)

Overview

Forward computation graph of the mixer network for Q_star in WQMIX. This mixer network is a feed-forward network that takes the state and the appropriate actions' utilities as input.

Arguments:
- agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
- states (:obj:`torch.FloatTensor`): The embedding vector of the global state.

Returns:
- q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.

Shapes:
- agent_qs (:obj:`torch.FloatTensor`): :math:`(T, B, A)`, where T is timestep, B is batch size, A is agent_num.
- states (:obj:`torch.FloatTensor`): :math:`(T, B, M)`, where M is global_obs_shape.
- q_tot (:obj:`torch.FloatTensor`): :math:`(T, B)`.
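As a shape check, the Q_star mixing described above can be sketched as a plain feedforward network in PyTorch. This is a minimal re-implementation for illustration only, not the library class: the layer sizes follow the "3 hidden layers of 256 dim" description, the sizes T, B, A and state_dim are example values, and all tensor contents are random.

```python
import torch
import torch.nn as nn

agent_num, state_dim, embed_dim = 8, 16, 256  # illustrative sizes

# Unconstrained feedforward mixer: 3 hidden layers of embed_dim, scalar output.
net = nn.Sequential(
    nn.Linear(agent_num + state_dim, embed_dim), nn.ReLU(),
    nn.Linear(embed_dim, embed_dim), nn.ReLU(),
    nn.Linear(embed_dim, embed_dim), nn.ReLU(),
    nn.Linear(embed_dim, 1),
)
# State-dependent value V(s) used instead of a bias on the last layer.
V = nn.Sequential(nn.Linear(state_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 1))

T, B = 10, 32
agent_qs = torch.randn(T, B, agent_num)   # (T, B, A): one q_value per agent
states = torch.randn(T, B, state_dim)     # (T, B, M): global state

bs = agent_qs.shape[:-1]                  # (T, B), kept for the final reshape
flat_states = states.reshape(-1, state_dim)       # (T*B, M)
flat_qs = agent_qs.reshape(-1, agent_num)         # (T*B, A)
inputs = torch.cat([flat_states, flat_qs], dim=1)  # (T*B, M+A)
q_tot = (net(inputs) + V(flat_states)).view(*bs)   # (T*B, 1) -> (T, B)
```

Because the state and agent q_values are concatenated and fed through ordinary linear layers, nothing forces q_tot to be monotonic in any agent's q_value, which is exactly the freedom Q_star needs.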

WQMix

Bases: Module

Overview

WQMIX (https://arxiv.org/abs/2006.10800) network. There are two components: 1) Q_tot, which is the same as the QMIX network and composed of an agent Q network and a mixer network. 2) An unrestricted joint-action Q_star, which is composed of an agent Q network and a mixer_star network. The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is initialized in Q_tot or Q_star.

Interface: __init__, forward.

__init__(agent_num, obs_shape, global_obs_shape, action_shape, hidden_size_list, lstm_type='gru', dueling=False)

Overview

Initialize the WQMIX neural network according to the arguments, i.e., the agent Q network and mixer, and the Q_star network and mixer_star.

Arguments:
- agent_num (:obj:`int`): The number of agents, such as 8.
- obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8.
- global_obs_shape (:obj:`int`): The dimension of the global observation state, such as 8.
- action_shape (:obj:`int`): The dimension of the action shape, such as 6.
- hidden_size_list (:obj:`list`): The list of hidden sizes for the q_network; the last element must match the mixer's mixing_embed_dim.
- lstm_type (:obj:`str`): The type of RNN module in the q_network; now supports ['normal', 'pytorch', 'gru'], defaults to 'gru'.
- dueling (:obj:`bool`): Whether to choose DuelingHead (True) or DiscreteHead (False), defaults to False.

forward(data, single_step=True, q_star=False)

Overview

Forward computation graph of the WQMIX network. The input dict includes time-series observations and related data used to predict the total q_value and each agent's q_value. Whether Q_tot or Q_star is calculated is determined by the q_star parameter.

Arguments:
- data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
  - agent_state (:obj:`torch.Tensor`): Time-series local observation data of each agent.
  - global_state (:obj:`torch.Tensor`): Time-series global observation data.
  - prev_state (:obj:`list`): Previous rnn state for q_network or _q_network_star.
  - action (:obj:`torch.Tensor` or None): If action is None, use the argmax q_value index as the action to calculate agent_q_act.
- single_step (:obj:`bool`): Whether to do a single-step forward; if so, add a timestep dim before forward and remove it afterwards.
- q_star (:obj:`bool`): Whether to forward the Q_star network. If True, use the Q_star network, where the agent networks have the same architecture as the Q network but do not share parameters, and the mixing network is a feedforward network with 3 hidden layers of 256 dim; if False, use the Q network, the same as the Q network in the QMIX paper.

Returns:
- ret (:obj:`dict`): Output data dict with keys [total_q, logit, next_state].
  - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of the mixer network.
  - agent_q (:obj:`torch.Tensor`): Each agent's q_value.
  - next_state (:obj:`list`): Next rnn state.

Shapes:
- agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, A is agent_num, N is obs_shape.
- global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
- prev_state (:obj:`list`): a list of length B, where each element is a list of length A.
- action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
- total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
- agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
- next_state (:obj:`list`): a list of length B, where each element is a list of length A.
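For instance, the input data dict for a multi-step forward (single_step=False) could be assembled as below. This is a hedged sketch: the sizes T, B, A, N, M, P are example values, and the None placeholders stand in for the real RNN hidden-state objects that DRQN produces; only the shape invariants that the forward pass asserts are demonstrated.

```python
import torch

# Illustrative sizes: T timesteps, B batch, A agents, N obs dim,
# M global obs dim, P actions (all hypothetical values).
T, B, A, N, M, P = 10, 32, 8, 216, 48, 6

data = {
    'obs': {
        'agent_state': torch.randn(T, B, A, N),   # per-agent time-series obs
        'global_state': torch.randn(T, B, M),     # global time-series obs
        'action_mask': torch.ones(T, B, A, P),    # 1.0 = action available
    },
    # prev_state: a list of length B, each element a list of length A
    # (None here stands in for real per-agent RNN hidden states).
    'prev_state': [[None for _ in range(A)] for _ in range(B)],
    'action': None,  # None -> forward uses masked argmax of agent_q as action
}

# These are the invariants the forward pass asserts before flattening:
assert len(data['prev_state']) == B
assert all(len(p) == A for p in data['prev_state'])
```

With action set to None, the forward pass masks unavailable actions via action_mask before taking the argmax, which is the target-network code path shown in the source below.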

Full Source Code

../ding/model/template/wqmix.py

from typing import Union, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils.network.nn_module import fc_block, MLP
from ding.torch_utils.network.transformer import ScaledDotProductAttention
from .q_learning import DRQN
from ding.model.template.qmix import Mixer


class MixerStar(nn.Module):
    """
    Overview:
        Mixer network for Q_star in WQMIX (https://arxiv.org/abs/2006.10800), which mixes the independent q_value \
        of each agent into a total q_value and is different from QMIX's mixer network; \
        here the mixing network is a feedforward network with 3 hidden layers of 256 dim. \
        This Q_star mixing network is not constrained to be monotonic by using non-negative weights and \
        having the state and agent_q be inputs, as opposed to having hypernetworks take the state as input \
        and generate the weights in QMIX.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(self, agent_num: int, state_dim: int, mixing_embed_dim: int) -> None:
        """
        Overview:
            Initialize the mixer network of Q_star in WQMIX.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, e.g., 8.
            - state_dim (:obj:`int`): The dimension of the global observation state, e.g., 16.
            - mixing_embed_dim (:obj:`int`): The dimension of the mixing state embedding, e.g., 128.
        """
        super(MixerStar, self).__init__()
        self.agent_num = agent_num
        self.state_dim = state_dim
        self.embed_dim = mixing_embed_dim
        self.input_dim = self.agent_num + self.state_dim  # A + M
        non_lin = nn.ReLU()
        self.net = nn.Sequential(
            nn.Linear(self.input_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, self.embed_dim), non_lin,
            nn.Linear(self.embed_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, 1)
        )

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs: torch.FloatTensor, states: torch.FloatTensor) -> torch.FloatTensor:
        """
        Overview:
            Forward computation graph of the mixer network for Q_star in WQMIX. This mixer network \
            is a feed-forward network that takes the state and the appropriate actions' utilities as input.
        Arguments:
            - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
            - states (:obj:`torch.FloatTensor`): The embedding vector of the global state.
        Returns:
            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
        Shapes:
            - agent_qs (:obj:`torch.FloatTensor`): :math:`(T, B, A)`, where T is timestep, \
                B is batch size, A is agent_num.
            - states (:obj:`torch.FloatTensor`): :math:`(T, B, M)`, where M is global_obs_shape.
            - q_tot (:obj:`torch.FloatTensor`): :math:`(T, B)`.
        """
        # in below annotations about the shape of the variables, T is timestep,
        # B is batch_size, A is agent_num, N is obs_shape, for example,
        # in 3s5z, we can set T=10, B=32, A=8, N=216
        bs = agent_qs.shape[:-1]  # (T, B)
        states = states.reshape(-1, self.state_dim)  # (T*B, M)
        agent_qs = agent_qs.reshape(-1, self.agent_num)  # (T, B, A) -> (T*B, A)
        inputs = torch.cat([states, agent_qs], dim=1)  # (T*B, M), (T*B, A) -> (T*B, M+A)
        advs = self.net(inputs)  # (T*B, 1)
        vs = self.V(states)  # (T*B, 1)
        y = advs + vs
        q_tot = y.view(*bs)  # (T*B, 1) -> (T, B)

        return q_tot


@MODEL_REGISTRY.register('wqmix')
class WQMix(nn.Module):
    """
    Overview:
        WQMIX (https://arxiv.org/abs/2006.10800) network. There are two components: \
        1) Q_tot, which is the same as the QMIX network and composed of an agent Q network and a mixer network. \
        2) An unrestricted joint-action Q_star, which is composed of an agent Q network and a mixer_star network. \
        The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is \
        initialized in Q_tot or Q_star.
    Interface:
        ``__init__``, ``forward``.
    """

    def __init__(
        self,
        agent_num: int,
        obs_shape: int,
        global_obs_shape: int,
        action_shape: int,
        hidden_size_list: list,
        lstm_type: str = 'gru',
        dueling: bool = False
    ) -> None:
        """
        Overview:
            Initialize the WQMIX neural network according to the arguments, i.e., the agent Q network and mixer, \
            and the Q_star network and mixer_star.
        Arguments:
            - agent_num (:obj:`int`): The number of agents, such as 8.
            - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8.
            - global_obs_shape (:obj:`int`): The dimension of the global observation state, such as 8.
            - action_shape (:obj:`int`): The dimension of the action shape, such as 6.
            - hidden_size_list (:obj:`list`): The list of hidden sizes for ``q_network``; \
                the last element must match the mixer's ``mixing_embed_dim``.
            - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now supports \
                ['normal', 'pytorch', 'gru'], defaults to 'gru'.
            - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
                defaults to False.
        """
        super(WQMix, self).__init__()
        self._act = nn.ReLU()
        self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
        self._q_network_star = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
        embedding_size = hidden_size_list[-1]
        self._mixer = Mixer(agent_num, global_obs_shape, mixing_embed_dim=embedding_size)
        self._mixer_star = MixerStar(
            agent_num, global_obs_shape, mixing_embed_dim=256
        )  # the mixing network of Q_star is a feedforward network with 3 hidden layers of 256 dim
        self._global_state_encoder = nn.Identity()  # nn.Sequential()

    def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> dict:
        """
        Overview:
            Forward computation graph of the WQMIX network. The input dict includes time-series observations and \
            related data used to predict the total q_value and each agent's q_value. Whether Q_tot or Q_star is \
            calculated is determined by the ``q_star`` parameter.
        Arguments:
            - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
                - agent_state (:obj:`torch.Tensor`): Time-series local observation data of each agent.
                - global_state (:obj:`torch.Tensor`): Time-series global observation data.
                - prev_state (:obj:`list`): Previous rnn state for ``q_network`` or ``_q_network_star``.
                - action (:obj:`torch.Tensor` or None): If action is None, use the argmax q_value index as the \
                    action to calculate ``agent_q_act``.
            - single_step (:obj:`bool`): Whether to do a single-step forward; if so, add a timestep dim before \
                forward and remove it afterwards.
            - q_star (:obj:`bool`): Whether to forward the Q_star network. If True, use the Q_star network, where \
                the agent networks have the same architecture as the Q network but do not share parameters, and \
                the mixing network is a feedforward network with 3 hidden layers of 256 dim; if False, use the \
                Q network, the same as the Q network in the QMIX paper.
        Returns:
            - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``].
                - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of the mixer network.
                - agent_q (:obj:`torch.Tensor`): Each agent's q_value.
                - next_state (:obj:`list`): Next rnn state.
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, \
                A is agent_num, N is obs_shape.
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
            - prev_state (:obj:`list`): a list of length B, where each element is a list of length A.
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
            - next_state (:obj:`list`): a list of length B, where each element is a list of length A.
        """
        if q_star:  # forward using Q_star network
            agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
                'prev_state']
            action = data.get('action', None)
            if single_step:
                agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
            T, B, A = agent_state.shape[:3]
            assert len(prev_state) == B and all(
                [len(p) == A for p in prev_state]
            ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
            prev_state = reduce(lambda x, y: x + y, prev_state)
            agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
            output = self._q_network_star(
                {
                    'obs': agent_state,
                    'prev_state': prev_state,
                }
            )  # here is the forward pass of the agent networks of Q_star
            agent_q, next_state = output['logit'], output['next_state']
            next_state, _ = list_split(next_state, step=A)
            agent_q = agent_q.reshape(T, B, A, -1)
            if action is None:
                # For target forward process
                if len(data['obs']['action_mask'].shape) == 3:
                    action_mask = data['obs']['action_mask'].unsqueeze(0)
                else:
                    action_mask = data['obs']['action_mask']
                agent_q[action_mask == 0.0] = -9999999
                action = agent_q.argmax(dim=-1)
            agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
            agent_q_act = agent_q_act.squeeze(-1)  # T, B, A

            global_state_embedding = self._global_state_encoder(global_state)
            total_q = self._mixer_star(
                agent_q_act, global_state_embedding
            )  # here is the forward pass of the mixer networks of Q_star

            if single_step:
                total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
            return {
                'total_q': total_q,
                'logit': agent_q,
                'next_state': next_state,
                'action_mask': data['obs']['action_mask']
            }
        else:  # forward using Q network
            agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
                'prev_state']
            action = data.get('action', None)
            if single_step:
                agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
            T, B, A = agent_state.shape[:3]
            assert len(prev_state) == B and all(
                [len(p) == A for p in prev_state]
            ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
            prev_state = reduce(lambda x, y: x + y, prev_state)
            agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
            output = self._q_network(
                {
                    'obs': agent_state,
                    'prev_state': prev_state,
                }
            )  # here is the forward pass of the agent networks of Q
            agent_q, next_state = output['logit'], output['next_state']
            next_state, _ = list_split(next_state, step=A)
            agent_q = agent_q.reshape(T, B, A, -1)
            if action is None:
                # For target forward process
                if len(data['obs']['action_mask'].shape) == 3:
                    action_mask = data['obs']['action_mask'].unsqueeze(0)
                else:
                    action_mask = data['obs']['action_mask']
                agent_q[action_mask == 0.0] = -9999999
                action = agent_q.argmax(dim=-1)
            agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
            agent_q_act = agent_q_act.squeeze(-1)  # T, B, A

            global_state_embedding = self._global_state_encoder(global_state)
            total_q = self._mixer(
                agent_q_act, global_state_embedding
            )  # here is the forward pass of the mixer networks of Q

            if single_step:
                total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
            return {
                'total_q': total_q,
                'logit': agent_q,
                'next_state': next_state,
                'action_mask': data['obs']['action_mask']
            }