ding.worker.replay_buffer.advanced_buffer

AdvancedReplayBuffer

Bases: IBuffer

Overview

Prioritized replay buffer derived from NaiveReplayBuffer. This replay buffer adds:

1) Prioritized experience replay implemented with segment trees.
2) Data quality monitor: monitors the use count and staleness of each data item.
3) Throughput monitor and control.
4) Logger: logs 2) and 3) to TensorBoard or a text log.
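Feature 1) samples each item in proportion to its weight priority ** alpha. A sum segment tree keeps those weights in its leaves and partial sums in its internal nodes, so a single descent from the root finds the item whose prefix-sum interval contains a random mass in O(log n). The pure-Python sketch below illustrates that idea together with the stratified draw used by _get_indices in the source (one uniform draw inside each of size equal slices of the total mass). SumTree and stratified_sample are hypothetical names for illustration, not ding's SumSegmentTree API; only the method name find_prefixsum_idx is borrowed from the source.

```python
import random
from typing import List


class SumTree:
    """Minimal array-backed sum segment tree (hypothetical sketch, not ding's SumSegmentTree)."""

    def __init__(self, capacity: int) -> None:
        # Round the capacity up to a power of 2, as the buffer does for its trees.
        self.capacity = 1
        while self.capacity < capacity:
            self.capacity *= 2
        # Leaves live at [capacity, 2 * capacity); node i has children 2i and 2i + 1.
        self.tree = [0.0] * (2 * self.capacity)

    def __setitem__(self, idx: int, value: float) -> None:
        pos = idx + self.capacity
        self.tree[pos] = value
        pos //= 2
        # Recompute every ancestor sum on the path back to the root.
        while pos >= 1:
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos //= 2

    def total(self) -> float:
        return self.tree[1]

    def find_prefixsum_idx(self, mass: float) -> int:
        """Return the smallest leaf index i with sum(leaves[:i + 1]) > mass."""
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if self.tree[left] > mass:
                pos = left
            else:
                mass -= self.tree[left]
                pos = left + 1
        return pos - self.capacity


def stratified_sample(tree: SumTree, size: int) -> List[int]:
    # Split [0, total) into `size` equal strata and draw one point uniformly in each.
    seg = tree.total() / size
    return [tree.find_prefixsum_idx((i + random.random()) * seg) for i in range(size)]
```

For example, with alpha = 1 and priorities [1.0, 2.0, 3.0], item 2 covers half of the total mass and is drawn about half of the time; an item with weight 0 is never drawn.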

Interface: start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
Property: beta, replay_buffer_size, push_count

__init__(cfg, tb_logger=None, exp_name='default_experiment', instance_name='buffer')

Overview

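Initialize the buffer.

Arguments:
- cfg (:obj:dict): Config dict. Fields are read as attributes (e.g. cfg.replay_buffer_size), so the dict must support attribute access.
- tb_logger (:obj:Optional['SummaryWriter']): Outer TensorBoard logger, usually passed in serial mode. If it is None, the rank-0 process builds its own.
- exp_name (:obj:Optional[str]): Name of this experiment. Logs are written under ./{exp_name}/log/{instance_name}.
- instance_name (:obj:Optional[str]): Name of this instance.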
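Several internal settings are derived from cfg at construction time (see __init__ in the full source code): the segment-tree capacity is replay_buffer_size rounded up to a power of 2, beta rises by (1 - beta) / anneal_step after every sample call until it reaches 1.0, and a throughput controller is built only when push_sample_rate_limit has a finite max or a non-zero min. The sketch below reproduces that arithmetic with the standard library; derive_settings, DerivedSettings and anneal are hypothetical helpers, not part of ding.

```python
import math
from dataclasses import dataclass


@dataclass
class DerivedSettings:
    capacity: int  # segment-tree capacity
    beta_anneal_step: float  # increment applied to beta after every sample call
    use_thruput_controller: bool  # whether a throughput controller would be built


def derive_settings(replay_buffer_size: int, beta: float, anneal_step: int,
                    rate_max: float = float("inf"), rate_min: float = 0.0) -> DerivedSettings:
    # Segment trees need a power-of-2 capacity, so round the buffer size up.
    capacity = 2 ** math.ceil(math.log2(replay_buffer_size))
    # Beta grows linearly towards 1.0; anneal_step == 0 disables annealing.
    beta_anneal_step = (1 - beta) / anneal_step if anneal_step != 0 else 0.0
    # A controller is only needed when at least one rate limit is active.
    use_ctrl = rate_max != float("inf") or rate_min != 0
    return DerivedSettings(capacity, beta_anneal_step, use_ctrl)


def anneal(beta: float, step: float, n_samples: int) -> float:
    # Mirrors ``min(1.0, beta + step)`` applied once per sample call.
    for _ in range(n_samples):
        beta = min(1.0, beta + step)
    return beta
```

With the defaults (replay_buffer_size=4096, beta=0.4, anneal_step=100000) the trees hold exactly 4096 leaves and beta grows by 6e-6 per sample call, while a replay_buffer_size of 5000 would already need 8192 leaves.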

start()

Overview

Start the buffer's used_data_remover thread if enable_track_used_data is set.

close()

Overview

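Clear the buffer and join the used_data_remover thread if enable_track_used_data is set. On the rank-0 process, also close the periodic throughput monitor and flush and close the TensorBoard logger. Calling close more than once has no further effect.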

sample(size, cur_learner_iter, sample_range=None)

Overview

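Sample size data items from the buffer.

Arguments:
- size (:obj:int): The number of data items to sample. A size of 0 returns an empty list.
- cur_learner_iter (:obj:int): Learner's current iteration, used to calculate staleness.
- sample_range (:obj:slice): Buffer slice to sample from. For example, slice(-10, None) samples only among the last 10 slots of the underlying storage array (negative indices are counted from replay_buffer_size).

Returns:
- sample_data (:obj:Optional[list]): A list of size data items, or None if the buffer refuses to sample, either because too few valid items remain after stale ones are removed or because the throughput limit is exceeded. The reason is logged.

ReturnsKeys:
- necessary: original keys (e.g. obs, action, next_obs, reward, info), replay_unique_id, replay_buffer_idx
- optional (if priority is used): IS, priority
- also written by the buffer: staleness, use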
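Each sampled item carries an importance-sampling weight under the key IS. Following _sample_with_indices in the source, with N valid items, w_i = priority_i ** alpha and P(i) = w_i / sum_j w_j, the raw weight is (N * P(i)) ** (-beta); it is then divided by the weight of the least likely item so every IS value lies in (0, 1]. The function below is a standalone sketch of that arithmetic over a plain list; importance_weights is a hypothetical helper, and the list (valid items only) stands in for the sum and min trees.

```python
from typing import List


def importance_weights(weights: List[float], indices: List[int], beta: float) -> List[float]:
    """Normalized IS weights, mirroring the arithmetic in ``_sample_with_indices``.

    ``weights`` holds priority ** alpha for every *valid* item only.
    """
    total = sum(weights)  # plays the role of the sum-tree root
    n = len(weights)  # plays the role of the buffer's valid count
    p_min = min(weights) / total  # min-tree root, normalized to a probability
    # The least likely item has the largest raw weight; dividing by it caps IS at 1.0.
    max_weight = (n * p_min) ** (-beta)
    return [(n * weights[i] / total) ** (-beta) / max_weight for i in indices]
```

With beta = 0 every weight is 1.0 (no correction); as beta is annealed towards 1, frequently sampled high-priority items are down-weighted more strongly.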

push(data, cur_collector_envstep)

Overview

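Push data into the buffer. If the throughput controller refuses the push, the data is dropped and the reason is logged.

Arguments:
- data (:obj:Union[List[Any], Any]): The data to push: either a single item (Any) or many items (List[Any]). Each item must be a dict; other items are rejected and logged.
- cur_collector_envstep (:obj:int): Collector's current env step.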
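The buffer stores items in a fixed-size circular queue. The sketch below mirrors the bookkeeping in _append from the source: each item gets replay_unique_id and replay_buffer_idx, tail advances modulo the buffer size, and overwriting an occupied slot moves head to the slot right after it (the new oldest item). RingBuffer is a hypothetical toy: it uses str(counter) instead of ding's generate_id, and it omits locking, priorities, monitors and logging.

```python
from typing import Any, Dict, List, Optional


class RingBuffer:
    """Toy circular queue mimicking the head/tail bookkeeping of ``_append`` (hypothetical)."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.data: List[Optional[Dict[str, Any]]] = [None] * size
        self.head = 0  # position of the oldest valid item
        self.tail = 0  # next write position
        self.next_unique_id = 0
        self.valid_count = 0

    def append(self, item: Any) -> None:
        if not isinstance(item, dict):
            return  # the buffer's default check only accepts dicts
        if self.data[self.tail] is not None:
            # Overwriting the oldest item: the new head is the slot after the tail.
            self.head = (self.tail + 1) % self.size
            self.valid_count -= 1
        item["replay_unique_id"] = str(self.next_unique_id)  # stand-in for generate_id
        item["replay_buffer_idx"] = self.tail
        self.data[self.tail] = item
        self.valid_count += 1
        self.tail = (self.tail + 1) % self.size
        self.next_unique_id += 1
```

Because the unique id keeps counting while the slot index wraps around, a later update can tell whether the item at a given replay_buffer_idx is still the one that was sampled.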

update(info)

Overview

Update the priorities of sampled data items. Each item is located by replay_buffer_idx and verified by replay_unique_id; items that have since been removed or overwritten are skipped. If info has no priority key, nothing is updated.

Arguments:
- info (:obj:dict): Info dict containing all keys needed for the priority update.

ArgumentsKeys:
- necessary: replay_unique_id, replay_buffer_idx, priority. All values are lists of the same length.
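The update rule in the source is short: for each (id, idx, priority) triple, skip the slot if it is empty or now holds a different unique id; otherwise require priority >= 0, store priority + eps (eps = 1e-5 in the source) so no item becomes unsampleable, refresh the tree weight priority ** alpha, and track the running maximum that new items inherit. The standalone sketch below follows those steps; update_priorities is a hypothetical helper, and the weights list stands in for the sum and min trees.

```python
from typing import Any, Dict, List, Optional

EPS = 1e-5  # same value as the buffer's ``self._eps``


def update_priorities(data: List[Optional[Dict[str, Any]]], info: Dict[str, list],
                      weights: List[float], alpha: float, max_priority: float) -> float:
    """Apply a priority update in the style of ``update``; returns the new max priority."""
    if "priority" not in info:
        return max_priority
    for uid, idx, priority in zip(info["replay_unique_id"], info["replay_buffer_idx"], info["priority"]):
        item = data[idx]
        # Skip slots emptied or overwritten since the batch was sampled.
        if item is None or item["replay_unique_id"] != uid:
            continue
        assert priority >= 0, priority
        item["priority"] = priority + EPS  # epsilon keeps every item sampleable
        weights[idx] = item["priority"] ** alpha  # stand-in for the sum/min tree leaves
        max_priority = max(max_priority, priority)
    return max_priority
```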

clear()

Overview

Clear all the data and reset the related variables.

__del__()

Overview

Call close (if the buffer has not been closed yet) when the object is deleted.

count()

Overview

Count how many valid data items are in the buffer.

Returns:
- count (:obj:int): Number of valid data items.
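The valid count is what sample relies on. Before drawing, _sample_check in the source walks forward from head, removing items whose staleness (cur_learner_iter minus collect_iter) is at least max_staleness and stopping at the first fresh item, then refuses to sample when count / size < sample_min_limit_ratio. The sketch below reproduces that precheck on a plain list; staleness and sample_check are hypothetical standalone helpers, and setting a slot to None stands in for _remove.

```python
from typing import Any, Dict, List, Optional, Tuple


def staleness(item: Dict[str, Any], cur_learner_iter: int) -> int:
    # Items without collect_iter default to cur_learner_iter + 1, i.e. staleness -1 ("unknown").
    collect_iter = item.get("collect_iter", cur_learner_iter + 1)
    if isinstance(collect_iter, list):
        collect_iter = min(collect_iter)  # a per-timestep list counts as its oldest entry
    return cur_learner_iter - collect_iter


def sample_check(data: List[Optional[Dict[str, Any]]], head: int, tail: int, size: int,
                 cur_learner_iter: int, max_staleness: float, min_ratio: float) -> Tuple[bool, int]:
    """Evict stale items from ``head`` onward, then test the valid-count ratio."""
    removed = 0
    if max_staleness != float("inf"):
        p = head
        while True:
            if data[p] is not None:
                if staleness(data[p], cur_learner_iter) >= max_staleness:
                    data[p] = None
                    removed += 1
                else:
                    break  # items get fresher from head to tail, so stop at the first fresh one
            p = (p + 1) % len(data)
            if p == tail:
                break
    valid_count = sum(d is not None for d in data)
    return valid_count / size >= min_ratio, removed
```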

state_dict()

Overview

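Provide a state dict that records the current state of the buffer.

Returns:
- state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer (data, use counts, head/tail positions, priorities, segment trees, etc.), from which the buffer can be reproduced. The values are references to the live buffer state, not copies.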

load_state_dict(_state_dict, deepcopy=False)

Overview

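Load a state dict to reproduce the buffer.

Arguments:
- _state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer. It must contain data. If data is its only key, the items are pushed into the buffer as new data; otherwise every key k is restored verbatim as attribute _k.
- deepcopy (:obj:bool): Whether to deep-copy each value before restoring it. Defaults to False.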
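The two restore paths of load_state_dict can be seen on a toy class. ToyBuffer below is hypothetical and keeps only the state-dict logic from the source: a data-only dict goes through _extend (which, in the real buffer, assigns fresh replay_unique_id and replay_buffer_idx values), while a full dict maps each key k onto attribute _k, optionally deep-copied.

```python
import copy
from typing import Any, Dict, List


class ToyBuffer:
    """Hypothetical stand-in that copies only the state-dict logic of the buffer."""

    def __init__(self) -> None:
        self._data: List[Any] = []
        self._tail = 0

    def _extend(self, items: List[Any]) -> None:
        # Simplified: the real buffer pushes through its circular queue and
        # assigns fresh replay_unique_id / replay_buffer_idx values.
        self._data.extend(items)
        self._tail = len(self._data)

    def state_dict(self) -> Dict[str, Any]:
        # Like the real method, this returns live references rather than copies.
        return {"data": self._data, "tail": self._tail}

    def load_state_dict(self, state: Dict[str, Any], deepcopy: bool = False) -> None:
        assert "data" in state
        if set(state.keys()) == {"data"}:
            # A data-only dict is pushed as new data.
            self._extend(state["data"])
        else:
            # A full dict is restored verbatim: key ``k`` becomes attribute ``_k``.
            for k, v in state.items():
                setattr(self, "_{}".format(k), copy.deepcopy(v) if deepcopy else v)
```

Because state_dict hands out references, passing deepcopy=True decouples the restored buffer from a source buffer that keeps running.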

Full Source Code

../ding/worker/replay_buffer/advanced_buffer.py

1import os 2import copy 3import time 4from typing import Union, Any, Optional, List, Dict, Tuple 5import numpy as np 6import hickle 7 8from ding.worker.replay_buffer import IBuffer 9from ding.utils import SumSegmentTree, MinSegmentTree, BUFFER_REGISTRY 10from ding.utils import LockContext, LockContextType, build_logger, get_rank 11from ding.utils.autolog import TickTime 12from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController 13 14 15def to_positive_index(idx: Union[int, None], size: int) -> int: 16 if idx is None or idx >= 0: 17 return idx 18 else: 19 return size + idx 20 21 22@BUFFER_REGISTRY.register('advanced') 23class AdvancedReplayBuffer(IBuffer): 24 r""" 25 Overview: 26 Prioritized replay buffer derived from ``NaiveReplayBuffer``. 27 This replay buffer adds: 28 29 1) Prioritized experience replay implemented by segment tree. 30 2) Data quality monitor. Monitor use count and staleness of each data. 31 3) Throughput monitor and control. 32 4) Logger. Log 2) and 3) in tensorboard or text. 33 Interface: 34 start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config 35 Property: 36 beta, replay_buffer_size, push_count 37 """ 38 39 config = dict( 40 type='advanced', 41 # Max length of the buffer. 42 replay_buffer_size=4096, 43 # Max use times of one data in the buffer. Data will be removed once used for too many times. 44 max_use=float("inf"), 45 # Max staleness time duration of one data in the buffer; Data will be removed if 46 # the duration from collecting to training is too long, i.e. The data is too stale. 
47 max_staleness=float("inf"), 48 # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization 49 alpha=0.6, 50 # (Float type) How much correction is used: 0 means no correction while 1 means full correction 51 beta=0.4, 52 # Anneal step for beta: 0 means no annealing 53 anneal_step=int(1e5), 54 # Whether to track the used data. Used data means they are removed out of buffer and would never be used again. 55 enable_track_used_data=False, 56 # Whether to deepcopy data when willing to insert and sample data. For security purpose. 57 deepcopy=False, 58 thruput_controller=dict( 59 # Rate limit. The ratio of "Sample Count" to "Push Count" should be in [min, max] range. 60 # If greater than max ratio, return `None` when calling ``sample```; 61 # If smaller than min ratio, throw away the new data when calling ``push``. 62 push_sample_rate_limit=dict( 63 max=float("inf"), 64 min=0, 65 ), 66 # Controller will take how many seconds into account, i.e. For the past `window_seconds` seconds, 67 # sample_push_rate will be calculated and campared with `push_sample_rate_limit`. 68 window_seconds=30, 69 # The minimum ratio that buffer must satisfy before anything can be sampled. 70 # The ratio is calculated by "Valid Count" divided by "Batch Size". 71 # E.g. sample_min_limit_ratio = 2.0, valid_count = 50, batch_size = 32, it is forbidden to sample. 72 sample_min_limit_ratio=1, 73 ), 74 # Monitor configuration for monitor and logger to use. This part does not affect buffer's function. 75 monitor=dict( 76 sampled_data_attr=dict( 77 # Past datas will be used for moving average. 78 average_range=5, 79 # Print data attributes every `print_freq` samples. 80 print_freq=200, # times 81 ), 82 periodic_thruput=dict( 83 # Every `seconds` seconds, thruput(push/sample/remove count) will be printed. 
84 seconds=60, 85 ), 86 ), 87 ) 88 89 def __init__( 90 self, 91 cfg: dict, 92 tb_logger: Optional['SummaryWriter'] = None, # noqa 93 exp_name: Optional[str] = 'default_experiment', 94 instance_name: Optional[str] = 'buffer', 95 ) -> int: 96 """ 97 Overview: 98 Initialize the buffer 99 Arguments: 100 - cfg (:obj:`dict`): Config dict. 101 - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode. 102 - exp_name (:obj:`Optional[str]`): Name of this experiment. 103 - instance_name (:obj:`Optional[str]`): Name of this instance. 104 """ 105 self._exp_name = exp_name 106 self._instance_name = instance_name 107 self._end_flag = False 108 self._cfg = cfg 109 self._rank = get_rank() 110 self._replay_buffer_size = self._cfg.replay_buffer_size 111 self._deepcopy = self._cfg.deepcopy 112 # ``_data`` is a circular queue to store data (full data or meta data) 113 self._data = [None for _ in range(self._replay_buffer_size)] 114 # Current valid data count, indicating how many elements in ``self._data`` is valid. 115 self._valid_count = 0 116 # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``. 117 self._push_count = 0 118 # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position. 119 self._tail = 0 120 # Is used to generate a unique id for each data: If a new data is inserted, its unique id will be this. 121 self._next_unique_id = 0 122 # Lock to guarantee thread safe 123 self._lock = LockContext(lock_type=LockContextType.THREAD_LOCK) 124 # Point to the head of the circular queue. The true data is the stalest(oldest) data in this queue. 125 # Because buffer would remove data due to staleness or use count, and at the beginning when queue is not 126 # filled with data head would always be 0, so ``head`` may be not equal to ``tail``; 127 # Otherwise, they two should be the same. Head is used to optimize staleness check in ``_sample_check``. 
128 self._head = 0 129 # use_count is {position_idx: use_count} 130 self._use_count = {idx: 0 for idx in range(self._cfg.replay_buffer_size)} 131 # Max priority till now. Is used to initizalize a data's priority if "priority" is not passed in with the data. 132 self._max_priority = 1.0 133 # A small positive number to avoid edge-case, e.g. "priority" == 0. 134 self._eps = 1e-5 135 # Data check function list, used in ``_append`` and ``_extend``. This buffer requires data to be dict. 136 self.check_list = [lambda x: isinstance(x, dict)] 137 138 self._max_use = self._cfg.max_use 139 self._max_staleness = self._cfg.max_staleness 140 self.alpha = self._cfg.alpha 141 assert 0 <= self.alpha <= 1, self.alpha 142 self._beta = self._cfg.beta 143 assert 0 <= self._beta <= 1, self._beta 144 self._anneal_step = self._cfg.anneal_step 145 if self._anneal_step != 0: 146 self._beta_anneal_step = (1 - self._beta) / self._anneal_step 147 148 # Prioritized sample. 149 # Capacity needs to be the power of 2. 150 capacity = int(np.power(2, np.ceil(np.log2(self.replay_buffer_size)))) 151 # Sum segtree and min segtree are used to sample data according to priority. 
152 self._sum_tree = SumSegmentTree(capacity) 153 self._min_tree = MinSegmentTree(capacity) 154 155 # Thruput controller 156 push_sample_rate_limit = self._cfg.thruput_controller.push_sample_rate_limit 157 self._always_can_push = True if push_sample_rate_limit['max'] == float('inf') else False 158 self._always_can_sample = True if push_sample_rate_limit['min'] == 0 else False 159 self._use_thruput_controller = not self._always_can_push or not self._always_can_sample 160 if self._use_thruput_controller: 161 self._thruput_controller = ThruputController(self._cfg.thruput_controller) 162 self._sample_min_limit_ratio = self._cfg.thruput_controller.sample_min_limit_ratio 163 assert self._sample_min_limit_ratio >= 1 164 165 # Monitor & Logger 166 monitor_cfg = self._cfg.monitor 167 if self._rank == 0: 168 if tb_logger is not None: 169 self._logger, _ = build_logger( 170 './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False 171 ) 172 self._tb_logger = tb_logger 173 else: 174 self._logger, self._tb_logger = build_logger( 175 './{}/log/{}'.format(self._exp_name, self._instance_name), 176 self._instance_name, 177 ) 178 else: 179 self._logger, _ = build_logger( 180 './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False 181 ) 182 self._tb_logger = None 183 self._start_time = time.time() 184 # Sampled data attributes. 185 self._cur_learner_iter = -1 186 self._cur_collector_envstep = -1 187 self._sampled_data_attr_print_count = 0 188 self._sampled_data_attr_monitor = SampledDataAttrMonitor( 189 TickTime(), expire=monitor_cfg.sampled_data_attr.average_range 190 ) 191 self._sampled_data_attr_print_freq = monitor_cfg.sampled_data_attr.print_freq 192 # Periodic thruput. 
193 if self._rank == 0: 194 self._periodic_thruput_monitor = PeriodicThruputMonitor( 195 self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger 196 ) 197 198 # Used data remover 199 self._enable_track_used_data = self._cfg.enable_track_used_data 200 if self._enable_track_used_data: 201 self._used_data_remover = UsedDataRemover() 202 203 def start(self) -> None: 204 """ 205 Overview: 206 Start the buffer's used_data_remover thread if enables track_used_data. 207 """ 208 if self._enable_track_used_data: 209 self._used_data_remover.start() 210 211 def close(self) -> None: 212 """ 213 Overview: 214 Clear the buffer; Join the buffer's used_data_remover thread if enables track_used_data. 215 Join periodic throughtput monitor, flush tensorboard logger. 216 """ 217 if self._end_flag: 218 return 219 self._end_flag = True 220 self.clear() 221 if self._rank == 0: 222 self._periodic_thruput_monitor.close() 223 self._tb_logger.flush() 224 self._tb_logger.close() 225 if self._enable_track_used_data: 226 self._used_data_remover.close() 227 228 def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]: 229 """ 230 Overview: 231 Sample data with length ``size``. 232 Arguments: 233 - size (:obj:`int`): The number of the data that will be sampled. 234 - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. 235 - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \ 236 means only sample among the last 10 data 237 Returns: 238 - sample_data (:obj:`list`): A list of data with length ``size`` 239 ReturnsKeys: 240 - necessary: original keys(e.g. 
`obs`, `action`, `next_obs`, `reward`, `info`), \ 241 `replay_unique_id`, `replay_buffer_idx` 242 - optional(if use priority): `IS`, `priority` 243 """ 244 if size == 0: 245 return [] 246 can_sample_stalenss, staleness_info = self._sample_check(size, cur_learner_iter) 247 if self._always_can_sample: 248 can_sample_thruput, thruput_info = True, "Always can sample because push_sample_rate_limit['min'] == 0" 249 else: 250 can_sample_thruput, thruput_info = self._thruput_controller.can_sample(size) 251 if not can_sample_stalenss or not can_sample_thruput: 252 self._logger.info( 253 'Refuse to sample due to -- \nstaleness: {}, {} \nthruput: {}, {}'.format( 254 not can_sample_stalenss, staleness_info, not can_sample_thruput, thruput_info 255 ) 256 ) 257 return None 258 with self._lock: 259 indices = self._get_indices(size, sample_range) 260 result = self._sample_with_indices(indices, cur_learner_iter) 261 # Deepcopy ``result``'s same indice datas in case ``self._get_indices`` may get datas with 262 # the same indices, i.e. the same datas would be sampled afterwards. 263 # if self._deepcopy==True -> all data is different 264 # if len(indices) == len(set(indices)) -> no duplicate data 265 if not self._deepcopy and len(indices) != len(set(indices)): 266 for i, index in enumerate(indices): 267 tmp = [] 268 for j in range(i + 1, size): 269 if index == indices[j]: 270 tmp.append(j) 271 for j in tmp: 272 result[j] = copy.deepcopy(result[j]) 273 self._monitor_update_of_sample(result, cur_learner_iter) 274 return result 275 276 def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None: 277 r""" 278 Overview: 279 Push a data into buffer. 280 Arguments: 281 - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \ 282 (in `Any` type), or many(int `List[Any]` type). 283 - cur_collector_envstep (:obj:`int`): Collector's current env step. 
284 """ 285 push_size = len(data) if isinstance(data, list) else 1 286 if self._always_can_push: 287 can_push, push_info = True, "Always can push because push_sample_rate_limit['max'] == float('inf')" 288 else: 289 can_push, push_info = self._thruput_controller.can_push(push_size) 290 if not can_push: 291 self._logger.info('Refuse to push because {}'.format(push_info)) 292 return 293 if isinstance(data, list): 294 self._extend(data, cur_collector_envstep) 295 else: 296 self._append(data, cur_collector_envstep) 297 298 def save_data(self, file_name: str): 299 if not os.path.exists(os.path.dirname(file_name)): 300 if os.path.dirname(file_name) != "": 301 os.makedirs(os.path.dirname(file_name)) 302 hickle.dump(py_obj=self._data, file_obj=file_name) 303 304 def load_data(self, file_name: str): 305 self.push(hickle.load(file_name), 0) 306 307 def _sample_check(self, size: int, cur_learner_iter: int) -> Tuple[bool, str]: 308 r""" 309 Overview: 310 Do preparations for sampling and check whether data is enough for sampling 311 Preparation includes removing stale datas in ``self._data``. 312 Check includes judging whether this buffer has more than ``size`` datas to sample. 313 Arguments: 314 - size (:obj:`int`): The number of the data that will be sampled. 315 - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. 316 Returns: 317 - can_sample (:obj:`bool`): Whether this buffer can sample enough data. 318 - str_info (:obj:`str`): Str type info, explaining why cannot sample. (If can sample, return "Can sample") 319 320 .. note:: 321 This function must be called before data sample. 
322 """ 323 staleness_remove_count = 0 324 with self._lock: 325 if self._max_staleness != float("inf"): 326 p = self._head 327 while True: 328 if self._data[p] is not None: 329 staleness = self._calculate_staleness(p, cur_learner_iter) 330 if staleness >= self._max_staleness: 331 self._remove(p) 332 staleness_remove_count += 1 333 else: 334 # Since the circular queue ``self._data`` guarantees that data's staleness is decreasing 335 # from index self._head to index self._tail - 1, we can jump out of the loop as soon as 336 # meeting a fresh enough data 337 break 338 p = (p + 1) % self._replay_buffer_size 339 if p == self._tail: 340 # Traverse a circle and go back to the tail, which means can stop staleness checking now 341 break 342 str_info = "Remove {} elements due to staleness. ".format(staleness_remove_count) 343 if self._valid_count / size < self._sample_min_limit_ratio: 344 str_info += "Not enough for sampling. valid({}) / sample({}) < sample_min_limit_ratio({})".format( 345 self._valid_count, size, self._sample_min_limit_ratio 346 ) 347 return False, str_info 348 else: 349 str_info += "Can sample." 350 return True, str_info 351 352 def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None: 353 r""" 354 Overview: 355 Append a data item into queue. 356 Add two keys in data: 357 358 - replay_unique_id: The data item's unique id, using ``generate_id`` to generate it. 359 - replay_buffer_idx: The data item's position index in the queue, this position may already have an \ 360 old element, then it would be replaced by this new input one. using ``self._tail`` to locate. 361 Arguments: 362 - ori_data (:obj:`Any`): The data which will be inserted. 363 - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard. 
364 """ 365 with self._lock: 366 if self._deepcopy: 367 data = copy.deepcopy(ori_data) 368 else: 369 data = ori_data 370 try: 371 assert self._data_check(data) 372 except AssertionError: 373 # If data check fails, log it and return without any operations. 374 self._logger.info('Illegal data type [{}], reject it...'.format(type(data))) 375 return 376 self._push_count += 1 377 # remove->set weight->set data 378 if self._data[self._tail] is not None: 379 self._head = (self._tail + 1) % self._replay_buffer_size 380 self._remove(self._tail) 381 data['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id) 382 data['replay_buffer_idx'] = self._tail 383 self._set_weight(data) 384 self._data[self._tail] = data 385 self._valid_count += 1 386 if self._rank == 0: 387 self._periodic_thruput_monitor.valid_count = self._valid_count 388 self._tail = (self._tail + 1) % self._replay_buffer_size 389 self._next_unique_id += 1 390 self._monitor_update_of_push(1, cur_collector_envstep) 391 392 def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None: 393 r""" 394 Overview: 395 Extend a data list into queue. 396 Add two keys in each data item, you can refer to ``_append`` for more details. 397 Arguments: 398 - ori_data (:obj:`List[Any]`): The data list. 399 - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard. 400 """ 401 with self._lock: 402 if self._deepcopy: 403 data = copy.deepcopy(ori_data) 404 else: 405 data = ori_data 406 check_result = [self._data_check(d) for d in data] 407 # Only keep data items that pass ``_data_check`. 408 valid_data = [d for d, flag in zip(data, check_result) if flag] 409 length = len(valid_data) 410 # When updating ``_data`` and ``_use_count``, should consider two cases regarding 411 # the relationship between "tail + data length" and "queue max length" to check whether 412 # data will exceed beyond queue's max length limitation. 
413 if self._tail + length <= self._replay_buffer_size: 414 for j in range(self._tail, self._tail + length): 415 if self._data[j] is not None: 416 self._head = (j + 1) % self._replay_buffer_size 417 self._remove(j) 418 for i in range(length): 419 valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i) 420 valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size 421 self._set_weight(valid_data[i]) 422 self._push_count += 1 423 self._data[self._tail:self._tail + length] = valid_data 424 else: 425 data_start = self._tail 426 valid_data_start = 0 427 residual_num = len(valid_data) 428 while True: 429 space = self._replay_buffer_size - data_start 430 L = min(space, residual_num) 431 for j in range(data_start, data_start + L): 432 if self._data[j] is not None: 433 self._head = (j + 1) % self._replay_buffer_size 434 self._remove(j) 435 for i in range(valid_data_start, valid_data_start + L): 436 valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i) 437 valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size 438 self._set_weight(valid_data[i]) 439 self._push_count += 1 440 self._data[data_start:data_start + L] = valid_data[valid_data_start:valid_data_start + L] 441 residual_num -= L 442 if residual_num <= 0: 443 break 444 else: 445 data_start = 0 446 valid_data_start += L 447 self._valid_count += len(valid_data) 448 if self._rank == 0: 449 self._periodic_thruput_monitor.valid_count = self._valid_count 450 # Update ``tail`` and ``next_unique_id`` after the whole list is pushed into buffer. 451 self._tail = (self._tail + length) % self._replay_buffer_size 452 self._next_unique_id += length 453 self._monitor_update_of_push(length, cur_collector_envstep) 454 455 def update(self, info: dict) -> None: 456 r""" 457 Overview: 458 Update a data's priority. Use `repaly_buffer_idx` to locate, and use `replay_unique_id` to verify. 
459 Arguments: 460 - info (:obj:`dict`): Info dict containing all necessary keys for priority update. 461 ArgumentsKeys: 462 - necessary: `replay_unique_id`, `replay_buffer_idx`, `priority`. All values are lists with the same length. 463 """ 464 with self._lock: 465 if 'priority' not in info: 466 return 467 data = [info['replay_unique_id'], info['replay_buffer_idx'], info['priority']] 468 for id_, idx, priority in zip(*data): 469 # Only if the data still exists in the queue, will the update operation be done. 470 if self._data[idx] is not None \ 471 and self._data[idx]['replay_unique_id'] == id_: # Verify the same transition(data) 472 assert priority >= 0, priority 473 assert self._data[idx]['replay_buffer_idx'] == idx 474 self._data[idx]['priority'] = priority + self._eps # Add epsilon to avoid priority == 0 475 self._set_weight(self._data[idx]) 476 # Update max priority 477 self._max_priority = max(self._max_priority, priority) 478 else: 479 self._logger.debug( 480 '[Skip Update]: buffer_idx: {}; id_in_buffer: {}; id_in_update_info: {}'.format( 481 idx, id_, priority 482 ) 483 ) 484 485 def clear(self) -> None: 486 """ 487 Overview: 488 Clear all the data and reset the related variables. 489 """ 490 with self._lock: 491 for i in range(len(self._data)): 492 self._remove(i) 493 assert self._valid_count == 0, self._valid_count 494 self._head = 0 495 self._tail = 0 496 self._max_priority = 1.0 497 498 def __del__(self) -> None: 499 """ 500 Overview: 501 Call ``close`` to delete the object. 502 """ 503 if not self._end_flag: 504 self.close() 505 506 def _set_weight(self, data: Dict) -> None: 507 r""" 508 Overview: 509 Set sumtree and mintree's weight of the input data according to its priority. 510 If input data does not have key "priority", it would set to ``self._max_priority`` instead. 511 Arguments: 512 - data (:obj:`Dict`): The data whose priority(weight) in segement tree should be set/updated. 
513 """ 514 if 'priority' not in data.keys() or data['priority'] is None: 515 data['priority'] = self._max_priority 516 weight = data['priority'] ** self.alpha 517 idx = data['replay_buffer_idx'] 518 self._sum_tree[idx] = weight 519 self._min_tree[idx] = weight 520 521 def _data_check(self, d: Any) -> bool: 522 r""" 523 Overview: 524 Data legality check, using rules(functions) in ``self.check_list``. 525 Arguments: 526 - d (:obj:`Any`): The data which needs to be checked. 527 Returns: 528 - result (:obj:`bool`): Whether the data passes the check. 529 """ 530 # only the data passes all the check functions, would the check return True 531 return all([fn(d) for fn in self.check_list]) 532 533 def _get_indices(self, size: int, sample_range: slice = None) -> list: 534 r""" 535 Overview: 536 Get the sample index list according to the priority probability. 537 Arguments: 538 - size (:obj:`int`): The number of the data that will be sampled 539 Returns: 540 - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``. 541 """ 542 # Divide [0, 1) into size intervals on average 543 intervals = np.array([i * 1.0 / size for i in range(size)]) 544 # Uniformly sample within each interval 545 mass = intervals + np.random.uniform(size=(size, )) * 1. 
/ size 546 if sample_range is None: 547 # Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree) 548 mass *= self._sum_tree.reduce() 549 else: 550 # Rescale to [a, b) 551 start = to_positive_index(sample_range.start, self._replay_buffer_size) 552 end = to_positive_index(sample_range.stop, self._replay_buffer_size) 553 a = self._sum_tree.reduce(0, start) 554 b = self._sum_tree.reduce(0, end) 555 mass = mass * (b - a) + a 556 # Find prefix sum index to sample with probability 557 return [self._sum_tree.find_prefixsum_idx(m) for m in mass] 558 559 def _remove(self, idx: int, use_too_many_times: bool = False) -> None: 560 r""" 561 Overview: 562 Remove a data(set the element in the list to ``None``) and update corresponding variables, 563 e.g. sum_tree, min_tree, valid_count. 564 Arguments: 565 - idx (:obj:`int`): Data at this position will be removed. 566 """ 567 if use_too_many_times: 568 if self._enable_track_used_data: 569 # Must track this data, but in parallel mode. 570 # Do not remove it, but make sure it will not be sampled. 
571 self._data[idx]['priority'] = 0 572 self._sum_tree[idx] = self._sum_tree.neutral_element 573 self._min_tree[idx] = self._min_tree.neutral_element 574 return 575 elif idx == self._head: 576 # Correct `self._head` when the queue head is removed due to use_count 577 self._head = (self._head + 1) % self._replay_buffer_size 578 if self._data[idx] is not None: 579 if self._enable_track_used_data: 580 self._used_data_remover.add_used_data(self._data[idx]) 581 self._valid_count -= 1 582 if self._rank == 0: 583 self._periodic_thruput_monitor.valid_count = self._valid_count 584 self._periodic_thruput_monitor.remove_data_count += 1 585 self._data[idx] = None 586 self._sum_tree[idx] = self._sum_tree.neutral_element 587 self._min_tree[idx] = self._min_tree.neutral_element 588 self._use_count[idx] = 0 589 590 def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list: 591 r""" 592 Overview: 593 Sample data with ``indices``; Remove a data item if it is used for too many times. 594 Arguments: 595 - indices (:obj:`List[int]`): A list including all the sample indices. 596 - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. 597 Returns: 598 - data (:obj:`list`) Sampled data. 
599 """ 600 # Calculate max weight for normalizing IS 601 sum_tree_root = self._sum_tree.reduce() 602 p_min = self._min_tree.reduce() / sum_tree_root 603 max_weight = (self._valid_count * p_min) ** (-self._beta) 604 data = [] 605 for idx in indices: 606 assert self._data[idx] is not None 607 assert self._data[idx]['replay_buffer_idx'] == idx, (self._data[idx]['replay_buffer_idx'], idx) 608 if self._deepcopy: 609 copy_data = copy.deepcopy(self._data[idx]) 610 else: 611 copy_data = self._data[idx] 612 # Store staleness, use and IS(importance sampling weight for gradient step) for monitor and outer use 613 self._use_count[idx] += 1 614 copy_data['staleness'] = self._calculate_staleness(idx, cur_learner_iter) 615 copy_data['use'] = self._use_count[idx] 616 p_sample = self._sum_tree[idx] / sum_tree_root 617 weight = (self._valid_count * p_sample) ** (-self._beta) 618 copy_data['IS'] = weight / max_weight 619 data.append(copy_data) 620 if self._max_use != float("inf"): 621 # Remove datas whose "use count" is greater than ``max_use`` 622 for idx in indices: 623 if self._use_count[idx] >= self._max_use: 624 self._remove(idx, use_too_many_times=True) 625 # Beta annealing 626 if self._anneal_step != 0: 627 self._beta = min(1.0, self._beta + self._beta_anneal_step) 628 return data 629 630 def _monitor_update_of_push(self, add_count: int, cur_collector_envstep: int = -1) -> None: 631 r""" 632 Overview: 633 Update values in monitor, then update text logger and tensorboard logger. 634 Called in ``_append`` and ``_extend``. 635 Arguments: 636 - add_count (:obj:`int`): How many datas are added into buffer. 637 - cur_collector_envstep (:obj:`int`): Collector envstep, passed in by collector. 
```python
        """
        if self._rank == 0:
            self._periodic_thruput_monitor.push_data_count += add_count
        if self._use_thruput_controller:
            self._thruput_controller.history_push_count += add_count
        self._cur_collector_envstep = cur_collector_envstep

    def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) -> None:
        r"""
        Overview:
            Update the monitors, then write to the text logger and the tensorboard logger.
            Called in ``sample``.
        Arguments:
            - sample_data (:obj:`list`): Sampled data. Used to get the sample length and the data's attributes, \
                e.g. use, priority, staleness, etc.
            - cur_learner_iter (:obj:`int`): Learner iteration, passed in by the learner.
        """
        if self._rank == 0:
            self._periodic_thruput_monitor.sample_data_count += len(sample_data)
        if self._use_thruput_controller:
            self._thruput_controller.history_sample_count += len(sample_data)
        self._cur_learner_iter = cur_learner_iter
        use_avg = sum([d['use'] for d in sample_data]) / len(sample_data)
        use_max = max([d['use'] for d in sample_data])
        priority_avg = sum([d['priority'] for d in sample_data]) / len(sample_data)
        priority_max = max([d['priority'] for d in sample_data])
        priority_min = min([d['priority'] for d in sample_data])
        staleness_avg = sum([d['staleness'] for d in sample_data]) / len(sample_data)
        staleness_max = max([d['staleness'] for d in sample_data])
        self._sampled_data_attr_monitor.use_avg = use_avg
        self._sampled_data_attr_monitor.use_max = use_max
        self._sampled_data_attr_monitor.priority_avg = priority_avg
        self._sampled_data_attr_monitor.priority_max = priority_max
        self._sampled_data_attr_monitor.priority_min = priority_min
        self._sampled_data_attr_monitor.staleness_avg = staleness_avg
        self._sampled_data_attr_monitor.staleness_max = staleness_max
        self._sampled_data_attr_monitor.time.step()
        out_dict = {
            'use_avg': self._sampled_data_attr_monitor.avg['use'](),
            'use_max': self._sampled_data_attr_monitor.max['use'](),
            'priority_avg': self._sampled_data_attr_monitor.avg['priority'](),
            'priority_max': self._sampled_data_attr_monitor.max['priority'](),
            'priority_min': self._sampled_data_attr_monitor.min['priority'](),
            'staleness_avg': self._sampled_data_attr_monitor.avg['staleness'](),
            'staleness_max': self._sampled_data_attr_monitor.max['staleness'](),
            'beta': self._beta,
        }
        if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0 and self._rank == 0:
            self._logger.info("=== Sample data {} Times ===".format(self._sampled_data_attr_print_count))
            self._logger.info(self._logger.get_tabulate_vars_hor(out_dict))
            for k, v in out_dict.items():
                iter_metric = self._cur_learner_iter if self._cur_learner_iter != -1 else None
                step_metric = self._cur_collector_envstep if self._cur_collector_envstep != -1 else None
                if iter_metric is not None:
                    self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, iter_metric)
                if step_metric is not None:
                    self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, step_metric)
        self._sampled_data_attr_print_count += 1

    def _calculate_staleness(self, pos_index: int, cur_learner_iter: int) -> Optional[int]:
        r"""
        Overview:
            Calculate a data's staleness according to its own attribute ``collect_iter``
            and the input parameter ``cur_learner_iter``.
        Arguments:
            - pos_index (:obj:`int`): The position index. Staleness of the data at this index will be calculated.
            - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
        Returns:
            - staleness (:obj:`int`): Staleness of the data at position ``pos_index``.

        .. note::
            The caller should guarantee that the data at ``pos_index`` is not None; otherwise this function \
            raises a ``ValueError``.
        """
        if self._data[pos_index] is None:
            raise ValueError("Prioritized's data at index {} is None".format(pos_index))
        else:
            # Calculate staleness only; this method does not remove any data itself
            collect_iter = self._data[pos_index].get('collect_iter', cur_learner_iter + 1)
            if isinstance(collect_iter, list):
                # A timestep transition's collect_iter is a list
                collect_iter = min(collect_iter)
            # ``staleness`` might be -1, which means invalid, e.g. the collector does not report the collecting
            # model iter, or it is a demonstration buffer (i.e. the data is not generated by a collector), etc.
            staleness = cur_learner_iter - collect_iter
            return staleness

    def count(self) -> int:
        """
        Overview:
            Count how many valid data items there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data items.
        """
        return self._valid_count

    @property
    def beta(self) -> float:
        return self._beta

    @beta.setter
    def beta(self, beta: float) -> None:
        self._beta = beta

    def state_dict(self) -> dict:
        """
        Overview:
            Provide a state dict that records the current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With this dict, one can easily reproduce the buffer.
        """
        return {
            'data': self._data,
            'use_count': self._use_count,
            'tail': self._tail,
            'max_priority': self._max_priority,
            'anneal_step': self._anneal_step,
            'beta': self._beta,
            'head': self._head,
            'next_unique_id': self._next_unique_id,
            'valid_count': self._valid_count,
            'push_count': self._push_count,
            'sum_tree': self._sum_tree,
            'min_tree': self._min_tree,
        }

    def load_state_dict(self, _state_dict: dict, deepcopy: bool = False) -> None:
        """
        Overview:
            Load a state dict to reproduce the buffer.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                If its only key is ``data``, the data is inserted via ``_extend``; otherwise each key ``k`` \
                is restored directly into the attribute ``_k``.
            - deepcopy (:obj:`bool`): Whether to deep-copy each value before restoring it.
        """
        assert 'data' in _state_dict
        if set(_state_dict.keys()) == set(['data']):
            self._extend(_state_dict['data'])
        else:
            for k, v in _state_dict.items():
                if deepcopy:
                    setattr(self, '_{}'.format(k), copy.deepcopy(v))
                else:
                    setattr(self, '_{}'.format(k), v)

    @property
    def replay_buffer_size(self) -> int:
        return self._replay_buffer_size

    @property
    def push_count(self) -> int:
        return self._push_count
```
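The per-batch statistics computed at the start of `_monitor_update_of_sample` can be reproduced without any DI-engine dependency. The following is a minimal sketch, assuming each sampled item is a plain dict with `use`, `priority` and `staleness` keys, as the listing reads them. `summarize_sample` is a hypothetical helper, not part of the library. Note that the buffer does not log these raw batch values directly: it feeds them into `_sampled_data_attr_monitor` and logs the monitor's `avg`/`max`/`min` readings.

```python
from typing import Any, Dict, List


def summarize_sample(sample_data: List[Dict[str, Any]]) -> Dict[str, float]:
    """Hypothetical helper: per-batch stats matching the values assigned in ``_monitor_update_of_sample``."""
    if not sample_data:
        # The listing divides by len(sample_data), so an empty batch would raise ZeroDivisionError there.
        raise ValueError("sample_data must not be empty")
    n = len(sample_data)
    uses = [d['use'] for d in sample_data]
    priorities = [d['priority'] for d in sample_data]
    staleness = [d['staleness'] for d in sample_data]
    return {
        'use_avg': sum(uses) / n,
        'use_max': max(uses),
        'priority_avg': sum(priorities) / n,
        'priority_max': max(priorities),
        'priority_min': min(priorities),
        'staleness_avg': sum(staleness) / n,
        'staleness_max': max(staleness),
    }


batch = [
    {'use': 1, 'priority': 0.5, 'staleness': 2},
    {'use': 3, 'priority': 1.5, 'staleness': 0},
]
stats = summarize_sample(batch)
print(stats['use_avg'], stats['priority_min'], stats['staleness_max'])  # 2.0 0.5 2
```

The explicit empty-batch check is an addition of this sketch; the listing relies on `sample` never passing an empty list.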
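The staleness rule in `_calculate_staleness` is compact but has two subtle cases: a missing `collect_iter` and a list-valued `collect_iter`. Here is a standalone sketch of the same rule. It assumes a data item is a plain dict. `calculate_staleness` is a hypothetical free function, not the buffer's method.

```python
from typing import Any, Dict, List, Optional, Union


def calculate_staleness(data: Optional[Dict[str, Any]], cur_learner_iter: int) -> int:
    """Hypothetical standalone version of the rule used by ``_calculate_staleness``."""
    if data is None:
        raise ValueError("data is None")
    # A missing ``collect_iter`` defaults to cur_learner_iter + 1, which makes staleness -1 (invalid).
    collect_iter: Union[int, List[int]] = data.get('collect_iter', cur_learner_iter + 1)
    if isinstance(collect_iter, list):
        # List-valued collect_iter (the listing's "timestep transition" case): the oldest entry counts.
        collect_iter = min(collect_iter)
    return cur_learner_iter - collect_iter


print(calculate_staleness({'collect_iter': 90}, 100))        # 10
print(calculate_staleness({'collect_iter': [95, 97]}, 100))  # 5
print(calculate_staleness({'obs': 0}, 100))                  # -1, no collect_iter reported
```

Taking `min` means a transition is treated as being as stale as its oldest component.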
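`load_state_dict` has two branches. A dict whose only key is `data` is re-pushed through `_extend`, while a full dict (such as the output of `state_dict`) is restored attribute by attribute, optionally deep-copied. The sketch below illustrates both behaviors on a toy class. `ToyBuffer` and its simplified `_extend` are hypothetical stand-ins that keep only three of the real buffer's fields.

```python
import copy
from typing import Any, Dict, List


class ToyBuffer:
    """Hypothetical stand-in reproducing only the state_dict / load_state_dict logic of the listing."""

    def __init__(self) -> None:
        self._data: List[Any] = []
        self._beta = 0.4
        self._push_count = 0

    def _extend(self, data: List[Any]) -> None:
        # Simplified assumption: just append and count; the real buffer does more bookkeeping.
        self._data.extend(data)
        self._push_count += len(data)

    def state_dict(self) -> Dict[str, Any]:
        return {'data': self._data, 'beta': self._beta, 'push_count': self._push_count}

    def load_state_dict(self, _state_dict: Dict[str, Any], deepcopy: bool = False) -> None:
        assert 'data' in _state_dict
        if set(_state_dict.keys()) == {'data'}:
            # Data-only dict: re-insert the data through the normal push path.
            self._extend(_state_dict['data'])
        else:
            # Full dict: restore each key ``k`` directly into attribute ``_k``.
            for k, v in _state_dict.items():
                setattr(self, '_{}'.format(k), copy.deepcopy(v) if deepcopy else v)


src = ToyBuffer()
src._extend([{'obs': 1}, {'obs': 2}])
src._beta = 0.6

full = ToyBuffer()
full.load_state_dict(src.state_dict())                   # shares objects with src
copied = ToyBuffer()
copied.load_state_dict(src.state_dict(), deepcopy=True)  # independent copies
data_only = ToyBuffer()
data_only.load_state_dict({'data': src.state_dict()['data']})

print(full._data is src._data, copied._data is src._data)  # True False
print(data_only._beta, data_only._push_count)              # 0.4 2
```

With the default `deepcopy=False`, the restored buffer aliases the dict's objects, so mutating one buffer's data also mutates the other's. The data-only path keeps the receiver's own settings (here `beta`) and recounts pushes. In the real buffer it presumably also rebuilds derived state such as priorities and the segment trees, instead of restoring them from the dict.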