import io
import os
import pickle
import time
from functools import lru_cache
from typing import Union

import torch
from ditk import logging

from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
from .lock_helper import get_file_lock

# Lazily-initialized singleton clients, created on first use by
# ``_ensure_memcached`` / ``_ensure_rediscluster``.
_memcached = None
_redis_cluster = None

if os.environ.get('DI_STORE', 'off').lower() == 'on':
    print('Enable DI-store')
    from di_store import Client

    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        """
        Overview:
            Put ``data`` into DI-store and return the object reference handle.
        """
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        """
        Overview:
            Fetch the object behind ``object_ref`` from DI-store, then delete it
            (read-once semantics).
        """
        data = di_store_client.get(object_ref)
        di_store_client.delete(object_ref)
        return data
else:
    # DI-store disabled: callers check these names for ``None``.
    save_to_di_store = read_from_di_store = None


@lru_cache()
def get_ceph_package():
    """Import the ceph package once and cache the result (``None`` if unavailable)."""
    return try_import_ceph()


@lru_cache()
def get_redis_package():
    """Import the redis package once and cache the result (``None`` if unavailable)."""
    return try_import_redis()


@lru_cache()
def get_rediscluster_package():
    """Import the rediscluster package once and cache the result (``None`` if unavailable)."""
    return try_import_rediscluster()


@lru_cache()
def get_mc_package():
    """Import the memcached package once and cache the result (``None`` if unavailable)."""
    return try_import_mc()


def read_from_ceph(path: str) -> object:
    """
    Overview:
        Read file from ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ``path`` does not exist in ceph.
    """
    value = get_ceph_package().Get(path)
    if not value:
        raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))

    # NOTE(security): pickle deserialization — only use with trusted storage.
    return pickle.loads(value)


@lru_cache()
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Ensures redis usage
    Arguments:
        - host (:obj:`str`): Host string
        - port (:obj:`int`): Port number
    Returns:
        - (:obj:`Redis(object)`): Redis object with given ``host``, ``port``, and ``db=0``
    """
    return get_redis_package().StrictRedis(host=host, port=port, db=0)


def read_from_redis(path: str) -> object:
    """
    Overview:
        Read file from redis
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    # NOTE(security): pickle deserialization — only use with trusted storage.
    return pickle.loads(_get_redis().get(path))


def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Ensures redis usage by lazily creating the module-level ``_redis_cluster`` client.
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Startup nodes, each with \
            ``host`` (:obj:`str`) and ``port`` keys; defaults to \
            ``[{"host": "127.0.0.1", "port": "7000"}]`` when ``None``.
    Returns:
        - None: The client is stored in the module-level ``_redis_cluster`` with \
            ``decode_responses=False`` in default.
    """
    global _redis_cluster
    if _redis_cluster is None:
        if startup_nodes is None:
            # Default built inside the body to avoid a mutable default argument.
            startup_nodes = [{"host": "127.0.0.1", "port": "7000"}]
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)


def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read file from rediscluster
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    # NOTE(security): pickle deserialization — only use with trusted storage.
    value = pickle.loads(value_bytes)
    return value


def read_from_file(path: str) -> object:
    """
    Overview:
        Read file from local file system
    Arguments:
        - path (:obj:`str`): File path in local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    with open(path, "rb") as f:
        value = pickle.load(f)

    return value


def _ensure_memcached():
    """
    Overview:
        Ensures memcache usage by lazily creating the module-level ``_memcached`` client.
    Returns:
        - None: The ``MemcachedClient`` instance, built with the cluster-wide \
            ``server_list.conf`` and ``client.conf`` files, is stored in ``_memcached``.
    """
    global _memcached
    if _memcached is None:
        # Hard-coded cluster paths; only valid on hosts with this lustre mount.
        server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
        client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)


def read_from_mc(path: str, flush=False) -> object:
    """
    Overview:
        Read file from memcache, file must be saved by `torch.save()`
    Arguments:
        - path (:obj:`str`): File path in local system
        - flush (:obj:`bool`): If ``True``, only force a read-through (cache refresh) and \
            return ``None`` without deserializing.
    Returns:
        - (:obj:`data`): Deserialized data
    """
    _ensure_memcached()
    # Best-effort retry: memcache reads can fail transiently, so loop until success.
    while True:
        try:
            value = get_mc_package().pyvector()
            if flush:
                _memcached.Get(path, value, get_mc_package().MC_READ_THROUGH)
                return
            else:
                _memcached.Get(path, value)
                value_buf = get_mc_package().ConvertBuffer(value)
                value_str = io.BytesIO(value_buf)
                value_str = torch.load(value_str, map_location='cpu')
                return value_str
        except Exception:
            print('read mc failed, retry...')
            time.sleep(0.01)


def read_from_path(path: str):
    """
    Overview:
        Read file from ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, or use local file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is None:
        logging.info(
            "You do not have ceph installed! Loading local file!"
            " If you are not testing locally, something is wrong!"
        )
        return read_from_file(path)
    else:
        return read_from_ceph(path)


def save_file_ceph(path, data):
    """
    Overview:
        Save pickle dumped data file to ceph
    Arguments:
        - path (:obj:`str`): File path in ceph, start with ``"s3://"``, use file system when not
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    Raises:
        - RuntimeError: If the ceph package exposes neither ``save_from_string`` nor ``put``.
    """
    data = pickle.dumps(data)
    save_path = os.path.dirname(path)
    file_name = os.path.basename(path)
    ceph = get_ceph_package()
    if ceph is not None:
        # Support both known ceph client APIs.
        if hasattr(ceph, 'save_from_string'):
            ceph.save_from_string(save_path, file_name, data)
        elif hasattr(ceph, 'put'):
            ceph.put(os.path.join(save_path, file_name), data)
        else:
            raise RuntimeError('ceph can not save file, check your ceph installation')
    else:
        size = len(data)
        # Sentinel directory name that means "discard the payload" (used in tests).
        if save_path == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        p = os.path.join(save_path, file_name)
        with open(p, 'wb') as f:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(p, size) +
                " If you are not testing locally, something is wrong!"
            )
            f.write(data)


def save_file_redis(path, data):
    """
    Overview:
        Save pickle dumped data file to redis
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _get_redis().set(path, pickle.dumps(data))


def save_file_rediscluster(path, data):
    """
    Overview:
        Save pickle dumped data file to rediscluster
    Arguments:
        - path (:obj:`str`): File path (could be a string key) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _ensure_rediscluster()
    data = pickle.dumps(data)
    _redis_cluster.set(path, data)


def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support \
            ``{'normal', 'ceph', 'mc'}``; auto-detected from the path prefix when ``None``
        - use_lock (:obj:`bool`): Whether ``use_lock`` is in local normal file system
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if fs_type is None:
        # Auto-detect: s3-prefixed paths go to ceph, prefer memcache when available.
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        data = read_from_path(path)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'read'):
                data = torch.load(path, map_location='cpu')
        else:
            data = torch.load(path, map_location='cpu')
    elif fs_type == 'mc':
        data = read_from_mc(path)
    return data


def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to file of path
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support \
            ``{'normal', 'ceph', 'mc'}``; auto-detected from the path prefix when ``None``
        - use_lock (:obj:`bool`): Whether ``use_lock`` is in local normal file system
    """
    if fs_type is None:
        # Auto-detect: s3-prefixed paths go to ceph, prefer memcache when available.
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc']
    if fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
        else:
            torch.save(data, path)
    elif fs_type == 'mc':
        # Save to disk first, then force a read-through so memcache picks up the new file.
        torch.save(data, path)
        read_from_mc(path, flush=True)


def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Remove file
    Arguments:
        - path (:obj:`str`): The path of file you want to remove
        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``
    """
    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph']
    # NOTE(security): path is interpolated into a shell command; do not pass
    # untrusted input. ``os.popen`` also returns without waiting for completion.
    if fs_type == 'ceph':
        os.popen("aws s3 rm --recursive {}".format(path))
    elif fs_type == 'normal':
        os.popen("rm -rf {}".format(path))