
ding.torch_utils.metric


levenshtein_distance(pred, target, pred_extra=None, target_extra=None, extra_fn=None)

Overview

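Levenshtein distance, i.e. edit distance: the minimum number of single-element insertions, deletions and substitutions that turn one sequence into the other. When extra_fn is given, matched elements additionally contribute extra_fn(pred_extra[i], target_extra[j]) to the cost.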

Arguments:
- pred (torch.LongTensor): The first tensor to calculate the distance, shape (N1, ) with N1 >= 0.
- target (torch.LongTensor): The second tensor to calculate the distance, shape (N2, ) with N2 >= 0.
- pred_extra (Optional[torch.Tensor]): Extra tensor used in the distance; only used when extra_fn is not None.
- target_extra (Optional[torch.Tensor]): Extra tensor used in the distance; only used when extra_fn is not None.
- extra_fn (Optional[Callable]): The distance function applied to pred_extra and target_extra. If set to None, this extra distance is not considered.

Returns:
- distance (torch.FloatTensor): The distance as a single-element tensor, shape (1, ).
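For reference, the textbook dynamic program behind edit distance can be sketched in plain Python. This is a minimal sketch under stated assumptions: edit_distance is a hypothetical name, it takes ordinary lists instead of torch.LongTensor, it omits the extra_fn hook, and it is not the library's implementation. The table has one extra row and column so that dp[i][j] covers the prefixes pred[:i] and target[:j], including the empty ones.

```python
from typing import List, Sequence


def edit_distance(pred: Sequence[int], target: Sequence[int]) -> int:
    """Textbook Levenshtein distance (hypothetical helper, plain lists)."""
    n1, n2 = len(pred), len(target)
    # dp[i][j] = number of edits turning pred[:i] into target[:j]
    dp: List[List[int]] = [[0] * (n2 + 1) for _ in range(n1 + 1)]
    for i in range(n1 + 1):
        dp[i][0] = i  # delete every element of pred[:i]
    for j in range(n2 + 1):
        dp[0][j] = j  # insert every element of target[:j]
    for i in range(1, n1 + 1):
        for j in range(1, n2 + 1):
            sub = 0 if pred[i - 1] == target[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # delete pred[i - 1]
                dp[i][j - 1] + 1,        # insert target[j - 1]
                dp[i - 1][j - 1] + sub,  # match (cost 0) or substitute (cost 1)
            )
    return dp[n1][n2]


print(edit_distance([1, 2, 3], [1, 3, 4]))  # → 2
```

Here the two edits are deleting 2 and appending 4; since both sequences have length 3 and differ in two positions, no single edit suffices.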

hamming_distance(pred, target, weight=1.0)

Overview

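Hamming distance: the number of positions at which pred and target differ, counted separately for each row of the batch and multiplied by weight.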

Arguments:
- pred (torch.LongTensor): Predicted input, a boolean vector (0 or 1).
- target (torch.LongTensor): Target input, a boolean vector (0 or 1).
- weight (float, default 1.0): Factor that each per-sample distance is multiplied by.

Returns:
- distance (torch.FloatTensor): Weighted count of positions where pred and target differ, one value per sample, shape (B, ).

Shapes:
- pred & target (torch.LongTensor): shape (B, N), where B is the batch size and N is the dimension.
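The body of hamming_distance (see the full source) reduces to pred.ne(target).sum(dim=1).float().mul_(weight). To make the output shape and dtype of that expression concrete, here is a minimal pure-Python sketch on nested lists. batch_hamming is a hypothetical name, and the lists stand in for the (B, N) LongTensors the real function takes; where the library asserts pred.shape == target.shape, this sketch raises ValueError.

```python
from typing import List, Sequence


def batch_hamming(pred: Sequence[Sequence[int]],
                  target: Sequence[Sequence[int]],
                  weight: float = 1.0) -> List[float]:
    """Per-row Hamming distance scaled by weight (hypothetical helper).

    Mirrors pred.ne(target).sum(dim=1).float().mul_(weight):
    input shape (B, N), output one float per row, i.e. shape (B,).
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same batch size")
    out: List[float] = []
    for p_row, t_row in zip(pred, target):
        if len(p_row) != len(t_row):
            raise ValueError("every row must have the same length N")
        # ne(...) then sum over dim=1: count mismatching positions in this row
        mismatches = sum(1 for p, t in zip(p_row, t_row) if p != t)
        # float() then mul_(weight)
        out.append(float(mismatches) * weight)
    return out


print(batch_hamming([[0, 1, 1], [1, 1, 1]], [[0, 0, 1], [0, 0, 0]], weight=0.5))
# → [0.5, 1.5]
```

The first row differs in one position and the second in all three, so with weight=0.5 the result is one value per row rather than a single scalar.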

Full Source Code

../ding/torch_utils/metric.py

```python
import torch
from typing import Optional, Callable


def levenshtein_distance(
        pred: torch.LongTensor,
        target: torch.LongTensor,
        pred_extra: Optional[torch.Tensor] = None,
        target_extra: Optional[torch.Tensor] = None,
        extra_fn: Optional[Callable] = None
) -> torch.FloatTensor:
    """
    Overview:
        Levenshtein Distance, i.e. Edit Distance.
    Arguments:
        - pred (:obj:`torch.LongTensor`): The first tensor to calculate the distance, shape: (N1, ) (N1 >= 0).
        - target (:obj:`torch.LongTensor`): The second tensor to calculate the distance, shape: (N2, ) (N2 >= 0).
        - pred_extra (:obj:`Optional[torch.Tensor]`): Extra tensor to calculate the distance, only works when \
            ``extra_fn`` is not ``None``.
        - target_extra (:obj:`Optional[torch.Tensor]`): Extra tensor to calculate the distance, only works when \
            ``extra_fn`` is not ``None``.
        - extra_fn (:obj:`Optional[Callable]`): The distance function for ``pred_extra`` and \
            ``target_extra``. If set to ``None``, this distance will not be considered.
    Returns:
        - distance (:obj:`torch.FloatTensor`): distance(scalar), shape: (1, ).
    """
    assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
    assert (pred.dtype == torch.long and target.dtype == torch.long), '{}\t{}'.format(pred.dtype, target.dtype)
    assert (pred.device == target.device)
    assert (type(pred_extra) == type(target_extra))
    if not extra_fn:
        assert (not pred_extra)
    N1, N2 = pred.shape[0], target.shape[0]
    assert (N1 >= 0 and N2 >= 0)
    if N1 == 0 or N2 == 0:
        distance = max(N1, N2)
    else:
        dp_array = torch.zeros(N1, N2).float()
        if extra_fn:
            if pred[0] == target[0]:
                extra = extra_fn(pred_extra[0], target_extra[0])
            else:
                extra = 1.
            dp_array[0, :] = torch.arange(0, N2) + extra
            dp_array[:, 0] = torch.arange(0, N1) + extra
        else:
            dp_array[0, :] = torch.arange(0, N2)
            dp_array[:, 0] = torch.arange(0, N1)
        for i in range(1, N1):
            for j in range(1, N2):
                if pred[i] == target[j]:
                    if extra_fn:
                        dp_array[i, j] = dp_array[i - 1, j - 1] + extra_fn(pred_extra[i], target_extra[j])
                    else:
                        dp_array[i, j] = dp_array[i - 1, j - 1]
                else:
                    dp_array[i, j] = min(dp_array[i - 1, j] + 1, dp_array[i, j - 1] + 1, dp_array[i - 1, j - 1] + 1)
        distance = dp_array[N1 - 1, N2 - 1]
    return torch.FloatTensor([distance]).to(pred.device)


def hamming_distance(pred: torch.LongTensor, target: torch.LongTensor, weight=1.) -> torch.LongTensor:
    """
    Overview:
        Hamming Distance.
    Arguments:
        - pred (:obj:`torch.LongTensor`): Pred input, boolean vector(0 or 1).
        - target (:obj:`torch.LongTensor`): Target input, boolean vector(0 or 1).
        - weight (:obj:`torch.LongTensor`): Weight to multiply.
    Returns:
        - distance(:obj:`torch.LongTensor`): Distance (scalar), shape (1, ).
    Shapes:
        - pred & target (:obj:`torch.LongTensor`): shape :math:`(B, N)`, \
            while B is the batch size, N is the dimension
    """
    assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
    assert (pred.dtype == torch.long and target.dtype == torch.long)
    assert (pred.device == target.device)
    assert (pred.shape == target.shape)
    return pred.ne(target).sum(dim=1).float().mul_(weight)
```
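Reading the loops in levenshtein_distance above, dp_array (written D here) is filled by the following recurrence, with 0-based indices as in the code. The symbols D, c_ij and e are shorthand introduced for this summary, not names from the library.

```latex
c_{ij} =
\begin{cases}
  \mathrm{extra\_fn}(\mathrm{pred\_extra}_i, \mathrm{target\_extra}_j) & \text{if extra\_fn is given} \\
  0 & \text{otherwise}
\end{cases}

e =
\begin{cases}
  0 & \text{if extra\_fn is None} \\
  c_{00} & \text{if extra\_fn is given and } \mathrm{pred}_0 = \mathrm{target}_0 \\
  1 & \text{if extra\_fn is given and } \mathrm{pred}_0 \neq \mathrm{target}_0
\end{cases}

D_{0,j} = j + e, \qquad D_{i,0} = i + e

D_{i,j} =
\begin{cases}
  D_{i-1,j-1} + c_{ij} & \text{if } \mathrm{pred}_i = \mathrm{target}_j \\
  1 + \min\left(D_{i-1,j},\ D_{i,j-1},\ D_{i-1,j-1}\right) & \text{otherwise}
\end{cases}
\qquad 1 \le i < N_1,\ 1 \le j < N_2

\mathrm{distance} =
\begin{cases}
  \max(N_1, N_2) & \text{if } N_1 = 0 \text{ or } N_2 = 0 \\
  D_{N_1-1,\,N_2-1} & \text{otherwise}
\end{cases}
```

Because D has shape N1 × N2 rather than the textbook (N1 + 1) × (N2 + 1), and with extra_fn=None its first row and column are seeded 0, 1, 2, ... without comparing pred[0] to target[0], tracing the code by hand for pred = torch.tensor([5]) and target = torch.tensor([7]) gives tensor([0.]), whereas the textbook edit distance of these one-element sequences is 1.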