
ding.utils.data.collate_fn

ttorch_collate(x, json=False, cat_1dim=True)

Overview

Collates a list of tensors or nested dictionaries of tensors into a single tensor or nested dictionary of tensors.

Parameters:

- x: The input list of tensors or nested dictionaries of tensors. Required.
- json (bool): If True, calls json() on the collated result to convert it to JSON format. Defaults to False.
- cat_1dim (bool): If True, squeezes tensors with shape (B, 1) to shape (B). Defaults to True.

Returns:

The collated output tensor or nested dictionary of tensors.

Examples:

>>> # case 1: Collate a list of tensors
>>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
>>> collated = ttorch_collate(tensors)
collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> # case 2: Collate a nested dictionary of tensors
>>> nested_dict = {
        'a': torch.tensor([1, 2, 3]),
        'b': torch.tensor([4, 5, 6]),
        'c': torch.tensor([7, 8, 9])
    }
>>> collated = ttorch_collate(nested_dict)
collated = {
    'a': torch.tensor([1, 2, 3]),
    'b': torch.tensor([4, 5, 6]),
    'c': torch.tensor([7, 8, 9])
}
>>> # case 3: Collate a list of nested dictionaries of tensors
>>> nested_dicts = [
        {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
        {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
    ]
>>> collated = ttorch_collate(nested_dicts)
collated = {
    'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
    'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
}
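
The effect of cat_1dim can be sketched without any tensor library. The snippet below is a simplified illustration, not the library's implementation: plain nested lists stand in for tensors, and the helper name squeeze_unit_columns is hypothetical. It walks a nested dictionary and flattens every leaf whose rows all have length 1, which mirrors the (B, 1) -> (B) reshape described above.

```python
def squeeze_unit_columns(tree):
    """Recursively turn every (B, 1)-shaped leaf into a flat (B,)-shaped leaf.

    Assumption of this sketch: leaves are nested Python lists standing in
    for tensors, and dictionaries are the only container type.
    """
    for key, value in tree.items():
        if isinstance(value, dict):
            squeeze_unit_columns(value)  # descend into nested dictionaries
        elif (isinstance(value, list) and value
              and all(isinstance(row, list) and len(row) == 1 for row in value)):
            tree[key] = [row[0] for row in value]  # (B, 1) -> (B,)
    return tree


collated = {
    'reward': [[1.0], [0.5], [0.0]],      # shape (3, 1): will be flattened
    'obs': {
        'pos': [[1, 2], [3, 4], [5, 6]],  # shape (3, 2): left untouched
        'done': [[0], [0], [1]],          # shape (3, 1): will be flattened
    },
}
squeezed = squeeze_unit_columns(collated)
print(squeezed['reward'])
print(squeezed['obs']['done'])
```

Leaves with more than one column keep their shape, so only trailing singleton dimensions are removed.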

default_collate(batch, cat_1dim=True, ignore_prefix=['collate_ignore'])

Overview

Put each data field into a tensor with outer dimension batch size.

Parameters:

- batch (Sequence): A data sequence whose length is the batch size and whose elements are individual pieces of data. Required.
- cat_1dim (bool): Whether to concatenate tensors of shape (1,) into a tensor of shape (B) instead of stacking them into shape (B, 1). Defaults to True.
- ignore_prefix (list): A list of key prefixes. When collating dictionaries, values whose key starts with one of these prefixes are gathered into a plain list instead of being collated. Defaults to ['collate_ignore'].

Returns:

- ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch dimension added to each data field. The return type depends on the original element type and can be a torch.Tensor, Mapping, or Sequence.

Example:

>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped (m, n) -->>
>>> # a dict whose values are tensors with shape (B, m, n)
>>> # (string keys are used here because each key is checked against ignore_prefix with str.startswith)
>>> a = [{str(i): torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0]['2'].shape, a[0]['3'].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b['2'].shape, b['3'].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
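
The recursive dispatch and the role of ignore_prefix can be illustrated with a dependency-free sketch. This is a simplified stand-in, not the library code: the function name sketch_collate is hypothetical, and plain lists take the place of the tensors that default_collate would build at the leaves.

```python
from collections.abc import Mapping, Sequence


def sketch_collate(batch, ignore_prefix=('collate_ignore',)):
    """Turn a list of samples into per-field batches (lists stand in for tensors)."""
    elem = batch[0]
    if isinstance(elem, (str, bytes)):
        return list(batch)  # strings are kept as they are
    if isinstance(elem, Mapping):
        out = {}
        for key in elem:
            column = [sample[key] for sample in batch]
            if any(key.startswith(prefix) for prefix in ignore_prefix):
                out[key] = column  # ignored keys: gathered, never collated
            else:
                out[key] = sketch_collate(column, ignore_prefix)
        return out
    if isinstance(elem, Sequence):
        # B lists of m elements -> m batches of B elements
        return [sketch_collate(list(samples), ignore_prefix) for samples in zip(*batch)]
    return list(batch)  # leaf values: this is where a tensor would be created


batch = [
    {'obs': [1, 2], 'collate_ignore_info': {'id': 'a'}},
    {'obs': [3, 4], 'collate_ignore_info': {'id': 'b'}},
]
result = sketch_collate(batch)
print(result['obs'])
print(result['collate_ignore_info'])
```

The 'obs' field is transposed into per-position batches, while the field with the ignored prefix comes back as an untouched list with one entry per sample.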

timestep_collate(batch)

Overview

Collates a batch of timestepped data fields into time-major tensors. Each collated data field is a tensor with shape [T, B, any_dims], where T is the length of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.

Parameters:

- batch (List[Dict[str, Any]]): A list of B dictionaries, one per sample. Each key is the name of a data field, and each value is a sequence of T torch.Tensor objects of any shape, one per timestep. In the source, each dictionary's 'prev_state' entry is popped and collated separately, so it must be present. Required.

Returns:

- ret (Dict[str, Union[torch.Tensor, list]]): The collated data, with the timestep and batch size incorporated into each data field. The shape of each tensor field is [T, B, dim1, dim2, ...]. For dictionary samples, 'prev_state' is returned as a list of T tuples with B entries each.

Examples:

>>> batch = [
        {'data': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], 'prev_state': [None, None]},
        {'data': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])], 'prev_state': [None, None]}
    ]
>>> collated_data = timestep_collate(batch)
>>> print(collated_data['data'].shape)
torch.Size([2, 2, 3])
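
The core of the timestep collation is a transposition from sample-major to time-major order. The following is a minimal sketch of that idea only, under the assumption that strings can stand in for per-step tensors and states; the helper name to_time_major is hypothetical.

```python
def to_time_major(batch_major):
    """Transpose a [B][T] nested list into a [T][B] nested list."""
    return [list(step) for step in zip(*batch_major)]


# B = 2 samples, T = 3 timesteps; strings stand in for per-step tensors.
obs = [
    ['s0_t0', 's0_t1', 's0_t2'],
    ['s1_t0', 's1_t1', 's1_t2'],
]
# One recurrent state per sample and timestep.
prev_state = [
    ['h0_t0', 'h0_t1', 'h0_t2'],
    ['h1_t0', 'h1_t1', 'h1_t2'],
]

time_major_obs = to_time_major(obs)
# The same zip-based transposition, kept as tuples for the states.
time_major_prev_state = list(zip(*prev_state))

print(len(time_major_obs), len(time_major_obs[0]))
print(time_major_prev_state[2])
```

After the transposition, indexing by timestep first yields all B samples for that step, which is the layout a recurrent model typically consumes.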

diff_shape_collate(batch)

Overview

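Collates a batch of data whose elements may have different shapes. This function is similar to default_collate, but it tolerates None entries and tensors with mismatched shapes, which are common in StarCraft observations: in either case the affected batch is returned as-is instead of being stacked.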

Parameters:

- batch (Sequence): A sequence of data, where each element is a piece of data. Required.

Returns:

- ret (Union[torch.Tensor, Mapping, Sequence]): The collated data, with the batch size applied to each data field. The return type depends on the original element type and can be a torch.Tensor, Mapping, or Sequence.

Examples:

>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> diff_shape_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> diff_shape_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
>>> # a dict whose values are tensors with shape :math:`(B, m, n)`
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = diff_shape_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
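
The examples above only cover batches that can be stacked. The fallback for None entries and mismatched shapes can be sketched as follows. This is an illustration under loud assumptions: the function name sketch_diff_shape_collate is hypothetical, a row's length stands in for a tensor's shape, and a tuple of tuples stands in for the stacked tensor.

```python
def sketch_diff_shape_collate(batch):
    """Stack rows only when every row is present and all rows have equal length."""
    if any(item is None for item in batch):
        return batch  # a missing entry: hand the batch back unchanged
    if len({len(item) for item in batch}) != 1:
        return batch  # mismatched "shapes": hand the batch back unchanged
    return tuple(tuple(item) for item in batch)  # stand-in for a stacked tensor


same_shape = [[1, 2], [3, 4]]
ragged = [[1, 2], [3]]
with_none = [[1, 2], None]

print(sketch_diff_shape_collate(same_shape))
print(sketch_diff_shape_collate(ragged) is ragged)
print(sketch_diff_shape_collate(with_none) is with_none)
```

Callers therefore need to handle both outcomes: a stacked result when shapes agree, and the original sequence otherwise.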

default_decollate(batch, ignore=['prev_state', 'prev_actor_state', 'prev_critic_state'])

Overview

Splits collated data along its batch dimension into a list of B per-sample items. This is the reverse operation of default_collate.

Parameters:

- batch (Union[torch.Tensor, Sequence, Mapping]): The collated data batch. It can be a tensor, sequence, or mapping. Required.
- ignore (List[str]): A list of keys to be ignored. Only applicable if the input batch is a dictionary. If a key is in this list, its value is not decollated recursively and is only indexed per sample. Defaults to ['prev_state', 'prev_actor_state', 'prev_critic_state'].

Returns:

- ret (List[Any]): A list with B elements, where B is the batch size.

Examples:

>>> batch = {
    'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
    'b': torch.tensor([[7, 8, 9], [10, 11, 12]])
}
>>> default_decollate(batch)
[
    {'a': tensor([1, 2, 3]), 'b': tensor([7, 8, 9])},
    {'a': tensor([4, 5, 6]), 'b': tensor([10, 11, 12])}
]
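
The mapping case, including the treatment of ignored keys, can be sketched without tensors. This is a simplified illustration rather than the library's implementation: the name sketch_decollate is hypothetical, lists stand in for tensors, and only dictionaries and flat batches are handled.

```python
def sketch_decollate(batch, ignore=('prev_state',)):
    """Turn a dict of batched fields into a list of per-sample dicts."""
    if isinstance(batch, dict):
        # Ignored keys keep their value as-is; other values are decollated recursively.
        columns = {
            key: value if key in ignore else sketch_decollate(value, ignore)
            for key, value in batch.items()
        }
        batch_size = len(next(iter(columns.values())))
        return [{key: columns[key][i] for key in columns} for i in range(batch_size)]
    return list(batch)  # leaf: one entry per sample along the batch dimension


batch = {
    'a': [[1, 2, 3], [4, 5, 6]],
    'b': {'c': [7, 8]},
    'prev_state': ['h0', 'h1'],
}
samples = sketch_decollate(batch)
print(samples[0])
print(samples[1])
```

Each resulting dictionary holds the data for one sample, so a list of B samples can be processed item by item after batched inference.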

Full Source Code

../ding/utils/data/collate_fn.py

from collections.abc import Sequence, Mapping
from typing import List, Dict, Union, Any

import torch
import treetensor.torch as ttorch
import re
import collections.abc as container_abcs
from ding.compatibility import torch_ge_131

int_classes = int
string_classes = (str, bytes)
np_str_obj_array_pattern = re.compile(r'[SaUO]')

default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def ttorch_collate(x, json: bool = False, cat_1dim: bool = True):
    """
    Overview:
        Collates a list of tensors or nested dictionaries of tensors into a single tensor or nested \
        dictionary of tensors.

    Arguments:
        - x : The input list of tensors or nested dictionaries of tensors.
        - json (:obj:`bool`): If True, converts the output to JSON format. Defaults to False.
        - cat_1dim (:obj:`bool`): If True, concatenates tensors with shape (B, 1) along the last dimension. \
            Defaults to True.

    Returns:
        The collated output tensor or nested dictionary of tensors.

    Examples:
        >>> # case 1: Collate a list of tensors
        >>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
        >>> collated = ttorch_collate(tensors)
        collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> # case 2: Collate a nested dictionary of tensors
        >>> nested_dict = {
                'a': torch.tensor([1, 2, 3]),
                'b': torch.tensor([4, 5, 6]),
                'c': torch.tensor([7, 8, 9])
            }
        >>> collated = ttorch_collate(nested_dict)
        collated = {
            'a': torch.tensor([1, 2, 3]),
            'b': torch.tensor([4, 5, 6]),
            'c': torch.tensor([7, 8, 9])
        }
        >>> # case 3: Collate a list of nested dictionaries of tensors
        >>> nested_dicts = [
                {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
                {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
            ]
        >>> collated = ttorch_collate(nested_dicts)
        collated = {
            'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
            'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
        }
    """

    def inplace_fn(t):
        for k in t.keys():
            if isinstance(t[k], torch.Tensor):
                if len(t[k].shape) == 2 and t[k].shape[1] == 1:  # reshape (B, 1) -> (B)
                    t[k] = t[k].squeeze(-1)
            else:
                inplace_fn(t[k])

    x = ttorch.stack(x)
    if cat_1dim:
        inplace_fn(x)
    if json:
        x = x.json()
    return x


def default_collate(batch: Sequence,
                    cat_1dim: bool = True,
                    ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, Sequence]:
    """
    Overview:
        Put each data field into a tensor with outer dimension batch size.

    Arguments:
        - batch (:obj:`Sequence`): A data sequence, whose length is batch size, whose element is one piece of data.
        - cat_1dim (:obj:`bool`): Whether to concatenate tensors with shape (B, 1) to (B), defaults to True.
        - ignore_prefix (:obj:`list`): A list of prefixes to ignore when collating dictionaries, \
            defaults to ['collate_ignore'].

    Returns:
        - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with batch size into each data \
            field. The return dtype depends on the original element dtype, can be [torch.Tensor, Mapping, Sequence].

    Example:
        >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
        >>> a = [torch.zeros(2,3) for _ in range(4)]
        >>> default_collate(a).shape
        torch.Size([4, 2, 3])
        >>>
        >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
        >>> a = [[0 for __ in range(3)] for _ in range(4)]
        >>> default_collate(a)
        [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
        >>>
        >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
        >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
        >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
        >>> print(a[0][2].shape, a[0][3].shape)
        torch.Size([2, 3]) torch.Size([3, 4])
        >>> b = default_collate(a)
        >>> print(b[2].shape, b[3].shape)
        torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
    """

    if isinstance(batch, ttorch.Tensor):
        return batch.json()

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch_ge_131() and torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, directly concatenate into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        if elem.shape == (1, ) and cat_1dim:
            # reshape (B, 1) -> (B)
            return torch.cat(batch, 0, out=out)
            # return torch.stack(batch, 0, out=out)
        else:
            return torch.stack(batch, 0, out=out)
    elif isinstance(elem, ttorch.Tensor):
        return ttorch_collate(batch, json=True, cat_1dim=cat_1dim)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch], cat_1dim=cat_1dim)
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float32)
    elif isinstance(elem, int_classes):
        dtype = torch.bool if isinstance(elem, bool) else torch.int64
        return torch.tensor(batch, dtype=dtype)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        ret = {}
        for key in elem:
            if any([key.startswith(t) for t in ignore_prefix]):
                ret[key] = [d[key] for d in batch]
            else:
                ret[key] = default_collate([d[key] for d in batch], cat_1dim=cat_1dim)
        return ret
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples, cat_1dim=cat_1dim) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples, cat_1dim=cat_1dim) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))


def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tensor, list]]:
    """
    Overview:
        Collates a batch of timestepped data fields into tensors with the outer dimension being the batch size. \
        Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length \
        of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.

    Arguments:
        - batch(:obj:`List[Dict[str, Any]]`): A list of dictionaries with length B, where each dictionary represents \
            a timestepped data field. Each dictionary contains a key-value pair, where the key is the name of the \
            data field and the value is a sequence of torch.Tensor objects with any shape.

    Returns:
        - ret(:obj:`Dict[str, Union[torch.Tensor, list]]`): The collated data, with the timestep and batch size \
            incorporated into each data field. The shape of each data field is [T, B, dim1, dim2, ...].

    Examples:
        >>> batch = [
                {'data0': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]},
                {'data1': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]}
            ]
        >>> collated_data = timestep_collate(batch)
        >>> print(collated_data['data'].shape)
        torch.Size([2, 2, 3])
    """

    def stack(data):
        if isinstance(data, container_abcs.Mapping):
            return {k: stack(data[k]) for k in data}
        elif isinstance(data, container_abcs.Sequence) and isinstance(data[0], torch.Tensor):
            return torch.stack(data)
        else:
            return data

    elem = batch[0]
    assert isinstance(elem, (container_abcs.Mapping, list)), type(elem)
    if isinstance(batch[0], list):  # new pipeline + treetensor
        prev_state = [[b[i].get('prev_state') for b in batch] for i in range(len(batch[0]))]
        batch_data = ttorch.stack([ttorch_collate(b) for b in batch])  # (B, T, *)
        del batch_data.prev_state
        batch_data = batch_data.transpose(1, 0)
        batch_data.prev_state = prev_state
    else:
        prev_state = [b.pop('prev_state') for b in batch]
        batch_data = default_collate(batch)  # -> {some_key: T lists}, each list is [B, some_dim]
        batch_data = stack(batch_data)  # -> {some_key: [T, B, some_dim]}
        transformed_prev_state = list(zip(*prev_state))
        batch_data['prev_state'] = transformed_prev_state
        # append back prev_state, avoiding multi batch share the same data bug
        for i in range(len(batch)):
            batch[i]['prev_state'] = prev_state[i]
    return batch_data


def diff_shape_collate(batch: Sequence) -> Union[torch.Tensor, Mapping, Sequence]:
    """
    Overview:
        Collates a batch of data with different shapes.
        This function is similar to `default_collate`, but it allows tensors in the batch to have `None` values, \
        which is common in StarCraft observations.

    Arguments:
        - batch (:obj:`Sequence`): A sequence of data, where each element is a piece of data.

    Returns:
        - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): The collated data, with the batch size applied \
            to each data field. The return type depends on the original element type and can be a torch.Tensor, \
            Mapping, or Sequence.

    Examples:
        >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
        >>> a = [torch.zeros(2,3) for _ in range(4)]
        >>> diff_shape_collate(a).shape
        torch.Size([4, 2, 3])
        >>>
        >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
        >>> a = [[0 for __ in range(3)] for _ in range(4)]
        >>> diff_shape_collate(a)
        [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
        >>>
        >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
        >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
        >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
        >>> print(a[0][2].shape, a[0][3].shape)
        torch.Size([2, 3]) torch.Size([3, 4])
        >>> b = diff_shape_collate(a)
        >>> print(b[2].shape, b[3].shape)
        torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
    """
    elem = batch[0]
    elem_type = type(elem)
    if any([isinstance(elem, type(None)) for elem in batch]):
        return batch
    elif isinstance(elem, torch.Tensor):
        shapes = [e.shape for e in batch]
        if len(set(shapes)) != 1:
            return batch
        else:
            return torch.stack(batch, 0)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            return diff_shape_collate([torch.as_tensor(b) for b in batch])  # todo
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float32)
    elif isinstance(elem, int_classes):
        dtype = torch.bool if isinstance(elem, bool) else torch.int64
        return torch.tensor(batch, dtype=dtype)
    elif isinstance(elem, Mapping):
        return {key: diff_shape_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(diff_shape_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, Sequence):
        transposed = zip(*batch)
        return [diff_shape_collate(samples) for samples in transposed]

    raise TypeError('not support element type: {}'.format(elem_type))


def default_decollate(
        batch: Union[torch.Tensor, Sequence, Mapping],
        ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']
) -> List[Any]:
    """
    Overview:
        Drag out batch_size collated data's batch size to decollate it, which is the reverse operation of \
        ``default_collate``.

    Arguments:
        - batch (:obj:`Union[torch.Tensor, Sequence, Mapping]`): The collated data batch. It can be a tensor, \
            sequence, or mapping.
        - ignore(:obj:`List[str]`): A list of names to be ignored. Only applicable if the input ``batch`` is a \
            dictionary. If a key is in this list, its value will remain the same without decollation. Defaults to \
            ['prev_state', 'prev_actor_state', 'prev_critic_state'].

    Returns:
        - ret (:obj:`List[Any]`): A list with B elements, where B is the batch size.

    Examples:
        >>> batch = {
            'a': [
                [1, 2, 3],
                [4, 5, 6]
            ],
            'b': [
                [7, 8, 9],
                [10, 11, 12]
            ]}
        >>> default_decollate(batch)
        {
            0: {'a': [1, 2, 3], 'b': [7, 8, 9]},
            1: {'a': [4, 5, 6], 'b': [10, 11, 12]},
        }
    """
    if isinstance(batch, torch.Tensor):
        batch = torch.split(batch, 1, dim=0)
        # Squeeze if the original batch's shape is like (B, dim1, dim2, ...);
        # otherwise, directly return the list.
        if len(batch[0].shape) > 1:
            batch = [elem.squeeze(0) for elem in batch]
        return list(batch)
    elif isinstance(batch, Sequence):
        return list(zip(*[default_decollate(e) for e in batch]))
    elif isinstance(batch, Mapping):
        tmp = {k: v if k in ignore else default_decollate(v) for k, v in batch.items()}
        B = len(list(tmp.values())[0])
        return [{k: tmp[k][i] for k in tmp.keys()} for i in range(B)]
    elif isinstance(batch, torch.distributions.Distribution):  # For compatibility
        return [None for _ in range(batch.batch_shape[0])]

    raise TypeError("Not supported batch type: {}".format(type(batch)))