ding.torch_utils.reshape_helper

fold_batch(x, nonbatch_ndims=1)

Overview

$(T, B, X) \rightarrow (T \times B, X)$. Fold the first (ndim - nonbatch_ndims) dimensions of a tensor into a single batch dimension. This operation is similar to torch.flatten, but it comes with an inverse function, unfold_batch, that restores the folded dimensions.
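Written out for a general input (the symbols d_i and k are our own notation, not the library's): for a tensor of shape (d_1, ..., d_n) and k = nonbatch_ndims with 0 < k ≤ n, the fold keeps the last k sizes and multiplies the remaining leading sizes together; with k = 0 every dimension is folded.

$$
(d_1, \dots, d_{n-k}, d_{n-k+1}, \dots, d_n) \rightarrow \Big(\prod_{i=1}^{n-k} d_i,\; d_{n-k+1}, \dots, d_n\Big),
\qquad
(d_1, \dots, d_n) \xrightarrow{\;k = 0\;} \Big(\prod_{i=1}^{n} d_i\Big)
$$

For the shape used in the examples, (10, 20, 5, 4, 8) with k = 2, this gives (10 · 20 · 5, 4, 8) = (1000, 4, 8). When k = n the product is empty (equal to 1), so the result gains a leading dimension of size 1.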

Parameters:

- x (torch.Tensor): the tensor to fold. Required.
- nonbatch_ndims (int): the number of trailing dimensions that are not folded into the batch dimension. Default: 1.

Returns:

- x (torch.Tensor): the folded tensor.
- batch_dims (torch.Size): the original leading dimensions that were folded; pass them to unfold_batch to reverse the operation.

Examples:

>>> import torch
>>> from ding.torch_utils.reshape_helper import fold_batch
>>> x = torch.ones(10, 20, 5, 4, 8)
>>> x, batch_dim = fold_batch(x, 2)
>>> x.shape == (1000, 4, 8)
True
>>> batch_dim == (10, 20, 5)
True
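Since fold_batch is described as similar to torch.flatten, the sketch below folds the leading dimensions with torch.flatten instead and compares the two. flatten_leading is a hypothetical helper written only for this illustration; it is not part of DI-engine. The sketch also shows the practical difference: fold_batch is built on Tensor.view, which raises a RuntimeError when the strides of the folded dimensions cannot be merged (for example, on a transposed tensor), whereas torch.flatten may fall back to copying.

```python
import torch


def flatten_leading(x: torch.Tensor, nonbatch_ndims: int = 1):
    """Hypothetical helper: fold the leading dims of x with torch.flatten.

    Intended to match fold_batch for 0 <= nonbatch_ndims < x.dim(), except
    that it copies instead of raising when x is not view-compatible.
    """
    assert 0 <= nonbatch_ndims < x.dim()
    batch_dims = x.shape[:x.dim() - nonbatch_ndims]
    folded = torch.flatten(x, start_dim=0, end_dim=x.dim() - nonbatch_ndims - 1)
    return folded, batch_dims


x = torch.randn(10, 20, 5, 4, 8)
flat, batch_dims = flatten_leading(x, 2)
print(flat.shape, batch_dims)  # torch.Size([1000, 4, 8]) torch.Size([10, 20, 5])

# Folding with view (as fold_batch does) needs view-compatible strides.
# A transposed tensor is not, so view raises while flatten copies.
xt = x.transpose(0, 1)  # shape (20, 10, 5, 4, 8), non-contiguous
try:
    xt.view(-1, *xt.shape[-2:])
except RuntimeError as err:
    print("view failed:", type(err).__name__)
print(flatten_leading(xt, 2)[0].shape)  # torch.Size([1000, 4, 8])
```

Which approach to prefer is a design trade-off: view guarantees that the folded tensor shares memory with its input (so folding and unfolding are free), at the cost of rejecting inputs whose strides do not allow it.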

unfold_batch(x, batch_dims)

Overview

Unfold the first (batch) dimension of a tensor back into batch_dims, reversing fold_batch.
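Written out with our own symbols (not the library's), for a folded tensor of shape (N, e_2, ..., e_m) and batch_dims = (b_1, ..., b_j), the first dimension is replaced by batch_dims and the remaining sizes are kept. Because the operation is a view, the element counts must agree, which requires N to equal the product of the batch sizes:

$$
(N, e_2, \dots, e_m) \rightarrow (b_1, \dots, b_j, e_2, \dots, e_m), \qquad N = \prod_{i=1}^{j} b_i
$$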

Parameters:

- x (torch.Tensor): the folded tensor to unfold. Required.
- batch_dims (torch.Size or tuple): the dimensions that were folded, typically the second value returned by fold_batch. Required.

Returns:

- x (torch.Tensor): the unfolded tensor, of shape (*batch_dims, *x.shape[1:]).

Examples:

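>>> import torch
>>> from ding.torch_utils.reshape_helper import fold_batch, unfold_batch
>>> x = torch.ones(10, 20, 5, 4, 8)
>>> x, batch_dim = fold_batch(x, 2)
>>> x.shape == (1000, 4, 8)
True
>>> batch_dim == (10, 20, 5)
True
>>> x = unfold_batch(x, batch_dim)
>>> x.shape == (10, 20, 5, 4, 8)
True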
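A common reason to fold is to push a (T, B, ...) sequence through a layer that accepts only a single batch dimension, and then restore the time and batch axes. The sketch below assumes a hypothetical (T, B, C, H, W) image sequence passed through torch.nn.Conv2d, which expects (N, C, H, W) input. fold_batch and unfold_batch are re-declared locally from the source listing so the snippet runs without DI-engine installed; treat them as stand-ins, not the library import.

```python
import torch
import torch.nn as nn


# Local stand-ins copied from the source listing (assumed to match DI-engine).
def fold_batch(x, nonbatch_ndims=1):
    if nonbatch_ndims > 0:
        batch_dims = x.shape[:-nonbatch_ndims]
        return x.view(-1, *x.shape[-nonbatch_ndims:]), batch_dims
    return x.view(-1), x.shape


def unfold_batch(x, batch_dims):
    return x.view(*batch_dims, *x.shape[1:])


torch.manual_seed(0)
T, B, C, H, W = 6, 4, 3, 16, 16
frames = torch.randn(T, B, C, H, W)  # hypothetical (T, B, C, H, W) image sequence
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1)

flat, batch_dims = fold_batch(frames, nonbatch_ndims=3)  # (T*B, C, H, W) for Conv2d
features = conv(flat)                                     # (T*B, 8, H, W)
out = unfold_batch(features, batch_dims)                  # (T, B, 8, H, W)
print(out.shape)  # torch.Size([6, 4, 8, 16, 16])
```

Only the trailing (C, H, W) dimensions are kept (nonbatch_ndims=3), so T and B are merged for the convolution and split again afterwards without copying.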

unsqueeze_repeat(x, repeat_times, unsqueeze_dim=0)

Overview

Unsqueeze the tensor at unsqueeze_dim and then repeat it repeat_times times along that new dimension. This is useful for preprocessing the input to a model ensemble.

Parameters:

- x (torch.Tensor): the tensor to unsqueeze and repeat. Required.
- repeat_times (int): the number of times the tensor is repeated. Required.
- unsqueeze_dim (int): the position at which the new dimension is inserted; it must lie between -1 and x.dim() inclusive (checked by an assertion). Default: 0.

Returns:

- x (torch.Tensor): the unsqueezed and repeated tensor.

Examples:

>>> import torch
>>> from ding.torch_utils.reshape_helper import unsqueeze_repeat
>>> x = torch.ones(64, 6)
>>> x = unsqueeze_repeat(x, 4)
>>> x.shape == (4, 64, 6)
True
>>> x = torch.ones(64, 6)
>>> x = unsqueeze_repeat(x, 4, -1)
>>> x.shape == (64, 6, 4)
True
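As a sketch of the model-ensemble use case mentioned in the overview, the snippet below assumes an ensemble of K linear models stored as one stacked weight tensor (weights is hypothetical, for illustration only) and feeds the same batch to every member with torch.bmm. unsqueeze_repeat is re-declared locally from the source listing so the snippet runs without DI-engine installed.

```python
import torch


def unsqueeze_repeat(x, repeat_times, unsqueeze_dim=0):
    # Local stand-in copied from the source listing (assumed to match DI-engine).
    assert -1 <= unsqueeze_dim <= len(x.shape), f'unsqueeze_dim should be from {-1} to {len(x.shape)}'
    x = x.unsqueeze(unsqueeze_dim)
    repeats = [1] * len(x.shape)
    repeats[unsqueeze_dim] *= repeat_times
    return x.repeat(*repeats)


torch.manual_seed(0)
K, N, D_in, D_out = 5, 64, 6, 2       # ensemble size, batch size, in/out features
obs = torch.randn(N, D_in)
weights = torch.randn(K, D_in, D_out)  # hypothetical: one linear model per member

inputs = unsqueeze_repeat(obs, K)      # (K, N, D_in): the same batch for every member
preds = torch.bmm(inputs, weights)     # (K, N, D_out): one prediction set per member

# expand() gives the same values as a stride-0 view instead of a copy.
expanded = obs.unsqueeze(0).expand(K, -1, -1)
print(preds.shape, torch.equal(inputs, expanded))  # torch.Size([5, 64, 2]) True

# The listed bounds check rejects dims below -1, even where torch.unsqueeze
# itself would accept them (-2 is valid for a 2-D input).
try:
    unsqueeze_repeat(obs, K, unsqueeze_dim=-2)
except AssertionError as err:
    print("rejected:", err)
```

Tensor.repeat materializes K copies of the input. When the ensemble members only read their input, x.unsqueeze(0).expand(K, -1, -1) produces the same values without the extra memory.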

Full Source Code

../ding/torch_utils/reshape_helper.py

```python
from typing import Tuple, Union

from torch import Tensor, Size


def fold_batch(x: Tensor, nonbatch_ndims: int = 1) -> Tuple[Tensor, Size]:
    """
    Overview:
        :math:`(T, B, X) \leftarrow (T*B, X)`\
        Fold the first (ndim - nonbatch_ndims) dimensions of a tensor as the batch dimension.\
        This operation is similar to `torch.flatten` but provides an inverse function
        `unfold_batch` to restore the folded dimensions.

    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to fold
        - nonbatch_ndims (:obj:`int`): the number of dimensions that are not folded as
            the batch dimension.

    Returns:
        - x (:obj:`torch.Tensor`): the folded tensor
        - batch_dims: the folded dimensions of the original tensor, which can be used to
            reverse the operation

    Examples:
        >>> x = torch.ones(10, 20, 5, 4, 8)
        >>> x, batch_dim = fold_batch(x, 2)
        >>> x.shape == (1000, 4, 8)
        >>> batch_dim == (10, 20, 5)

    """
    if nonbatch_ndims > 0:
        batch_dims = x.shape[:-nonbatch_ndims]
        x = x.view(-1, *(x.shape[-nonbatch_ndims:]))
        return x, batch_dims
    else:
        batch_dims = x.shape
        x = x.view(-1)
        return x, batch_dims


def unfold_batch(x: Tensor, batch_dims: Union[Size, Tuple]) -> Tensor:
    """
    Overview:
        Unfold the batch dimension of a tensor.

    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to unfold
        - batch_dims (:obj:`torch.Size`): the dimensions that are folded

    Returns:
        - x (:obj:`torch.Tensor`): the original unfolded tensor

    Examples:
        >>> x = torch.ones(10, 20, 5, 4, 8)
        >>> x, batch_dim = fold_batch(x, 2)
        >>> x.shape == (1000, 4, 8)
        >>> batch_dim == (10, 20, 5)
        >>> x = unfold_batch(x, batch_dim)
        >>> x.shape == (10, 20, 5, 4, 8)
    """
    return x.view(*batch_dims, *x.shape[1:])


def unsqueeze_repeat(x: Tensor, repeat_times: int, unsqueeze_dim: int = 0) -> Tensor:
    """
    Overview:
        Unsqueeze the tensor on `unsqueeze_dim` and then repeat in this dimension for `repeat_times` times.\
        This is useful for preprocessing the input to a model ensemble.

    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to unsqueeze and repeat
        - repeat_times (:obj:`int`): the number of times the tensor is repeated
        - unsqueeze_dim (:obj:`int`): the unsqueezed dimension

    Returns:
        - x (:obj:`torch.Tensor`): the unsqueezed and repeated tensor

    Examples:
        >>> x = torch.ones(64, 6)
        >>> x = unsqueeze_repeat(x, 4)
        >>> x.shape == (4, 64, 6)

        >>> x = torch.ones(64, 6)
        >>> x = unsqueeze_repeat(x, 4, -1)
        >>> x.shape == (64, 6, 4)
    """
    assert -1 <= unsqueeze_dim <= len(x.shape), f'unsqueeze_dim should be from {-1} to {len(x.shape)}'
    x = x.unsqueeze(unsqueeze_dim)
    repeats = [1] * len(x.shape)
    repeats[unsqueeze_dim] *= repeat_times
    return x.repeat(*repeats)
```