
ding.rl_utils.a2c

a2c_error(data)

Overview

Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action spaces.

Arguments:
    - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
Returns:
    - a2c_loss (:obj:`namedtuple`): the a2c loss item; all entries are differentiable 0-dim tensors
Shapes:
    - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
    - action (:obj:`torch.LongTensor`): :math:`(B, )`
    - value (:obj:`torch.FloatTensor`): :math:`(B, )`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
    >>> data = a2c_data(
    ...     logit=torch.randn(2, 3),
    ...     action=torch.randint(0, 3, (2, )),
    ...     value=torch.randn(2, ),
    ...     adv=torch.randn(2, ),
    ...     return_=torch.randn(2, ),
    ...     weight=torch.ones(2, ),
    ... )
    >>> loss = a2c_error(data)
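A minimal self-contained sketch of the discrete-action loss computation above, mirroring the body of `a2c_error` from the full source on this page. The function name `a2c_error_sketch` is illustrative, not part of the library; it takes plain tensors instead of the `a2c_data` namedtuple.

```python
import torch
import torch.nn.functional as F

def a2c_error_sketch(logit, action, value, adv, return_, weight=None):
    # Mirrors ding/rl_utils/a2c.py's a2c_error; not the library function itself.
    if weight is None:
        weight = torch.ones_like(value)
    dist = torch.distributions.Categorical(logits=logit)
    logp = dist.log_prob(action)  # log pi(a|s), shape (B,)
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return policy_loss, value_loss, entropy_loss

B, N = 2, 3  # batch size, action dim
policy_loss, value_loss, entropy_loss = a2c_error_sketch(
    logit=torch.randn(B, N),
    action=torch.randint(0, N, (B, )),
    value=torch.randn(B),
    adv=torch.randn(B),
    return_=torch.randn(B),
)
```

Each returned loss is a differentiable 0-dim tensor, matching the shapes documented above.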

a2c_error_continuous(data)

Overview

Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action spaces.

Arguments:
    - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
Returns:
    - a2c_loss (:obj:`namedtuple`): the a2c loss item; all entries are differentiable 0-dim tensors
Shapes:
    - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
    - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
    - value (:obj:`torch.FloatTensor`): :math:`(B, )`
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
    - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
    >>> data = a2c_data(
    ...     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
    ...     action=torch.randn(2, 3),
    ...     value=torch.randn(2, ),
    ...     adv=torch.randn(2, ),
    ...     return_=torch.randn(2, ),
    ...     weight=torch.ones(2, ),
    ... )
    >>> loss = a2c_error_continuous(data)
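The continuous-action variant can be sketched the same way. This mirrors the body of `a2c_error_continuous` from the full source on this page; the function name `a2c_error_continuous_sketch` and the `torch.rand(...) + 0.1` choice for a strictly positive `sigma` are illustrative, not taken from the library.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

def a2c_error_continuous_sketch(logit, action, value, adv, return_, weight=None):
    # Mirrors ding/rl_utils/a2c.py's a2c_error_continuous; not the library function.
    if weight is None:
        weight = torch.ones_like(value)
    # Independent(..., 1) treats the N action dims as one joint event,
    # so log_prob sums over the action dimension and has shape (B,).
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return policy_loss, value_loss, entropy_loss

B, N = 2, 3  # batch size, action dim
losses = a2c_error_continuous_sketch(
    logit={'mu': torch.randn(B, N), 'sigma': torch.rand(B, N) + 0.1},  # sigma > 0
    action=torch.randn(B, N),
    value=torch.randn(B),
    adv=torch.randn(B),
    return_=torch.randn(B),
)
```

Note that `sigma` must be strictly positive for `Normal` to be valid; the docstring example guarantees non-negativity via `torch.sqrt(...**2)` but can in principle produce zeros.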

Full Source Code

../ding/rl_utils/a2c.py

from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action spaces
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item; all entries are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        ...     logit=torch.randn(2, 3),
        ...     action=torch.randint(0, 3, (2, )),
        ...     value=torch.randn(2, ),
        ...     adv=torch.randn(2, ),
        ...     return_=torch.randn(2, ),
        ...     weight=torch.ones(2, ),
        ... )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)
    dist = torch.distributions.categorical.Categorical(logits=logit)
    logp = dist.log_prob(action)
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)


def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action spaces
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item; all entries are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        ...     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        ...     action=torch.randn(2, 3),
        ...     value=torch.randn(2, ),
        ...     adv=torch.randn(2, ),
        ...     return_=torch.randn(2, ),
        ...     weight=torch.ones(2, ),
        ... )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)

    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)
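The three loss terms are documented above as differentiable 0-dim tensors, so they are typically combined into one objective and backpropagated through the actor-critic network. The sketch below is self-contained and hedged: the linear actor/critic, the random advantage and return targets, and the weighting coefficients 0.5 and 0.01 are illustrative assumptions, not values prescribed by this module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical tiny actor-critic; shapes follow the docs above (B, N).
torch.manual_seed(0)
B, obs_dim, N = 4, 8, 3
actor = nn.Linear(obs_dim, N)    # produces logit, shape (B, N)
critic = nn.Linear(obs_dim, 1)   # produces value, shape (B,)

obs = torch.randn(B, obs_dim)
logit = actor(obs)
value = critic(obs).squeeze(-1)
action = torch.randint(0, N, (B, ))
adv = torch.randn(B)      # stand-in advantage estimates
return_ = torch.randn(B)  # stand-in return targets

# Same computation as a2c_error, inlined here to keep the sketch standalone.
dist = torch.distributions.Categorical(logits=logit)
policy_loss = -(dist.log_prob(action) * adv).mean()
value_loss = F.mse_loss(return_, value)
entropy_loss = dist.entropy().mean()

# Common A2C combination: minimize policy and value losses, maximize entropy.
# The 0.5 / 0.01 coefficients are typical defaults, not taken from this page.
total_loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_loss
total_loss.backward()  # gradients flow into both actor and critic
```

After `backward()`, both `actor.weight.grad` and `critic.weight.grad` are populated, confirming all three terms are differentiable.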