
ding.data.buffer.middleware.staleness_check


staleness_check(buffer_, max_staleness=float('inf'))

Overview

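This middleware checks data staleness before each sample operation. Staleness is defined as staleness = train_iter_sample_data - train_iter_data_collected, i.e. the number of training iterations between collecting an item and sampling from the buffer, which measures how old (how off-policy) the data is. Before each sample, the middleware computes the staleness of every stored item, records it in the item's meta['staleness'], and deletes every item whose staleness is strictly greater (>) than max_staleness. For this to work, data must be pushed with meta={'train_iter_data_collected': <iter>}, and each sample call must provide train_iter_sample_data.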

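Arguments:
- buffer_ (Buffer): The buffer whose stored items are checked and, if too stale, deleted.
- max_staleness (int): The maximum legal span, in training iterations, between the time of collecting and the time of sampling. Defaults to float('inf'), which never removes any data.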
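The eviction rule can be checked in isolation. The following is a minimal sketch using a hypothetical helper, `stale_indices` (not part of DI-engine), that applies the same formula and the same strict `>` comparison to a list of collection iterations and returns the positions that would be removed.

```python
from typing import List


def stale_indices(train_iter_sample_data: int,
                  collected_iters: List[int],
                  max_staleness: float = float("inf")) -> List[int]:
    """Hypothetical helper: positions of items whose staleness exceeds max_staleness."""
    result = []
    for i, collected in enumerate(collected_iters):
        # staleness = train_iter_sample_data - train_iter_data_collected
        staleness = train_iter_sample_data - collected
        # Strictly greater: staleness == max_staleness is still legal.
        if staleness > max_staleness:
            result.append(i)
    return result


# Items collected at iterations 0, 5 and 10, sampled at iteration 10.
print(stale_indices(10, [0, 5, 10], max_staleness=5))   # [0]
print(stale_indices(10, [0, 5, 10], max_staleness=10))  # []
print(stale_indices(10, [0, 5, 10]))                    # []
```

Staleness here is 10, 5 and 0. With `max_staleness=5`, the item collected at iteration 5 sits exactly on the limit and is kept; only the item from iteration 0 is evicted. With the default `float('inf')`, nothing is ever evicted.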

Full Source Code

../ding/data/buffer/middleware/staleness_check.py

```python
from typing import Callable, Any, List, TYPE_CHECKING
if TYPE_CHECKING:
    from ding.data.buffer.buffer import Buffer


def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable:
    """
    Overview:
        This middleware aims to check staleness before each sample operation,
        staleness = train_iter_sample_data - train_iter_data_collected, means how old/off-policy the data is,
        If data's staleness is greater(>) than max_staleness, this data will be removed from buffer as soon as possible.
    Arguments:
        - max_staleness (:obj:`int`): The maximum legal span between the time of collecting and time of sampling.
    """

    def push(next: Callable, data: Any, *args, **kwargs) -> Any:
        assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[
            'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
        return next(data, *args, **kwargs)

    def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]:
        delete_index = []
        for i, item in enumerate(buffer_.storage):
            index, meta = item.index, item.meta
            staleness = train_iter_sample_data - meta['train_iter_data_collected']
            meta['staleness'] = staleness
            if staleness > max_staleness:
                delete_index.append(index)
        for index in delete_index:
            buffer_.delete(index)
        data = next(*args, **kwargs)
        return data

    def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any:
        if action == "push":
            return push(next, *args, **kwargs)
        elif action == "sample":
            return sample(next, *args, **kwargs)
        return next(*args, **kwargs)

    return _staleness_check
```
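To see how the middleware plugs into a buffer without installing DI-engine, the sketch below wires a compact copy of the function above into a hypothetical `ToyBuffer`. `ToyBuffer`, `Item`, and the way `_call` chains middleware are assumptions made for illustration, modeled only on the `(action, next, *args, **kwargs)` calling convention that `_staleness_check` expects; they are not DI-engine's `Buffer` API. The toy buffer provides just what the middleware touches: a `storage` list of items carrying `index` and `meta`, and a `delete(index)` method.

```python
from typing import Any, Callable, Dict, List, Optional


class Item:
    """Hypothetical stand-in for a stored buffer entry."""

    def __init__(self, data: Any, index: int, meta: Dict[str, Any]) -> None:
        self.data, self.index, self.meta = data, index, meta


class ToyBuffer:
    """Hypothetical minimal buffer that routes push/sample through middleware."""

    def __init__(self) -> None:
        self.storage: List[Item] = []
        self.middleware: List[Callable] = []
        self._next_index = 0

    def use(self, func: Callable) -> "ToyBuffer":
        self.middleware.append(func)
        return self

    def delete(self, index: int) -> None:
        self.storage = [item for item in self.storage if item.index != index]

    def _push(self, data: Any, meta: Optional[Dict[str, Any]] = None) -> Item:
        item = Item(data, self._next_index, meta or {})
        self._next_index += 1
        self.storage.append(item)
        return item

    def _sample(self, size: int) -> List[Item]:
        return self.storage[:size]

    @staticmethod
    def _wrap(mw: Callable, action: str, next_: Callable) -> Callable:
        # Each middleware is called as mw(action, next, *args, **kwargs).
        def step(*args: Any, **kwargs: Any) -> Any:
            return mw(action, next_, *args, **kwargs)
        return step

    def _call(self, action: str, base: Callable, *args: Any, **kwargs: Any) -> Any:
        # Build the chain so the first registered middleware runs first, `base` last.
        chain = base
        for mw in reversed(self.middleware):
            chain = self._wrap(mw, action, chain)
        return chain(*args, **kwargs)

    def push(self, data: Any, **kwargs: Any) -> Item:
        return self._call("push", self._push, data, **kwargs)

    def sample(self, size: int, **kwargs: Any) -> List[Item]:
        # `size` goes by keyword so it cannot bind to train_iter_sample_data.
        return self._call("sample", self._sample, size=size, **kwargs)


def staleness_check(buffer_: ToyBuffer, max_staleness: float = float("inf")) -> Callable:
    # Compact restatement of the middleware above so this block runs without DI-engine.
    def push(next: Callable, data: Any, *args: Any, **kwargs: Any) -> Any:
        assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs['meta']
        return next(data, *args, **kwargs)

    def sample(next: Callable, train_iter_sample_data: int, *args: Any, **kwargs: Any) -> List[Any]:
        stale = []
        for item in buffer_.storage:
            item.meta['staleness'] = train_iter_sample_data - item.meta['train_iter_data_collected']
            if item.meta['staleness'] > max_staleness:
                stale.append(item.index)
        for index in stale:  # delete after the scan, not while iterating
            buffer_.delete(index)
        return next(*args, **kwargs)

    def _staleness_check(action: str, next: Callable, *args: Any, **kwargs: Any) -> Any:
        if action == "push":
            return push(next, *args, **kwargs)
        if action == "sample":
            return sample(next, *args, **kwargs)
        return next(*args, **kwargs)

    return _staleness_check


buffer = ToyBuffer()
buffer.use(staleness_check(buffer, max_staleness=5))
for it in (0, 4, 8):
    buffer.push(f"data@{it}", meta={"train_iter_data_collected": it})

# Staleness at iteration 10 is 10, 6 and 2; only the last is within max_staleness=5.
batch = buffer.sample(size=3, train_iter_sample_data=10)
print([item.data for item in batch])  # ['data@8']
```

Because `sample` in the middleware takes `train_iter_sample_data` as its first parameter after `next`, any other sampling argument in this toy setup has to be passed by keyword. Pushing without `meta` fails the middleware's assertion before anything is stored, so `buffer.push("x")` raises `AssertionError`.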