ding.rl_utils.coma

coma_error(data, gamma, lambda_)

Overview

Implementation of COMA
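
COMA (Counterfactual Multi-Agent Policy Gradients, Foerster et al., 2018) trains decentralized actors against a centralized critic. Its key ingredient is a counterfactual baseline that marginalizes out a single agent's action while keeping the other agents' actions fixed, :math:`b^a(s, \mathbf{u}^{-a}) = \sum_{u'^a} \pi^a(u'^a \mid \tau^a)\, Q(s, (\mathbf{u}^{-a}, u'^a))`, which yields the per-agent advantage :math:`A^a(s, \mathbf{u}) = Q(s, \mathbf{u}) - b^a(s, \mathbf{u}^{-a})`. In the source below this is the ``baseline`` tensor: the policy-weighted sum of the critic's per-action Q-values.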

Arguments:
    - data (:obj:`namedtuple`): COMA input data with fields shown in ``coma_data``
    - gamma (:obj:`float`): the discount factor
    - lambda_ (:obj:`float`): the lambda in the generalized lambda-return

Returns:
    - coma_loss (:obj:`namedtuple`): the COMA loss item; all fields are differentiable 0-dim tensors

Shapes:
    - logit (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`, where T is the trajectory length, B is the batch size, A is the agent num, and N is the action dim
    - action (:obj:`torch.LongTensor`): :math:`(T, B, A)`
    - q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
    - target_q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
    - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T, B, A)`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - q_value_loss (:obj:`torch.FloatTensor`): :math:`()`
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`

Examples:
    >>> action_dim = 4
    >>> agent_num = 3
    >>> data = coma_data(
    >>>     logit=torch.randn(2, 3, agent_num, action_dim),
    >>>     action=torch.randint(0, action_dim, (2, 3, agent_num)),
    >>>     q_value=torch.randn(2, 3, agent_num, action_dim),
    >>>     target_q_value=torch.randn(2, 3, agent_num, action_dim),
    >>>     reward=torch.randn(2, 3),
    >>>     weight=torch.ones(2, 3, agent_num),
    >>> )
    >>> loss = coma_error(data, 0.99, 0.99)
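
For convenience, here is the docstring example as a self-contained script; it assumes only that PyTorch and DI-engine's ``ding`` package are installed, and the shapes follow the :math:`(T, B, A, N)` convention above:

    import torch
    from ding.rl_utils.coma import coma_data, coma_error

    T, B, agent_num, action_dim = 2, 3, 3, 4  # trajectory length, batch size, agents, actions
    data = coma_data(
        logit=torch.randn(T, B, agent_num, action_dim),
        action=torch.randint(0, action_dim, (T, B, agent_num)),
        q_value=torch.randn(T, B, agent_num, action_dim),
        target_q_value=torch.randn(T, B, agent_num, action_dim),
        reward=torch.randn(T, B),
        weight=torch.ones(T, B, agent_num),
    )
    loss = coma_error(data, 0.99, 0.99)
    print(loss.policy_loss, loss.q_value_loss, loss.entropy_loss)  # three 0-dim tensors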

Full Source Code

../ding/rl_utils/coma.py

from collections import namedtuple
import torch
import torch.nn.functional as F
from ding.rl_utils.td import generalized_lambda_returns

coma_data = namedtuple('coma_data', ['logit', 'action', 'q_value', 'target_q_value', 'reward', 'weight'])
coma_loss = namedtuple('coma_loss', ['policy_loss', 'q_value_loss', 'entropy_loss'])


def coma_error(data: namedtuple, gamma: float, lambda_: float) -> namedtuple:
    """
    Overview:
        Implementation of COMA
    Arguments:
        - data (:obj:`namedtuple`): COMA input data with fields shown in ``coma_data``
        - gamma (:obj:`float`): the discount factor
        - lambda_ (:obj:`float`): the lambda in the generalized lambda-return
    Returns:
        - coma_loss (:obj:`namedtuple`): the COMA loss item; all fields are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`, where T is the trajectory length, B is the \
            batch size, A is the agent num, and N is the action dim
        - action (:obj:`torch.LongTensor`): :math:`(T, B, A)`
        - q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
        - target_q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T, B, A)`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - q_value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> agent_num = 3
        >>> data = coma_data(
        >>>     logit=torch.randn(2, 3, agent_num, action_dim),
        >>>     action=torch.randint(0, action_dim, (2, 3, agent_num)),
        >>>     q_value=torch.randn(2, 3, agent_num, action_dim),
        >>>     target_q_value=torch.randn(2, 3, agent_num, action_dim),
        >>>     reward=torch.randn(2, 3),
        >>>     weight=torch.ones(2, 3, agent_num),
        >>> )
        >>> loss = coma_error(data, 0.99, 0.99)
    """
    logit, action, q_value, target_q_value, reward, weight = data
    if weight is None:
        weight = torch.ones_like(action)
    # Q-values of the actions actually taken, from the current and target critics: (T, B, A)
    q_taken = torch.gather(q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
    target_q_taken = torch.gather(target_q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
    T, B, A = target_q_taken.shape
    # Fold the agent dim into the batch dim so lambda-returns are computed per agent
    reward = reward.unsqueeze(-1).expand_as(target_q_taken).reshape(T, -1)
    target_q_taken = target_q_taken.reshape(T, -1)
    return_ = generalized_lambda_returns(target_q_taken, reward[:-1], gamma, lambda_)
    return_ = return_.reshape(T - 1, B, A)
    # Critic loss: regress the taken-action Q-values onto the lambda-return targets
    q_value_loss = (F.mse_loss(return_, q_taken[:-1], reduction='none') * weight[:-1]).mean()

    dist = torch.distributions.categorical.Categorical(logits=logit)
    logp = dist.log_prob(action)
    # Counterfactual baseline: expected Q-value under the current policy (detached)
    baseline = (torch.softmax(logit, dim=-1) * q_value).sum(-1).detach()
    adv = (q_taken - baseline).detach()
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    return coma_loss(policy_loss, q_value_loss, entropy_loss)
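
In training code, the three returned terms are typically combined into one scalar objective before ``backward()``. A minimal sketch under stated assumptions: the 0.5 and 0.01 coefficients are illustrative placeholders (not DI-engine defaults), and plain ``requires_grad`` tensors stand in for actor/critic network outputs:

    import torch
    from ding.rl_utils.coma import coma_data, coma_error

    T, B, agent_num, action_dim = 8, 4, 3, 5
    logit = torch.randn(T, B, agent_num, action_dim, requires_grad=True)    # stand-in actor output
    q_value = torch.randn(T, B, agent_num, action_dim, requires_grad=True)  # stand-in critic output
    data = coma_data(
        logit=logit,
        action=torch.randint(0, action_dim, (T, B, agent_num)),
        q_value=q_value,
        target_q_value=torch.randn(T, B, agent_num, action_dim),  # target critic output, no grad
        reward=torch.randn(T, B),
        weight=None,  # None falls back to uniform weights inside coma_error
    )
    loss = coma_error(data, gamma=0.99, lambda_=0.97)
    # Entropy enters with a minus sign so that minimizing the total loss maximizes entropy.
    total_loss = loss.policy_loss + 0.5 * loss.q_value_loss - 0.01 * loss.entropy_loss
    total_loss.backward()  # gradients flow into both logit and q_value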