Skip to content

ding.worker.replay_buffer.utils

ding.worker.replay_buffer.utils

UsedDataRemover

Overview

UsedDataRemover is a tool to remove data files that will no longer be used.

Interface: start, close, add_used_data

start()

Overview

Start the delete_used_data thread.

close()

Overview

Delete all datas in self._used_data. Then join the delete_used_data thread.

add_used_data(data)

Overview

Add a used data item into self._used_data for further removal.

Arguments: - data (:obj:Any): Add a used data item into self._used_data for further remove.

SampledDataAttrMonitor

Bases: LoggedModel

Overview

SampledDataAttrMonitor is to monitor read-out indicators for expire times recent read-outs. Indicators include: read-out time; average and max of read-out data items' use count; average, max and min of read-out data items' priority; average and max of staleness.

Interface: init, fixed_time, current_time, freeze, unfreeze, register_attribute_value, getattr Property: time, expire

PeriodicThruputMonitor

Overview

PeriodicThruputMonitor is a tool to record and print logs (text & tensorboard) of how many data items are pushed/sampled/removed/valid in a period of time. For tensorboard, you can view it in 'buffer_{$NAME}_sec'.

Interface: close Property: push_data_count, sample_data_count, remove_data_count, valid_count

.. note:: The thruput_log thread is initialized and started in the __init__ method, so PeriodicThruputMonitor only provides one single interface: close.

close()

Overview

Join the thruput_log thread by setting self._end_flag to True.

generate_id(name, data_id)

Overview

Use self.name and input id to generate a unique id for next data to be inserted.

Arguments: - data_id (:obj:int): Current unique id. Returns: - id (:obj:str): Id in format "BufferName_DataId".

Full Source Code

../ding/worker/replay_buffer/utils.py

from typing import Any
import time
from queue import Queue
from typing import Union, Tuple
from threading import Thread
from functools import partial

from ding.utils.autolog import LoggedValue, LoggedModel
from ding.utils import LockContext, LockContextType, remove_file


def generate_id(name: str, data_id: int) -> str:
    """
    Overview:
        Use ``name`` and input ``data_id`` to generate a unique id for next data to be inserted.
    Arguments:
        - name (:obj:`str`): Buffer name, used as the id prefix.
        - data_id (:obj:`int`): Current unique id.
    Returns:
        - id (:obj:`str`): Id in format "BufferName_DataId".
    """
    return "{}_{}".format(name, str(data_id))


class UsedDataRemover:
    """
    Overview:
        UsedDataRemover is a tool to remove data files that will no longer be used.
    Interface:
        start, close, add_used_data
    """

    def __init__(self) -> None:
        # Queue of file ids (paths) that are pending deletion.
        self._used_data = Queue()
        self._delete_used_data_thread = Thread(target=self._delete_used_data, name='delete_used_data')
        self._delete_used_data_thread.daemon = True
        # Stays True until ``start`` is called, so the worker loop does not run prematurely.
        self._end_flag = True

    def start(self) -> None:
        """
        Overview:
            Start the `delete_used_data` thread.
        """
        self._end_flag = False
        self._delete_used_data_thread.start()

    def close(self) -> None:
        """
        Overview:
            Delete all remaining datas in `self._used_data`, then stop and join the
            `delete_used_data` thread.
        """
        while not self._used_data.empty():
            data_id = self._used_data.get()
            remove_file(data_id)
        self._end_flag = True
        # BUGFIX: the docstring promised a join, but the thread was never joined.
        # The worker loop exits within ~1ms once ``_end_flag`` is set; guard with
        # ``is_alive`` in case ``start`` was never called.
        if self._delete_used_data_thread.is_alive():
            self._delete_used_data_thread.join()

    def add_used_data(self, data: Any) -> None:
        """
        Overview:
            Add a used data item into `self._used_data` for further removal by the
            `delete_used_data` thread.
        Arguments:
            - data (:obj:`Any`): A dict that must contain key ``data_id``; the value is
              the file id to remove.
        """
        assert data is not None and isinstance(data, dict) and 'data_id' in data
        self._used_data.put(data['data_id'])

    def _delete_used_data(self) -> None:
        # Worker loop: keep removing queued files until ``close`` sets the end flag.
        while not self._end_flag:
            if not self._used_data.empty():
                data_id = self._used_data.get()
                remove_file(data_id)
            else:
                time.sleep(0.001)


class SampledDataAttrMonitor(LoggedModel):
    """
    Overview:
        SampledDataAttrMonitor is to monitor read-out indicators for ``expire`` times recent read-outs.
        Indicators include: read-out time; average and max of read-out data items' use; average, max and
        min of read-out data items' priority; average and max of staleness.
    Interface:
        __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
    Property:
        time, expire
    """
    use_max = LoggedValue(int)
    use_avg = LoggedValue(float)
    priority_max = LoggedValue(float)
    priority_avg = LoggedValue(float)
    priority_min = LoggedValue(float)
    staleness_max = LoggedValue(int)
    staleness_avg = LoggedValue(float)

    def __init__(self, time_: 'BaseTime', expire: Union[int, float]):  # noqa
        LoggedModel.__init__(self, time_, expire)
        self.__register()

    def __register(self):
        # Aggregators over the recent record window; all return 0 on an empty window
        # so callers never see a ZeroDivisionError / ValueError.

        def __avg_func(prop_name: str) -> float:
            records = self.range_values[prop_name]()
            _list = [_value for (_begin_time, _end_time), _value in records]
            return sum(_list) / len(_list) if len(_list) != 0 else 0

        def __max_func(prop_name: str) -> Union[float, int]:
            records = self.range_values[prop_name]()
            _list = [_value for (_begin_time, _end_time), _value in records]
            return max(_list) if len(_list) != 0 else 0

        def __min_func(prop_name: str) -> Union[float, int]:
            records = self.range_values[prop_name]()
            _list = [_value for (_begin_time, _end_time), _value in records]
            return min(_list) if len(_list) != 0 else 0

        self.register_attribute_value('avg', 'use', partial(__avg_func, prop_name='use_avg'))
        self.register_attribute_value('max', 'use', partial(__max_func, prop_name='use_max'))
        self.register_attribute_value('avg', 'priority', partial(__avg_func, prop_name='priority_avg'))
        self.register_attribute_value('max', 'priority', partial(__max_func, prop_name='priority_max'))
        self.register_attribute_value('min', 'priority', partial(__min_func, prop_name='priority_min'))
        self.register_attribute_value('avg', 'staleness', partial(__avg_func, prop_name='staleness_avg'))
        self.register_attribute_value('max', 'staleness', partial(__max_func, prop_name='staleness_max'))


class PeriodicThruputMonitor:
    """
    Overview:
        PeriodicThruputMonitor is a tool to record and print logs(text & tensorboard) how many datas are
        pushed/sampled/removed/valid in a period of time. For tensorboard, you can view it in 'buffer_{$NAME}_sec'.
    Interface:
        close
    Property:
        push_data_count, sample_data_count, remove_data_count, valid_count

    .. note::
        `thruput_log` thread is initialized and started in `__init__` method, so PeriodicThruputMonitor only
        provides one single interface `close`
    """

    def __init__(self, name, cfg, logger, tb_logger) -> None:
        self.name = name
        self._end_flag = False
        self._logger = logger
        self._tb_logger = tb_logger
        # How often (in seconds) to emit a statistics line.
        self._thruput_print_seconds = cfg.seconds
        self._thruput_print_times = 0
        self._thruput_start_time = time.time()
        self._history_push_count = 0
        self._history_sample_count = 0
        self._remove_data_count = 0
        self._valid_count = 0
        self._thruput_log_thread = Thread(
            target=self._thruput_print_periodically, args=(), name='periodic_thruput_log'
        )
        self._thruput_log_thread.daemon = True
        self._thruput_log_thread.start()

    def _thruput_print_periodically(self) -> None:
        # Renamed from ``_thrput_print_periodically`` (typo); private, referenced only in __init__.
        while not self._end_flag:
            time_passed = time.time() - self._thruput_start_time
            if time_passed >= self._thruput_print_seconds:
                self._logger.info('In the past {:.1f} seconds, buffer statistics is as follows:'.format(time_passed))
                count_dict = {
                    'pushed_in': self._history_push_count,
                    'sampled_out': self._history_sample_count,
                    'removed': self._remove_data_count,
                    'current_have': self._valid_count,
                }
                self._logger.info(self._logger.get_tabulate_vars_hor(count_dict))
                for k, v in count_dict.items():
                    self._tb_logger.add_scalar('{}_sec/'.format(self.name) + k, v, self._thruput_print_times)
                # Reset the per-period counters; ``_valid_count`` is a gauge, not a rate,
                # so it is intentionally NOT reset.
                self._history_push_count = 0
                self._history_sample_count = 0
                self._remove_data_count = 0
                self._thruput_start_time = time.time()
                self._thruput_print_times += 1
            else:
                # Sleep a fraction of the period (at most 1s) so ``close`` is noticed promptly.
                time.sleep(min(1, self._thruput_print_seconds * 0.2))

    def close(self) -> None:
        """
        Overview:
            Join the `thruput_log` thread by setting `self._end_flag` to `True`.
        """
        self._end_flag = True

    def __del__(self) -> None:
        self.close()

    @property
    def push_data_count(self) -> int:
        return self._history_push_count

    @push_data_count.setter
    def push_data_count(self, count) -> None:
        self._history_push_count = count

    @property
    def sample_data_count(self) -> int:
        return self._history_sample_count

    @sample_data_count.setter
    def sample_data_count(self, count) -> None:
        self._history_sample_count = count

    @property
    def remove_data_count(self) -> int:
        return self._remove_data_count

    @remove_data_count.setter
    def remove_data_count(self, count) -> None:
        self._remove_data_count = count

    @property
    def valid_count(self) -> int:
        return self._valid_count

    @valid_count.setter
    def valid_count(self, count) -> None:
        self._valid_count = count


class ThruputController:
    """
    Overview:
        ThruputController tracks exponentially-decayed push/sample counts and decides whether a
        push or sample request would violate the configured push/sample rate limits.
    Interface:
        can_push, can_sample, close
    Property:
        history_push_count, history_sample_count
    """

    def __init__(self, cfg) -> None:
        self._push_sample_rate_limit = cfg.push_sample_rate_limit
        assert 'min' in self._push_sample_rate_limit and self._push_sample_rate_limit['min'] >= 0
        # NOTE: ``<= float("inf")`` is always true; kept to assert the 'max' key exists.
        assert 'max' in self._push_sample_rate_limit and self._push_sample_rate_limit['max'] <= float("inf")
        window_seconds = cfg.window_seconds
        # Per-second decay factor such that a count decays to 1% of its value
        # after ``window_seconds`` seconds.
        self._decay_factor = 0.01 ** (1 / window_seconds)

        self._push_lock = LockContext(lock_type=LockContextType.THREAD_LOCK)
        self._sample_lock = LockContext(lock_type=LockContextType.THREAD_LOCK)
        self._history_push_count = 0
        self._history_sample_count = 0

        self._end_flag = False
        self._count_decay_thread = Thread(target=self._count_decay, name='count_decay')
        self._count_decay_thread.daemon = True
        self._count_decay_thread.start()

    def _count_decay(self) -> None:
        # Background loop: decay both counters once per second under their locks.
        while not self._end_flag:
            time.sleep(1)
            with self._push_lock:
                self._history_push_count *= self._decay_factor
            with self._sample_lock:
                self._history_sample_count *= self._decay_factor

    def can_push(self, push_size: int) -> Tuple[bool, str]:
        """
        Overview:
            Check whether pushing ``push_size`` items keeps push/sample below the 'max' limit.
        Arguments:
            - push_size (:obj:`int`): Number of items about to be pushed.
        Returns:
            - result (:obj:`Tuple[bool, str]`): Decision flag and a human-readable reason.
        """
        # If (almost) nothing has been sampled yet, allow pushing unconditionally.
        if abs(self._history_sample_count) < 1e-5:
            return True, "Can push because `self._history_sample_count` < 1e-5"
        rate = (self._history_push_count + push_size) / self._history_sample_count
        if rate > self._push_sample_rate_limit['max']:
            return False, "push({}+{}) / sample({}) > limit_max({})".format(
                self._history_push_count, push_size, self._history_sample_count, self._push_sample_rate_limit['max']
            )
        return True, "Can push."

    def can_sample(self, sample_size: int) -> Tuple[bool, str]:
        """
        Overview:
            Check whether sampling ``sample_size`` items keeps push/sample above the 'min' limit.
        Arguments:
            - sample_size (:obj:`int`): Number of items about to be sampled.
        Returns:
            - result (:obj:`Tuple[bool, str]`): Decision flag and a human-readable reason.
        """
        # NOTE(review): assumes sample_size > 0 (or a non-zero decayed sample count),
        # otherwise this raises ZeroDivisionError — confirm against callers.
        rate = self._history_push_count / (self._history_sample_count + sample_size)
        if rate < self._push_sample_rate_limit['min']:
            return False, "push({}) / sample({}+{}) < limit_min({})".format(
                self._history_push_count, self._history_sample_count, sample_size, self._push_sample_rate_limit['min']
            )
        return True, "Can sample."

    def close(self) -> None:
        """
        Overview:
            Stop the `count_decay` thread by setting `self._end_flag` to `True`.
        """
        self._end_flag = True

    @property
    def history_push_count(self) -> int:
        return self._history_push_count

    @history_push_count.setter
    def history_push_count(self, count) -> None:
        with self._push_lock:
            self._history_push_count = count

    @property
    def history_sample_count(self) -> int:
        return self._history_sample_count

    @history_sample_count.setter
    def history_sample_count(self, count) -> None:
        with self._sample_lock:
            self._history_sample_count = count