ding.data.buffer.deque_buffer_wrapper
Full Source Code
../ding/data/buffer/deque_buffer_wrapper.py
import os
from typing import Optional
import copy
from easydict import EasyDict
import numpy as np
import hickle

from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import use_time_check, PriorityExperienceReplay
from ding.utils import BUFFER_REGISTRY


@BUFFER_REGISTRY.register('deque')
class DequeBufferWrapper(object):
    """
    Compatibility wrapper that exposes the classic replay-buffer interface
    (``sample``/``push``/``update``/``count``/``save_data``/``load_data``) on
    top of :class:`DequeBuffer`, with optional use-count limiting, prioritized
    experience replay (PER), and periodic TensorBoard statistics logging.
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Return a deep copy of the class-level default ``config`` as an
        EasyDict, tagged with ``cfg_type`` for the registry machinery.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        replay_buffer_size=10000,
        # Maximum times a sample may be drawn before eviction; inf disables the check.
        max_use=float("inf"),
        # Minimum train-iteration gap between two TensorBoard log writes.
        train_iter_per_log=100,
        # Enable prioritized experience replay middleware.
        priority=False,
        # Attach importance-sampling weights to sampled data (requires priority=True).
        priority_IS_weight=False,
        priority_power_factor=0.6,
        IS_weight_power_factor=0.4,
        IS_weight_anneal_train_iter=int(1e5),
        # Upper bound clamped onto incoming priority values in ``update``.
        priority_max_limit=1000,
    )

    def __init__(
        self,
        cfg: EasyDict,
        tb_logger: Optional[object] = None,
        # NOTE(review): 'experiement' typo kept — it is a runtime default value
        # (experiment name) and correcting it would change output paths.
        exp_name: str = 'default_experiement',
        instance_name: str = 'buffer'
    ) -> None:
        """
        Arguments:
            - cfg: buffer configuration (see ``default_config``).
            - tb_logger: optional TensorBoard-style writer with ``add_scalar``;
              when None, statistics logging is skipped.
            - exp_name: experiment name (unused here beyond bookkeeping).
            - instance_name: prefix for TensorBoard scalar tags.
        """
        self.cfg = cfg
        self.priority_max_limit = cfg.priority_max_limit
        self.name = '{}_iter'.format(instance_name)
        self.tb_logger = tb_logger
        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
        self.last_log_train_iter = -1

        # use_count middleware: evicts samples drawn more than max_use times.
        if self.cfg.max_use != float("inf"):
            self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
        # priority middleware: PER sampling with optional IS-weight annealing.
        if self.cfg.priority:
            self.buffer.use(
                PriorityExperienceReplay(
                    self.buffer,
                    IS_weight=self.cfg.priority_IS_weight,
                    priority_power_factor=self.cfg.priority_power_factor,
                    IS_weight_power_factor=self.cfg.IS_weight_power_factor,
                    IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
                )
            )
        # Bookkeeping for ``update``: indices/meta of the most recent sample.
        self.last_sample_index = None
        self.last_sample_meta = None

    def sample(self, size: int, train_iter: int = 0):
        """
        Sample ``size`` items from the buffer.

        Returns a list of data items (with an ``'IS'`` key merged into each
        item when IS-weighting is enabled), or ``None`` when the buffer cannot
        serve the request (``ignore_insufficient`` yields an empty result).
        """
        output = self.buffer.sample(size=size, ignore_insufficient=True)
        if len(output) == 0:
            return None

        meta = [o.meta for o in output]
        if self.cfg.priority:
            # BUGFIX: record sampled indices/meta on EVERY sample. Previously
            # this happened only on logging iterations, so ``update`` silently
            # skipped priority updates for all batches in between.
            self.last_sample_index = [o.index for o in output]
            self.last_sample_meta = meta

        if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log:
            # BUGFIX: tb_logger is Optional — guard before dereferencing.
            if self.tb_logger is not None:
                if self.cfg.max_use != float("inf"):
                    use_count_avg = np.mean([m['use_count'] for m in meta])
                    self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
                if self.cfg.priority:
                    priority_list = [m['priority'] for m in meta]
                    priority_avg = np.mean(priority_list)
                    priority_max = np.max(priority_list)
                    self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter)
                    self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter)
                self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)
            self.last_log_train_iter = train_iter

        data = [o.data for o in output]
        if self.cfg.priority_IS_weight:
            # Merge the importance-sampling weight into each data item.
            IS = [o.meta['priority_IS'] for o in output]
            for i in range(len(data)):
                data[i]['IS'] = IS[i]
        return data

    def push(self, data, cur_collector_envstep: int = -1) -> None:
        """
        Push an iterable of data items into the buffer. A per-item
        ``'priority'`` key (when PER is enabled) is popped from the item and
        forwarded as sampling metadata. ``cur_collector_envstep`` is accepted
        for interface compatibility and currently unused.
        """
        for d in data:
            meta = {}
            if self.cfg.priority and 'priority' in d:
                meta['priority'] = d.pop('priority')
            self.buffer.push(d, meta=meta)

    def update(self, meta: dict) -> None:
        """
        Write back new priorities (``meta['priority']``, aligned with the most
        recent ``sample``) into the buffer, clamped to ``priority_max_limit``.
        No-op when PER is disabled or no sample has been recorded.
        """
        if not self.cfg.priority:
            return
        if self.last_sample_index is None:
            return
        new_meta = self.last_sample_meta
        for m, p in zip(new_meta, meta['priority']):
            m['priority'] = min(self.priority_max_limit, p)
        for idx, m in zip(self.last_sample_index, new_meta):
            self.buffer.update(idx, data=None, meta=m)
        # Each recorded sample may be consumed by exactly one update.
        self.last_sample_index = None
        self.last_sample_meta = None

    def count(self) -> int:
        """Return the number of items currently stored."""
        return self.buffer.count()

    def save_data(self, file_name):
        """Persist buffer contents to ``file_name`` (delegated to DequeBuffer)."""
        self.buffer.save_data(file_name)

    def load_data(self, file_name: str):
        """Restore buffer contents from ``file_name`` (delegated to DequeBuffer)."""
        self.buffer.load_data(file_name)