ding.rl_utils.log_prob_utils

naive_method(logits, index)

Calculate per-token log probabilities by applying log_softmax over the vocabulary axis and then gathering the entries for the selected tokens.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| logits | Tensor | Token logits of shape [B, S, V] or [S, V], where B = batch size, S = sequence length, V = vocabulary size | required |
| index | Tensor | Selected token indices of shape [B, S] or [S] | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Log probabilities for selected tokens, of shape [B, S] or [S] |
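
The description above can be sketched in plain PyTorch so that it runs without DI-engine installed. The helper name `naive_log_probs` and the toy tensors are illustrative assumptions, not part of the library.

```python
import torch


def naive_log_probs(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Illustrative restatement of the naive approach: normalise over the
    # whole vocabulary axis first, then pick the entries for the chosen tokens.
    log_probs = torch.log_softmax(logits, dim=-1)  # [B, S, V]
    return torch.gather(log_probs, -1, index.unsqueeze(-1)).squeeze(-1)  # [B, S]


# Toy batch: B=2 sequences, S=3 positions, V=5 vocabulary entries.
torch.manual_seed(0)
logits = torch.randn(2, 3, 5)
index = torch.randint(0, 5, (2, 3))

per_token_logps = naive_log_probs(logits, index)
```

Note that the intermediate `log_probs` tensor has the same [B, S, V] shape as the logits, which is the memory cost the other variants on this page try to avoid.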

efficient_method(logits, index)

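Calculate per-token log probabilities with lower peak memory. For float32 and float64 logits, the selected logits are gathered directly and a logsumexp, computed in a loop over the leading dimension, is subtracted from them. For other dtypes (such as bfloat16, where the logsumexp approach is unstable), log_softmax followed by gather is applied one slice at a time instead.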

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| logits | Tensor | Token logits of shape [B, S, V] or [S, V], where B = batch size, S = sequence length, V = vocabulary size | required |
| index | Tensor | Selected token indices of shape [B, S] or [S] | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Log probabilities for selected tokens, of shape [B, S] or [S] |
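
The floating-point branch of this function relies on the identity log_softmax(x)_i = x_i - logsumexp(x). The following pure-Python sketch checks that identity on a single row of logits; the helper `log_softmax_at` and the sample values are assumptions made for illustration only.

```python
import math


def log_softmax_at(logits, i):
    # log_softmax(x)_i = x_i - logsumexp(x). The maximum is subtracted
    # before exponentiating so that exp() cannot overflow for large logits.
    m = max(logits)
    logsumexp = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[i] - logsumexp


row = [2.0, 1.0, 0.1]

# Direct definition: log of the softmax probability of entry 0.
direct = math.log(math.exp(row[0]) / sum(math.exp(x) for x in row))

# Same quantity via the identity, without forming the full softmax vector.
via_identity = log_softmax_at(row, 0)
```

Because only the selected logit and one scalar per row are needed, this formulation never has to materialise a full [B, S, V] tensor of log probabilities.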

less_efficient_method(logits, index)

Calculate per-token log probabilities using torch.distributions.Categorical.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| logits | Tensor | Token logits of shape [B, S, V] or [S, V], where B = batch size, S = sequence length, V = vocabulary size | required |
| index | Tensor | Selected token indices of shape [B, S] or [S] | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | Log probabilities for selected tokens, of shape [B, S] or [S] |
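
As a rough check of what the categorical route computes, the sketch below compares `Categorical.log_prob` with an explicit log_softmax followed by gather. The tensor sizes and variable names are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(2, 3, 5)        # [B, S, V]
index = torch.randint(0, 5, (2, 3))  # [B, S]

# Categorical normalises the logits internally; log_prob then looks up
# the entry for each selected token along the last axis.
via_categorical = Categorical(logits=logits).log_prob(index)

# Reference: explicit log_softmax followed by gather.
reference = (
    torch.log_softmax(logits, dim=-1)
    .gather(-1, index.unsqueeze(-1))
    .squeeze(-1)
)
```

Both tensors should agree up to floating-point tolerance, so the choice between the variants is about memory and speed rather than the values returned.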

Full Source Code

../ding/rl_utils/log_prob_utils.py

    from typing import List, Callable, Optional, Any
    import torch
    from torch import Tensor

    LogitsProcessor = Callable[[Tensor, Tensor], Tensor]


    def naive_method(logits: Tensor, index: Tensor) -> Tensor:
        """Calculate per-token log probabilities using naive method.

        Args:
            logits: Token logits of shape [B, S, V] or [S, V] where:
                B = batch size
                S = sequence length
                V = vocabulary size
            index: Selected token indices of shape [B, S] or [S]

        Returns:
            Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
        """
        # Calculate log probabilities for each token
        log_prob_new: Tensor = torch.log_softmax(logits, dim=-1)
        # Get log probabilities for selected actions
        index = index.unsqueeze(-1)  # [B, S, 1] or [S, 1]
        per_token_logps: Tensor = torch.gather(log_prob_new, -1, index).squeeze(-1)
        return per_token_logps


    def efficient_method(logits: Tensor, index: Tensor) -> Tensor:
        """Calculate per-token log probabilities efficiently.

        Args:
            logits: Token logits of shape [B, S, V] or [S, V] where:
                B = batch size
                S = sequence length
                V = vocabulary size
            index: Selected token indices of shape [B, S] or [S]

        Returns:
            Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
        """
        if logits.dtype in [torch.float32, torch.float64]:
            selected_logits: Tensor = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

            # Loop to reduce peak mem consumption
            logsumexp_values: Tensor = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])

            # log_softmax(x_i) = x_i - logsumexp(x)
            per_token_logps: Tensor = selected_logits - logsumexp_values
        else:
            # logsumexp approach is unstable with bfloat16
            per_token_logps: List[Tensor] = []

            # Loop to reduce peak mem consumption
            for row_logits, row_labels in zip(logits, index):  # Iterate over sequence length
                row_logps: Tensor = torch.log_softmax(row_logits, dim=-1)
                row_per_token_logps: Tensor = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
                per_token_logps.append(row_per_token_logps)

            per_token_logps = torch.stack(per_token_logps)

        return per_token_logps


    def less_efficient_method(logits: Tensor, index: Tensor) -> Tensor:
        """Calculate per-token log probabilities using categorical distribution.

        Args:
            logits: Token logits of shape [B, S, V] or [S, V] where:
                B = batch size
                S = sequence length
                V = vocabulary size
            index: Selected token indices of shape [B, S] or [S]

        Returns:
            Tensor: Log probabilities for selected tokens of shape [B, S] or [S]
        """
        dist = torch.distributions.categorical.Categorical(logits=logits)
        logp: Tensor = dist.log_prob(index)
        return logp


    # Define a unified type
    LogProbFunction = Callable[[Tensor, Tensor], Tensor]

    # Export all methods
    __all__ = ['naive_method', 'efficient_method', 'less_efficient_method', 'LogProbFunction']
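
Since all three functions share the `LogProbFunction` signature, they can be swapped behind a single call site. The sketch below shows one possible way to do that with a name-keyed registry. The registry `LOG_PROB_FUNCTIONS`, the helper `compute_log_probs`, and the two stand-in implementations are hypothetical and are not part of `ding.rl_utils.log_prob_utils`.

```python
from typing import Callable, Dict

import torch
from torch import Tensor
from torch.distributions import Categorical

# Same signature as the module's LogProbFunction alias.
LogProbFunction = Callable[[Tensor, Tensor], Tensor]


def gather_log_probs(logits: Tensor, index: Tensor) -> Tensor:
    # Stand-in for the naive variant: log_softmax, then gather.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, index.unsqueeze(-1)).squeeze(-1)


def categorical_log_probs(logits: Tensor, index: Tensor) -> Tensor:
    # Stand-in for the categorical variant: defer to Categorical.log_prob.
    return Categorical(logits=logits).log_prob(index)


# Hypothetical registry keyed by a configuration string.
LOG_PROB_FUNCTIONS: Dict[str, LogProbFunction] = {
    "gather": gather_log_probs,
    "categorical": categorical_log_probs,
}


def compute_log_probs(method: str, logits: Tensor, index: Tensor) -> Tensor:
    # Look up the implementation by name and apply it.
    return LOG_PROB_FUNCTIONS[method](logits, index)


torch.manual_seed(0)
logits = torch.randn(4, 7)         # [S, V], unbatched input
index = torch.randint(0, 7, (4,))  # [S]
results = {name: compute_log_probs(name, logits, index) for name in LOG_PROB_FUNCTIONS}
```

With DI-engine installed, the same pattern would presumably work with the functions exported in `__all__` in place of the stand-ins.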