ding.rl_utils.grpo

grpo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2, beta=0.1)

Overview

Group Relative Policy Optimization (GRPO) algorithm, see https://arxiv.org/abs/2402.03300.
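
Concretely, for each token the loss implemented below combines PPO's clipped surrogate with a per-token KL penalty toward a frozen reference policy. Writing the probability ratio as r = \pi_\theta(a) / \pi_{\theta_{\mathrm{old}}}(a) and the group-relative advantage as \hat{A}:

    L = -\min\left( r \hat{A},\; \mathrm{clip}(r, 1 - \epsilon, 1 + \epsilon)\, \hat{A} \right) + \beta \, \hat{D}_{\mathrm{KL}}\!\left[ \pi_\theta \,\|\, \pi_{\mathrm{ref}} \right],
    \qquad
    \hat{D}_{\mathrm{KL}} = e^{\,q - p} - (q - p) - 1,

where p = \log \pi_\theta(a), q = \log \pi_{\mathrm{ref}}(a), \epsilon is clip_ratio, and \beta is beta. The per-token losses are averaged over valid tokens (using the optional weight mask) and then over the batch.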

Arguments:

- data (:obj:`namedtuple`): the GRPO input data with fields shown in ``grpo_policy_data``.
- clip_ratio (:obj:`float`): the PPO clip ratio constraining the policy update, defaults to 0.2.
- beta (:obj:`float`): weight coefficient for the KL divergence regularization, defaults to 0.1.
- log_prob_fn (:obj:`LogProbFunction`): the method used to calculate the log probabilities, defaults to ``efficient_method``.

Returns:

- loss (:obj:`torch.FloatTensor`): the GRPO policy loss, a differentiable 0-dim tensor.
- grpo_info (:obj:`namedtuple`): the GRPO optimization information for monitoring; all fields are Python scalars.

Shapes:

- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
- approx_kl (:obj:`float`): approximate mean KL divergence between the old and current policy, a Python scalar.
- clipfrac (:obj:`float`): proportion of probability ratios clipped by the PPO clipping range, a Python scalar.
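A minimal usage sketch, assuming the module path shown above; the tensor shapes follow the table, while the values are random placeholders rather than anything taken from the library's tests:

import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error

# Illustrative sizes: batch B, sequence length S, vocabulary V
B, S, V = 4, 16, 1000
data = grpo_policy_data(
    logit_new=torch.randn(B, S, V, requires_grad=True),  # logits of the policy being optimized
    logit_old=torch.randn(B, S, V),                       # logits of the behavior (old) policy
    logit_ref=torch.randn(B, S, V),                       # logits of the frozen reference policy
    action=torch.randint(0, V, (B, S)),                   # sampled token ids
    adv=torch.randn(B),                                   # group-relative advantage per sequence
    weight=torch.ones(B, S),                              # token-level mask (e.g. to exclude padding)
)
loss, info = grpo_policy_error(data, clip_ratio=0.2, beta=0.1)
loss.backward()
print(loss.item(), info.approx_kl, info.clipfrac)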

Full Source Code

../ding/rl_utils/grpo.py

from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

grpo_policy_data = namedtuple('grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight'])
grpo_info = namedtuple('grpo_info', ['approx_kl', 'clipfrac'])


def grpo_policy_error(
        data: namedtuple,
        log_prob_fn: LogProbFunction = efficient_method,  # Method to calculate the log probabilities
        clip_ratio: float = 0.2,
        beta: float = 0.1  # Weight coefficient for KL divergence
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Group Relative Policy Optimization (GRPO) algorithm, see https://arxiv.org/abs/2402.03300.
    Arguments:
        - data (:obj:`namedtuple`): the grpo input data with fields shown in ``grpo_policy_data``.
        - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2.
        - beta (:obj:`float`): weight coefficient for KL divergence regularization, defaults to 0.1.
        - log_prob_fn (:obj:`LogProbFunction`): The method to calculate the log probabilities, \
            defaults to `efficient_method`.
    Returns:
        - loss (:obj:`torch.FloatTensor`): the grpo policy loss, a differentiable 0-dim tensor.
        - grpo_info (:obj:`namedtuple`): the grpo optim information for monitoring, all of them are Python scalar.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
            and V is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - logit_ref (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`.
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
        - approx_kl (:obj:`float`): approximate mean KL divergence between the old and current policy.
        - clipfrac (:obj:`float`): proportion of clipped probability ratios.
    """

    # Calculate log probabilities for the selected tokens
    per_token_logps = log_prob_fn(data.logit_new, data.action)
    per_token_ref_logps = log_prob_fn(data.logit_ref, data.action)
    per_token_old_logps = log_prob_fn(data.logit_old, data.action)

    # Calculate KL divergence: exp(q - p) - (q - p) - 1,
    # where p is the current policy log-prob and q is the reference policy log-prob
    per_token_kl = (torch.exp(per_token_ref_logps - per_token_logps) - (per_token_ref_logps - per_token_logps) - 1)

    # Calculate policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate loss for each token
    advantages = data.adv.unsqueeze(1)  # [B, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add KL divergence regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Calculate average loss using weight mask
    weight = data.weight if data.weight is not None \
        else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate additional metrics
    with torch.no_grad():
        approx_kl = (per_token_old_logps - per_token_logps).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped).float().mean().item()

    return loss, grpo_info(approx_kl=approx_kl, clipfrac=clipfrac)