ding.rl_utils.acer
acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio, c_clip_ratio=10.0)
Overview
Get the ACER policy loss.
Arguments:
- q_values (:obj:`torch.Tensor`): Q values
- q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
- v_pred (:obj:`torch.Tensor`): V values
- target_logit (:obj:`torch.Tensor`): the logit of the new policy
- actions (:obj:`torch.Tensor`): the actions sampled from the replay buffer
- ratio (:obj:`torch.Tensor`): the importance ratio of the new policy to the behavior policy
- c_clip_ratio (:obj:`float`): the clipping threshold for the importance ratio
Returns:
- actor_loss (:obj:`torch.Tensor`): the policy loss computed from q_retraces
- bc_loss (:obj:`torch.Tensor`): the bias-correction policy loss
Shapes:
- q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
- q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- actions (:obj:`torch.LongTensor`): :math:`(T, B)`
- ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
>>> q_values = torch.randn(2, 3, 4)
>>> q_retraces = torch.randn(2, 3, 1)
>>> v_pred = torch.randn(2, 3, 1)
>>> target_logit = torch.randn(2, 3, 4)
>>> actions = torch.randint(0, 4, (2, 3))
>>> ratio = torch.randn(2, 3, 4)
>>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
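For intuition, the sketch below shows one way the two loss terms can be formed, following the truncated importance sampling with bias correction objective from the ACER paper (Wang et al., 2017). It assumes target_logit holds per-action log probabilities; the name acer_policy_error_sketch is hypothetical, and this is not necessarily the library's exact implementation (see the full source code link below).

import torch

def acer_policy_error_sketch(q_values, q_retraces, v_pred, target_logit,
                             actions, ratio, c_clip_ratio=10.0):
    # Hypothetical sketch, assuming target_logit = log pi(a|x) per action.
    actions = actions.unsqueeze(-1)                        # (T, B) -> (T, B, 1)
    # Truncated importance-sampling term:
    # -min(c, rho(a_t)) * log pi(a_t|x_t) * (Q_ret - V)
    logp_taken = target_logit.gather(-1, actions)          # (T, B, 1)
    rho_taken = ratio.gather(-1, actions)                  # (T, B, 1)
    actor_loss = -rho_taken.clamp(max=c_clip_ratio) * logp_taken * (q_retraces - v_pred)
    # Bias-correction term, summed over all actions:
    # -[1 - c/rho(a)]_+ * pi(a|x) * log pi(a|x) * (Q(x, a) - V)
    pi = target_logit.exp()                                # (T, B, N)
    coeff = (1.0 - c_clip_ratio / ratio).clamp(min=0.0)    # (T, B, N)
    bc_loss = -(coeff * pi * target_logit * (q_values - v_pred)).sum(-1, keepdim=True)
    return actor_loss, bc_loss                             # both (T, B, 1)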
acer_value_error(q_values, q_retraces, actions)
Overview
Get the ACER critic loss.
Arguments:
- q_values (:obj:`torch.Tensor`): Q values
- q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
- actions (:obj:`torch.Tensor`): the actions sampled from the replay buffer
Returns:
- critic_loss (:obj:`torch.Tensor`): the critic loss
Shapes:
- q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
- q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
- actions (:obj:`torch.LongTensor`): :math:`(T, B)`
- critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
>>> q_values = torch.randn(2, 3, 4)
>>> q_retraces = torch.randn(2, 3, 1)
>>> actions = torch.randint(0, 4, (2, 3))
>>> loss = acer_value_error(q_values, q_retraces, actions)
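As a rough illustration, the critic loss can be formed as a squared error between the Q value of the taken action and its retrace target. The name acer_value_error_sketch is hypothetical, and details (e.g. whether the target is detached, or the 0.5 factor) may differ from the library's exact code.

import torch

def acer_value_error_sketch(q_values, q_retraces, actions):
    # Hypothetical sketch: squared error between the Q value of the taken
    # action and its (detached) retrace target.
    q_taken = q_values.gather(-1, actions.unsqueeze(-1))   # (T, B, 1)
    critic_loss = 0.5 * (q_retraces.detach() - q_taken).pow(2)
    return critic_loss                                     # (T, B, 1)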
acer_trust_region_update(actor_gradients, target_logit, avg_logit, trust_region_value)
Overview
Calculate the gradient with a trust region constraint.
Arguments:
- actor_gradients (:obj:`list(torch.Tensor)`): the gradient values of the different parts of the actor loss
- target_logit (:obj:`torch.Tensor`): the logit of the new policy
- avg_logit (:obj:`torch.Tensor`): the logit of the average policy
- trust_region_value (:obj:`float`): the range of the trust region
Returns:
- update_gradients (:obj:`list(torch.Tensor)`): the gradients with the trust region constraint applied
Shapes:
- target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
- update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
Examples:
>>> actor_gradients = [torch.randn(2, 3, 4)]
>>> target_logit = torch.randn(2, 3, 4)
>>> avg_logit = torch.randn(2, 3, 4)
>>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
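The trust-region step in the ACER paper projects the raw gradient g so that its component along k, the gradient of KL(avg_pi || pi), stays within trust_region_value: z = g - max(0, (k.g - delta) / ||k||^2) * k. The sketch below assumes the logits are log probabilities, so the KL gradient with respect to the logits is -avg_pi; the name acer_trust_region_update_sketch is hypothetical and this may not match the library's exact implementation.

import torch

def acer_trust_region_update_sketch(actor_gradients, target_logit, avg_logit,
                                    trust_region_value):
    # Hypothetical sketch: z = g - max(0, (k.g - delta) / ||k||^2) * k,
    # where k is the gradient of KL(avg_pi || pi) w.r.t. the logits.
    # If the logits are log probabilities, that gradient is -avg_pi;
    # target_logit is kept only for signature parity with the real function.
    k = -avg_logit.exp()                                   # (T, B, N)
    update_gradients = []
    for g in actor_gradients:
        kg = (k * g).sum(-1, keepdim=True)                 # k^T g, (T, B, 1)
        k_sq = (k * k).sum(-1, keepdim=True)               # ||k||^2, (T, B, 1)
        scale = ((kg - trust_region_value) / k_sq).clamp(min=0.0)
        update_gradients.append(g - scale * k)             # (T, B, N)
    return update_gradients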
Full Source Code
../ding/rl_utils/acer.py