ding.rl_utils.ppo

calculate_kl_div(log_ratio, kl_type)
Overview
Calculate different Monte-Carlo estimators for KL-divergence KL(q, p) = E_q[log(q/p)], where q is the current policy and p is the pretrained policy. The implementation is based on John Schulman's blog post "Approximating KL Divergence". Reference: http://joschu.net/blog/kl-approx.html
Arguments:
- log_ratio (:obj:torch.Tensor): The log-ratio of probabilities, which should be
log(q/p) = logp_new - logp_pretrained.
- kl_type (:obj:str): The type of KL divergence estimator to use.
- 'k1': The standard, unbiased but high-variance estimator: E_q[log(q/p)].
- 'k2': A biased, low-variance estimator from a second-order approximation: E_q[1/2 * (log(p/q))^2].
- 'k3': An unbiased, low-variance estimator: E_q[(p/q - 1) - log(p/q)].
Returns:
- kl_div (:obj:torch.Tensor): The calculated KL divergence estimate.
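The three estimators above fit in a few lines; `calculate_kl_div_sketch` below is a hypothetical re-implementation for illustration, not the library's actual code:

```python
import torch

def calculate_kl_div_sketch(log_ratio: torch.Tensor, kl_type: str = 'k1') -> torch.Tensor:
    # log_ratio = log(q/p); Schulman's blog writes the estimators in terms
    # of log(p/q), so negate where needed.
    neg_log_ratio = -log_ratio
    if kl_type == 'k1':
        # Unbiased but high variance: E_q[log(q/p)].
        return log_ratio
    elif kl_type == 'k2':
        # Biased, low variance: E_q[1/2 * (log(p/q))^2].
        return 0.5 * neg_log_ratio.pow(2)
    elif kl_type == 'k3':
        # Unbiased, low variance: E_q[(p/q - 1) - log(p/q)].
        return neg_log_ratio.exp() - 1 - neg_log_ratio
    raise ValueError(f"Unknown kl_type: {kl_type}")
```

Note that the k3 estimator is pointwise non-negative, which makes it a convenient per-sample penalty term in RLHF.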
shape_fn_ppo(args, kwargs)
Overview
Return the shape of PPO data for HPC (high-performance computing) acceleration.
Returns: shape: [B, N]
ppo_error(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None, kl_type='k1')
Overview
Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
Arguments:
- data (:obj:namedtuple): the PPO input data with fields shown in ppo_data
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- use_value_clip (:obj:bool): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:float): the parameter c mentioned in arXiv:1912.09729 Equ. 5, which should be in [1, inf); set to None (the default) to disable dual clipping
- kl_type (:obj:str): which KL divergence estimator to use, defaults to 'k1'.
Returns:
- ppo_loss (:obj:namedtuple): the PPO loss items; each is a differentiable 0-dim tensor
- ppo_info (:obj:namedtuple): the PPO optimization information for monitoring; each entry is a Python scalar
Shapes:
- logit_new (:obj:torch.FloatTensor): :math:(B, N), where B is batch size and N is action dim
- logit_old (:obj:torch.FloatTensor): :math:(B, N)
- action (:obj:torch.LongTensor): :math:(B, )
- value_new (:obj:torch.FloatTensor): :math:(B, )
- value_old (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_data(
>>> logit_new=torch.randn(3, action_dim),
>>> logit_old=torch.randn(3, action_dim),
>>> action=torch.randint(0, action_dim, (3,)),
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> adv=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = ppo_error(data)
.. note::
adv should already be normalized, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are many
ways to compute this mean and std (e.g. over the whole data buffer or per train batch), so this
step is not coupled into ppo_error; refer to our examples for the different approaches.
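As a concrete illustration of the note above, a minimal per-train-batch normalization (one of several valid granularities) might look like:

```python
import torch

def normalize_adv(adv: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Per-batch normalization; other granularities (e.g. over the whole
    # data buffer) are equally valid and are left to the caller.
    return (adv - adv.mean()) / (adv.std() + eps)
```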
ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, entropy_bonus=True, kl_type='k1')
Overview
Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
Arguments:
- data (:obj:namedtuple): PPO input data with fields shown in ppo_policy_data.
- clip_ratio (:obj:float): Clip value for ratio, defaults to 0.2.
- dual_clip (:obj:float): A parameter c mentioned in arXiv:1912.09729 Equ. 5, which should be in [1, inf); set to None (the default) to disable dual clipping
- entropy_bonus (:obj:bool): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
- kl_type (:obj:str): which KL divergence estimator to use, defaults to 'k1'.
Returns:
- ppo_policy_loss (:obj:namedtuple): the PPO policy loss items; each is a differentiable 0-dim tensor
- ppo_info (:obj:namedtuple): the PPO optimization information for monitoring; each entry is a Python scalar
Shapes:
- logit_new (:obj:torch.FloatTensor): :math:(B, N), where B is batch size and N is action dim
- logit_old (:obj:torch.FloatTensor): :math:(B, N)
- action (:obj:torch.LongTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_policy_data(
>>> logit_new=torch.randn(3, action_dim),
>>> logit_old=torch.randn(3, action_dim),
>>> action=torch.randint(0, action_dim, (3,)),
>>> adv=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = ppo_policy_error(data)
.. note::
This function can be extended from B to more parallel dimensions, like (B, S), where S is the
sequence length in LLM/VLM.
.. note::
For the action mask often used in LLM/VLM, users can set the weight to the action mask.
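To illustrate the two notes above, here is a hypothetical LLM-style sketch of reducing a per-token (B, S) loss with an action mask used as the weight; the shapes and mask values are made up for illustration:

```python
import torch

def masked_mean(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over positions where mask == 1 (generated tokens);
    # prompt and padding positions contribute nothing to the loss.
    return (loss * mask).sum() / mask.sum().clamp(min=1)

# Hypothetical per-token loss for 2 sequences of length 5.
per_token_loss = torch.arange(10.0).reshape(2, 5)
# Action mask: 1 for response tokens, 0 for prompt/padding tokens.
action_mask = torch.tensor([[0., 0., 1., 1., 1.],
                            [0., 1., 1., 1., 0.]])
```

Passing such a mask as the `weight` argument zeroes out the prompt and padding positions when the loss is reduced.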
ppo_value_error(data, clip_ratio=0.2, use_value_clip=True)
Overview
Get PPO value loss
Arguments:
- data (:obj:namedtuple): PPO input data with fields shown in ppo_value_data
- clip_ratio (:obj:float): clip value for the ratio, defaults to 0.2
- use_value_clip (:obj:bool): whether to use value clip
Returns:
- value_loss (:obj:torch.FloatTensor): the PPO value loss, a differentiable 0-dim tensor
Shapes:
- value_new (:obj:torch.FloatTensor): :math:(B, ), where B is batch size
- value_old (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- value_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
Examples:
>>> action_dim = 4
>>> data = ppo_value_data(
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss = ppo_value_error(data)
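The value-clipping idea behind use_value_clip can be sketched as follows; this is an assumed form for illustration and may differ from the exact ppo_value_error implementation (e.g. in weighting and reduction):

```python
import torch

def clipped_value_loss_sketch(value_new, value_old, return_, clip_ratio=0.2):
    # Limit how far the new value prediction may move away from the old
    # one, then take the pessimistic (larger) of the two squared errors.
    value_clipped = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
    loss_unclipped = (value_new - return_).pow(2)
    loss_clipped = (value_clipped - return_).pow(2)
    return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()
```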
ppo_error_continuous(data, clip_ratio=0.2, use_value_clip=True, dual_clip=None, kl_type='k1')
Overview
Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
Arguments:
- data (:obj:namedtuple): the PPO input data with fields shown in ppo_data_continuous
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- use_value_clip (:obj:bool): whether to use clip in value loss with the same ratio as policy
- dual_clip (:obj:float): the parameter c mentioned in arXiv:1912.09729 Equ. 5, which should be in [1, inf); set to None (the default) to disable dual clipping
- kl_type (:obj:str): which KL divergence estimator to use, defaults to 'k1'.
Returns:
- ppo_loss (:obj:namedtuple): the PPO loss items; each is a differentiable 0-dim tensor
- ppo_info (:obj:namedtuple): the PPO optimization information for monitoring; each entry is a Python scalar
Shapes:
- mu_sigma_new (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- mu_sigma_old (:obj:tuple): :math:((B, N), (B, N))
- action (:obj:torch.FloatTensor): :math:(B, N)
- value_new (:obj:torch.FloatTensor): :math:(B, )
- value_old (:obj:torch.FloatTensor): :math:(B, )
- adv (:obj:torch.FloatTensor): :math:(B, )
- return (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- value_loss (:obj:torch.FloatTensor): :math:()
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_data_continuous(
>>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim) ** 2),
>>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim) ** 2),
>>> action=torch.randn(3, action_dim),
>>> value_new=torch.randn(3),
>>> value_old=torch.randn(3),
>>> adv=torch.randn(3),
>>> return_=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = ppo_error_continuous(data)
.. note::
adv should already be normalized, i.e. (adv - adv.mean()) / (adv.std() + 1e-8). There are many
ways to compute this mean and std (e.g. over the whole data buffer or per train batch), so this
step is not coupled into ppo_error_continuous; refer to our examples for the different approaches.
ppo_policy_error_continuous(data, clip_ratio=0.2, dual_clip=None, kl_type='k1')
Overview
Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
Arguments:
- data (:obj:namedtuple): the PPO input data with fields shown in ppo_policy_data_continuous
- clip_ratio (:obj:float): the ppo clip ratio for the constraint of policy update, defaults to 0.2
- dual_clip (:obj:float): the parameter c mentioned in arXiv:1912.09729 Equ. 5, which should be in [1, inf); set to None (the default) to disable dual clipping
- kl_type (:obj:str): which KL divergence estimator to use, defaults to 'k1'.
Returns:
- ppo_loss (:obj:namedtuple): the PPO policy loss items; each is a differentiable 0-dim tensor
- ppo_info (:obj:namedtuple): the PPO optimization information for monitoring; each entry is a Python scalar
Shapes:
- mu_sigma_new (:obj:tuple): :math:((B, N), (B, N)), where B is batch size and N is action dim
- mu_sigma_old (:obj:tuple): :math:((B, N), (B, N))
- action (:obj:torch.FloatTensor): :math:(B, N)
- adv (:obj:torch.FloatTensor): :math:(B, )
- weight (:obj:torch.FloatTensor or :obj:None): :math:(B, )
- policy_loss (:obj:torch.FloatTensor): :math:(), 0-dim tensor
- entropy_loss (:obj:torch.FloatTensor): :math:()
Examples:
>>> action_dim = 4
>>> data = ppo_policy_data_continuous(
>>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim) ** 2),
>>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim) ** 2),
>>> action=torch.randn(3, action_dim),
>>> adv=torch.randn(3),
>>> weight=torch.ones(3),
>>> )
>>> loss, info = ppo_policy_error_continuous(data)
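The clipped surrogate plus dual clip used throughout this module can be sketched as below; the exact weighting and reduction in DI-engine may differ, so treat this as an illustrative assumed form:

```python
import torch

def dual_clip_surrogate_sketch(ratio, adv, clip_ratio=0.2, dual_clip=None):
    # Standard PPO clipped surrogate (arXiv:1707.06347).
    surr1 = ratio * adv
    surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    objective = torch.min(surr1, surr2)
    if dual_clip is not None:
        # Dual clip (arXiv:1912.09729 Equ. 5): for negative advantages,
        # bound the objective from below by dual_clip * adv.
        objective = torch.where(adv < 0, torch.max(objective, dual_clip * adv), objective)
    # Minimize the negative of the surrogate objective.
    return -objective.mean()
```

Without dual_clip, a very large ratio combined with a negative advantage can produce an arbitrarily large loss; the dual clip caps that contribution at dual_clip * adv.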
Full Source Code
../ding/rl_utils/ppo.py