Skip to content

ding.model.template.qtran

ding.model.template.qtran

QTran

Bases: Module

Overview

QTRAN network

Interface: `__init__`, `forward`

__init__(agent_num, obs_shape, global_obs_shape, action_shape, hidden_size_list, embedding_size, lstm_type='gru', dueling=False)

Overview

initialize QTRAN network

Arguments:
    - agent_num (:obj:`int`): the number of agents
    - obs_shape (:obj:`int`): the dimension of each agent's observation state
    - global_obs_shape (:obj:`int`): the dimension of the global observation state
    - action_shape (:obj:`int`): the dimension of the action shape
    - hidden_size_list (:obj:`list`): the list of hidden sizes
    - embedding_size (:obj:`int`): the dimension of the embedding
    - lstm_type (:obj:`str`): use lstm or gru, defaults to gru
    - dueling (:obj:`bool`): whether to use a dueling head, defaults to False

forward(data, single_step=True)

Overview

forward computation graph of qtran network

Arguments:
    - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
        - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
        - global_state (:obj:`torch.Tensor`): global state (obs)
        - prev_state (:obj:`list`): previous rnn state
        - action (:obj:`torch.Tensor` or None): if action is None, use the argmax q_value index as the action to calculate ``agent_q_act``
    - single_step (:obj:`bool`): whether to do a single-step forward; if so, a timestep dim is added before the forward pass and removed afterwards
Returns:
    - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
        - total_q (:obj:`torch.Tensor`): total q_value, which is the result of the mixer network
        - agent_q (:obj:`torch.Tensor`): each agent's q_value
        - next_state (:obj:`list`): next rnn state
Shapes:
    - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, A is agent_num, and N is obs_shape
    - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
    - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, where each element is a list of length A
    - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
    - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
    - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
    - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, where each element is a list of length A

Full Source Code

../ding/model/template/qtran.py

from typing import Union, List
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from ding.utils import list_split, squeeze, MODEL_REGISTRY
from ding.torch_utils.network.nn_module import fc_block, MLP
from ding.torch_utils.network.transformer import ScaledDotProductAttention
from ding.torch_utils import to_tensor, tensor_to_list
from .q_learning import DRQN


@MODEL_REGISTRY.register('qtran')
class QTran(nn.Module):
    """
    Overview:
        QTRAN network: per-agent DRQN q-networks plus a centralized joint-action
        value network Q and a state-value network V.
    Interface:
        __init__, forward
    """

    def __init__(
        self,
        agent_num: int,
        obs_shape: int,
        global_obs_shape: int,
        action_shape: int,
        hidden_size_list: list,
        embedding_size: int,
        lstm_type: str = 'gru',
        dueling: bool = False
    ) -> None:
        """
        Overview:
            Initialize the QTRAN network.
        Arguments:
            - agent_num (:obj:`int`): the number of agents
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - global_obs_shape (:obj:`int`): the dimension of the global observation state
            - action_shape (:obj:`int`): the dimension of the action shape
            - hidden_size_list (:obj:`list`): the list of hidden sizes for the DRQN
            - embedding_size (:obj:`int`): the dimension of the embedding used by Q and V
            - lstm_type (:obj:`str`): use lstm or gru, defaults to gru
            - dueling (:obj:`bool`): whether to use a dueling head, defaults to False
        """
        super(QTran, self).__init__()
        self._act = nn.ReLU()
        # Per-agent recurrent q-network; all agents share these weights.
        self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
        # Joint Q takes the global state concatenated with the summed
        # (hidden_state, action) encoding, hence this input width.
        q_input_size = global_obs_shape + hidden_size_list[-1] + action_shape
        # Joint action-value network Q(s, u): 3-layer MLP -> scalar.
        self.Q = nn.Sequential(
            nn.Linear(q_input_size, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size), nn.ReLU(),
            nn.Linear(embedding_size, 1)
        )

        # V(s): state-value network, 3-layer MLP -> scalar.
        self.V = nn.Sequential(
            nn.Linear(global_obs_shape, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size),
            nn.ReLU(), nn.Linear(embedding_size, 1)
        )
        # Encoder for each agent's (rnn hidden state, one-hot action) pair;
        # output width equals input width so per-agent encodings can be summed.
        ae_input = hidden_size_list[-1] + action_shape
        self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), nn.ReLU(), nn.Linear(ae_input, ae_input))

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
        Overview:
            Forward computation graph of the QTRAN network.
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
                - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
                - global_state (:obj:`torch.Tensor`): global state (obs)
                - prev_state (:obj:`list`): previous rnn state
                - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\
                    calculate ``agent_q_act``
            - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\
                remove it after forward
        Return:
            - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
            - total_q (:obj:`torch.Tensor`): total q_value, which is the result of the mixer network
            - agent_q (:obj:`torch.Tensor`): each agent's q_value
            - next_state (:obj:`list`): next rnn state
        Shapes:
            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size,\
                A is agent_num, N is obs_shape
            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
        """
        agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
            'prev_state']
        action = data.get('action', None)
        if single_step:
            # Add a timestep dim so the rest of the method always sees (T, B, ...).
            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]
        assert len(prev_state) == B and all(
            [len(p) == A for p in prev_state]
        ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
        # Flatten the B-length list of A-length state lists into one flat list of B*A states.
        prev_state = reduce(lambda x, y: x + y, prev_state)
        # Merge the agent dim into the batch dim: each agent is an independent DRQN sequence.
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
        agent_q, next_state = output['logit'], output['next_state']
        # Regroup the flat B*A next states back into B lists of A states.
        next_state, _ = list_split(next_state, step=A)
        agent_q = agent_q.reshape(T, B, A, -1)
        if action is None:
            # For target forward process: mask out unavailable actions with a
            # large negative value (in place) so argmax only selects legal actions.
            if len(data['obs']['action_mask'].shape) == 3:
                action_mask = data['obs']['action_mask'].unsqueeze(0)
            else:
                action_mask = data['obs']['action_mask']
            agent_q[action_mask == 0.0] = -9999999
            action = agent_q.argmax(dim=-1)
        # Gather each agent's q_value for the chosen action.
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A

        hidden_states = output['hidden_state'].reshape(T * B, A, -1)
        action = action.reshape(T * B, A).unsqueeze(-1)
        # One-hot encode the chosen actions via scatter along the action dim.
        action_onehot = torch.zeros(size=(T * B, A, agent_q.shape[-1]), device=action.device)
        action_onehot = action_onehot.scatter(2, action, 1)
        agent_state_action_input = torch.cat([hidden_states, action_onehot], dim=2)
        agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(T * B * A,
                                                                                            -1)).reshape(T * B, A, -1)
        # Sum across agents: makes the joint encoding permutation-invariant in the agents.
        agent_state_action_encoding = agent_state_action_encoding.sum(dim=1)  # Sum across agents

        inputs = torch.cat([global_state.reshape(T * B, -1), agent_state_action_encoding], dim=1)
        q_outputs = self.Q(inputs)
        q_outputs = q_outputs.reshape(T, B)
        v_outputs = self.V(global_state.reshape(T * B, -1))
        v_outputs = v_outputs.reshape(T, B)
        if single_step:
            # Remove the timestep dim that was added at the top.
            q_outputs, agent_q, agent_q_act, v_outputs = q_outputs.squeeze(0), agent_q.squeeze(0), agent_q_act.squeeze(
                0
            ), v_outputs.squeeze(0)
        return {
            'total_q': q_outputs,
            'logit': agent_q,
            'agent_q_act': agent_q_act,
            'vs': v_outputs,
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }