ding.rl_utils.rloo
rloo_policy_error(data, log_prob_fn=efficient_method, clip_ratio=0.2)
Overview
Calculate the policy loss of the REINFORCE Leave-One-Out (RLOO) algorithm; see https://arxiv.org/abs/2402.14740 for details.
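The core of RLOO is its leave-one-out baseline: for each of the :math:`K` responses sampled for the same prompt, the baseline is the mean reward of the other :math:`K - 1` responses. Below is a minimal illustrative sketch of this advantage computation (not the library's internal code, which also handles the PPO-style clipped ratio)::

    import torch

    def leave_one_out_advantage(reward: torch.Tensor) -> torch.Tensor:
        """Compute RLOO advantages from a reward tensor of shape (K, B).

        For sample i, the baseline is the mean reward of the other K - 1
        samples drawn for the same prompt:
            baseline_i = (sum of all K rewards - reward_i) / (K - 1)
        """
        K = reward.shape[0]
        assert K > 1, "RLOO needs at least 2 samples per prompt"
        baseline = (reward.sum(dim=0, keepdim=True) - reward) / (K - 1)
        return reward - baseline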
Arguments:
- data (:obj:`namedtuple`): The RLOO input data, with fields shown in ``rloo_policy_data``.
- log_prob_fn (:obj:`LogProbFunction`): The method used to calculate the log probabilities, defaults to ``efficient_method``.
- clip_ratio (:obj:`float`): The PPO clip ratio that constrains the policy update, defaults to 0.2.
Returns:
- loss (:obj:`torch.FloatTensor`): The RLOO policy loss, a differentiable 0-dim tensor.
- rloo_info (:obj:`namedtuple`): The RLOO optimization information for monitoring; all fields are Python scalars.
Shapes:
- logit_new (:obj:`torch.FloatTensor`): :math:`(B, S, V)`, where :math:`B` is batch size, :math:`S` is sequence length, and :math:`V` is vocabulary size.
- logit_old (:obj:`torch.FloatTensor`): :math:`(B, S, V)`.
- action (:obj:`torch.LongTensor`): :math:`(B, S)`.
- reward (:obj:`torch.FloatTensor`): :math:`(K, B)`, where :math:`K` is the number of samples per prompt.
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, S)`.
- policy_loss (:obj:`torch.FloatTensor`): :math:`()`, a 0-dim tensor.
- mean_ratio (:obj:`float`): The mean probability ratio.
- mean_clipped (:obj:`float`): The proportion of clipped probability ratios.
- mean_advantage (:obj:`float`): The mean advantage value.
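For reference, a hypothetical end-to-end call is sketched below. It assumes that ``rloo_policy_data`` is a namedtuple whose fields mirror the Shapes list above and that both names are importable from ``ding.rl_utils``; tensor contents are random placeholders::

    import torch
    from ding.rl_utils import rloo_policy_data, rloo_policy_error

    # B: batch size, S: sequence length, V: vocab size, K: samples per prompt
    B, S, V, K = 8, 16, 1000, 4

    data = rloo_policy_data(
        logit_new=torch.randn(B, S, V, requires_grad=True),  # current policy logits
        logit_old=torch.randn(B, S, V),                      # behavior policy logits
        action=torch.randint(0, V, (B, S)),                  # sampled token ids
        reward=torch.randn(K, B),                            # K rewards per prompt
        weight=torch.ones(B, S),                             # mask over valid tokens
    )

    loss, info = rloo_policy_error(data, clip_ratio=0.2)
    loss.backward()  # per the Returns section, loss is a differentiable 0-dim tensor
    print(info.mean_ratio, info.mean_clipped, info.mean_advantage)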
Full Source Code
../ding/rl_utils/rloo.py