ding.model.template.wqmix
MixerStar
Bases: Module
Overview
Mixer network for Q_star in WQMIX (https://arxiv.org/abs/2006.10800), which mixes the independent q_value of each agent into a total q_value. It differs from QMIX's mixer network: here the mixing network is a feedforward network with 3 hidden layers of 256 dim. This Q_star mixing network is not constrained to be monotonic: it does not use non-negative weights, and it takes the state and agent_q directly as inputs, as opposed to QMIX, where hypernetworks take the state as input and generate the weights.
Interface:
__init__, forward.
__init__(agent_num, state_dim, mixing_embed_dim)
Overview
Initialize the mixer network of Q_star in WQMIX.
Arguments:
- agent_num (:obj:int): The number of agents, e.g., 8.
- state_dim (:obj:int): The dimension of the global observation state, e.g., 16.
- mixing_embed_dim (:obj:int): The dimension of the mixing state embedding, e.g., 128.
forward(agent_qs, states)
Overview
Forward computation graph of the mixer network for Q_star in WQMIX. This mixer network is a feed-forward network that takes the state and the appropriate actions' utilities as input.
Arguments:
- agent_qs (:obj:torch.FloatTensor): The independent q_value of each agent.
- states (:obj:torch.FloatTensor): The embedding vector of the global state.
Returns:
- q_tot (:obj:torch.FloatTensor): The total mixed q_value.
Shapes:
- agent_qs (:obj:torch.FloatTensor): :math:(T, B, A), where T is timestep, B is batch size, A is agent_num.
- states (:obj:torch.FloatTensor): :math:(T, B, M), where M is global_obs_shape.
- q_tot (:obj:torch.FloatTensor): :math:(T, B).
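The forward pass described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the DI-engine source: the class name MixerStarSketch, the ReLU activations, and the hidden_dim argument are hypothetical, and agent_qs is assumed to hold one q_value per agent with shape (T, B, A). Only the concatenated state and agent_q input, the three hidden layers, and the (T, B) output come from the description.

```python
import torch
import torch.nn as nn


class MixerStarSketch(nn.Module):
    """Hypothetical stand-in for an unconstrained Q_star mixer (not the DI-engine class)."""

    def __init__(self, agent_num: int, state_dim: int, hidden_dim: int = 256) -> None:
        super().__init__()
        self.agent_num = agent_num
        self.state_dim = state_dim
        # Three hidden layers with ordinary (possibly negative) weights,
        # so the output is not forced to be monotonic in agent_qs.
        self.net = nn.Sequential(
            nn.Linear(state_dim + agent_num, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, agent_qs: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
        lead_shape = agent_qs.shape[:-1]                   # (T, B)
        flat_qs = agent_qs.reshape(-1, self.agent_num)     # (T*B, A)
        flat_states = states.reshape(-1, self.state_dim)   # (T*B, M)
        # State and agent q_values are fed to the network together.
        q_tot = self.net(torch.cat([flat_states, flat_qs], dim=-1))  # (T*B, 1)
        return q_tot.view(*lead_shape)                     # (T, B)


T, B, A, M = 4, 2, 8, 16  # illustrative sizes only
mixer = MixerStarSketch(agent_num=A, state_dim=M)
q_tot = mixer(torch.randn(T, B, A), torch.randn(T, B, M))
```

Because the state enters the network as a regular input rather than through hypernetworks that generate non-negative weights, nothing in this sketch restricts how q_tot changes with any single agent's q_value.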
WQMix
Bases: Module
Overview
WQMIX (https://arxiv.org/abs/2006.10800) network. There are two components: 1) Q_tot, which is the same as the QMIX network and is composed of the agent Q network and the mixer network. 2) An unrestricted joint action Q_star, which is composed of the agent Q network and the mixer_star network. The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is initialized in Q_tot or Q_star.
Interface:
__init__, forward.
__init__(agent_num, obs_shape, global_obs_shape, action_shape, hidden_size_list, lstm_type='gru', dueling=False)
Overview
Initialize WQMIX neural network according to arguments, i.e. agent Q network and mixer, Q_star network and mixer_star.
Arguments:
- agent_num (:obj:int): The number of agents, such as 8.
- obs_shape (:obj:int): The dimension of each agent's observation state, such as 8.
- global_obs_shape (:obj:int): The dimension of global observation state, such as 8.
- action_shape (:obj:int): The dimension of the action space, such as 6.
- hidden_size_list (:obj:list): The list of hidden sizes for q_network; the last element must match the mixer's mixing_embed_dim.
- lstm_type (:obj:str): The type of RNN module in q_network; currently supports ['normal', 'pytorch', 'gru'], defaults to gru.
- dueling (:obj:bool): Whether to use DuelingHead (True) or DiscreteHead (False), defaults to False.
forward(data, single_step=True, q_star=False)
Overview
Forward computation graph of the WQMIX network. The input dict includes time series observations and related data, which are used to predict the total q_value and each agent's q_value. The q_star parameter determines whether Q_tot or Q_star is calculated.
Arguments:
- data (:obj:dict): Input data dict with keys ['obs', 'prev_state', 'action'].
- agent_state (:obj:torch.Tensor): Time series local observation data of each agent.
- global_state (:obj:torch.Tensor): Time series global observation data.
- prev_state (:obj:list): Previous rnn state for q_network or _q_network_star.
- action (:obj:torch.Tensor or None): If action is None, use argmax q_value index as action to calculate agent_q_act.
- single_step (:obj:bool): Whether to run a single-step forward pass; if so, a timestep dim is added before forward and removed after forward.
- q_star (:obj:bool): Whether to run the Q_star network. If True, use the Q_star network, whose agent networks have the same architecture as the Q network but do not share parameters with it, and whose mixing network is a feedforward network with 3 hidden layers of 256 dim; if False, use the Q network, which is the same as the Q network in the QMIX paper.
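The handling of action described in the arguments above can be sketched as follows. The helper name select_agent_q_act and the random tensors are hypothetical and only illustrate the indexing; this is not the DI-engine implementation. It assumes agent_q has shape (T, B, A, P) and action holds integer indices of shape (T, B, A).

```python
from typing import Optional

import torch


def select_agent_q_act(agent_q: torch.Tensor, action: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: pick each agent's q_value for its action.

    agent_q: (T, B, A, P); action: (T, B, A) integer indices, or None.
    """
    if action is None:
        # No action given: fall back to the argmax q_value index.
        action = agent_q.argmax(dim=-1)
    # gather needs an index with the same number of dims as agent_q.
    return agent_q.gather(dim=-1, index=action.unsqueeze(-1)).squeeze(-1)


T, B, A, P = 3, 2, 4, 6  # illustrative sizes only
agent_q = torch.randn(T, B, A, P)
greedy = select_agent_q_act(agent_q)  # (T, B, A)
chosen = select_agent_q_act(agent_q, torch.zeros(T, B, A, dtype=torch.long))
```

The resulting (T, B, A) tensor is the kind of per-agent input a mixer consumes together with the global state.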
Returns:
- ret (:obj:dict): Output data dict with keys [total_q, logit, next_state].
- total_q (:obj:torch.Tensor): Total q_value, which is the result of mixer network.
- agent_q (:obj:torch.Tensor): The q_value of each agent, listed under the key logit above.
- next_state (:obj:list): Next rnn state.
Shapes:
- agent_state (:obj:torch.Tensor): :math:(T, B, A, N), where T is timestep, B is batch_size, A is agent_num, N is obs_shape.
- global_state (:obj:torch.Tensor): :math:(T, B, M), where M is global_obs_shape.
- prev_state (:obj:list): :math:(T, B, A), a list of length B, and each element is a list of length A.
- action (:obj:torch.Tensor): :math:(T, B, A).
- total_q (:obj:torch.Tensor): :math:(T, B).
- agent_q (:obj:torch.Tensor): :math:(T, B, A, P), where P is action_shape.
- next_state (:obj:list): :math:(T, B, A), a list of length B, and each element is a list of length A.
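The input shapes listed above can be assembled into a data dict roughly as follows. This is a sketch under assumptions rather than a verified call into DI-engine: nesting agent_state and global_state under the 'obs' key, and using None as a placeholder for an empty RNN state, are assumptions, and all sizes are illustrative.

```python
import torch

T, B, A, N, M, P = 5, 2, 4, 8, 8, 6  # illustrative sizes only

data = {
    # Assumption: agent_state and global_state are nested under 'obs'.
    'obs': {
        'agent_state': torch.randn(T, B, A, N),
        'global_state': torch.randn(T, B, M),
    },
    # One entry per batch element, each a list with one RNN state per agent.
    # None is an assumed placeholder meaning "no previous state".
    'prev_state': [[None for _ in range(A)] for _ in range(B)],
    'action': torch.randint(0, P, (T, B, A)),
}

# single_step: a one-step input has no time dimension, so one is added
# before the forward pass and removed from the result afterwards.
step_obs = torch.randn(B, A, N)
with_time = step_obs.unsqueeze(0)  # (1, B, A, N)
restored = with_time.squeeze(0)    # (B, A, N)
```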
Full Source Code
../ding/model/template/wqmix.py