import os
import itertools
import random
import uuid
from ditk import logging
import hickle
from typing import Any, Iterable, List, Optional, Tuple, Union
from collections import Counter
from collections import defaultdict, deque, OrderedDict
from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data


class BufferIndex():
    """
    Overview:
        Save index string and offset in key value pair.
    """

    def __init__(self, maxlen: int, *args, **kwargs):
        self.maxlen = maxlen
        self.__map = OrderedDict(*args, **kwargs)
        self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
        # Monotonic count of keys ever appended; unlike len(self), it does not
        # decrease when old keys are evicted in append().
        self._cumlen = len(self.__map)

    def get(self, key: str) -> int:
        value = self.__map[key]
        # Map the monotonic counter value to the current offset in the bounded
        # storage: once more than maxlen items have been appended, the oldest
        # items shift left, which the min(0, ...) correction accounts for.
        value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
        return value

    def __len__(self) -> int:
        return len(self.__map)

    def has(self, key: str) -> bool:
        return key in self.__map

    def append(self, key: str):
        self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
        self._last_key = key
        self._cumlen += 1
        if len(self) > self.maxlen:
            # Evict the oldest key, mirroring the deque dropping its leftmost item.
            self.__map.popitem(last=False)

    def clear(self):
        self.__map = OrderedDict()
        self._last_key = None
        self._cumlen = 0


class DequeBuffer(Buffer):
    """
    Overview:
        A buffer implementation based on the deque structure.
    """

    def __init__(self, size: int, sliced: bool = False) -> None:
        """
        Overview:
            The initialization method of DequeBuffer.
        Arguments:
            - size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): The flag whether slice data by unroll_len when sample by group
        """
        super().__init__(size=size)
        self.storage = deque(maxlen=size)
        self.indices = BufferIndex(maxlen=size)
        self.sliced = sliced
        # Meta index is a dict which uses deque as values; each deque mirrors
        # storage positionally for one meta key (see _create_index).
        self.meta_index = {}

    @apply_middleware("push")
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            The method that input the objects and the related meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The input object which can be in any format.
            - meta (:obj:`Optional[dict]`): A dict that helps describe data, such as\
                category, label, priority, etc. Default to ``None``.
        """
        return self._push(data, meta)

    @apply_middleware("sample")
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            The method that randomly sample data from the buffer or retrieve certain data by indices.
        Arguments:
            - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
                If ``indices`` is not specified, the ``size`` is required to randomly sample the\
                corresponding number of objects from the buffer.
            - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
                Default to ``None``.
            - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
                determines whether the previous samples will be put back into the buffer for subsequent\
                sampling. Default to ``False``, it means that duplicate samples will not appear in one\
                ``sample`` call.
            - sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``,\
                it means no restrictions on the range of indices for the sampling process.
            - ignore_insufficient (:obj:`bool`): whether throw ``ValueError`` if the sampled size is smaller\
                than the required size. Default to ``False``.
            - groupby (:obj:`Optional[str]`): If this parameter is activated, the method will return a\
                target size of object groups.
            - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the\
                ``groupby`` is activated.
        Returns:
            - sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
        """
        storage = self.storage
        if sample_range:
            storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))

        # Size and indices
        assert size or indices, "One of size and indices must not be empty."
        if (size and indices) and (size != len(indices)):
            raise AssertionError("Size and indices length must be equal.")
        if not size:
            size = len(indices)
        # Indices and groupby
        assert not (indices and groupby), "Cannot use groupby and indices at the same time."
        # Groupby and unroll_len
        assert not unroll_len or (
            unroll_len and groupby
        ), "Parameter unroll_len needs to be used in conjunction with groupby."

        value_error = None
        sampled_data = []
        if indices:
            indices_set = set(indices)
            hashed_data = filter(lambda item: item.index in indices_set, storage)
            hashed_data = map(lambda item: (item.index, item), hashed_data)
            hashed_data = dict(hashed_data)
            # Re-sample and return in indices order
            sampled_data = [hashed_data[index] for index in indices]
        elif groupby:
            sampled_data = self._sample_by_group(
                size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
            )
        else:
            if replace:
                sampled_data = random.choices(storage, k=size)
            else:
                try:
                    sampled_data = random.sample(storage, k=size)
                except ValueError as e:
                    value_error = e

        if value_error or len(sampled_data) != size:
            if ignore_insufficient:
                logging.warning(
                    "Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
                    format(self.count(), size)
                )
            else:
                raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))

        # Copy duplicated records so that callers can safely mutate each sample.
        sampled_data = self._independence(sampled_data)

        return sampled_data

    @apply_middleware("update")
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            the method that update data and the related meta information with a certain index.
        Arguments:
            - index (:obj:`str`): The index string of the record to be updated.
            - data (:obj:`Any`): The data which is supposed to replace the old one. If you set it\
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new dict which is supposed to merge with the old one.
        Returns:
            - success (:obj:`bool`): Whether the index was found and updated.
        """
        if not self.indices.has(index):
            return False
        i = self.indices.get(index)
        item = self.storage[i]
        if data is not None:
            item.data = data
        if meta is not None:
            item.meta = meta
            # Keep every existing meta index column positionally in sync.
            for key in self.meta_index:
                self.meta_index[key][i] = meta[key] if key in meta else None
        return True

    @apply_middleware("delete")
    def delete(self, indices: Union[str, Iterable[str]]) -> None:
        """
        Overview:
            The method that delete the data and related meta information by specific indices.
        Arguments:
            - indices (Union[str, Iterable[str]]): Where the data to be cleared in the buffer.
        """
        if isinstance(indices, str):
            indices = [indices]
        del_idx = []
        for index in indices:
            if self.indices.has(index):
                del_idx.append(self.indices.get(index))
        if len(del_idx) == 0:
            return
        # Delete from the tail first so earlier offsets stay valid while deleting.
        del_idx = sorted(del_idx, reverse=True)
        for idx in del_idx:
            del self.storage[idx]
        # Rebuild the index map from the records that remain in storage.
        # Bug fix: the pairs must enumerate the *remaining* records; the previous
        # code zipped with range(len(indices)) (the count of deleted keys), which
        # truncated or misnumbered the rebuilt map whenever the two counts differed.
        remain_indices = [item.index for item in self.storage]
        key_value_pairs = zip(remain_indices, range(len(remain_indices)))
        self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)
        # NOTE(review): self.meta_index is not rebuilt here, so its columns go
        # stale after a delete — confirm whether callers rely on it afterwards.

    def save_data(self, file_name: str):
        """
        Overview:
            Serialize storage, indices and meta index to ``file_name`` via hickle.
        Arguments:
            - file_name (:obj:`str`): Target file path; parent folders are created if missing.
        """
        if not os.path.exists(os.path.dirname(file_name)):
            # If the folder for the specified file does not exist, it will be created.
            if os.path.dirname(file_name) != "":
                os.makedirs(os.path.dirname(file_name))
        hickle.dump(
            py_obj=(
                self.storage,
                self.indices,
                self.meta_index,
            ), file_obj=file_name
        )

    def load_data(self, file_name: str):
        """
        Overview:
            Restore storage, indices and meta index previously written by ``save_data``.
        Arguments:
            - file_name (:obj:`str`): Source file path.
        """
        self.storage, self.indices, self.meta_index = hickle.load(file_name)

    def count(self) -> int:
        """
        Overview:
            The method that returns the current length of the buffer.
        """
        return len(self.storage)

    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            The method that returns the BufferedData object by subscript idx (int).
        """
        return self.storage[idx]

    def get_by_index(self, index: str) -> BufferedData:
        """
        Overview:
            The method that returns the BufferedData object given a specific index (str).
        """
        return self.storage[self.indices.get(index)]

    @apply_middleware("clear")
    def clear(self) -> None:
        """
        Overview:
            The method that clear all data, indices, and the meta information in the buffer.
        """
        self.storage.clear()
        self.indices.clear()
        self.meta_index = {}

    def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Internal insertion: wrap data with a fresh uuid index, append to the
            storage deque, and keep index and meta-index structures in sync.
        """
        index = uuid.uuid1().hex
        if meta is None:
            meta = {}
        buffered = BufferedData(data=data, index=index, meta=meta)
        self.storage.append(buffered)
        self.indices.append(index)
        # Add meta index; deques share storage's maxlen so eviction stays aligned.
        for key in self.meta_index:
            self.meta_index[key].append(meta[key] if key in meta else None)

        return buffered

    def _independence(
        self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Make sure that each record is different from each other, but remember that this function
            is different from clone_object. You may change the data in the buffer by modifying a record.
        Arguments:
            - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data,
                can be nested if groupby has been set.
        """
        if len(buffered_samples) == 0:
            return buffered_samples
        occurred = defaultdict(int)

        for i, buffered in enumerate(buffered_samples):
            if isinstance(buffered, list):
                sampled_list = buffered
                # Loop over nested samples; only duplicates (2nd+ occurrence) are copied.
                for j, buffered in enumerate(sampled_list):
                    occurred[buffered.index] += 1
                    if occurred[buffered.index] > 1:
                        sampled_list[j] = fastcopy.copy(buffered)
            elif isinstance(buffered, BufferedData):
                occurred[buffered.index] += 1
                if occurred[buffered.index] > 1:
                    buffered_samples[i] = fastcopy.copy(buffered)
            else:
                raise Exception("Get unexpected buffered type {}".format(type(buffered)))
        return buffered_samples

    def _sample_by_group(
            self,
            size: int,
            groupby: str,
            replace: bool = False,
            unroll_len: Optional[int] = None,
            storage: deque = None,
            sliced: bool = False
    ) -> List[List[BufferedData]]:
        """
        Overview:
            Sampling by `group` instead of records, the result will be a collection
            of lists with a length of `size`, but the length of each list may be different from other lists.
        """
        if storage is None:
            storage = self.storage
        if groupby not in self.meta_index:
            self._create_index(groupby)

        def filter_by_unroll_len():
            "Filter groups by unroll len, ensure count of items in each group is greater than unroll_len."
            group_count = Counter(self.meta_index[groupby])
            group_names = []
            for key, count in group_count.items():
                if count >= unroll_len:
                    group_names.append(key)
            return group_names

        if unroll_len and unroll_len > 1:
            group_names = filter_by_unroll_len()
            if len(group_names) == 0:
                return []
        else:
            group_names = list(set(self.meta_index[groupby]))

        sampled_groups = []
        if replace:
            sampled_groups = random.choices(group_names, k=size)
        else:
            try:
                sampled_groups = random.sample(group_names, k=size)
            except ValueError:
                raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names)))

        # Build dict like {"group name": [records]}
        sampled_data = defaultdict(list)
        for buffered in storage:
            meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
            if meta_value in sampled_groups:
                sampled_data[buffered.meta[groupby]].append(buffered)

        final_sampled_data = []
        for group in sampled_groups:
            seq_data = sampled_data[group]
            # Filter records by unroll_len
            if unroll_len:
                # Slice by unroll_len. If we don't do this, it is more likely to obtain \
                # duplicate data, and the training will easily crash.
                if sliced:
                    start_indice = random.choice(range(max(1, len(seq_data))))
                    start_indice = start_indice // unroll_len
                    if start_indice == (len(seq_data) - 1) // unroll_len:
                        # The tail slice may be shorter than a full stride, so take the last unroll_len items.
                        seq_data = seq_data[-unroll_len:]
                    else:
                        seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
                else:
                    start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
                    seq_data = seq_data[start_indice:start_indice + unroll_len]

            final_sampled_data.append(seq_data)

        return final_sampled_data

    def _create_index(self, meta_key: str):
        """
        Overview:
            Build (lazily) a positional deque of ``meta_key`` values mirroring storage.
        """
        self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
        for data in self.storage:
            self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)

    def __iter__(self) -> deque:
        return iter(self.storage)

    def __copy__(self) -> "DequeBuffer":
        # Shallow copy: the new buffer shares storage, meta index and indices.
        buffer = type(self)(size=self.storage.maxlen)
        buffer.storage = self.storage
        buffer.meta_index = self.meta_index
        buffer.indices = self.indices
        return buffer