Skip to content

ding.data.storage_loader

ding.data.storage_loader

StorageLoader

Bases: Supervisor, ABC

__init__(worker_num=3)

Overview

Save and send data synchronously and load them asynchronously.

Arguments: - worker_num (:obj:int): Subprocess worker number.

save(obj) abstractmethod

Overview

Save data with a storage object synchronously.

Arguments: - obj (:obj:Union[Dict, List]): The data (traj or episodes), can be numpy, tensor or treetensor. Returns: - storage (:obj:Storage): The storage object.

load(storage, callback)

Overview

Load data from a storage object asynchronously. This function will analyze the data structure when it first meets a new kind of data, then allocate a shared memory buffer for each subprocess; these shared memory buffers will be responsible for asynchronously loading data into memory.

Arguments: - storage (:obj:Storage): The storage object. - callback (:obj:Callable): Callback function after data loaded.

FileStorageLoader

Bases: StorageLoader

__init__(dirname, ttl=20, worker_num=3)

Overview

Dump and load object with file storage.

Arguments: - dirname (:obj:str): The directory to save files. - ttl (:obj:int): Maximum time (in seconds) to keep a file, after which it will be deleted. - worker_num (:obj:int): Number of subprocess worker loaders.

Full Source Code

../ding/data/storage_loader.py

from dataclasses import dataclass
import os
import torch
import numpy as np
import uuid
import treetensor.torch as ttorch
from abc import ABC, abstractmethod
from ditk import logging
from time import sleep, time
from threading import Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Union
from ding.data import FileStorage, Storage
from os import path
from ding.data.shm_buffer import ShmBuffer
from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload


@dataclass
class ShmObject:
    # id_ holds a one-element int buffer used as a handshake token between the
    # subprocess (_shm_callback) and the main process (_shm_putback);
    # buf mirrors the loaded object's structure with ShmBuffer leaves.
    id_: ShmBuffer
    buf: Any


class StorageWorker:
    """
    Overview:
        Worker that runs inside each subprocess; simply materializes a storage object.
    """

    def load(self, storage: Storage) -> Any:
        return storage.load()


class StorageLoader(Supervisor, ABC):

    def __init__(self, worker_num: int = 3) -> None:
        """
        Overview:
            Save and send data synchronously and load them asynchronously.
        Arguments:
            - worker_num (:obj:`int`): Subprocess worker number.
        """
        super().__init__(type_=ChildType.PROCESS)
        self._load_lock = Lock()  # Load (first meet) should be called one by one.
        self._callback_map: Dict[str, Callable] = {}
        self._shm_obj_map: Dict[int, ShmObject] = {}
        self._worker_num = worker_num
        self._req_count = 0

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._recv_loop = None
        self._callback_map = {}
        self._shm_obj_map = {}
        self._req_count = 0

    def start_link(self) -> None:
        if not self._running:
            super().start_link()
            # Daemon thread: dies with the main process, no explicit join needed.
            self._recv_loop = Thread(target=self._loop_recv, daemon=True)
            self._recv_loop.start()

    @property
    def _next_proc_id(self):
        # Round-robin dispatch of load requests over the worker pool.
        return self._req_count % self._worker_num

    @abstractmethod
    def save(self, obj: Union[Dict, List]) -> Storage:
        """
        Overview:
            Save data with a storage object synchronously.
        Arguments:
            - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
        Returns:
            - storage (:obj:`Storage`): The storage object.
        """
        raise NotImplementedError

    def load(self, storage: Storage, callback: Callable):
        """
        Overview:
            Load data from a storage object asynchronously. \
            This function will analyze the data structure when it first meets a new kind of data, \
            then allocate a shared memory buffer for each subprocess; these shared memory buffers \
            will be responsible for asynchronously loading data into memory.
        Arguments:
            - storage (:obj:`Storage`): The storage object.
            - callback (:obj:`Callable`): Callback function after data loaded.
        """
        with self._load_lock:
            if not self._running:
                self._first_meet(storage, callback)
                return

        payload = SendPayload(proc_id=self._next_proc_id, method="load", args=[storage])
        self._callback_map[payload.req_id] = callback
        self.send(payload)
        self._req_count += 1

    def _first_meet(self, storage: Storage, callback: Callable):
        """
        Overview:
            When first meet an object type, we'll load this object directly and analyze the structure,
            to allocate the shared memory object and create subprocess workers.
        Arguments:
            - storage (:obj:`Storage`): The storage object.
            - callback (:obj:`Callable`): Callback function after data loaded.
        """
        obj = storage.load()
        # Create `worker_num` workers, each with its own shared memory buffer.
        for i in range(self._worker_num):
            shm_obj = self._create_shm_buffer(obj)
            self._shm_obj_map[i] = shm_obj
            self.register(StorageWorker, shm_buffer=shm_obj, shm_callback=self._shm_callback)
        self.start_link()
        callback(obj)

    def _loop_recv(self):
        # Runs forever in a daemon thread; drains worker replies and fires callbacks.
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                if payload.req_id in self._callback_map:
                    del self._callback_map[payload.req_id]
            else:
                self._shm_putback(payload, self._shm_obj_map[payload.proc_id])
                if payload.req_id in self._callback_map:
                    callback = self._callback_map.pop(payload.req_id)
                    callback(payload.data)

    def _create_shm_buffer(self, obj: Union[Dict, List]) -> Optional[ShmObject]:
        """
        Overview:
            Create shared object (buf and callback) by walk through the data structure.
        Arguments:
            - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
        Returns:
            - shm_buf (:obj:`Optional[ShmObject]`): The shared memory buffer.
        """
        max_level = 2  # Only mirror the structure two levels deep.

        def to_shm(obj: Dict, level: int):
            if level > max_level:
                return
            shm_buf = None
            if isinstance(obj, Dict) or isinstance(obj, ttorch.Tensor):
                shm_buf = {}
                for key, val in obj.items():
                    # Only numpy array can fill into shm buffer
                    if isinstance(val, np.ndarray):
                        shm_buf[key] = ShmBuffer(val.dtype, val.shape, copy_on_get=False)
                    elif isinstance(val, torch.Tensor):
                        shm_buf[key] = ShmBuffer(
                            val.numpy().dtype, val.numpy().shape, copy_on_get=False, ctype=torch.Tensor
                        )
                    # Recursive parsing structure
                    elif isinstance(val, Dict) or isinstance(val, ttorch.Tensor) or isinstance(val, List):
                        buf = to_shm(val, level=level + 1)
                        if buf:
                            shm_buf[key] = buf
            elif isinstance(obj, List):
                # Double the size of buffer, in case a later payload list is longer than the first one.
                shm_buf = [to_shm(o, level=level) for o in obj] * 2
                if all(s is None for s in shm_buf):
                    shm_buf = []
            return shm_buf

        shm_buf = to_shm(obj, level=0)
        if shm_buf is not None:
            random_id = self._random_id()
            shm_buf = ShmObject(id_=ShmBuffer(random_id.dtype, random_id.shape, copy_on_get=False), buf=shm_buf)
        return shm_buf

    def _random_id(self) -> np.ndarray:
        # Non-zero token: zero is reserved to mean "buffer is free" (see _shm_callback/_shm_putback).
        return np.random.randint(1, 9e6, size=(1, ))

    def _shm_callback(self, payload: RecvPayload, shm_obj: ShmObject):
        """
        Overview:
            Called in subprocess, put payload.data into buf.
        Arguments:
            - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
            - shm_obj (:obj:`ShmObject`): The shm buffer.
        """
        assert isinstance(payload.data, type(
            shm_obj.buf
        )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf))

        # Sleep while shm object is not ready (id 0 means free, see _shm_putback).
        while shm_obj.id_.get()[0] != 0:
            sleep(0.001)

        max_level = 2

        def shm_callback(data: Union[Dict, List, ttorch.Tensor], buf: Union[Dict, List], level: int):
            if level > max_level:
                return

            if isinstance(buf, List):
                assert isinstance(data, List), "Data ({}) and buf ({}) type not match".format(type(data), type(buf))
            elif isinstance(buf, Dict):
                assert isinstance(data, ttorch.Tensor) or isinstance(
                    data, Dict
                ), "Data ({}) and buf ({}) type not match".format(type(data), type(buf))

            if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
                for key, val in data.items():
                    if isinstance(val, torch.Tensor):
                        val = val.numpy()
                    buf_val = buf.get(key)
                    if buf_val is None:
                        continue
                    if isinstance(buf_val, ShmBuffer) and isinstance(val, np.ndarray):
                        buf_val.fill(val)
                        # Null out the pickled copy so only the shm buffer carries the array.
                        data[key] = None
                    else:
                        shm_callback(val, buf_val, level=level + 1)
            elif isinstance(data, List):
                for i, data_ in enumerate(data):
                    shm_callback(data_, buf[i], level=level)

        shm_callback(payload.data, buf=shm_obj.buf, level=0)
        # Stamp a fresh token so the main process can match payload to buffer.
        id_ = self._random_id()
        shm_obj.id_.fill(id_)
        payload.extra = id_

    def _shm_putback(self, payload: RecvPayload, shm_obj: ShmObject):
        """
        Overview:
            Called in main process, put buf back into payload.data.
        Arguments:
            - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
            - shm_obj (:obj:`ShmObject`): The shm buffer.
        """
        assert isinstance(payload.data, type(
            shm_obj.buf
        )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf))

        assert shm_obj.id_.get()[0] == payload.extra[0], "Shm object and payload do not match ({} - {}).".format(
            shm_obj.id_.get()[0], payload.extra[0]
        )

        def shm_putback(data: Union[Dict, List], buf: Union[Dict, List]):
            if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
                for key, val in data.items():
                    buf_val = buf.get(key)
                    if buf_val is None:
                        continue
                    if val is None and isinstance(buf_val, ShmBuffer):
                        data[key] = buf[key].get()
                    else:
                        shm_putback(val, buf_val)
            elif isinstance(data, List):
                for i, data_ in enumerate(data):
                    shm_putback(data_, buf[i])

        shm_putback(payload.data, buf=shm_obj.buf)
        # Reset token to 0: buffer is free for the next request.
        shm_obj.id_.fill(np.array([0]))


class FileStorageLoader(StorageLoader):

    def __init__(self, dirname: str, ttl: int = 20, worker_num: int = 3) -> None:
        """
        Overview:
            Dump and load object with file storage.
        Arguments:
            - dirname (:obj:`str`): The directory to save files.
            - ttl (:obj:`int`): Maximum time (seconds) to keep a file, after which it will be deleted.
            - worker_num (:obj:`int`): Number of subprocess worker loaders.
        """
        super().__init__(worker_num)
        self._dirname = dirname
        self._files = []
        self._cleanup_thread = None
        self._ttl = ttl  # Files older than ttl seconds are removed by the cleanup thread.

    def save(self, obj: Union[Dict, List]) -> FileStorage:
        # makedirs(exist_ok=True) is race-safe and also creates missing parents,
        # unlike the exists()+mkdir pair.
        os.makedirs(self._dirname, exist_ok=True)
        filename = "{}.pkl".format(uuid.uuid1())
        full_path = path.join(self._dirname, filename)
        f = FileStorage(full_path)
        f.save(obj)
        self._files.append([time(), f.path])
        self._start_cleanup()
        return f

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to delete files that have exceeded their ttl on disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._cleanup_thread = None

    def _loop_cleanup(self):
        # Files are appended in creation order, so checking the head is enough.
        while True:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)