ding.rl_utils.coma
coma_error(data, gamma, lambda_)
Overview:
Implementation of the COMA (Counterfactual Multi-Agent Policy Gradients) loss
Arguments:
- data (:obj:namedtuple): COMA input data with the fields shown in coma_data
- gamma (:obj:float): discount factor
- lambda_ (:obj:float): lambda coefficient used when computing the lambda-return target
Returns:
- coma_loss (:obj:namedtuple): the COMA loss items, each of which is a differentiable 0-dim tensor
Shapes:
- logit (:obj:torch.FloatTensor): :math:(T, B, A, N), where T is the number of timesteps, B is the batch size, A is the number of agents, and N is the action dimension
- action (:obj:torch.LongTensor): :math:(T, B, A)
- q_value (:obj:torch.FloatTensor): :math:(T, B, A, N)
- target_q_value (:obj:torch.FloatTensor): :math:(T, B, A, N)
- reward (:obj:torch.FloatTensor): :math:(T, B)
- weight (:obj:torch.FloatTensor or :obj:None): :math:(T, B, A)
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> import torch
>>> from ding.rl_utils.coma import coma_data, coma_error
>>> action_dim = 4
>>> agent_num = 3
>>> data = coma_data(
...     logit=torch.randn(2, 3, agent_num, action_dim),
...     action=torch.randint(0, action_dim, (2, 3, agent_num)),
...     q_value=torch.randn(2, 3, agent_num, action_dim),
...     target_q_value=torch.randn(2, 3, agent_num, action_dim),
...     reward=torch.randn(2, 3),
...     weight=torch.ones(2, 3, agent_num),
... )
>>> loss = coma_error(data, 0.99, 0.99)
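For orientation, the policy term in COMA rests on a counterfactual baseline: each agent's advantage is the Q-value of the action it took minus the expectation of its Q-values under its own policy. The block below is a minimal sketch of that idea using the tensor shapes listed above. It is not the code in ding/rl_utils/coma.py: counterfactual_advantage is a hypothetical helper name, and reducing the weighted loss with a plain mean is an assumption.

```python
import torch


def counterfactual_advantage(logit, q_value, action):
    """Hypothetical helper (not part of ding): COMA-style advantage.

    logit, q_value: (T, B, A, N); action: (T, B, A). Returns (T, B, A).
    """
    # Q-value of the action each agent actually took -> (T, B, A)
    q_taken = q_value.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    # Counterfactual baseline: expected Q under the agent's own policy -> (T, B, A)
    baseline = (torch.softmax(logit, dim=-1) * q_value).sum(-1)
    # Detach so the policy gradient does not flow back through the advantage
    return (q_taken - baseline).detach()


T, B, agent_num, action_dim = 2, 3, 3, 4
logit = torch.randn(T, B, agent_num, action_dim, requires_grad=True)
q_value = torch.randn(T, B, agent_num, action_dim)
action = torch.randint(0, action_dim, (T, B, agent_num))
weight = torch.ones(T, B, agent_num)

adv = counterfactual_advantage(logit, q_value, action)
# Log-probability of the taken actions under the current policy -> (T, B, A)
logp = torch.distributions.Categorical(logits=logit).log_prob(action)
# Assumed reduction: weighted mean over timesteps, batch, and agents
policy_loss = -(logp * adv * weight).mean()
```

If every action has the same Q-value, the baseline equals that value and the advantage is zero, which is a quick sanity check on the helper.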
Full Source Code
../ding/rl_utils/coma.py