from typing import Iterable, Callable, Optional, Any, Union
import time
import platform
import threading
import queue

import torch
import torch.multiprocessing as tm
from ding.torch_utils import to_device
from ding.utils import LockContext, LockContextType
from .base_dataloader import IDataLoader
from .collate_fn import default_collate


class AsyncDataLoader(IDataLoader):
    """
    Overview:
        An asynchronous dataloader. Data fetching, (optionally multi-process) processing and the
        optional host-to-CUDA transfer are pipelined in background processes/threads, so that
        ``__next__`` only needs to pop an already-prepared batch from a queue.
    Interfaces:
        ``__init__``, ``__iter__``, ``__next__``, ``_get_data``, ``_async_loop``, ``_worker_loop``, \
        ``_cuda_loop``, ``close``
    """

    def __init__(
            self,
            data_source: Union[Callable, dict],
            batch_size: int,
            device: str,
            chunk_size: Optional[int] = None,
            collate_fn: Optional[Callable] = None,
            num_workers: int = 0
    ) -> None:
        """
        Overview:
            Init dataloader with input parameters.
            If ``data_source`` is ``dict``, data will only be processed in ``get_data_thread`` and put into
            ``async_train_queue``.
            If ``data_source`` is ``Callable``, data will be processed by implementing functions, and can be sorted
            in two types:

                - ``num_workers`` == 0 or 1: Only main worker will process it and put into ``async_train_queue``.
                - ``num_workers`` > 1: Main worker will divide a job into several pieces, push every job into \
                    ``job_queue``; Then slave workers get jobs and implement; Finally they will push processed data \
                    into ``async_train_queue``.

            At the last step, if ``device`` contains "cuda", data in ``async_train_queue`` will be transferred to
            ``cuda_queue`` for user to access.
        Arguments:
            - data_source (:obj:`Union[Callable, dict]`): The data source, e.g. function to be implemented(Callable), \
                replay buffer's real data(dict), etc.
            - batch_size (:obj:`int`): Batch size.
            - device (:obj:`str`): Device.
            - chunk_size (:obj:`int`): The size of a chunked piece in a batch, should exactly divide ``batch_size``, \
                only function when there are more than 1 worker.
            - collate_fn (:obj:`Callable`): The function which is used to collate batch size into each data field.
            - num_workers (:obj:`int`): Number of extra workers. \
                0 or 1 means only 1 main worker and no extra ones, i.e. Multiprocessing is disabled. \
                More than 1 means multiple workers implemented by multiprocessing are to process data respectively.
        """
        self.data_source = data_source
        self.batch_size = batch_size
        self.device = device
        self.use_cuda = 'cuda' in self.device
        if self.use_cuda:
            # Dedicated CUDA stream so the host-to-device copy in ``_cuda_loop`` overlaps with compute
            # on the default stream.
            self.stream = torch.cuda.Stream()
        if chunk_size is None:
            # chunk_size == 1 means each job is a single sample.
            self.chunk_size = 1
        else:
            self.chunk_size = chunk_size
        # ``batch_size`` must be an exact multiple of ``chunk_size`` so a batch splits into whole chunks.
        assert self.batch_size >= self.chunk_size and self.batch_size % self.chunk_size == 0, '{}/{}'.format(
            self.batch_size, self.chunk_size
        )
        if collate_fn is None:
            self.collate_fn = default_collate
        else:
            self.collate_fn = collate_fn
        self.num_workers = num_workers
        if self.num_workers < 0:
            raise ValueError(
                '"num_workers" should be non-negative; '
                'Use num_workers = 0 or 1 to disable multiprocessing.'
            )
        # Up to "2 * num_workers" pieces of data will be stored in dataloader, waiting for learner to get.
        # Up to "2 * num_workers" jobs will be stored in dataloader, waiting for slave process to get and accomplish.
        queue_maxsize = max(1, self.num_workers) * 2
        self.queue_maxsize = queue_maxsize

        # For multiprocessing: Use ``spawn`` on Windows (``fork`` is unavailable there), ``fork`` elsewhere.
        context_str = 'spawn' if platform.system().lower() == 'windows' else 'fork'
        self.mp_context = tm.get_context(context_str)
        self.manager = self.mp_context.Manager()
        # ``async_train_queue`` is the queue to store processed data.
        # User can directly access data if don't use cuda; Otherwise, user will access data from ``cuda_queue``.
        self.async_train_queue = self.mp_context.Queue(maxsize=queue_maxsize)
        # NOTE(review): ``end_flag`` is a plain bool, not shared memory. Under ``fork`` each child process
        # gets its own copy (always False), so child loops are actually stopped by ``terminate()`` in
        # ``close()``, not by observing this flag — only same-process threads see it flip. Confirm intended.
        self.end_flag = False

        # Multiprocessing workers: If num_workers > 1, more than 1 worker are to process data.
        if self.num_workers > 1:
            # ``batch_id``: next batch id to assign; ``cur_batch``: batch id whose result may be enqueued next.
            # Both wrap modulo ``queue_maxsize`` so finished batches are emitted in assignment order.
            self.batch_id = self.mp_context.Value('i', 0)
            self.cur_batch = self.mp_context.Value('i', 0)
            if self.batch_size != self.chunk_size:
                # job_result {batch_id: result_list} is used to store processed result in temporal.
                self.job_result = self.manager.dict()
                self.job_result_lock = LockContext(lock_type=LockContextType.PROCESS_LOCK)
            self.job_queue = self.mp_context.Queue(maxsize=queue_maxsize)
            self.worker = [
                self.mp_context.Process(
                    target=self._worker_loop, args=(), name='dataloader_worker{}_{}'.format(i, time.time())
                ) for i in range(self.num_workers)
            ]
            for w in self.worker:
                # Daemonic processes die with the parent, so an unclean exit does not leave orphans.
                w.daemon = True
                w.start()
            print('Using {} workers to load data'.format(self.num_workers))

        # Parent and child pipe ends. Used by ``async_process`` and ``get_data_thread`` to coordinate:
        # the main worker requests data over the pipe; the thread replies with raw data or 'pass'.
        p, c = self.mp_context.Pipe()

        # Async process (Main worker): Process data if num_workers <= 1; Assign job to other workers if num_workers > 1.
        self.async_process = self.mp_context.Process(target=self._async_loop, args=(p, c))
        self.async_process.daemon = True
        self.async_process.start()

        # Get data thread: Get data from ``data_source`` and send it to ``async_process``.
        self.get_data_thread = threading.Thread(target=self._get_data, args=(p, c))
        self.get_data_thread.daemon = True
        self.get_data_thread.start()

        # Cuda thread: If use cuda, data in ``async_train_queue`` will be transferred to ``cuda_queue``;
        # Then user will access data from ``cuda_queue``.
        if self.use_cuda:
            self.cuda_queue = queue.Queue(maxsize=queue_maxsize)
            self.cuda_thread = threading.Thread(target=self._cuda_loop, args=(), name='dataloader_cuda')
            self.cuda_thread.daemon = True
            self.cuda_thread.start()

    def __iter__(self) -> Iterable:
        """
        Overview:
            Return the iterable self as an iterator.
        Returns:
            - self (:obj:`Iterable`): Self as an iterator.
        """
        return self

    def _get_data(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
        """
        Overview:
            Fetch data from ``self.data_source`` on request from the main worker process.
            Will run as a thread through ``self.get_data_thread``.
        Arguments:
            - p (:obj:`tm.multiprocessing.connection`): Parent connection.
            - c (:obj:`tm.multiprocessing.connection`): Child connection.
        """
        c.close()  # Close unused c, only use p
        while not self.end_flag:
            # ``poll`` already blocks up to 0.2s waiting for a request from the main worker.
            if not p.poll(timeout=0.2):
                time.sleep(0.01)
                continue
            try:
                cmd = p.recv()
            except EOFError:
                # Peer end closed (e.g. during shutdown): stop the thread.
                break
            if cmd == 'get_data':
                # Main worker asks for data.
                data = self.data_source(self.batch_size)
                # ``data`` can be callable, e.g. a function to read data from file, therefore we can divide
                # this job to pieces, assign to every slave worker and accomplish jobs asynchronously.
                # But if we get a list of dicts, which means the data has already been processed and
                # can be used directly, we can put it directly in async_train_queue and wait it
                # to be accessed by a user, e.g. learner.
                if isinstance(data[0], dict):
                    data = self.collate_fn(data)
                    self.async_train_queue.put(data)
                    # Tell the main worker there is nothing for it to process this round.
                    p.send('pass')
                else:
                    p.send(data)
        p.close()

    def _async_loop(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
        """
        Overview:
            Main worker process. Run through ``self.async_process``.
            Firstly, get data from ``self.get_data_thread``.
            If multiple workers, put data in ``self.job_queue`` for further multiprocessing operation;
            If only one worker, process data and put directly into ``self.async_train_queue``.
        Arguments:
            - p (:obj:`tm.multiprocessing.connection`): Parent connection.
            - c (:obj:`tm.multiprocessing.connection`): Child connection.
        """
        # Avoid intra-op thread over-subscription inside this worker process.
        torch.set_num_threads(1)
        p.close()  # Close unused p, only use c
        while not self.end_flag:
            if self.num_workers > 1:
                # Multiple workers: Put jobs (chunked data) into job_queue
                if self.job_queue.full():
                    time.sleep(0.001)
                else:
                    # Get data from ``_get_data`` thread.
                    c.send('get_data')
                    data = c.recv()
                    if isinstance(data, str) and data == 'pass':
                        # Dict-type data was already enqueued by ``_get_data``; nothing to chunk.
                        continue
                    # Get data to be processed, chunk it into pieces and put them into job_queue.
                    chunk_num = self.batch_size // self.chunk_size
                    with self.batch_id.get_lock():
                        # All chunks of one batch share the same batch_id, so workers can reassemble them.
                        for i in range(chunk_num):
                            start, end = i * self.chunk_size, (i + 1) * self.chunk_size
                            self.job_queue.put({'batch_id': self.batch_id.value, 'job': data[start:end]})
                        self.batch_id.value = (self.batch_id.value + 1) % self.queue_maxsize  # Increment batch_id
                    time.sleep(0.001)
            else:
                # Only one worker: Process data and directly put it into async_train_queue
                if self.async_train_queue.full():
                    time.sleep(0.001)
                else:
                    c.send('get_data')
                    data = c.recv()
                    if isinstance(data, str) and data == 'pass':
                        continue
                    data = [fn() for fn in data]  # Implement functions in list ``data``.
                    data = self.collate_fn(data)
                    self.async_train_queue.put(data)
        c.close()

    def _worker_loop(self) -> None:
        """
        Overview:
            Worker process. Run through each element in list ``self.worker``.
            Get data job from ``self.job_queue``, process it and then put into ``self.async_train_queue``.
            Only function when ``self.num_workers`` > 1, which means using multiprocessing.
        """
        while not self.end_flag:
            if self.job_queue.empty() or self.async_train_queue.full():
                # No left job to be done, or finished job have no space to store.
                time.sleep(0.01)
                continue
            else:
                try:
                    element = self.job_queue.get()
                except (ConnectionResetError, ConnectionRefusedError) as e:
                    # Queue's underlying connection was torn down (parent exiting): stop the worker.
                    break
                batch_id, job = element['batch_id'], element['job']
                # Process the assigned data.
                data = [fn() for fn in job]  # Only function-type job will arrive here, dict-type will not
                if len(data) == self.batch_size == self.chunk_size:
                    # Data not chunked: Finish the assigned one means finishing a whole batch.
                    data = self.collate_fn(data)
                    # Spin until it is this batch's turn, to keep batches in assignment order.
                    while batch_id != self.cur_batch.value:
                        time.sleep(0.01)
                    self.async_train_queue.put(data)
                    # Directly update cur_batch, since a whole batch is finished
                    with self.cur_batch.get_lock():
                        self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
                else:
                    # Data chunked: Must wait for all chunked pieces in a batch to be accomplished.
                    finish_flag = False  # indicate whether a whole batch is accomplished
                    with self.job_result_lock:
                        if batch_id not in self.job_result:
                            # The first one in a batch
                            self.job_result[batch_id] = data
                        elif len(self.job_result[batch_id]) + len(data) == self.batch_size:
                            # The last one in a batch
                            data += self.job_result.pop(batch_id)
                            assert batch_id not in self.job_result
                            finish_flag = True
                        else:
                            # Middle pieces in a batch
                            self.job_result[batch_id] += data
                    if finish_flag:
                        data = self.collate_fn(data)
                        # Spin until it is this batch's turn, to keep batches in assignment order.
                        while batch_id != self.cur_batch.value:
                            time.sleep(0.01)
                        self.async_train_queue.put(data)
                        with self.cur_batch.get_lock():
                            self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
        # If ``self.end_flag`` is True, clear and close job_queue, because _worker_loop gets jobs from job_queue.
        # NOTE(review): under ``fork`` this cleanup is normally not reached, since workers are terminated
        # by ``close()`` before their (process-local) ``end_flag`` could ever become True — confirm intended.
        while not self.job_queue.empty():
            try:
                _ = self.job_queue.get()
            except Exception as e:
                break
        self.job_queue.close()
        self.job_queue.join_thread()

    def _cuda_loop(self) -> None:
        """
        Overview:
            Only when using cuda, would this be run as a thread through ``self.cuda_thread``.
            Get data from ``self.async_train_queue``, change its device and put it into ``self.cuda_queue``.
        """
        # Run copies on a side stream so they overlap with compute on the default stream.
        with torch.cuda.stream(self.stream):
            while not self.end_flag:
                if self.async_train_queue.empty() or self.cuda_queue.full():
                    time.sleep(0.01)
                else:
                    data = self.async_train_queue.get()
                    data = to_device(data, self.device)
                    self.cuda_queue.put(data)
        # If ``self.end_flag`` is True, clear and close async_train_queue,
        # because _cuda_loop gets data from async_train_queue.
        while not self.async_train_queue.empty():
            _ = self.async_train_queue.get()
        self.async_train_queue.close()
        self.async_train_queue.join_thread()

    def __next__(self) -> Any:
        """
        Overview:
            Return next data in the iterator. If use cuda, get from ``self.cuda_queue``;
            Otherwise, get from ``self.async_train_queue``.
        Returns:
            - data (:obj:`torch.Tensor`): Next data in the dataloader iterator.
        """
        while not self.end_flag:
            if self.use_cuda:
                if self.cuda_queue.empty():
                    time.sleep(0.01)
                else:
                    data = self.cuda_queue.get(timeout=60)
                    self.cuda_queue.task_done()
                    return data
            else:
                if self.async_train_queue.empty():
                    time.sleep(0.01)
                else:
                    return self.async_train_queue.get()
        # If ``self.end_flag`` is True, clear and close either 1) or 2):
        # 1) cuda_queue. Because user get data from cuda_queue, and async_train_queue is closed by cuda_loop.
        # 2) async_train_queue. Because user get data from async_train_queue.
        # NOTE(review): after shutdown this method implicitly returns None instead of raising
        # StopIteration — callers iterating with ``for`` would receive None; confirm intended.
        if self.use_cuda:
            while not self.cuda_queue.empty():
                _ = self.cuda_queue.get()
                self.cuda_queue.task_done()
            self.cuda_queue.join()
        else:
            while not self.async_train_queue.empty():
                _ = self.async_train_queue.get()
            self.async_train_queue.close()
            self.async_train_queue.join_thread()

    def __del__(self) -> None:
        """
        Overview:
            Delete this dataloader, releasing its background processes via ``close``.
        """
        self.close()

    def close(self) -> None:
        """
        Overview:
            Shut down this dataloader. First set ``end_flag`` to True, which means different processes/threads
            in this process will clear and close all data queues; Then all child processes will be terminated
            and joined. Idempotent: repeated calls return immediately.
        """
        if self.end_flag:
            return
        self.end_flag = True
        # Child processes do not observe ``end_flag`` (see __init__), so stop them forcibly.
        self.async_process.terminate()
        self.async_process.join()
        if self.num_workers > 1:
            for w in self.worker:
                w.terminate()
                w.join()
        print('Del AsyncDataLoader')