ding.model.template.coma¶
COMAActorNetwork
¶
Bases: Module
Overview
The decentralized actor network in the COMA algorithm.
Interface:
__init__, forward
__init__(obs_shape, action_shape, hidden_size_list=[128, 128, 64])
¶
Overview
Initialize the COMA actor network.
Arguments:
- obs_shape (int): the dimension of each agent's local observation
- action_shape (int): the dimension of the action space
- hidden_size_list (list): the list of hidden layer sizes, defaults to [128, 128, 64]
forward(inputs)
¶
Overview
The forward computation graph of the COMA actor network.
Arguments:
- inputs (dict): input data dict with keys ['obs', 'prev_state']
- agent_state (torch.Tensor): each agent's local state (obs)
- action_mask (torch.Tensor): the mask of available actions
- prev_state (torch.Tensor): the previous RNN hidden state
Returns:
- output (dict): output data dict with keys ['logit', 'next_state', 'action_mask']
ArgumentsKeys:
- necessary: obs { agent_state, action_mask }, prev_state
ReturnsKeys:
- necessary: logit, next_state, action_mask
Examples:
>>> T, B, A, N = 4, 8, 3, 32
>>> embedding_dim = 64
>>> action_dim = 6
>>> data = torch.randn(T, B, A, N)
>>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
>>> prev_state = [[None for _ in range(A)] for _ in range(B)]
>>> for t in range(T):
...     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
...     outputs = model(inputs)
...     logit, prev_state = outputs['logit'], outputs['next_state']
COMACriticNetwork
¶
Bases: Module
Overview
The centralized critic network in the COMA algorithm.
Interface:
__init__, forward
__init__(input_size, action_shape, hidden_size=128)
¶
Overview
Initialize the COMA critic network.
Arguments:
- input_size (int): the size of the critic input (the concatenated observation and action features)
- action_shape (int): the dimension of the action space
- hidden_size (int): the hidden layer size, defaults to 128
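Note that input_size is the width of the full concatenated critic input, not the raw observation size. A minimal sketch of the arithmetic, using the same expression as the Examples section below (the per-term grouping in the comments is an assumption for illustration):

```python
# Width of the concatenated COMA critic input, mirroring the expression
# used in the Examples section (the term grouping is illustrative).
agent_num = 4
obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9

input_size = (
    (obs_dim - action_dim)        # local observation part
    + global_obs_dim              # centralized global state
    + 2 * action_dim * agent_num  # action-related features for all agents
)
print(input_size)  # 223
```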
forward(data)
¶
Overview
The forward computation graph of the COMA critic network.
Arguments:
- data (dict): input data dict with keys ['obs', 'action']
- agent_state (torch.Tensor): each agent's local state (obs)
- global_state (torch.Tensor): the global state (obs)
- action (torch.Tensor): the taken action
Returns:
- output (dict): output data dict with keys ['q_value']
ArgumentsKeys:
- necessary: obs { agent_state, global_state }, action
ReturnsKeys:
- necessary: q_value
Examples:
>>> agent_num, bs, T = 4, 3, 8
>>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
>>> coma_model = COMACriticNetwork(
...     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
>>> data = {
...     'obs': {
...         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
...         'global_state': torch.randn(T, bs, global_obs_dim),
...     },
...     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
... }
>>> output = coma_model(data)
COMA
¶
Bases: Module
Overview
The network of the COMA algorithm, a QAC-style actor-critic.
Interface:
__init__, forward
Properties:
- mode (list): the list of forward modes, including compute_actor and compute_critic
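The mode argument of forward selects which sub-network is run. A minimal sketch of this dispatch pattern (the class and its return values are hypothetical placeholders, not the actual DI-engine implementation):

```python
# Hypothetical sketch of QAC-style mode dispatch: forward checks that the
# requested mode is registered, then delegates to the method of that name.
class QACStyleModel:
    mode = ['compute_actor', 'compute_critic']

    def compute_actor(self, inputs):
        # placeholder: a real model would return {'logit': ..., 'next_state': ...}
        return {'logit': inputs}

    def compute_critic(self, inputs):
        # placeholder: a real model would return {'q_value': ...}
        return {'q_value': inputs}

    def forward(self, inputs, mode):
        assert mode in self.mode, f'unsupported forward mode: {mode}'
        return getattr(self, mode)(inputs)

model = QACStyleModel()
out = model.forward('x', mode='compute_actor')  # {'logit': 'x'}
```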
__init__(agent_num, obs_shape, action_shape, actor_hidden_size_list)
¶
Overview
Initialize the COMA network.
Arguments:
- agent_num (int): the number of agents
- obs_shape (Dict): the observation information, including agent_state and global_state
- action_shape (Union[int, SequenceType]): the dimension of the action space
- actor_hidden_size_list (SequenceType): the list of actor hidden layer sizes
forward(inputs, mode)
¶
Overview
The forward computation graph of the COMA network.
Arguments:
- inputs (dict): input data dict with keys ['obs', 'prev_state', 'action']
- agent_state (torch.Tensor): each agent's local state (obs)
- global_state (torch.Tensor): the global state (obs)
- action (torch.Tensor): the taken action
- mode (str): the forward mode, either compute_actor or compute_critic
ArgumentsKeys:
- necessary: obs { agent_state, global_state, action_mask }, action, prev_state
ReturnsKeys:
- necessary:
- compute_critic: q_value
- compute_actor: logit, next_state, action_mask
Shapes:
- obs (dict): agent_state: (T, B, A, N, D), action_mask: (T, B, A, N, A)
- prev_state (list): [[[h, c] for _ in range(A)] for _ in range(B)]
- logit (torch.Tensor): (T, B, A, N, A)
- next_state (list): [[[h, c] for _ in range(A)] for _ in range(B)]
- action_mask (torch.Tensor): (T, B, A, N, A)
- q_value (torch.Tensor): (T, B, A, N, A)
Examples:
>>> agent_num, bs, T = 4, 3, 8
>>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
>>> coma_model = COMA(
...     agent_num=agent_num,
...     obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
...     action_shape=action_dim,
...     actor_hidden_size_list=[128, 64],
... )
>>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
>>> data = {
...     'obs': {
...         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
...         'action_mask': None,
...     },
...     'prev_state': prev_state,
... }
>>> output = coma_model(data, mode='compute_actor')
>>> data = {
...     'obs': {
...         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
...         'global_state': torch.randn(T, bs, global_obs_dim),
...     },
...     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
... }
>>> output = coma_model(data, mode='compute_critic')
Full Source Code
../ding/model/template/coma.py