
ding.rl_utils.ppo


calculate_kl_div(log_ratio, kl_type)

Overview

Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)], where q is the current policy and p is the pretrained policy. The implementation is based on John Schulman's blog post "Approximating KL Divergence". Reference: http://joschu.net/blog/kl-approx.html

Arguments:
    - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be log(q/p) = logp_new - logp_pretrained.
    - kl_type (:obj:`str`): The type of KL divergence estimator to use.
        - 'k1': The standard, unbiased but high-variance estimator: E_q[log(q/p)].
        - 'k2': A biased, low-variance estimator from a second-order approximation: E_q[1/2 * (log(p/q))^2].
        - 'k3': An unbiased, low-variance estimator: E_q[(p/q - 1) - log(p/q)].

Returns:
    - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate.
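The three estimators can be illustrated with a minimal pure-Python sketch (no torch dependency; `kl_estimate` is a hypothetical stand-in for `calculate_kl_div`, operating on a list of sampled log-ratios):

```python
import math

def kl_estimate(log_ratios, kl_type):
    """Monte-Carlo estimators of KL(q, p) from sampled log(q(x)/p(x)), x ~ q."""
    n = len(log_ratios)
    if kl_type == 'k1':
        # unbiased but high variance: mean of log(q/p)
        return sum(log_ratios) / n
    elif kl_type == 'k2':
        # biased, low variance: mean of (log ratio)^2 / 2
        return sum(lr ** 2 / 2 for lr in log_ratios) / n
    elif kl_type == 'k3':
        # unbiased, low variance: mean of (p/q - 1) - log(p/q) = exp(-lr) - 1 + lr
        return sum(math.exp(-lr) - 1 + lr for lr in log_ratios) / n
    raise ValueError(f"Unknown kl_type: {kl_type}")
```

Note that every per-sample term of 'k2' and 'k3' is non-negative, while 'k1' terms can be negative, which is why 'k1' has much higher variance when q and p are close.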

shape_fn_ppo(args, kwargs)

Overview

Return the shape of the PPO input (``logit_new``) for the HPC-accelerated kernel.

Returns: shape: [B, N]

ppo_error(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None, kl_type='k1')

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip

Arguments:
    - data (:obj:`namedtuple`): The PPO input data with fields shown in ``ppo_data``.
    - clip_ratio (:obj:`float`): The PPO clip ratio for the constraint of policy update, defaults to 0.2.
    - use_value_clip (:obj:`bool`): Whether to clip the value loss with the same ratio as the policy.
    - dual_clip (:obj:`float`): The parameter c mentioned in arXiv:1912.09729 Eq. 5, which should be in [1, inf), e.g. 5.0; if you don't want to use it, set this parameter to None (the default).
    - kl_type (:obj:`str`): Which KL estimator to use, defaults to 'k1'.

Returns:
    - ppo_loss (:obj:`namedtuple`): The PPO loss items, each a differentiable 0-dim tensor.
    - ppo_info (:obj:`namedtuple`): The PPO optimization information for monitoring, each a Python scalar.

Shapes:
    - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim.
    - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`.
    - action (:obj:`torch.LongTensor`): :math:`(B, )`.
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`.
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`.
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`.

Examples:
    >>> action_dim = 4
    >>> data = ppo_data(
    >>>     logit_new=torch.randn(3, action_dim),
    >>>     logit_old=torch.randn(3, action_dim),
    >>>     action=torch.randint(0, action_dim, (3,)),
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     adv=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     logit_pretrained=None,
    >>> )
    >>> loss, info = ppo_error(data)

.. note::

``adv`` is assumed to be an already-normalized value, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are
many ways to compute this mean and std (e.g. over the whole data buffer or over each train batch), so this
normalization is deliberately not coupled into ``ppo_error``; refer to our examples for the different options.
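The per-batch variant of this normalization can be sketched in pure Python (`normalize_adv` is a hypothetical helper, not part of this module; the batch-vs-buffer choice is left to the caller):

```python
import math

def normalize_adv(adv, eps=1e-8):
    """Normalize advantages to zero mean / unit std, as ppo_error expects.

    adv: list of floats, one advantage per sample in the train batch.
    """
    n = len(adv)
    mean = sum(adv) / n
    # population variance; torch.Tensor.std defaults to the unbiased (n-1) estimate,
    # so the two differ slightly for small batches
    var = sum((a - mean) ** 2 for a in adv) / n
    std = math.sqrt(var)
    return [(a - mean) / (std + eps) for a in adv]
```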

ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, entropy_bonus=True, kl_type='k1')

Overview

Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).

Arguments:
    - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
    - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
    - dual_clip (:obj:`float`): The parameter c mentioned in arXiv:1912.09729 Eq. 5, which should be in [1, inf), e.g. 5.0; if you don't want to use it, set this parameter to None (the default).
    - entropy_bonus (:obj:`bool`): Whether to use the entropy bonus, defaults to True. LLM RLHF usually does not use it.
    - kl_type (:obj:`str`): Which KL estimator to use, defaults to 'k1'.

Returns:
    - ppo_policy_loss (:obj:`namedtuple`): The PPO policy loss items, each a differentiable 0-dim tensor.
    - ppo_info (:obj:`namedtuple`): The PPO optimization information for monitoring, each a Python scalar.

Shapes:
    - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim.
    - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`.
    - action (:obj:`torch.LongTensor`): :math:`(B, )`.
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`.
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`.

Examples:
    >>> action_dim = 4
    >>> data = ppo_policy_data(
    >>>     logit_new=torch.randn(3, action_dim),
    >>>     logit_old=torch.randn(3, action_dim),
    >>>     action=torch.randint(0, action_dim, (3,)),
    >>>     adv=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     logit_pretrained=None,
    >>> )
    >>> loss, info = ppo_policy_error(data)

.. note:: This function can be extended from B to more parallel dimensions, like (B, S), where S is the sequence length in LLM/VLM.

.. note:: For the action mask often used in LLM/VLM, users can set the weight to the action mask.
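The clipped surrogate with dual clipping and a weight mask can be sketched in pure Python (hypothetical scalar helpers, not this module's API; the real implementation works on tensors):

```python
def clipped_surrogate(ratio, adv, clip_ratio=0.2, dual_clip=None):
    """Per-sample PPO objective term (averaged and negated to form the loss).

    ratio: pi_new(a|s) / pi_old(a|s); adv: advantage estimate.
    """
    clipped = max(1 - clip_ratio, min(1 + clip_ratio, ratio))
    surr = min(ratio * adv, clipped * adv)
    if dual_clip is not None and adv < 0:
        # dual clip (arXiv:1912.09729): lower-bound the objective only when adv < 0,
        # preventing an unbounded penalty when the ratio is very large
        surr = max(surr, dual_clip * adv)
    return surr

def policy_loss(ratios, advs, weights, clip_ratio=0.2, dual_clip=None):
    """Weighted mean of the negated surrogate; weights can act as an action mask
    (masked-out samples contribute zero, but the mean still divides by the count)."""
    terms = [
        -clipped_surrogate(r, a, clip_ratio, dual_clip) * w
        for r, a, w in zip(ratios, advs, weights)
    ]
    return sum(terms) / len(terms)
```

Setting a weight entry to 0 drops that sample's (or token's) contribution, which is exactly how an LLM action mask plugs in.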

ppo_value_error(data, clip_ratio=0.2, use_value_clip=True)

Overview

Get PPO value loss

Arguments:
    - data (:obj:`namedtuple`): The PPO input data with fields shown in ``ppo_value_data``.
    - clip_ratio (:obj:`float`): Clip value for the value update, defaults to 0.2.
    - use_value_clip (:obj:`bool`): Whether to use value clip.

Returns:
    - value_loss (:obj:`torch.FloatTensor`): The PPO value loss, a differentiable 0-dim tensor.

Shapes:
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`.
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.

Examples:
    >>> data = ppo_value_data(
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>> )
    >>> value_loss = ppo_value_error(data)
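The value-clip branch can be illustrated with a scalar pure-Python sketch (`value_loss_term` is a hypothetical per-sample helper mirroring the documented behavior):

```python
def value_loss_term(value_new, value_old, return_, clip_ratio=0.2, use_value_clip=True):
    """Per-sample 0.5 * squared-error term of the PPO value loss."""
    if use_value_clip:
        # limit how far the new value prediction may move from the old one
        delta = max(-clip_ratio, min(clip_ratio, value_new - value_old))
        value_clipped = value_old + delta
        v1 = (return_ - value_new) ** 2
        v2 = (return_ - value_clipped) ** 2
        # pessimistic choice: take the larger of the two squared errors
        return 0.5 * max(v1, v2)
    return 0.5 * (return_ - value_new) ** 2
```

Taking the max of the clipped and unclipped errors makes the update conservative: a value head that jumps far from its previous prediction still pays at least the clipped error.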

ppo_error_continuous(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None, kl_type='k1')

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip

Arguments:
    - data (:obj:`namedtuple`): The PPO input data with fields shown in ``ppo_data_continuous``.
    - clip_ratio (:obj:`float`): The PPO clip ratio for the constraint of policy update, defaults to 0.2.
    - use_value_clip (:obj:`bool`): Whether to clip the value loss with the same ratio as the policy.
    - dual_clip (:obj:`float`): The parameter c mentioned in arXiv:1912.09729 Eq. 5, which should be in [1, inf), e.g. 5.0; if you don't want to use it, set this parameter to None (the default).
    - kl_type (:obj:`str`): Which KL estimator to use, defaults to 'k1'.

Returns:
    - ppo_loss (:obj:`namedtuple`): The PPO loss items, each a differentiable 0-dim tensor.
    - ppo_info (:obj:`namedtuple`): The PPO optimization information for monitoring, each a Python scalar.

Shapes:
    - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim.
    - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`.
    - action (:obj:`torch.FloatTensor`): :math:`(B, N)`.
    - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - return (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`.
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
    - value_loss (:obj:`torch.FloatTensor`): :math:`()`.
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`.

Examples:
    >>> action_dim = 4
    >>> data = ppo_data_continuous(
    >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     action=torch.randn(3, action_dim),
    >>>     value_new=torch.randn(3),
    >>>     value_old=torch.randn(3),
    >>>     adv=torch.randn(3),
    >>>     return_=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     logit_pretrained=None,
    >>> )
    >>> loss, info = ppo_error_continuous(data)

.. note::

``adv`` is assumed to be an already-normalized value, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are
many ways to compute this mean and std (e.g. over the whole data buffer or over each train batch), so this
normalization is deliberately not coupled into ``ppo_error_continuous``; refer to our examples for the different options.

ppo_policy_error_continuous(data, clip_ratio=0.2, dual_clip=None, kl_type='k1')

Overview

Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip

Arguments:
    - data (:obj:`namedtuple`): The PPO input data with fields shown in ``ppo_policy_data_continuous``.
    - clip_ratio (:obj:`float`): The PPO clip ratio for the constraint of policy update, defaults to 0.2.
    - dual_clip (:obj:`float`): The parameter c mentioned in arXiv:1912.09729 Eq. 5, which should be in [1, inf), e.g. 5.0; if you don't want to use it, set this parameter to None (the default).
    - kl_type (:obj:`str`): Which KL estimator to use, defaults to 'k1'.

Returns:
    - ppo_loss (:obj:`namedtuple`): The PPO policy loss items, each a differentiable 0-dim tensor.
    - ppo_info (:obj:`namedtuple`): The PPO optimization information for monitoring, each a Python scalar.

Shapes:
    - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim.
    - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`.
    - action (:obj:`torch.FloatTensor`): :math:`(B, N)`.
    - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
    - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`.
    - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
    - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`.

Examples:
    >>> action_dim = 4
    >>> data = ppo_policy_data_continuous(
    >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
    >>>     action=torch.randn(3, action_dim),
    >>>     adv=torch.randn(3),
    >>>     weight=torch.ones(3),
    >>>     logit_pretrained=None,
    >>> )
    >>> loss, info = ppo_policy_error_continuous(data)
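In the continuous case the probability ratio comes from diagonal-Gaussian log-probabilities summed over action dimensions, which is what `Independent(Normal(mu, sigma), 1).log_prob` computes. A pure-Python sketch of that computation (hypothetical helpers, not this module's API):

```python
import math

def diag_gaussian_logp(action, mu, sigma):
    """log N(action | mu, diag(sigma^2)), summed over action dims."""
    logp = 0.0
    for a, m, s in zip(action, mu, sigma):
        # per-dimension univariate normal log-density
        logp += -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
    return logp

def prob_ratio(action, mu_new, sigma_new, mu_old, sigma_old):
    """pi_new(action) / pi_old(action), computed in log space for numerical stability."""
    return math.exp(
        diag_gaussian_logp(action, mu_new, sigma_new)
        - diag_gaussian_logp(action, mu_old, sigma_old)
    )
```

When the new and old distributions coincide, the ratio is exactly 1, so the clipped surrogate reduces to the plain advantage.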

Full Source Code

../ding/rl_utils/ppo.py

from collections import namedtuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from ding.hpc_rl import hpc_wrapper

ppo_data = namedtuple(
    'ppo_data',
    ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'logit_pretrained']
)
ppo_data_continuous = namedtuple(
    'ppo_data_continuous', [
        'mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight',
        'logit_pretrained'
    ]
)
ppo_policy_data = namedtuple(
    'ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'logit_pretrained']
)
ppo_policy_data_continuous = namedtuple(
    'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight', 'logit_pretrained']
)
ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss', 'kl_div'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss', 'kl_div'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])


def calculate_kl_div(log_ratio: torch.Tensor, kl_type: str) -> torch.Tensor:
    """
    Overview:
        Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)],
        where q is the current policy and p is the pretrained policy.
        The implementation is based on John Schulman's blog post "Approximating KL Divergence".
        Reference: http://joschu.net/blog/kl-approx.html
    Arguments:
        - log_ratio (:obj:`torch.Tensor`): The log-ratio of probabilities, which should be
          log(q/p) = logp_new - logp_pretrained.
        - kl_type (:obj:`str`): The type of KL divergence estimator to use.
            - 'k1': The standard, unbiased but high-variance estimator: `E_q[log(q/p)]`.
            - 'k2': A biased, low-variance estimator from a second-order approximation: `E_q[1/2 * (log(p/q))^2]`.
            - 'k3': An unbiased, low-variance estimator: `E_q[(p/q - 1) - log(p/q)]`.
    Returns:
        - kl_div (:obj:`torch.Tensor`): The calculated KL divergence estimate.
    """
    if kl_type == 'k1':
        return log_ratio.mean()
    elif kl_type == 'k2':
        return (log_ratio ** 2 / 2).mean()
    elif kl_type == 'k3':
        return (torch.exp(-log_ratio) - 1 + log_ratio).mean()
    else:
        raise ValueError(f"Unknown kl_type: {kl_type}")


def shape_fn_ppo(args, kwargs):
    r"""
    Overview:
        Return the shape of the PPO input for the HPC wrapper.
    Returns:
        shape: [B, N]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].logit_new.shape
    else:
        tmp = args[0].logit_new.shape
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_ppo,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    include_kwargs=['data', 'clip_ratio', 'use_value_clip', 'dual_clip']
)
def ppo_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf), \
            e.g. 5.0; if you don't want to use it, set this parameter to None (the default)
        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
    Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_error(data)

    .. note::

        adv is an already-normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into ppo_error, you can refer to our examples for different ways.
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    logit_new, logit_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
    policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight, logit_pretrained)
    policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip, kl_type=kl_type)
    value_data = ppo_value_data(value_new, value_old, return_, weight)
    value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)

    return ppo_loss(
        policy_output.policy_loss, value_loss, policy_output.entropy_loss, policy_output.kl_div
    ), policy_info


def ppo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
        entropy_bonus: bool = True,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
    Arguments:
        - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf), \
            e.g. 5.0; if you don't want to use it, set this parameter to None (the default).
        - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
    Returns:
        - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_policy_data(
        >>>     logit_new=torch.randn(3, action_dim),
        >>>     logit_old=torch.randn(3, action_dim),
        >>>     action=torch.randint(0, action_dim, (3,)),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_policy_error(data)

    .. note::
        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
        sequence length in LLM/VLM.

    .. note::
        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
    """
    logit_new, logit_old, action, adv, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)
    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)

    if entropy_bonus:
        dist_new_entropy = dist_new.entropy()
        if dist_new_entropy.shape != weight.shape:  # for the multi-agent rl case
            dist_new_entropy = dist_new.entropy().mean(dim=1)
        entropy_loss = (dist_new_entropy * weight).mean()
    else:
        entropy_loss = torch.tensor(0.0)
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    if ratio.shape != adv.shape:
        ratio = ratio.mean(dim=1)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        clip1 = torch.min(surr1, surr2)
        clip2 = torch.max(clip1, dual_clip * adv)
        # only use dual_clip when adv < 0
        policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    if logit_pretrained is not None:
        dist_pretrained = torch.distributions.categorical.Categorical(logits=logit_pretrained)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = torch.tensor(0., dtype=policy_loss.dtype, device=policy_loss.device)

    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


def ppo_value_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
) -> torch.Tensor:
    '''
    Overview:
        Get PPO value loss
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fields shown in ``ppo_value_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - use_value_clip (:obj:`bool`): whether use value clip
    Returns:
        - value_loss (:obj:`torch.FloatTensor`): the ppo value loss item, \
            a differentiable 0-dim tensor
    Shapes:
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    Examples:
        >>> data = ppo_value_data(
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>> )
        >>> value_loss = ppo_value_error(data)
    '''
    value_new, value_old, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value_old)
    # value_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
    return value_loss


def ppo_error_continuous(
        data: namedtuple,
        clip_ratio: float = 0.2,
        use_value_clip: bool = True,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data_continuous``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf), \
            e.g. 5.0; if you don't want to use it, set this parameter to None (the default)
        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
    Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
        - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     value_new=torch.randn(3),
        >>>     value_old=torch.randn(3),
        >>>     adv=torch.randn(3),
        >>>     return_=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_error_continuous(data)

    .. note::

        adv is an already-normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
        ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
        this part into ppo_error_continuous, you can refer to our examples for different ways.
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)

    dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
    else:
        dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight).mean()
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()
    # value_loss
    if use_value_clip:
        value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
        v1 = (return_ - value_new).pow(2)
        v2 = (return_ - value_clip).pow(2)
        value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
    else:
        value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()

    if logit_pretrained is not None:
        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = torch.tensor(0., dtype=policy_loss.dtype, device=policy_loss.device)

    return ppo_loss(policy_loss, value_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)


def ppo_policy_error_continuous(
        data: namedtuple,
        clip_ratio: float = 0.2,
        dual_clip: Optional[float] = None,
        kl_type: str = 'k1'
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
    Arguments:
        - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_policy_data_continuous``
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Eq. 5, should be in [1, inf), \
            e.g. 5.0; if you don't want to use it, set this parameter to None (the default)
        - kl_type (:obj:`str`): which kl loss to use, default set to 'k1'.
    Returns:
        - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
        - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
    Shapes:
        - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
        - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> action_dim = 4
        >>> data = ppo_policy_data_continuous(
        >>>     mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
        >>>     action=torch.randn(3, action_dim),
        >>>     adv=torch.randn(3),
        >>>     weight=torch.ones(3),
        >>>     logit_pretrained=None,
        >>> )
        >>> loss, info = ppo_policy_error_continuous(data)
    """
    assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
        dual_clip
    )
    mu_sigma_new, mu_sigma_old, action, adv, weight, logit_pretrained = data
    if weight is None:
        weight = torch.ones_like(adv)

    dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
    if len(mu_sigma_old['mu'].shape) == 1:
        dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
    else:
        dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    entropy_loss = (dist_new.entropy() * weight).mean()
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    if dual_clip is not None:
        policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
    else:
        policy_loss = (-torch.min(surr1, surr2) * weight).mean()
    with torch.no_grad():
        approx_kl = (logp_old - logp_new).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    if logit_pretrained is not None:
        dist_pretrained = Independent(Normal(logit_pretrained['mu'], logit_pretrained['sigma']), 1)
        logp_pretrained = dist_pretrained.log_prob(action)
        log_ratio = logp_new - logp_pretrained
        kl_div = calculate_kl_div(log_ratio, kl_type)
    else:
        kl_div = torch.tensor(0., dtype=policy_loss.dtype, device=policy_loss.device)

    return ppo_policy_loss(policy_loss, entropy_loss, kl_div), ppo_info(approx_kl, clipfrac)