ding.rl_utils.a2c
a2c_error(data)
Overview
Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action spaces
Arguments:
- data (:obj:namedtuple): a2c input data with the fields shown in a2c_data
Returns:
- a2c_loss (:obj:namedtuple): the a2c loss items, each of which is a differentiable 0-dim tensor
Shapes:
- logit (:obj:torch.FloatTensor): :math:(B, N), where B is batch size and N is action dim
- action (:obj:torch.LongTensor): :math:(B, )
- value (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return_ (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> import torch
>>> from ding.rl_utils.a2c import a2c_data, a2c_error
>>> data = a2c_data(
...     logit=torch.randn(2, 3),
...     action=torch.randint(0, 3, (2, )),
...     value=torch.randn(2, ),
...     adv=torch.randn(2, ),
...     return_=torch.randn(2, ),
...     weight=torch.ones(2, ),
... )
>>> loss = a2c_error(data)
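To make the three returned terms concrete, here is a minimal plain-PyTorch sketch of a discrete A2C loss. It assumes the usual formulation (advantage-weighted negative log-likelihood, per-sample squared error against the return, and mean policy entropy) and is not necessarily how a2c_error computes them. The names SketchData, SketchLoss and a2c_error_sketch are hypothetical stand-ins, not part of ding.

```python
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# Hypothetical stand-ins for illustration only; not the library's own definitions.
SketchData = namedtuple('SketchData', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
SketchLoss = namedtuple('SketchLoss', ['policy_loss', 'value_loss', 'entropy_loss'])


def a2c_error_sketch(data):
    """Assumed textbook A2C loss terms for a discrete action space."""
    logit, action, value, adv, return_, weight = data
    if weight is None:
        # Treat a missing weight as uniform per-sample weighting.
        weight = torch.ones_like(value)
    dist = Categorical(logits=logit)
    logp = dist.log_prob(action)  # shape (B,)
    # Policy gradient term: raise log-probability of actions with positive advantage.
    policy_loss = -(logp * adv * weight).mean()
    # Critic term: per-sample squared error between predicted value and return.
    value_loss = (F.mse_loss(value, return_, reduction='none') * weight).mean()
    # Mean policy entropy; named "loss" only to mirror the field name in the docs.
    entropy_loss = (dist.entropy() * weight).mean()
    return SketchLoss(policy_loss, value_loss, entropy_loss)


torch.manual_seed(0)
logit = torch.randn(2, 3, requires_grad=True)
value = torch.randn(2, requires_grad=True)
data = SketchData(logit, torch.randint(0, 3, (2,)), value, torch.randn(2), torch.randn(2), None)
loss = a2c_error_sketch(data)
```

Each field of the result is a 0-dim tensor attached to the autograd graph, which matches the shapes documented above.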
a2c_error_continuous(data)
Overview
Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action spaces
Arguments:
- data (:obj:namedtuple): a2c input data with the fields shown in a2c_data
Returns:
- a2c_loss (:obj:namedtuple): the a2c loss items, each of which is a differentiable 0-dim tensor
Shapes:
- logit (:obj:dict): keys mu and sigma, each a :obj:torch.FloatTensor of shape :math:(B, N), where B is batch size and N is action dim
- action (:obj:torch.FloatTensor): :math:(B, N)
- value (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return_ (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> import torch
>>> from ding.rl_utils.a2c import a2c_data, a2c_error_continuous
>>> data = a2c_data(
...     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
...     action=torch.randn(2, 3),
...     value=torch.randn(2, ),
...     adv=torch.randn(2, ),
...     return_=torch.randn(2, ),
...     weight=torch.ones(2, ),
... )
>>> loss = a2c_error_continuous(data)
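For the continuous case, the following sketch shows one plausible way the same three terms could be computed from the mu and sigma entries, assuming a diagonal Gaussian policy. It also shows how a training loop might fold the terms into a single scalar. The function name a2c_error_continuous_sketch and the coefficients value_weight and entropy_weight are hypothetical and chosen only for illustration; the library's actual implementation and defaults may differ.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal


def a2c_error_continuous_sketch(logit, action, value, adv, return_, weight=None):
    """Assumed A2C loss terms for a diagonal Gaussian policy (illustrative only)."""
    if weight is None:
        weight = torch.ones_like(value)
    # Independent(..., 1) treats the last dimension as one event, so
    # log_prob and entropy are summed over the N action dimensions.
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)  # shape (B,)
    policy_loss = -(logp * adv * weight).mean()
    value_loss = (F.mse_loss(value, return_, reduction='none') * weight).mean()
    entropy_loss = (dist.entropy() * weight).mean()
    return policy_loss, value_loss, entropy_loss


torch.manual_seed(0)
mu = torch.randn(2, 3, requires_grad=True)
sigma = torch.rand(2, 3) + 0.1  # standard deviation must be strictly positive
value = torch.randn(2, requires_grad=True)
policy_loss, value_loss, entropy_loss = a2c_error_continuous_sketch(
    {'mu': mu, 'sigma': sigma}, torch.randn(2, 3), value, torch.randn(2), torch.randn(2)
)

# Hypothetical coefficients; the entropy term is subtracted so that
# minimizing the total encourages higher-entropy (more exploratory) policies.
value_weight, entropy_weight = 0.5, 0.01
total_loss = policy_loss + value_weight * value_loss - entropy_weight * entropy_loss
total_loss.backward()
```

Unlike the categorical case, the entropy of a Gaussian can be negative when sigma is small, so the sign of entropy_loss alone says little about the policy.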
Full Source Code
../ding/rl_utils/a2c.py