ding.rl_utils.grpo

grpo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2, beta=0.1)
Overview
Group Relative Policy Optimization (GRPO) algorithm, see https://arxiv.org/abs/2402.03300.
Arguments:
- data (:obj:`namedtuple`): the grpo input data with fields shown in grpo_policy_data.
- log_prob_fn (:obj:`LogProbFunction`): the method used to calculate log probabilities, defaults to efficient_method.
- clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
- beta (:obj:`float`): weight coefficient for the KL divergence regularization, defaults to 0.1.
Returns:
- loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
- grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of which are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
- mean_kl (:obj:`float`): mean KL divergence between the current and reference policy.
- mean_ratio (:obj:`float`): mean probability ratio.
- mean_clipped (:obj:`float`): proportion of clipped probability ratios.
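To make the inputs and outputs above concrete, here is a minimal, hedged sketch of a GRPO-style per-token loss with the documented shapes and defaults. It is not the library's exact implementation: the helper name `grpo_loss_sketch` is hypothetical, and the k3-style KL estimator against the reference policy is a common choice in GRPO implementations (following the paper linked above), assumed rather than confirmed from this source.

```python
# Hedged sketch of a GRPO-style policy loss; names and shapes follow the
# Arguments/Shapes sections above, but this is NOT ding's exact code.
import torch
import torch.nn.functional as F


def grpo_loss_sketch(logit_new, logit_old, logit_ref, action, adv,
                     weight=None, clip_ratio=0.2, beta=0.1):
    # Gather per-token log-probabilities of the taken actions: (B, S, V) -> (B, S)
    def gather_logp(logit, act):
        logp = F.log_softmax(logit, dim=-1)
        return logp.gather(-1, act.unsqueeze(-1)).squeeze(-1)

    logp_new = gather_logp(logit_new, action)
    logp_old = gather_logp(logit_old, action)
    logp_ref = gather_logp(logit_ref, action)

    # PPO-style clipped surrogate objective with a group-relative advantage,
    # broadcast from (B,) to (B, 1) over the sequence dimension.
    ratio = torch.exp(logp_new - logp_old)            # (B, S)
    adv = adv.unsqueeze(-1)                           # (B, 1)
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    policy_obj = torch.min(surr1, surr2)

    # k3 KL estimator vs. the reference policy (assumed; common in GRPO impls):
    # exp(x) - x - 1 with x = logp_ref - logp_new, always non-negative.
    diff = logp_ref - logp_new
    kl = torch.exp(diff) - diff - 1

    per_token_loss = -(policy_obj - beta * kl)        # (B, S)
    if weight is None:
        weight = torch.ones_like(per_token_loss)
    loss = (per_token_loss * weight).sum() / weight.sum()  # 0-dim tensor
    return loss, kl.mean().item(), ratio.mean().item()
```

With `logit_old = logit_ref = logit_new.detach()` (the first update step of an iteration), the ratio is exactly 1 and the KL term vanishes, so the loss reduces to the negative mean advantage.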
Full Source Code
../ding/rl_utils/grpo.py