Skip to content

ding.data.model_loader

ding.data.model_loader

ModelLoader

Bases: Supervisor, ABC

__init__(model)

Overview

Save and send models asynchronously and load them synchronously.

Arguments: - model (:obj:torch.nn.Module): Torch module.

load(storage)

Overview

Load model synchronously.

Arguments: - storage (:obj:Storage): The model should be wrapped in a storage object, e.g. FileModelStorage. Returns: - object (:obj:): The loaded model.

save(callback) abstractmethod

Overview

Save model asynchronously.

Arguments: - callback (:obj:Callable): The callback function after saving model. Returns: - storage (:obj:Storage): The storage object is created synchronously, so it can be returned.

FileModelLoader

Bases: ModelLoader

__init__(model, dirname, ttl=20)

Overview

Model loader using files as storage media.

Arguments: - model (:obj:torch.nn.Module): Torch module. - dirname (:obj:str): The directory for saving files. - ttl (:obj:int): Files will be automatically cleaned after ttl. Note that files that do not time out when the process is stopped are not cleaned up (to avoid errors when other processes read the file), so you may need to clean up the remaining files manually.

Full Source Code

../ding/data/model_loader.py

from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker():
    """
    Overview:
        Worker registered into the supervisor's child process. It holds a
        (shared-memory) reference to the model and persists its ``state_dict``
        into a storage object on request.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Write the model's ``state_dict`` into ``storage`` and return it.
        Arguments:
            - storage (:obj:`Storage`): Destination storage object.
        Returns:
            - storage (:obj:`Storage`): The same storage object, now populated.
        """
        storage.save(self._model.state_dict())
        return storage


class ModelLoader(Supervisor, ABC):

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # CUDA tensors can only be shared with a child process started via the
        # "spawn" method; CPU tensors work with the platform default context.
        if next(model.parameters()).is_cuda:
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        # Thread that routes worker replies to user callbacks; created in start().
        self._send_callback_loop: Optional[Thread] = None
        # Maps in-flight request id -> callback to invoke on completion.
        self._send_callbacks: dict = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Start the supervisor child process and the callback dispatch thread.
            Idempotent: does nothing if already running.
        """
        if not self._running:
            # Move parameters to shared memory so the child process sees updates.
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and drop pending callbacks.
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait for shutdown.
        """
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        # Dispatch loop: receive payloads from the worker process and invoke
        # the callback registered for each request id.
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                if payload.req_id in self._send_callbacks:
                    del self._send_callbacks[payload.req_id]
            else:
                if payload.req_id in self._send_callbacks:
                    callback = self._send_callbacks.pop(payload.req_id)
                    callback(payload.data)

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:): The loaded model.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError


class FileModelLoader(ModelLoader):

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually.
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # FIFO of [saved_at_timestamp, file_path] pairs awaiting expiry.
        self._files: list = []
        self._cleanup_thread: Optional[Thread] = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to clean up files that are taking up too much time on the disk.
            Idempotent: only one cleanup thread is ever created.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the loader. Remaining files are intentionally NOT removed
            here, so other processes can still read them (see class docstring).
        Arguments:
            - timeout (:obj:`Optional[float]`): Seconds to wait for shutdown.
        """
        super().shutdown(timeout)
        self._cleanup_thread = None

    def _loop_cleanup(self):
        # Periodically delete the oldest saved file once it is older than ttl.
        while True:
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)

    def save(self, callback: Callable) -> Optional[FileModelStorage]:
        """
        Overview:
            Asynchronously dump the model's state_dict to a unique file under
            ``dirname`` and schedule the file for ttl-based cleanup.
        Arguments:
            - callback (:obj:`Callable`): Invoked with the storage object once
                the worker process has finished writing the file.
        Returns:
            - storage (:obj:`Optional[FileModelStorage]`): The storage handle,
                created synchronously; ``None`` if the loader is not running.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        # makedirs with exist_ok avoids a FileExistsError race when several
        # savers start concurrently, and supports nested directory paths
        # (os.mkdir would fail on both).
        os.makedirs(self._dirname, exist_ok=True)
        file_path = "model_{}.pth.tar".format(uuid.uuid1())
        file_path = path.join(self._dirname, file_path)
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # Register the file for ttl cleanup only after the worker has
            # actually written it, then forward to the user's callback.
            self._files.append([time(), file_path])
            callback(storage)

        self._send_callbacks[payload.req_id] = clean_callback
        return model_storage