Skip to content

ding.data.buffer.middleware.use_time_check

ding.data.buffer.middleware.use_time_check

use_time_check(buffer_, max_use=float('inf'))

Overview

This middleware tracks how many times each piece of data in the buffer has been used (sampled). Once the use count of a piece of data is greater than or equal to max_use, that data is removed from the buffer as soon as possible.

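Arguments:
- buffer_ (:obj:`Buffer`): The buffer instance from which over-used data is deleted.
- max_use (:obj:`int`): The maximum number of times any individual object may be reused (resampled). Defaults to `float('inf')`, i.e. no limit.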

Full Source Code

../ding/data/buffer/middleware/use_time_check.py

from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
from collections import defaultdict
from ding.data.buffer import BufferedData
if TYPE_CHECKING:
    from ding.data.buffer.buffer import Buffer


def use_time_check(buffer_: 'Buffer', max_use: Union[int, float] = float("inf")) -> Callable:
    """
    Overview:
        This middleware tracks how many times each piece of data has been returned by ``sample``.
        Once the use count of a piece of data is greater than or equal to ``max_use``, that data is
        deleted from the buffer right after the sample call that reached the limit (the sampled
        data itself is still returned to the caller).
    Arguments:
        - buffer_ (:obj:`Buffer`): The buffer whose ``delete`` method is called to drop over-used data.
        - max_use (:obj:`int`): The max reused (resampled) count for any individual object. \
            The default ``float("inf")`` means data is never removed by this middleware.
    Returns:
        - middleware (:obj:`Callable`): A function ``(action, chain, *args, **kwargs)`` that applies \
            the use-count check to the ``"sample"`` action and forwards every other action unchanged.
    """

    # Maps a data index to the number of times that data has been sampled so far.
    use_count = defaultdict(int)

    def _need_delete(item: BufferedData) -> bool:
        # Count this use, expose the running count in the item's meta, and report whether
        # the item has now reached the limit.
        idx = item.index
        use_count[idx] += 1
        item.meta['use_count'] = use_count[idx]
        return use_count[idx] >= max_use

    def _check_use_count(sampled_data: List[BufferedData]) -> None:
        # Every sampled item must be counted (even duplicates), so `_need_delete` is called for
        # each element. The resulting indices are de-duplicated (order preserved) because the same
        # item can appear more than once in a single batch; a repeated index would otherwise make
        # the counter cleanup below fail with KeyError on its second occurrence.
        delete_indices = list(dict.fromkeys(item.index for item in sampled_data if _need_delete(item)))
        buffer_.delete(delete_indices)
        for index in delete_indices:
            # Forget the counter of removed data so the table does not grow without bound.
            use_count.pop(index, None)

    def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        sampled_data = chain(*args, **kwargs)
        if len(sampled_data) == 0:
            return sampled_data

        if isinstance(sampled_data[0], BufferedData):
            # Flat result: a single list of sampled items.
            _check_use_count(sampled_data)
        else:
            # Grouped result: a list of lists, each group is checked on its own.
            for grouped_data in sampled_data:
                _check_use_count(grouped_data)
        return sampled_data

    def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any:
        # Only sampling affects use counts; all other actions pass straight through.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _use_time_check