ding.rl_utils.acer

acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio, c_clip_ratio=10.0)

Overview

Get ACER policy loss.

Arguments:
    - q_values (:obj:`torch.Tensor`): Q values
    - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
    - v_pred (:obj:`torch.Tensor`): V values
    - target_logit (:obj:`torch.Tensor`): Log probability of the new (target) policy
    - actions (:obj:`torch.Tensor`): The actions sampled from the replay buffer
    - ratio (:obj:`torch.Tensor`): Ratio of the new policy to the behavior policy
    - c_clip_ratio (:obj:`float`): Clip value for the ratio
Returns:
    - actor_loss (:obj:`torch.Tensor`): Policy loss computed from q_retraces
    - bc_loss (:obj:`torch.Tensor`): Bias correction policy loss
Shapes:
    - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
    - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
    - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
    >>> q_values = torch.randn(2, 3, 4)
    >>> q_retraces = torch.randn(2, 3, 1)
    >>> v_pred = torch.randn(2, 3, 1)
    >>> target_logit = torch.randn(2, 3, 4)
    >>> actions = torch.randint(0, 4, (2, 3))
    >>> ratio = torch.randn(2, 3, 4)
    >>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
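A slightly fuller sketch of how the inputs are typically prepared is shown below. The variable names (logits, behaviour_logits) and the way target_logit and ratio are built from log_softmax outputs, as well as the sign convention of negating the surrogate before gradient descent, are assumptions of this sketch, not part of this module's API.

import torch
import torch.nn.functional as F
from ding.rl_utils.acer import acer_policy_error

T, B, N = 2, 3, 4  # unroll length, batch size, action dim
logits = torch.randn(T, B, N, requires_grad=True)   # stand-in for the current policy head output
behaviour_logits = torch.randn(T, B, N)              # behavior policy output stored in the replay buffer
target_logit = F.log_softmax(logits, dim=-1)         # log pi(a|s) of the new policy
ratio = (target_logit - F.log_softmax(behaviour_logits, dim=-1)).exp().detach()  # pi / mu, shape (T, B, N)

q_values = torch.randn(T, B, N)
q_retraces = torch.randn(T, B, 1)                    # retrace targets (random stand-ins here)
v_pred = torch.randn(T, B, 1)
actions = torch.randint(0, N, (T, B))

actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
# Both terms are surrogate objectives to maximize; negate them before a gradient-descent step.
(-(actor_loss + bc_loss)).mean().backward()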

acer_value_error(q_values, q_retraces, actions)

Overview

Get ACER critic loss.

Arguments:
    - q_values (:obj:`torch.Tensor`): Q values
    - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
    - actions (:obj:`torch.Tensor`): The actions sampled from the replay buffer
Returns:
    - critic_loss (:obj:`torch.Tensor`): Critic loss
Shapes:
    - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
    - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
    - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
    >>> q_values = torch.randn(2, 3, 4)
    >>> q_retraces = torch.randn(2, 3, 1)
    >>> actions = torch.randint(0, 4, (2, 3))
    >>> critic_loss = acer_value_error(q_values, q_retraces, actions)
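A minimal runnable sketch of the critic update around acer_value_error is given below; the requires_grad Q tensor as a stand-in for a critic network output and the mean reduction are assumptions of this sketch.

import torch
from ding.rl_utils.acer import acer_value_error

T, B, N = 2, 3, 4
q_values = torch.randn(T, B, N, requires_grad=True)  # stand-in for the critic's Q head output
q_retraces = torch.randn(T, B, 1)                    # retrace targets, treated as constants
actions = torch.randint(0, N, (T, B))

critic_loss = acer_value_error(q_values, q_retraces, actions)  # element-wise loss, shape (T, B, 1)
critic_loss.mean().backward()                                  # reduce before the optimizer step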

acer_trust_region_update(actor_gradients, target_logit, avg_logit, trust_region_value)

Overview

Calculate gradients with a trust region constraint.

Arguments:
    - actor_gradients (:obj:`list(torch.Tensor)`): Gradients of the actor objective, one tensor per component
    - target_logit (:obj:`torch.Tensor`): Log probability of the new (target) policy
    - avg_logit (:obj:`torch.Tensor`): Log probability of the average policy
    - trust_region_value (:obj:`float`): The range of the trust region
Returns:
    - update_gradients (:obj:`list(torch.Tensor)`): Gradients with the trust region constraint applied
Shapes:
    - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
    - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
Examples:
    >>> actor_gradients = [torch.randn(2, 3, 4)]
    >>> target_logit = torch.randn(2, 3, 4)
    >>> avg_logit = torch.randn(2, 3, 4)
    >>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
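The sketch below shows how this function is typically wired together with acer_policy_error: the policy gradient is taken with respect to the log probabilities, corrected by the trust region, and then propagated back through the network. The average-policy logits, the negation of the surrogate objective, and the use of torch.autograd.grad are assumptions of this sketch, not requirements of the API.

import torch
import torch.nn.functional as F
from ding.rl_utils.acer import acer_policy_error, acer_trust_region_update

T, B, N = 2, 3, 4
logits = torch.randn(T, B, N, requires_grad=True)    # stand-in for the current policy head output
avg_logits = torch.randn(T, B, N)                     # e.g. output of a slowly updated average policy network
target_logit = F.log_softmax(logits, dim=-1)
avg_logit = F.log_softmax(avg_logits, dim=-1)

q_values = torch.randn(T, B, N)
q_retraces = torch.randn(T, B, 1)
v_pred = torch.randn(T, B, 1)
actions = torch.randint(0, N, (T, B))
ratio = torch.rand(T, B, N)                           # importance ratios (random stand-ins here)

actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
# Gradient of the negated surrogate objective w.r.t. the log probabilities.
actor_gradients = torch.autograd.grad(-(actor_loss + bc_loss).sum(), target_logit, retain_graph=True)
update_gradients = acer_trust_region_update(list(actor_gradients), target_logit, avg_logit, 0.1)
# Propagate the trust-region-corrected gradient back through the policy parameters.
target_logit.backward(update_gradients[0])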

Full Source Code

../ding/rl_utils/acer.py

from typing import Tuple, List
from collections import namedtuple
import torch
import torch.nn.functional as F
EPS = 1e-8


def acer_policy_error(
        q_values: torch.Tensor,
        q_retraces: torch.Tensor,
        v_pred: torch.Tensor,
        target_logit: torch.Tensor,
        actions: torch.Tensor,
        ratio: torch.Tensor,
        c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Get ACER policy loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
        - v_pred (:obj:`torch.Tensor`): V values
        - target_logit (:obj:`torch.Tensor`): Log probability of the new (target) policy
        - actions (:obj:`torch.Tensor`): The actions sampled from the replay buffer
        - ratio (:obj:`torch.Tensor`): Ratio of the new policy to the behavior policy
        - c_clip_ratio (:obj:`float`): Clip value for the ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): Policy loss computed from q_retraces
        - bc_loss (:obj:`torch.Tensor`): Bias correction policy loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> v_pred = torch.randn(2, 3, 1)
        >>> target_logit = torch.randn(2, 3, 4)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> ratio = torch.randn(2, 3, 4)
        >>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    """
    actions = actions.unsqueeze(-1)
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape T,B,1
        advantage_native = q_values - v_pred  # shape T,B,env_action_shape
    actor_loss = ratio.gather(-1, actions).clamp(max=c_clip_ratio) * advantage_retraces * target_logit.gather(
        -1, actions
    )  # shape T,B,1

    # bias correction term; the probability factor exp(target_logit) is detached, so no gradient flows through it
    bias_correction_loss = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0) * torch.exp(target_logit).detach() * \
        advantage_native * target_logit  # shape T,B,env_action_shape
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)
    return actor_loss, bias_correction_loss


def acer_value_error(q_values, q_retraces, actions):
    """
    Overview:
        Get ACER critic loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values
        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
        - actions (:obj:`torch.Tensor`): The actions sampled from the replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): Critic loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> critic_loss = acer_value_error(q_values, q_retraces, actions)
    """
    actions = actions.unsqueeze(-1)
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss


def acer_trust_region_update(
        actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
        trust_region_value: float
) -> List[torch.Tensor]:
    """
    Overview:
        Calculate gradients with a trust region constraint.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): Gradients of the actor objective, one tensor per component
        - target_logit (:obj:`torch.Tensor`): Log probability of the new (target) policy
        - avg_logit (:obj:`torch.Tensor`): Log probability of the average policy
        - trust_region_value (:obj:`float`): The range of the trust region
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): Gradients with the trust region constraint applied
    Shapes:
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
    Examples:
        >>> actor_gradients = [torch.randn(2, 3, 4)]
        >>> target_logit = torch.randn(2, 3, 4)
        >>> avg_logit = torch.randn(2, 3, 4)
        >>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
    """
    with torch.no_grad():
        KL_gradients = [torch.exp(avg_logit)]
    update_gradients = []
    # TODO: there is only one element in this list; more elements may be used in the future
    actor_gradient = actor_gradients[0]
    KL_gradient = KL_gradients[0]
    scale = actor_gradient.mul(KL_gradient).sum(-1, keepdim=True) - trust_region_value
    scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
    update_gradients.append(actor_gradient - scale * KL_gradient)
    return update_gradients