ding.data.buffer.middleware.padding
padding(policy='random')
Overview
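Pad every group in a nested (grouped) sample result to the length of the largest group.
With the default policy `random`, each shorter group is filled with copies of items randomly
drawn from that same group; with `none`, it is filled with empty `BufferedData(data=None, index=None, meta=None)` placeholders.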
Arguments:
- policy (:obj:`str`): Padding policy, supports `random`, `none`.
Full Source Code
../ding/data/buffer/middleware/padding.py
import random
from typing import Callable, Union, List

from ding.data.buffer import BufferedData
from ding.utils import fastcopy


def padding(policy="random"):
    """
    Overview:
        Fill the nested buffer list to the same size as the largest list.
        The default policy `random` fills each shorter group with copies of items
        randomly drawn from that group's own (original) data; the policy `none`
        fills it with empty `BufferedData(data=None, index=None, meta=None)` items.
        Only grouped (nested) sample results are padded; flat results are returned as-is.
        Groups are extended in place.
    Arguments:
        - policy (:obj:`str`): Padding policy, supports `random`, `none`.
    Returns:
        - middleware (:obj:`Callable`): A buffer middleware `(action, chain, *args, **kwargs)`
          that only intercepts the `sample` action and forwards every other action unchanged.
    Raises:
        - ValueError: If `policy` is not one of the supported policies.
    """
    # Validate eagerly: previously an unknown policy silently skipped padding altogether.
    if policy not in ("random", "none"):
        raise ValueError(f"Unsupported padding policy {policy!r}; expected 'random' or 'none'.")

    def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        sampled_data = chain(*args, **kwargs)
        # Nothing sampled, or a flat (ungrouped) list of BufferedData: nothing to pad.
        if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
            return sampled_data
        max_len = max(len(group) for group in sampled_data)
        for group in sampled_data:
            missing = max_len - len(group)
            if missing <= 0:
                continue
            if policy == "random":
                if len(group) == 0:
                    raise IndexError("Cannot pad an empty group with the 'random' padding policy.")
                # Draw all picks before extending so that only the group's original items are
                # candidates; drawing from the growing list would favour already-padded copies.
                picks = random.choices(group, k=missing)
                group.extend(fastcopy.copy(item) for item in picks)
            else:  # policy == "none"
                group.extend(BufferedData(data=None, index=None, meta=None) for _ in range(missing))
        return sampled_data

    def _padding(action: str, chain: Callable, *args, **kwargs):
        # Only sampling is affected; every other buffer action passes straight through.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _padding