
ding.rl_utils.rloo

rloo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2)

Overview

REINFORCE Leave-One-Out (RLOO) algorithm, see https://arxiv.org/abs/2402.14740.

Arguments:
- data (:obj:`namedtuple`): the RLOO input data with fields shown in ``rloo_policy_data``.
- log_prob_fn (:obj:`LogProbFunction`): the method used to calculate log probabilities, defaults to ``efficient_method``.
- clip_ratio (:obj:`float`): the PPO clip ratio constraining the policy update, defaults to 0.2.

Returns:
- loss (:obj:`torch.FloatTensor`): the RLOO policy loss, a differentiable 0-dim tensor.
- rloo_info (:obj:`namedtuple`): RLOO optimization information for monitoring; all fields are Python scalars.

Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, and V is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
- approx_kl (:obj:`float`): approximate KL divergence between the old and new policies.
- clipfrac (:obj:`float`): proportion of clipped probability ratios.
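The leave-one-out baseline behind the advantage computation can be illustrated with a small pure-Python sketch. The helper name and numbers below are illustrative, not part of the library: for K sampled completions of one prompt, each sample's baseline is the mean reward of the other K - 1 samples.

```python
# Illustrative sketch of the leave-one-out baseline used by RLOO:
# baseline_i = (sum(rewards) - r_i) / (K - 1); advantage_i = r_i - baseline_i


def leave_one_out_advantages(rewards):
    """Return the RLOO advantage for each of K rewards (K >= 2)."""
    k = len(rewards)
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]


advs = leave_one_out_advantages([1.0, 2.0, 3.0])
print(advs)  # [-1.5, 0.0, 1.5]
```

Note that the advantages always sum to zero, so the baseline centers the rewards without introducing bias.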

Full Source Code

../ding/rl_utils/rloo.py

from typing import Tuple
from collections import namedtuple
import torch
from .log_prob_utils import efficient_method, naive_method, less_efficient_method, LogProbFunction

rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
rloo_info = namedtuple('rloo_info', ['approx_kl', 'clipfrac'])


def rloo_policy_error(
        data: namedtuple,
        log_prob_fn: LogProbFunction = efficient_method,  # Method to calculate the log probabilities
        clip_ratio: float = 0.2,
) -> Tuple[torch.Tensor, namedtuple]:
    """
    Overview:
        REINFORCE Leave-One-Out (RLOO) algorithm, see https://arxiv.org/abs/2402.14740.
    Arguments:
        - data (:obj:`namedtuple`): the RLOO input data with fields shown in ``rloo_policy_data``.
        - log_prob_fn (:obj:`LogProbFunction`): the method used to calculate log probabilities, \
            defaults to ``efficient_method``.
        - clip_ratio (:obj:`float`): the PPO clip ratio constraining the policy update, defaults to 0.2.
    Returns:
        - loss (:obj:`torch.FloatTensor`): the RLOO policy loss, a differentiable 0-dim tensor.
        - rloo_info (:obj:`namedtuple`): RLOO optimization information for monitoring; all fields are Python scalars.
    Shapes:
        - logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where B is batch size, S is sequence length, \
            and V is vocabulary size.
        - logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
        - action (:obj:`torch.LongTensor`): :math:`(B, S)`.
        - reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where K is the number of samples per prompt.
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor.
        - approx_kl (:obj:`float`): approximate KL divergence between the old and new policies.
        - clipfrac (:obj:`float`): proportion of clipped probability ratios.
    """

    # Calculate the leave-one-out advantage of each sample: the baseline for each
    # sample is the mean reward of the other K - 1 samples for the same prompt
    rloo_k = data.reward.size(0)
    baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
    adv = data.reward - baseline
    adv = adv.flatten()

    # Get per-token log probabilities for the selected actions
    per_token_logps = log_prob_fn(data.logit_new, data.action)
    per_token_old_logps = log_prob_fn(data.logit_old, data.action)

    # Calculate the policy probability ratio and its clipped counterpart
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Calculate the PPO-style pessimistic (clipped) loss for each token
    advantages = adv.unsqueeze(1)  # [B, 1], broadcast over the sequence dimension
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Average the loss over valid tokens using the weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Calculate monitoring metrics (no gradients needed)
    with torch.no_grad():
        approx_kl = (per_token_old_logps - per_token_logps).mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = clipped.float().mean().item()

    return loss, rloo_info(approx_kl=approx_kl, clipfrac=clipfrac)