ding.model.template.atoc

ATOCAttentionUnit

Bases: Module

Overview

The attention unit of the ATOC network, implemented as a two-layer MLP as in the original paper.

Interface: __init__, forward

.. note:: ATOC paper: "We use a two-layer MLP to implement the attention unit, but it can also be realized by an RNN."

__init__(thought_size, embedding_size)

Overview

Initialize the attention unit according to the size of input arguments.

Arguments:

- thought_size (`int`): the size of the input thought
- embedding_size (`int`): the size of the hidden layers

forward(data)

Overview

Take the thoughts of agents as input and generate the probability of each agent being an initiator.

Arguments:

- data (`Union[Dict, torch.Tensor]`): the input thoughts tensor, or a dict containing it under the key 'thought'

Returns:

- ret (`torch.Tensor`): the initiator probability of each agent

Shapes:

- data['thought']: (M, B, N), where M is the number of thoughts to integrate, B is the batch size and N is the thought size

Examples:

>>> attention_unit = ATOCAttentionUnit(64, 64)
>>> thought = torch.randn(2, 3, 64)
>>> attention_unit(thought)
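The attention unit described above can be sketched as a minimal standalone module in plain PyTorch (an illustrative version of the two-hidden-layer MLP with a sigmoid output, not the library class itself):

```python
import torch
import torch.nn as nn


class TwoLayerAttentionSketch(nn.Module):
    """Illustrative sketch: maps each thought vector to an initiator probability."""

    def __init__(self, thought_size: int, embedding_size: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(thought_size, embedding_size), nn.ReLU(),
            nn.Linear(embedding_size, embedding_size), nn.ReLU(),
            nn.Linear(embedding_size, 1), nn.Sigmoid(),
        )

    def forward(self, thought: torch.Tensor) -> torch.Tensor:
        # (M, B, N) -> (M, B): drop the trailing scalar-probability dim
        return self.net(thought).squeeze(-1)


unit = TwoLayerAttentionSketch(64, 64)
prob = unit(torch.randn(2, 3, 64))  # shape (2, 3), values in (0, 1)
```

The sigmoid keeps every output strictly between 0 and 1, so it can be thresholded directly to decide which agents become initiators.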

ATOCCommunicationNet

Bases: Module

Overview

This ATOC communication net is a bidirectional LSTM, so it can integrate all the thoughts in the group.

Interface: __init__, forward

__init__(thought_size)

Overview

Initialize the communication network according to the size of input arguments.

Arguments:

- thought_size (`int`): the size of the input thought

.. note::

The communication hidden size should be half of the actor hidden size, because the bidirectional LSTM concatenates the outputs of its two directions.

forward(data)

Overview

The forward of ATOCCommunicationNet integrates thoughts in the group.

Arguments:

- data (`Union[Dict, torch.Tensor]`): the input thoughts tensor, or a dict containing it under the key 'thoughts'

Returns:

- out (`torch.Tensor`): the integrated thoughts

Shapes:

- data['thoughts']: (M, B, N), where M is the number of thoughts to integrate, B is the batch size and N is the thought size

Examples:

>>> comm_net = ATOCCommunicationNet(64)
>>> thoughts = torch.randn(2, 3, 64)
>>> comm_net(thoughts)
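The half-size hidden state can be verified with a short standalone sketch (plain PyTorch, not the library class): a bidirectional LSTM concatenates its forward and backward hidden states, so choosing `hidden_size = thought_size // 2` makes the output feature size equal to the input thought size again.

```python
import torch
import torch.nn as nn

thought_size = 64
# Bidirectional LSTM: output feature size is 2 * hidden_size == thought_size.
bi_lstm = nn.LSTM(thought_size, thought_size // 2, bidirectional=True)

thoughts = torch.randn(2, 3, thought_size)  # (M, B, N): 2 thoughts, batch 3
integrated, _ = bi_lstm(thoughts)           # same (M, B, N) shape as the input
```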

ATOCActorNet

Bases: Module

Overview

The actor network of ATOC.

Interface: __init__, forward

.. note::
    "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers."

__init__(obs_shape, thought_size, action_shape, n_agent, communication=True, agent_per_group=2, initiator_threshold=0.5, attention_embedding_size=64, actor_1_embedding_size=None, actor_2_embedding_size=None, activation=nn.ReLU(), norm_type=None)

Overview

Initialize the actor network of ATOC

Arguments:

- obs_shape (`Union[Tuple, int]`): the observation size
- thought_size (`int`): the size of thoughts
- action_shape (`int`): the action size
- n_agent (`int`): the number of agents
- agent_per_group (`int`): the number of agents in each group
- initiator_threshold (`float`): the threshold for becoming an initiator, default 0.5
- attention_embedding_size (`int`): the embedding size of the attention unit, default 64
- actor_1_embedding_size (`Union[int, None]`): the embedding size of actor network part 1; if None, defaults to the thought size
- actor_2_embedding_size (`Union[int, None]`): the embedding size of actor network part 2; if None, defaults to the thought size

forward(obs)

Overview

Take the input obs and calculate the corresponding action, group, initiator_prob, thoughts, etc.

Arguments:

- obs (`torch.Tensor`): the observation tensor

Returns:

- ret (`Dict`): the output, including action, group, initiator_prob, is_initiator, new_thoughts and old_thoughts

ReturnsKeys:

- necessary: action
- optional: group, initiator_prob, is_initiator, new_thoughts, old_thoughts

Shapes:

- obs (`torch.Tensor`): (B, A, N), where B is the batch size, A is the number of agents and N is the obs size
- action (`torch.Tensor`): (B, A, M), where M is the action size
- group (`torch.Tensor`): (B, A, A)
- initiator_prob (`torch.Tensor`): (B, A)
- is_initiator (`torch.Tensor`): (B, A)
- new_thoughts (`torch.Tensor`): (B, A, M)
- old_thoughts (`torch.Tensor`): (B, A, M)

Examples:

>>> actor_net = ATOCActorNet(64, 64, 64, 3)
>>> obs = torch.randn(2, 3, 64)
>>> actor_net(obs)

ATOC

Bases: Module

Overview

The QAC network of ATOC, an extension of DDPG for MARL. Paper: Learning Attentional Communication for Multi-Agent Cooperation, https://arxiv.org/abs/1805.07733

Interface: __init__, forward, compute_critic, compute_actor, optimize_actor_attention

__init__(obs_shape, action_shape, thought_size, n_agent, communication=True, agent_per_group=2, actor_1_embedding_size=None, actor_2_embedding_size=None, critic_head_hidden_size=64, critic_head_layer_num=2, activation=nn.ReLU(), norm_type=None)

Overview

Initialize the ATOC QAC network

Arguments:

- obs_shape (`Union[Tuple, int]`): the observation space shape
- action_shape (`int`): the action space shape
- thought_size (`int`): the size of thoughts
- n_agent (`int`): the number of agents
- agent_per_group (`int`): the number of agents in each group

compute_actor(obs, get_delta_q=False)

Overview

Compute the action according to the inputs; when get_delta_q is True, call the _compute_delta_q method to compute delta_q.

Arguments:

- obs (`torch.Tensor`): the observation tensor
- get_delta_q (`bool`): whether delta_q is needed

Returns:

- outputs (`Dict`): the output of the actor network, plus delta_q if requested

ReturnsKeys:

- necessary: action
- optional: group, initiator_prob, is_initiator, new_thoughts, old_thoughts, delta_q

Shapes:

- obs (`torch.Tensor`): (B, A, N), where B is the batch size, A is the number of agents and N is the obs size
- action (`torch.Tensor`): (B, A, M), where M is the action size
- group (`torch.Tensor`): (B, A, A)
- initiator_prob (`torch.Tensor`): (B, A)
- is_initiator (`torch.Tensor`): (B, A)
- new_thoughts (`torch.Tensor`): (B, A, M)
- old_thoughts (`torch.Tensor`): (B, A, M)
- delta_q (`torch.Tensor`): (B, A)

Examples:

>>> net = ATOC(64, 64, 64, 3)
>>> obs = torch.randn(2, 3, 64)
>>> net.compute_actor(obs)
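The idea behind delta_q can be sketched in isolation (a toy stand-in critic, not the library's critic head): for the members of an initiator's group, compare the critic value of the action computed from the pre-communication thoughts with the one computed from the post-communication thoughts, and average the difference.

```python
import torch

# Hypothetical toy critic: any map from (obs || action) to a scalar works
# for illustration; the real model uses an MLP RegressionHead.
def toy_critic(obs_action: torch.Tensor) -> torch.Tensor:
    return obs_action.sum(dim=-1)

obs = torch.randn(4, 8)         # 4 group members, obs size 8
act_before = torch.randn(4, 2)  # actions derived from old (pre-comm) thoughts
act_after = torch.randn(4, 2)   # actions derived from new (post-comm) thoughts

q_before = toy_critic(torch.cat([obs, act_before], dim=-1))
q_after = toy_critic(torch.cat([obs, act_after], dim=-1))

# Scalar credit assigned to communication for this group.
delta_q = q_after.mean() - q_before.mean()
```

A positive delta_q means communication improved the group's expected value, which is exactly the signal used to train the attention unit.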

compute_critic(inputs)

Overview

Compute the q_value according to the inputs.

Arguments:

- inputs (`Dict`): a dict containing the obs and action

ArgumentsKeys:

- necessary: obs, action

Returns:

- outputs (`Dict`): the output of the critic network

ReturnsKeys:

- necessary: q_value

Shapes:

- obs (`torch.Tensor`): (B, A, N), where B is the batch size, A is the number of agents and N is the obs size
- action (`torch.Tensor`): (B, A, M), where M is the action size
- q_value (`torch.Tensor`): (B, A)

Examples:

>>> net = ATOC(64, 64, 64, 3)
>>> obs = torch.randn(2, 3, 64)
>>> action = torch.randn(2, 3, 64)
>>> net.compute_critic({'obs': obs, 'action': action})

optimize_actor_attention(inputs)

Overview

Compute and return the actor attention loss.

Arguments:

- inputs (`Dict`): a dict containing delta_q, initiator_prob and is_initiator

ArgumentsKeys:

- necessary: delta_q, initiator_prob, is_initiator

Returns:

- loss (`Dict`): the loss of the actor attention unit

ReturnsKeys:

- necessary: loss

Shapes:

- delta_q (`torch.Tensor`): (B, A)
- initiator_prob (`torch.Tensor`): (B, A)
- is_initiator (`torch.Tensor`): (B, A)
- loss (`torch.Tensor`): (1)

Examples:

>>> net = ATOC(64, 64, 64, 3)
>>> delta_q = torch.randn(2, 3)
>>> initiator_prob = torch.randn(2, 3)
>>> is_initiator = torch.randn(2, 3)
>>> net.optimize_actor_attention(
>>>     {'delta_q': delta_q,
>>>      'initiator_prob': initiator_prob,
>>>      'is_initiator': is_initiator})
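The attention loss is a delta_q-weighted, binary-cross-entropy-style objective computed only over initiator agents. A minimal standalone sketch (plain PyTorch, mirroring the formula in the source below; the `0.9 * p + 0.05` rescaling keeps probabilities away from 0 and 1 so the logs stay finite):

```python
import torch

def attention_loss_sketch(delta_q: torch.Tensor,
                          init_prob: torch.Tensor,
                          is_init: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of the delta_q-weighted BCE-style attention loss."""
    mask = is_init.reshape(-1).bool()
    dq = delta_q.reshape(-1)[mask]
    p = init_prob.reshape(-1)[mask]
    p = 0.9 * p + 0.05  # squash into [0.05, 0.95] to avoid log(0)
    if p.numel() == 0:
        # no initiators: return a zero loss that still carries a grad
        return torch.zeros(1, requires_grad=True).sum()
    loss = -dq * torch.log(p) - (1 - dq) * torch.log(1 - p)
    return loss.mean()

loss = attention_loss_sketch(torch.rand(2, 3), torch.rand(2, 3), torch.ones(2, 3))
```

Intuitively, a large delta_q pushes the predicted initiator probability toward 1 (communication helped), while a small delta_q pushes it toward 0.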

Full Source Code

../ding/model/template/atoc.py

```python
from typing import Union, Dict, Optional, Tuple

import torch
import torch.nn as nn

from ding.utils import squeeze, MODEL_REGISTRY, SequenceType
from ding.torch_utils import MLP
from ding.model.common import RegressionHead


class ATOCAttentionUnit(nn.Module):
    """
    Overview:
        The attention unit of the ATOC network. We now implement it as two-layer MLP, same as the original paper.
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: We use two-layer MLP to implement the attention unit but it is also can be realized by RNN."
    """

    def __init__(self, thought_size: int, embedding_size: int) -> None:
        """
        Overview:
            Initialize the attention unit according to the size of input arguments.
        Arguments:
            - thought_size (:obj:`int`): the size of input thought
            - embedding_size (:obj:`int`): the size of hidden layers
        """
        super(ATOCAttentionUnit, self).__init__()
        self._thought_size = thought_size
        self._hidden_size = embedding_size
        self._output_size = 1
        self._act1 = nn.ReLU()
        self._fc1 = nn.Linear(self._thought_size, self._hidden_size, bias=True)
        self._fc2 = nn.Linear(self._hidden_size, self._hidden_size, bias=True)
        self._fc3 = nn.Linear(self._hidden_size, self._output_size, bias=True)
        self._act2 = nn.Sigmoid()

    def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor:
        """
        Overview:
            Take the thought of agents as input and generate the probability of these agent being initiator
        Arguments:
            - x (:obj:`Union[Dict, torch.Tensor]`): the input tensor or dict contain the thoughts tensor
            - ret (:obj:`torch.Tensor`): the output initiator probability
        Shapes:
            - data['thought']: :math:`(M, B, N)`, M is the num of thoughts to integrate,\
                B is batch_size and N is thought size
        Examples:
            >>> attention_unit = ATOCAttentionUnit(64, 64)
            >>> thought = torch.randn(2, 3, 64)
            >>> attention_unit(thought)
        """
        x = data
        if isinstance(data, Dict):
            x = data['thought']
        x = self._fc1(x)
        x = self._act1(x)
        x = self._fc2(x)
        x = self._act1(x)
        x = self._fc3(x)
        x = self._act2(x)
        return x.squeeze(-1)


class ATOCCommunicationNet(nn.Module):
    """
    Overview:
        This ATOC commnication net is a bi-direction LSTM, so it can integrate all the thoughts in the group.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(self, thought_size: int) -> None:
        """
        Overview:
            Initialize the communication network according to the size of input arguments.
        Arguments:
            - thought_size (:obj:`int`): the size of input thought

        .. note::

            communication hidden size should be half of the actor_hidden_size because of the bi-direction lstm
        """
        super(ATOCCommunicationNet, self).__init__()
        assert thought_size % 2 == 0
        self._thought_size = thought_size
        self._comm_hidden_size = thought_size // 2
        self._bi_lstm = nn.LSTM(self._thought_size, self._comm_hidden_size, bidirectional=True)

    def forward(self, data: Union[Dict, torch.Tensor]):
        """
        Overview:
            The forward of ATOCCommunicationNet integrates thoughts in the group.
        Arguments:
            - x (:obj:`Union[Dict, torch.Tensor]`): the input tensor or dict contain the thoughts tensor
            - out (:obj:`torch.Tensor`): the integrated thoughts
        Shapes:
            - data['thoughts']: :math:`(M, B, N)`, M is the num of thoughts to integrate,\
                B is batch_size and N is thought size
        Examples:
            >>> comm_net = ATOCCommunicationNet(64)
            >>> thoughts = torch.randn(2, 3, 64)
            >>> comm_net(thoughts)
        """
        self._bi_lstm.flatten_parameters()
        x = data
        if isinstance(data, Dict):
            x = data['thoughts']
        out, _ = self._bi_lstm(x)
        return out


class ATOCActorNet(nn.Module):
    """
    Overview:
        The actor network of ATOC.
    Interface:
        ``__init__``, ``forward``

    .. note::
        "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers."
    """

    def __init__(
            self,
            obs_shape: Union[Tuple, int],
            thought_size: int,
            action_shape: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            initiator_threshold: float = 0.5,
            attention_embedding_size: int = 64,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ):
        """
        Overview:
            Initialize the actor network of ATOC
        Arguments:
            - obs_shape(:obj:`Union[Tuple, int]`): the observation size
            - thought_size (:obj:`int`): the size of thoughts
            - action_shape (:obj:`int`): the action size
            - n_agent (:obj:`int`): the num of agents
            - agent_per_group (:obj:`int`): the num of agent in each group
            - initiator_threshold (:obj:`float`): the threshold of becoming an initiator, default set to 0.5
            - attention_embedding_size (:obj:`int`): the embedding size of attention unit, default set to 64
            - actor_1_embedding_size (:obj:`Union[int, None]`): the size of embedding size of actor network part1, \
                if None, then default set to thought size
            - actor_2_embedding_size (:obj:`Union[int, None]`): the size of embedding size of actor network part2, \
                if None, then default set to thought size
        """
        super(ATOCActorNet, self).__init__()
        # now only support obs_shape of shape (O_dim, )
        self._obs_shape = squeeze(obs_shape)
        self._thought_size = thought_size
        self._act_shape = action_shape
        self._n_agent = n_agent
        self._communication = communication
        self._agent_per_group = agent_per_group
        self._initiator_threshold = initiator_threshold
        if not actor_1_embedding_size:
            actor_1_embedding_size = self._thought_size
        if not actor_2_embedding_size:
            actor_2_embedding_size = self._thought_size

        # Actor Net(I)
        self.actor_1 = MLP(
            self._obs_shape,
            actor_1_embedding_size,
            self._thought_size,
            layer_num=2,
            activation=activation,
            norm_type=norm_type
        )

        # Actor Net(II)
        self.actor_2 = nn.Sequential(
            nn.Linear(self._thought_size * 2, actor_2_embedding_size), activation,
            RegressionHead(
                actor_2_embedding_size, self._act_shape, 2, final_tanh=True, activation=activation, norm_type=norm_type
            )
        )

        # Communication
        if self._communication:
            self.attention = ATOCAttentionUnit(self._thought_size, attention_embedding_size)
            self.comm_net = ATOCCommunicationNet(self._thought_size)

    def forward(self, obs: torch.Tensor) -> Dict:
        """
        Overview:
            Take the input obs, and calculate the corresponding action, group, initiator_prob, thoughts, etc...
        Arguments:
            - obs (:obj:`Dict`): the input obs containing the observation
        Returns:
            - ret (:obj:`Dict`): the returned output, including action, group, initiator_prob, is_initiator, \
                new_thoughts and old_thoughts
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_net(obs)
        """
        assert len(obs.shape) == 3
        self._cur_batch_size = obs.shape[0]
        B, A, N = obs.shape
        assert A == self._n_agent
        assert N == self._obs_shape

        current_thoughts = self.actor_1(obs)  # B, A, thought size

        if self._communication:
            old_thoughts = current_thoughts.clone().detach()
            init_prob, is_initiator, group = self._get_initiate_group(old_thoughts)

            new_thoughts = self._get_new_thoughts(current_thoughts, group, is_initiator)
        else:
            new_thoughts = current_thoughts
        action = self.actor_2(torch.cat([current_thoughts, new_thoughts], dim=-1))['pred']

        if self._communication:
            return {
                'action': action,
                'group': group,
                'initiator_prob': init_prob,
                'is_initiator': is_initiator,
                'new_thoughts': new_thoughts,
                'old_thoughts': old_thoughts,
            }
        else:
            return {'action': action}

    def _get_initiate_group(self, current_thoughts):
        """
        Overview:
            Calculate the initiator probability, group and is_initiator
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
        Returns:
            - init_prob (:obj:`torch.Tensor`): tensor of initiator probability
            - is_initiator (:obj:`torch.Tensor`): tensor of is initiator
            - group (:obj:`torch.Tensor`): tensor of group
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - init_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> actor_net._get_initiate_group(current_thoughts)
        """
        if not self._communication:
            raise NotImplementedError
        init_prob = self.attention(current_thoughts)  # B, A
        is_initiator = (init_prob > self._initiator_threshold)
        B, A = init_prob.shape[:2]

        thoughts_pair_dot = current_thoughts.bmm(current_thoughts.transpose(1, 2))
        thoughts_square = thoughts_pair_dot.diagonal(0, 1, 2)
        curr_thought_dists = thoughts_square.unsqueeze(1) - 2 * thoughts_pair_dot + thoughts_square.unsqueeze(2)

        group = torch.zeros(B, A, A).to(init_prob.device)

        # "considers the agents in its observable field"
        # "initiator first chooses collaborators from agents who have not been selected,
        # then from agents selected by other initiators,
        # finally from other initiators"
        # "all based on proximity"

        # roughly choose m closest as group
        for b in range(B):
            for i in range(A):
                if is_initiator[b][i]:
                    index_seq = curr_thought_dists[b][i].argsort()
                    index_seq = index_seq[:self._agent_per_group]
                    group[b][i][index_seq] = 1
        return init_prob, is_initiator, group

    def _get_new_thoughts(self, current_thoughts, group, is_initiator):
        """
        Overview:
            Calculate the new thoughts according to current thoughts, group and is_initiator
        Arguments:
            - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
            - group (:obj:`torch.Tensor`): tensor of group
            - is_initiator (:obj:`torch.Tensor`): tensor of is initiator
        Returns:
            - new_thoughts (:obj:`torch.Tensor`): tensor of new thoughts
        Shapes:
            - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
        Examples:
            >>> actor_net = ATOCActorNet(64, 64, 64, 3)
            >>> current_thoughts = torch.randn(2, 3, 64)
            >>> group = torch.randn(2, 3, 3)
            >>> is_initiator = torch.randn(2, 3)
            >>> actor_net._get_new_thoughts(current_thoughts, group, is_initiator)
        """
        if not self._communication:
            raise NotImplementedError
        B, A = current_thoughts.shape[:2]
        new_thoughts = current_thoughts.detach().clone()
        if len(torch.nonzero(is_initiator)) == 0:
            return new_thoughts

        # TODO(nyz) execute communication serially for shared agent in different group
        thoughts_to_commute = []
        for b in range(B):
            for i in range(A):
                if is_initiator[b][i]:
                    tmp = []
                    for j in range(A):
                        if group[b][i][j]:
                            tmp.append(new_thoughts[b][j])
                    thoughts_to_commute.append(torch.stack(tmp, dim=0))
        thoughts_to_commute = torch.stack(thoughts_to_commute, dim=1)  # agent_per_group, B_, N
        integrated_thoughts = self.comm_net(thoughts_to_commute)
        b_count = 0
        for b in range(B):
            for i in range(A):
                if is_initiator[b][i]:
                    j_count = 0
                    for j in range(A):
                        if group[b][i][j]:
                            new_thoughts[b][j] = integrated_thoughts[j_count][b_count]
                            j_count += 1
                    b_count += 1
        return new_thoughts


@MODEL_REGISTRY.register('atoc')
class ATOC(nn.Module):
    """
    Overview:
        The QAC network of ATOC, a kind of extension of DDPG for MARL.
        Learning Attentional Communication for Multi-Agent Cooperation
        https://arxiv.org/abs/1805.07733
    Interface:
        ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``optimize_actor_attention``
    """
    mode = ['compute_actor', 'compute_critic', 'optimize_actor_attention']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            thought_size: int,
            n_agent: int,
            communication: bool = True,
            agent_per_group: int = 2,
            actor_1_embedding_size: Union[int, None] = None,
            actor_2_embedding_size: Union[int, None] = None,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 2,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the ATOC QAC network
        Arguments:
            - obs_shape(:obj:`Union[Tuple, int]`): the observation space shape
            - thought_size (:obj:`int`): the size of thoughts
            - action_shape (:obj:`int`): the action space shape
            - n_agent (:obj:`int`): the num of agents
            - agent_per_group (:obj:`int`): the num of agent in each group
        """
        super(ATOC, self).__init__()
        self._communication = communication

        self.actor = ATOCActorNet(
            obs_shape,
            thought_size,
            action_shape,
            n_agent,
            communication,
            agent_per_group,
            actor_1_embedding_size=actor_1_embedding_size,
            actor_2_embedding_size=actor_2_embedding_size
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
            RegressionHead(
                critic_head_hidden_size,
                1,
                critic_head_layer_num,
                final_tanh=False,
                activation=activation,
                norm_type=norm_type,
            )
        )

    def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tensor:
        """
        Overview:
            calculate the delta_q according to obs and actor_outputs
        Arguments:
            - obs (:obj:`torch.Tensor`): the observations
            - actor_outputs (:obj:`dict`): the output of actors
        Returns:
            - delta_q (:obj:`Dict`): the calculated delta_q
        ArgumentsKeys:
            - necessary: ``new_thoughts``, ``old_thoughts``, ``group``, ``is_initiator``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - actor_outputs (:obj:`Dict`): the output of actor network, including ``action``, ``new_thoughts``, \
                ``old_thoughts``, ``group``, ``initiator_prob``, ``is_initiator``
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is action size
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> actor_outputs = net.compute_actor(obs)
            >>> net._compute_delta_q(obs, actor_outputs)
        """
        if not self._communication:
            raise NotImplementedError
        assert len(obs.shape) == 3
        new_thoughts, old_thoughts, group, is_initiator = actor_outputs['new_thoughts'], actor_outputs[
            'old_thoughts'], actor_outputs['group'], actor_outputs['is_initiator']
        B, A = new_thoughts.shape[:2]
        curr_delta_q = torch.zeros(B, A).to(new_thoughts.device)
        with torch.no_grad():
            for b in range(B):
                for i in range(A):
                    if not is_initiator[b][i]:
                        continue
                    q_group = []
                    actual_q_group = []
                    for j in range(A):
                        if not group[b][i][j]:
                            continue
                        before_update_action_j = self.actor.actor_2(
                            torch.cat([old_thoughts[b][j], old_thoughts[b][j]], dim=-1)
                        )
                        after_update_action_j = self.actor.actor_2(
                            torch.cat([old_thoughts[b][j], new_thoughts[b][j]], dim=-1)
                        )
                        before_update_input = torch.cat([obs[b][j], before_update_action_j['pred']], dim=-1)
                        before_update_Q_j = self.critic(before_update_input)['pred']
                        after_update_input = torch.cat([obs[b][j], after_update_action_j['pred']], dim=-1)
                        after_update_Q_j = self.critic(after_update_input)['pred']
                        q_group.append(before_update_Q_j)
                        actual_q_group.append(after_update_Q_j)
                    q_group = torch.stack(q_group)
                    actual_q_group = torch.stack(actual_q_group)
                    curr_delta_q[b][i] = actual_q_group.mean() - q_group.mean()
        return curr_delta_q

    def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[str, torch.Tensor]:
        """
        Overview:
            compute the action according to inputs, call the _compute_delta_q function to compute delta_q
        Arguments:
            - obs (:obj:`torch.Tensor`): observation
            - get_delta_q (:obj:`bool`): whether need to get delta_q
        Returns:
            - outputs (:obj:`Dict`): the output of actor network and delta_q
        ReturnsKeys:
            - necessary: ``action``
            - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``, \
                ``delta_q``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> net.compute_actor(obs)
        """
        outputs = self.actor(obs)
        if get_delta_q and self._communication:
            delta_q = self._compute_delta_q(obs, outputs)
            outputs['delta_q'] = delta_q
        return outputs

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            compute the q_value according to inputs
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the obs and action
        Returns:
            - outputs (:obj:`Dict`): the output of critic network
        ArgumentsKeys:
            - necessary: ``obs``, ``action``
        ReturnsKeys:
            - necessary: ``q_value``
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
            - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
            - q_value (:obj:`torch.Tensor`): :math:`(B, A)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> obs = torch.randn(2, 3, 64)
            >>> action = torch.randn(2, 3, 64)
            >>> net.compute_critic({'obs': obs, 'action': action})
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 2:  # (B, A) -> (B, A, 1)
            action = action.unsqueeze(2)
        x = torch.cat([obs, action], dim=-1)
        x = self.critic(x)['pred']
        return {'q_value': x}

    def optimize_actor_attention(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            return the actor attention loss
        Arguments:
            - inputs (:obj:`Dict`): the inputs contain the delta_q, initiator_prob, and is_initiator
        Returns:
            - loss (:obj:`Dict`): the loss of actor attention unit
        ArgumentsKeys:
            - necessary: ``delta_q``, ``initiator_prob``, ``is_initiator``
        ReturnsKeys:
            - necessary: ``loss``
        Shapes:
            - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
            - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
            - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
            - loss (:obj:`torch.Tensor`): :math:`(1)`
        Examples:
            >>> net = ATOC(64, 64, 64, 3)
            >>> delta_q = torch.randn(2, 3)
            >>> initiator_prob = torch.randn(2, 3)
            >>> is_initiator = torch.randn(2, 3)
            >>> net.optimize_actor_attention(
            >>>     {'delta_q': delta_q,
            >>>      'initiator_prob': initiator_prob,
            >>>      'is_initiator': is_initiator})
        """
        if not self._communication:
            raise NotImplementedError
        delta_q = inputs['delta_q'].reshape(-1)
        init_prob = inputs['initiator_prob'].reshape(-1)
        is_init = inputs['is_initiator'].reshape(-1)
        delta_q = delta_q[is_init.nonzero()]
        init_prob = init_prob[is_init.nonzero()]
        init_prob = 0.9 * init_prob + 0.05

        # judge to avoid nan
        if init_prob.shape == (0, 1):
            actor_attention_loss = torch.FloatTensor([-0.0]).to(delta_q.device)
            actor_attention_loss.requires_grad = True
        else:
            actor_attention_loss = -delta_q * \
                torch.log(init_prob) - (1 - delta_q) * torch.log(1 - init_prob)
        return {'loss': actor_attention_loss.mean()}

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str, **kwargs) -> Dict:
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs, **kwargs)
```