ding.data.buffer.middleware.group_sample
group_sample(size_in_group, ordered_in_group=True, max_use_in_group=True)
Overview
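After the buffer performs a grouped sample (a `sample` call with the `groupby` keyword argument), this middleware reduces each returned group to at most `size_in_group` records. It raises an exception if `groupby` is not given.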
Arguments:
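- size_in_group (int): Number of records to keep from each group.
- ordered_in_group (bool): If True, return a contiguous slice of each group in its original order. If False, draw `size_in_group` records at random without replacement (in that case every group must contain at least `size_in_group` records, otherwise `random.sample` raises `ValueError`). Default: True.
- max_use_in_group (bool): Only takes effect when `ordered_in_group` is True. If True, the slice's start index is chosen so that the slice holds a full `size_in_group` records whenever the group is long enough. If False, the start index is drawn from the whole group, so slices that start near the end can be shorter. Default: True.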
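To make the interaction of `ordered_in_group` and `max_use_in_group` concrete, the sketch below enumerates every window the ordered branch could return for a given group. `possible_windows` is a hypothetical helper written for illustration only; it restates the start-index arithmetic from the listed source (`end` and `random.choice(range(end))`) but returns all candidate slices instead of picking one at random.

```python
from typing import List


def possible_windows(group: List[int], size_in_group: int, max_use_in_group: bool) -> List[List[int]]:
    """Hypothetical helper: list every slice the ordered branch (ordered_in_group=True) can return.

    Mirrors the listed source: a start index is drawn from range(end) and the
    returned records are group[start:start + size_in_group].
    """
    if max_use_in_group:
        # Only starts that leave room for a full window (start 0 alone if the group is short).
        end = max(0, len(group) - size_in_group) + 1
    else:
        # Any index in the group may start the window, so tail windows get truncated.
        end = len(group)
    return [group[start:start + size_in_group] for start in range(end)]


print(possible_windows([0, 1, 2, 3, 4], 3, max_use_in_group=True))
# → [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
print(possible_windows([0, 1, 2, 3, 4], 3, max_use_in_group=False))
# → [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4], [4]]
print(possible_windows([0, 1], 3, max_use_in_group=True))
# → [[0, 1]]
```

With `max_use_in_group=True` every candidate is a full window unless the group itself is shorter than `size_in_group`, in which case the only candidate is the whole group. With `max_use_in_group=False` the tail starts yield truncated slices, and an empty group gives `range(0)`, for which `random.choice` raises `IndexError`.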
Full Source Code
../ding/data/buffer/middleware/group_sample.py
```python
import random
from typing import Callable, List
from ding.data.buffer.buffer import BufferedData


def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
    """
    Overview:
        The middleware is designed to process the data in each group after sampling from the buffer.
    Arguments:
        - size_in_group (:obj:`int`): Sample size in each group.
        - ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true.
        - max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true.
    """

    def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]:
        if not kwargs.get("groupby"):
            raise Exception("Group sample must be used when the `groupby` parameter is specified.")
        sampled_data = chain(*args, **kwargs)
        for i, grouped_data in enumerate(sampled_data):
            if ordered_in_group:
                if max_use_in_group:
                    end = max(0, len(grouped_data) - size_in_group) + 1
                else:
                    end = len(grouped_data)
                start_idx = random.choice(range(end))
                sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
            else:
                sampled_data[i] = random.sample(grouped_data, k=size_in_group)
        return sampled_data

    def _group_sample(action: str, chain: Callable, *args, **kwargs):
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _group_sample
```
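The module uses DI-engine's `BufferedData` only in a type hint, so the middleware's calling convention can be tried on its own. The sketch below is a copy of the listing with that type dropped, driven by `fake_chain`, a hypothetical stand-in for the rest of the buffer's sample chain (a real buffer supplies its own chain and records). It exercises the three paths through `_group_sample`: a grouped `sample` is trimmed per group, any other action passes straight through, and `sample` without `groupby` raises before the chain runs.

```python
import random
from typing import Callable, List, Optional


# Copy of the listed middleware; BufferedData is dropped so this runs without DI-engine.
def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:

    def sample(chain: Callable, *args, **kwargs) -> List[list]:
        if not kwargs.get("groupby"):
            raise Exception("Group sample must be used when the `groupby` parameter is specified.")
        sampled_data = chain(*args, **kwargs)
        for i, grouped_data in enumerate(sampled_data):
            if ordered_in_group:
                if max_use_in_group:
                    end = max(0, len(grouped_data) - size_in_group) + 1
                else:
                    end = len(grouped_data)
                start_idx = random.choice(range(end))
                sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
            else:
                sampled_data[i] = random.sample(grouped_data, k=size_in_group)
        return sampled_data

    def _group_sample(action: str, chain: Callable, *args, **kwargs):
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _group_sample


# Hypothetical stand-in for the downstream chain: it ignores its arguments and
# returns two groups of records, the shape a grouped sample is assumed to have.
def fake_chain(size: int, groupby: Optional[str] = None) -> List[List[int]]:
    return [list(range(0, 6)), list(range(10, 16))]


middleware = group_sample(size_in_group=3)

# Grouped sample: each group becomes a slice of three consecutive records (random start).
print(middleware("sample", fake_chain, 2, groupby="env_id"))

# Any action other than "sample" is forwarded to the chain unchanged.
print(middleware("push", lambda record: record, "record"))  # prints "record"

# "sample" without groupby raises before fake_chain is called.
try:
    middleware("sample", fake_chain, 2)
except Exception as err:
    print(err)
```

Note that the middleware overwrites the entries of the list returned by the chain (`sampled_data[i] = ...`) and returns that same list object rather than a new one.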