ding.model.template.collaq

CollaQMultiHeadAttention

Bases: Module

Overview

The multi-head attention head of the CollaQ attention module.

Interface: __init__, forward

__init__(n_head, d_model_q, d_model_v, d_k, d_v, d_out, dropout=0.0, activation=nn.ReLU())

Overview

Initialize the head of the CollaQ attention module.

Arguments:

- n_head (:obj:`int`): the number of attention heads
- d_model_q (:obj:`int`): the size of input q
- d_model_v (:obj:`int`): the size of input v
- d_k (:obj:`int`): the size of k, used by Scaled Dot Product Attention
- d_v (:obj:`int`): the size of v, used by Scaled Dot Product Attention
- d_out (:obj:`int`): the size of output q
- dropout (:obj:`float`): dropout ratio, defaults to 0.
- activation (:obj:`nn.Module`): activation in the FFN after attention, defaults to nn.ReLU().

forward(q, k, v, mask=None)

Overview

Compute the forward graph of the CollaQ multi-head attention net.

Arguments:

- q (:obj:`torch.Tensor`): the input query q
- k (:obj:`torch.Tensor`): the input key k
- v (:obj:`torch.Tensor`): the input value v
- mask (:obj:`torch.Tensor` or None): optional attention mask, defaults to None

Returns:

- q (:obj:`torch.Tensor`): the attention output q
- residual (:obj:`torch.Tensor`): the projected query before attention, returned as the residual

Shapes:

- q (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, N is the size of input q
- k (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, N is the size of input k
- v (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, N is the size of input v
- output q (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, N is the size of output q
- residual (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, N is the size of output residual

Examples:

>>> net = CollaQMultiHeadAttention(1, 2, 3, 4, 5, 6)
>>> q = torch.randn(1, 2, 2)
>>> k = torch.randn(1, 3, 3)
>>> v = torch.randn(1, 3, 3)
>>> q, residual = net(q, k, v)
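Under the hood, the attention core follows the standard scaled dot-product formula, softmax(QK^T / sqrt(d_k)) V, which the module delegates to ding's `ScaledDotProductAttention`. A minimal pure-Python sketch of that math for a single head and a single query (illustrative only, not the library implementation):

```python
import math

def scaled_dot_product_attention(q, k, v):
    """Single-query attention over Python lists: softmax(q . k_i / sqrt(d_k)) weighted sum of v."""
    d_k = len(k[0])
    # Attention score for each key: dot(q, k_i) scaled by sqrt(d_k)
    scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d_k) for key in k]
    # Numerically stable softmax over the scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Output is the attention-weighted sum of the value vectors
    d_v = len(v[0])
    return [sum(w * val[j] for w, val in zip(weights, v)) for j in range(d_v)]
```

With two identical keys the weights are uniform, so the output is the mean of the values; a key that matches the query much more strongly pulls the output toward its value.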

CollaQSMACAttentionModule

Bases: Module

Overview

The CollaQ attention module, used to compute an agent's attention observation. The output combines the agent's own observation with the parts of the observation that describe the allies the agent attends to.

Interface: __init__, _cut_obs, forward

__init__(q_dim, v_dim, self_feature_range, ally_feature_range, attention_size, activation=nn.ReLU())

Overview

Initialize the CollaQ attention module.

Arguments:

- q_dim (:obj:`int`): the dimension of the attention query input
- v_dim (:obj:`int`): the dimension of the attention value input
- self_feature_range (:obj:`List[int]`): the [start, end) index range of the agent's own features within the observation
- ally_feature_range (:obj:`List[int]`): the [start, end) index range of the ally features within the observation
- attention_size (:obj:`int`): the size of the attention net layer
- activation (:obj:`nn.Module`): activation in the FFN after attention, defaults to nn.ReLU().

forward(inputs)

Overview

Forward computation that produces the agent's attention observation.

Arguments:

- inputs (:obj:`torch.Tensor`): each agent's input observation

Returns:

- obs (:obj:`torch.Tensor`): the agent's attention observation

Shapes:

- obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, A is agent_num, N is obs_shape
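The feature ranges drive a simple slice-and-splice: `_cut_obs` slices the self and ally segments out of the flat observation vector, and `forward` later splices the attended features back between the untouched prefix and suffix. A plain-Python sketch of that index logic (list-based, with a hypothetical `cut_obs` helper mirroring `_cut_obs`):

```python
def cut_obs(obs, self_feature_range, ally_feature_range):
    # The ranges are [start, end) indices into the flat observation vector,
    # mirroring CollaQSMACAttentionModule._cut_obs.
    self_features = obs[self_feature_range[0]:self_feature_range[1]]
    ally_features = obs[ally_feature_range[0]:ally_feature_range[1]]
    return self_features, ally_features

# Toy observation of length 20, self features at [8, 10), ally features at [10, 16)
obs = list(range(20))
self_f, ally_f = cut_obs(obs, [8, 10], [10, 16])
# The module's forward assumes the two ranges are adjacent, so the (attended)
# features can be spliced back as prefix + self + ally + suffix:
rebuilt = obs[:8] + self_f + ally_f + obs[16:]
```

This is why the source carries the note that the self features are assumed to sit right next to the ally features: non-adjacent ranges would make the final concatenation drop or duplicate columns.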

CollaQ

Bases: Module

Overview

The network of the CollaQ (Collaborative Q-learning) algorithm. It consists of two parts: q_network and q_alone_network. The q_network produces the Q-value from the agent's observation combined with the parts of the observation covering the agent's concerned allies. The q_alone_network produces the Q-value from the agent's observation with the concerned allies removed.

Reference: Multi-Agent Collaboration via Reward Attribution Decomposition, https://arxiv.org/abs/2010.08531

Interface: __init__, forward, _setup_global_encoder

__init__(agent_num, obs_shape, alone_obs_shape, global_obs_shape, action_shape, hidden_size_list, attention=False, self_feature_range=None, ally_feature_range=None, attention_size=32, mixer=True, lstm_type='gru', activation=nn.ReLU(), dueling=False)

Overview

Initialize the CollaQ network.

Arguments:

- agent_num (:obj:`int`): the number of agents
- obs_shape (:obj:`int`): the dimension of each agent's observation state
- alone_obs_shape (:obj:`int`): the dimension of each agent's observation state without other agents
- global_obs_shape (:obj:`int`): the dimension of the global observation state
- action_shape (:obj:`int`): the dimension of the action space
- hidden_size_list (:obj:`list`): the list of hidden sizes
- attention (:obj:`bool`): whether to use the attention module, defaults to False
- self_feature_range (:obj:`Union[List[int], None]`): the index range of the agent's own features
- ally_feature_range (:obj:`Union[List[int], None]`): the index range of the agent's ally features
- attention_size (:obj:`int`): the size of the attention net layer, defaults to 32
- mixer (:obj:`bool`): whether to use the mixer net, defaults to True
- lstm_type (:obj:`str`): whether to use lstm or gru, defaults to 'gru'
- activation (:obj:`nn.Module`): activation function in the network, defaults to nn.ReLU().
- dueling (:obj:`bool`): whether to use a dueling head, defaults to False.

forward(data, single_step=True)

Overview

The forward method calculates the q_value of each agent and the total q_value of all agents. The q_value of each agent is calculated by the q_network, and the total q_value is calculated by the mixer.

Arguments:

- data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
  - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
  - agent_alone_state (:obj:`torch.Tensor`): each agent's local state alone; in the SMAC setting, the observation without ally features (obs_alone)
  - global_state (:obj:`torch.Tensor`): the global state (obs)
  - prev_state (:obj:`list`): previous RNN state, made of 3 parts: two hidden states of q_network (for the obs and padded obs_alone inputs) and one hidden state of q_alone_network
  - action (:obj:`torch.Tensor` or None): if action is None, use the argmax Q-value index as the action to calculate agent_q_act
- single_step (:obj:`bool`): whether to forward a single step; if so, add a timestep dim before forward and remove it afterwards

Returns:

- ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
  - total_q (:obj:`torch.Tensor`): total Q-value, the result of the mixer network
  - agent_q (:obj:`torch.Tensor`): each agent's Q-value
  - next_state (:obj:`list`): next RNN state

Shapes:

- agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, A is agent_num, N is obs_shape
- global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
- prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, each element a list of length A
- action (:obj:`torch.Tensor`): :math:`(T, B, A)`
- total_q (:obj:`torch.Tensor`): :math:`(T, B)`
- agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
- next_state (:obj:`list`): :math:`(B, A)`, a list of length B, each element a list of length A

Examples:

>>> collaQ_model = CollaQ(
>>>     agent_num=4,
>>>     obs_shape=32,
>>>     alone_obs_shape=24,
>>>     global_obs_shape=32 * 4,
>>>     action_shape=9,
>>>     hidden_size_list=[128, 64],
>>>     self_feature_range=[8, 10],
>>>     ally_feature_range=[10, 16],
>>>     attention_size=64,
>>>     mixer=True,
>>>     activation=torch.nn.Tanh()
>>> )
>>> data = {
>>>     'obs': {
>>>         'agent_state': torch.randn(8, 4, 4, 32),
>>>         'agent_alone_state': torch.randn(8, 4, 4, 24),
>>>         'agent_alone_padding_state': torch.randn(8, 4, 4, 32),
>>>         'global_state': torch.randn(8, 4, 32 * 4),
>>>         'action_mask': torch.randint(0, 2, size=(8, 4, 4, 9))
>>>     },
>>>     'prev_state': [[[None for _ in range(4)] for _ in range(3)] for _ in range(4)],
>>>     'action': torch.randint(0, 9, size=(8, 4, 4))
>>> }
>>> output = collaQ_model(data, single_step=False)
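The per-agent Q-value is assembled from three streams, as in the source (`total_q_before_mix = agent_alone_q + agent_colla_q - agent_colla_alone_q`): the collaborative Q is corrected by subtracting the collaborative network's response to the ally-masked input. This is CollaQ's reward-attribution decomposition. A scalar sketch with made-up numbers (illustrative values, not real network outputs):

```python
def collaq_agent_q(q_alone, q_colla, q_colla_alone):
    # Per-action decomposition: alone + collaborative - collaborative-on-masked-input,
    # matching total_q_before_mix in CollaQ.forward.
    return [a + c - ca for a, c, ca in zip(q_alone, q_colla, q_colla_alone)]

# One agent, three actions (illustrative values only)
agent_q = collaq_agent_q([1.0, 0.5, 0.2], [0.8, 1.2, 0.1], [0.3, 0.4, 0.1])
# agent_q is approximately [1.5, 1.3, 0.2]; the greedy action is the argmax, here action 0
greedy = max(range(len(agent_q)), key=lambda i: agent_q[i])
```

In the network this runs elementwise over tensors of shape (T, B, A, P), and the resulting per-agent greedy Q-values are then fed to the mixer to form total_q.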

Full Source Code

../ding/model/template/collaq.py

from typing import Union, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, MODEL_REGISTRY
from ding.torch_utils import fc_block, MLP, ScaledDotProductAttention
from .q_learning import DRQN
from .qmix import Mixer


class CollaQMultiHeadAttention(nn.Module):
    """
    Overview:
        The head of collaq attention module.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        n_head: int,
        d_model_q: int,
        d_model_v: int,
        d_k: int,
        d_v: int,
        d_out: int,
        dropout: float = 0.,
        activation: nn.Module = nn.ReLU()
    ):
        """
        Overview:
            initialize the head of collaq attention module
        Arguments:
            - n_head (:obj:`int`): the num of head
            - d_model_q (:obj:`int`): the size of input q
            - d_model_v (:obj:`int`): the size of input v
            - d_k (:obj:`int`): the size of k, used by Scaled Dot Product Attention
            - d_v (:obj:`int`): the size of v, used by Scaled Dot Product Attention
            - d_out (:obj:`int`): the size of output q
            - dropout (:obj:`float`): Dropout ratio, defaults to 0.
            - activation (:obj:`nn.Module`): Activation in FFN after attention.
        """
        super(CollaQMultiHeadAttention, self).__init__()

        self.act = activation

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model_q, n_head * d_k)
        self.w_ks = nn.Linear(d_model_v, n_head * d_k)
        self.w_vs = nn.Linear(d_model_v, n_head * d_v)

        self.fc1 = fc_block(n_head * d_v, n_head * d_v, activation=self.act)
        self.fc2 = fc_block(n_head * d_v, d_out)

        self.attention = ScaledDotProductAttention(d_k=d_k)
        self.layer_norm_q = nn.LayerNorm(n_head * d_k, eps=1e-6)
        self.layer_norm_k = nn.LayerNorm(n_head * d_k, eps=1e-6)
        self.layer_norm_v = nn.LayerNorm(n_head * d_v, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """
        Overview:
            forward computation graph of collaQ multi head attention net.
        Arguments:
            - q (:obj:`torch.nn.Sequential`): the transformer information q
            - k (:obj:`torch.nn.Sequential`): the transformer information k
            - v (:obj:`torch.nn.Sequential`): the transformer information v
        Returns:
            - q (:obj:`torch.nn.Sequential`): the transformer output q
            - residual (:obj:`torch.nn.Sequential`): the transformer output residual
        Shapes:
            - q (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
                N is the size of input q
            - k (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
                N is the size of input k
            - v (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
                N is the size of input v
            - q (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
                N is the size of output q
            - residual (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
                N is the size of output residual
        Examples:
            >>> net = CollaQMultiHeadAttention(1, 2, 3, 4, 5, 6)
            >>> q = torch.randn(1, 2, 2)
            >>> k = torch.randn(1, 3, 3)
            >>> v = torch.randn(1, 3, 3)
            >>> q, residual = net(q, k, v)
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Pass through the pre-attention projection: batch_size x len_q x (n_head * d_v)
        # Separate different heads: batch_size x len_q x n_head x d_v
        q = self.w_qs(q).view(batch_size, len_q, n_head, d_k)
        k = self.w_ks(k).view(batch_size, len_k, n_head, d_k)
        v = self.w_vs(v).view(batch_size, len_v, n_head, d_v)
        residual = q

        # Transpose for attention dot product: batch_size x n_head x len_q x d_v
        q, k, v = self.layer_norm_q(q).transpose(1, 2), self.layer_norm_k(k).transpose(
            1, 2
        ), self.layer_norm_v(v).transpose(1, 2)
        # Unsqueeze the mask tensor for head axis broadcasting
        if mask is not None:
            mask = mask.unsqueeze(1)
        q = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: batch_size x len_q x n_head x d_v
        # Combine the last two dimensions to concatenate all the heads together: batch_size x len_q x (n*dv)
        q = q.transpose(1, 2).contiguous().view(batch_size, len_q, -1)
        q = self.fc2(self.fc1(q))
        return q, residual


class CollaQSMACAttentionModule(nn.Module):
    """
    Overview:
        Collaq attention module. Used to get agent's attention observation. It includes agent's observation\
        and agent's part of the observation information of the agent's concerned allies
    Interface:
        ``__init__``, ``_cut_obs``, ``forward``
    """

    def __init__(
        self,
        q_dim: int,
        v_dim: int,
        self_feature_range: List[int],
        ally_feature_range: List[int],
        attention_size: int,
        activation: nn.Module = nn.ReLU()
    ):
        """
        Overview:
            initialize collaq attention module
        Arguments:
            - q_dim (:obj:`int`): the dimension of transformer output q
            - v_dim (:obj:`int`): the dimension of transformer output v
            - self_features (:obj:`torch.Tensor`): output self agent's attention observation
            - ally_features (:obj:`torch.Tensor`): output ally agent's attention observation
            - attention_size (:obj:`int`): the size of attention net layer
            - activation (:obj:`nn.Module`): Activation in FFN after attention.
        """
        super(CollaQSMACAttentionModule, self).__init__()
        self.self_feature_range = self_feature_range
        self.ally_feature_range = ally_feature_range
        self.attention_layer = CollaQMultiHeadAttention(
            1, q_dim, v_dim, attention_size, attention_size, attention_size, activation=activation
        )

    def _cut_obs(self, obs: torch.Tensor):
        """
        Overview:
            cut the observed information into self's observation and allay's observation
        Arguments:
            - obs (:obj:`torch.Tensor`): input each agent's observation
        Returns:
            - self_features (:obj:`torch.Tensor`): output self agent's attention observation
            - ally_features (:obj:`torch.Tensor`): output ally agent's attention observation
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
                A is agent_num, N is obs_shape
            - self_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
                A is agent_num, N is self_feature_range[1] - self_feature_range[0]
            - ally_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
                A is agent_num, N is ally_feature_range[1] - ally_feature_range[0]
        """
        # obs shape = (T, B, A, obs_shape)
        self_features = obs[:, :, :, self.self_feature_range[0]:self.self_feature_range[1]]
        ally_features = obs[:, :, :, self.ally_feature_range[0]:self.ally_feature_range[1]]
        return self_features, ally_features

    def forward(self, inputs: torch.Tensor):
        """
        Overview:
            forward computation to get agent's attention observation information
        Arguments:
            - obs (:obj:`torch.Tensor`): input each agent's observation
        Returns:
            - obs (:obj:`torch.Tensor`): output agent's attention observation
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
                A is agent_num, N is obs_shape
        """
        # obs shape = (T, B ,A, obs_shape)
        obs = inputs
        self_features, ally_features = self._cut_obs(obs)
        T, B, A, _ = self_features.shape
        self_features = self_features.reshape(T * B * A, 1, -1)
        ally_features = ally_features.reshape(T * B * A, A - 1, -1)
        self_features, ally_features = self.attention_layer(self_features, ally_features, ally_features)
        self_features = self_features.reshape(T, B, A, -1)
        ally_features = ally_features.reshape(T, B, A, -1)
        # note: we assume self_feature is near the ally_feature here so we can do this concat
        obs = torch.cat(
            [
                obs[:, :, :, :self.self_feature_range[0]], self_features, ally_features,
                obs[:, :, :, self.ally_feature_range[1]:]
            ],
            dim=-1
        )
        return obs


@MODEL_REGISTRY.register('collaq')
class CollaQ(nn.Module):
    """
    Overview:
        The network of CollaQ (Collaborative Q-learning) algorithm.
        It includes two parts: q_network and q_alone_network.
        The q_network is used to get the q_value of the agent's observation and \
            the agent's part of the observation information of the agent's concerned allies.
        The q_alone_network is used to get the q_value of the agent's observation and \
            the agent's observation information without the agent's concerned allies.
        Multi-Agent Collaboration via Reward Attribution Decomposition
        https://arxiv.org/abs/2010.08531
    Interface:
        ``__init__``, ``forward``, ``_setup_global_encoder``
    """

    def __init__(
        self,
        agent_num: int,
        obs_shape: int,
        alone_obs_shape: int,
        global_obs_shape: int,
        action_shape: int,
        hidden_size_list: list,
        attention: bool = False,
        self_feature_range: Union[List[int], None] = None,
        ally_feature_range: Union[List[int], None] = None,
        attention_size: int = 32,
        mixer: bool = True,
        lstm_type: str = 'gru',
        activation: nn.Module = nn.ReLU(),
        dueling: bool = False,
    ) -> None:
        """
        Overview:
            Initialize Collaq network.
        Arguments:
            - agent_num (:obj:`int`): the number of agent
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - alone_obs_shape (:obj:`int`): the dimension of each agent's observation state without\
                other agents
            - global_obs_shape (:obj:`int`): the dimension of global observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size
            - attention (:obj:`bool`): use attention module or not, default to False
            - self_feature_range (:obj:`Union[List[int], None]`): the agent's feature range
            - ally_feature_range (:obj:`Union[List[int], None]`): the agent ally's feature range
            - attention_size (:obj:`int`): the size of attention net layer
            - mixer (:obj:`bool`): use mixer net or not, default to True
            - lstm_type (:obj:`str`): use lstm or gru, default to gru
            - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
            - dueling (:obj:`bool`): use dueling head or not, default to False.
        """
        super(CollaQ, self).__init__()
        self.attention = attention
        self.attention_size = attention_size
        self._act = activation
        self.mixer = mixer
        if not self.attention:
            self._q_network = DRQN(
                obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling, activation=activation
            )
        else:
            # TODO set the attention layer here beautifully
            self._self_attention = CollaQSMACAttentionModule(
                self_feature_range[1] - self_feature_range[0],
                (ally_feature_range[1] - ally_feature_range[0]) // (agent_num - 1),
                self_feature_range,
                ally_feature_range,
                attention_size,
                activation=activation
            )
            # TODO get the obs_dim_after_attention here beautifully
            obs_shape_after_attention = self._self_attention(
                # torch.randn(
                #     1, 1, (ally_feature_range[1] - ally_feature_range[0]) //
                #     ((self_feature_range[1] - self_feature_range[0])*2) + 1, obs_dim
                # )
                torch.randn(1, 1, agent_num, obs_shape)
            ).shape[-1]
            self._q_network = DRQN(
                obs_shape_after_attention,
                action_shape,
                hidden_size_list,
                lstm_type=lstm_type,
                dueling=dueling,
                activation=activation
            )
        self._q_alone_network = DRQN(
            alone_obs_shape,
            action_shape,
            hidden_size_list,
            lstm_type=lstm_type,
            dueling=dueling,
            activation=activation
        )
        embedding_size = hidden_size_list[-1]
        if self.mixer:
            self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
            self._global_state_encoder = nn.Identity()

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
        Overview:
            The forward method calculates the q_value of each agent and the total q_value of all agents.
            The q_value of each agent is calculated by the q_network, and the total q_value is calculated by the mixer.
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - agent_alone_state (:obj:`torch.Tensor`): each agent's local state alone, \
                in smac setting is without ally feature(obs_along)
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - prev_state (:obj:`list`): previous rnn state, should include 3 parts: \
                one hidden state of q_network, and two hidden state if q_alone_network for obs and obs_alone inputs
            - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\
                calculate ``agent_q_act``
            - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\
                remove it after forward
        Return:
            - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
            - total_q (:obj:`torch.Tensor`): total q_value, which is the result of mixer network
            - agent_q (:obj:`torch.Tensor`): each agent q_value
            - next_state (:obj:`list`): next rnn state
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\
                A is agent_num, N is obs_shape
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
            - prev_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
            - next_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A
        Examples:
            >>> collaQ_model = CollaQ(
            >>>     agent_num=4,
            >>>     obs_shape=32,
            >>>     alone_obs_shape=24,
            >>>     global_obs_shape=32 * 4,
            >>>     action_shape=9,
            >>>     hidden_size_list=[128, 64],
            >>>     self_feature_range=[8, 10],
            >>>     ally_feature_range=[10, 16],
            >>>     attention_size=64,
            >>>     mixer=True,
            >>>     activation=torch.nn.Tanh()
            >>> )
            >>> data={
            >>>     'obs': {
            >>>         'agent_state': torch.randn(8, 4, 4, 32),
            >>>         'agent_alone_state': torch.randn(8, 4, 4, 24),
            >>>         'agent_alone_padding_state': torch.randn(8, 4, 4, 32),
            >>>         'global_state': torch.randn(8, 4, 32 * 4),
            >>>         'action_mask': torch.randint(0, 2, size=(8, 4, 4, 9))
            >>>     },
            >>>     'prev_state': [[[None for _ in range(4)] for _ in range(3)] for _ in range(4)],
            >>>     'action': torch.randint(0, 9, size=(8, 4, 4))
            >>> }
            >>> output = collaQ_model(data, single_step=False)
        """
        agent_state, agent_alone_state = data['obs']['agent_state'], data['obs']['agent_alone_state']
        agent_alone_padding_state = data['obs']['agent_alone_padding_state']
        global_state, prev_state = data['obs']['global_state'], data['prev_state']
        # TODO find a better way to implement agent_along_padding_state

        action = data.get('action', None)
        if single_step:
            agent_state, agent_alone_state, agent_alone_padding_state, global_state = agent_state.unsqueeze(
                0
            ), agent_alone_state.unsqueeze(0), agent_alone_padding_state.unsqueeze(0), global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]

        if self.attention:
            agent_state = self._self_attention(agent_state)
            agent_alone_padding_state = self._self_attention(agent_alone_padding_state)

        # prev state should be of size (B, 3, A) hidden_size)
        """
        Note: to achieve such work, we should change the init_fn of hidden_state plugin in collaQ policy
        """
        assert len(prev_state) == B and all([len(p) == 3 for p in prev_state]) and all(
            [len(q) == A] for p in prev_state for q in p
        ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))

        alone_prev_state = [[None for _ in range(A)] for _ in range(B)]
        colla_prev_state = [[None for _ in range(A)] for _ in range(B)]
        colla_alone_prev_state = [[None for _ in range(A)] for _ in range(B)]

        for i in range(B):
            for j in range(3):
                for k in range(A):
                    if j == 0:
                        alone_prev_state[i][k] = prev_state[i][j][k]
                    elif j == 1:
                        colla_prev_state[i][k] = prev_state[i][j][k]
                    elif j == 2:
                        colla_alone_prev_state[i][k] = prev_state[i][j][k]

        alone_prev_state = reduce(lambda x, y: x + y, alone_prev_state)
        colla_prev_state = reduce(lambda x, y: x + y, colla_prev_state)
        colla_alone_prev_state = reduce(lambda x, y: x + y, colla_alone_prev_state)

        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:])
        agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:])

        colla_output = self._q_network({
            'obs': agent_state,
            'prev_state': colla_prev_state,
        })
        colla_alone_output = self._q_network(
            {
                'obs': agent_alone_padding_state,
                'prev_state': colla_alone_prev_state,
            }
        )
        alone_output = self._q_alone_network({
            'obs': agent_alone_state,
            'prev_state': alone_prev_state,
        })

        agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state']
        agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state']
        agent_colla_q, colla_next_state = colla_output['logit'], colla_output['next_state']

        colla_next_state, _ = list_split(colla_next_state, step=A)
        alone_next_state, _ = list_split(alone_next_state, step=A)
        colla_alone_next_state, _ = list_split(colla_alone_next_state, step=A)

        next_state = list(
            map(lambda x: [x[0], x[1], x[2]], zip(alone_next_state, colla_next_state, colla_alone_next_state))
        )

        agent_alone_q = agent_alone_q.reshape(T, B, A, -1)
        agent_colla_alone_q = agent_colla_alone_q.reshape(T, B, A, -1)
        agent_colla_q = agent_colla_q.reshape(T, B, A, -1)

        total_q_before_mix = agent_alone_q + agent_colla_q - agent_colla_alone_q
        # total_q_before_mix = agent_colla_q
        # total_q_before_mix = agent_alone_q
        agent_q = total_q_before_mix

        if action is None:
            # For target forward process
            if len(data['obs']['action_mask'].shape) == 3:
                action_mask = data['obs']['action_mask'].unsqueeze(0)
            else:
                action_mask = data['obs']['action_mask']
            agent_q[action_mask == 0.0] = -9999999
            action = agent_q.argmax(dim=-1)
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
        if self.mixer:
            global_state_embedding = self._global_state_encoder(global_state)
            total_q = self._mixer(agent_q_act, global_state_embedding)
        else:
            total_q = agent_q_act.sum(-1)
        if single_step:
            total_q, agent_q, agent_colla_alone_q = total_q.squeeze(0), agent_q.squeeze(0), agent_colla_alone_q.squeeze(
                0
            )
        return {
            'total_q': total_q,
            'logit': agent_q,
            'agent_colla_alone_q': agent_colla_alone_q * data['obs']['action_mask'],
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }

    def _setup_global_encoder(self, global_obs_shape: int, embedding_size: int) -> torch.nn.Module:
        """
        Overview:
            Used to encoder global observation.
        Arguments:
            - global_obs_shape (:obj:`int`): the dimension of global observation state
            - embedding_size (:obj:`int`): the dimension of state emdedding
        Returns:
            - outputs (:obj:`torch.nn.Module`): Global observation encoding network
        """
        return MLP(global_obs_shape, embedding_size, embedding_size, 2, activation=self._act)