ding.framework.middleware.data_fetcher
Full Source Code
../ding/framework/middleware/data_fetcher.py
from typing import TYPE_CHECKING
from threading import Thread, Event
from queue import Queue
import time
import numpy as np
import torch
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
    from ding.framework import OfflineRLContext


class OfflineMemoryDataFetcher:
    """
    Middleware that prefetches mini-batches from an offline ``Dataset`` on a
    background thread and hands one batch per call to ``ctx.train_data``.

    A bounded queue (maxsize 50) decouples the producer thread from the
    training loop; when CUDA is enabled, batches are stacked and moved to the
    local device inside a dedicated CUDA stream.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only processes holding the FETCHER role run
        # this middleware; all others get a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset):
        """
        Arguments:
            - cfg: config; reads ``cfg.policy.cuda`` and ``cfg.policy.batch_size``.
            - dataset: map-style dataset whose items are tuples/sequences of
              tensors (each batch is stacked field-wise with ``torch.stack``).
        """
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            sbatch_size = batch_size * get_world_size()
            rank = get_rank()
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            # Partition each "super-batch" of sbatch_size shuffled indices so
            # that every rank takes its own disjoint batch_size slice.
            # BUGFIX: the slice start must advance by sbatch_size per step;
            # previously it advanced by 1 (i + rank * batch_size), so ranks
            # repeatedly drew overlapping samples from the head of idx_list.
            for i in range(len(dataset) // sbatch_size):
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)

            def next_batch():
                # Draw batch_size samples; on exhaustion, reshuffle and go on.
                nonlocal idx_iter
                data = []
                for _ in range(batch_size):
                    try:
                        data.append(dataset[next(idx_iter)])
                    except StopIteration:
                        # NOTE(review): the refill iterates a full, un-partitioned
                        # permutation (not the per-rank slice) — original
                        # behavior, kept as-is; confirm whether this is intended.
                        idx_iter = iter(np.random.permutation(len(dataset)))
                        data.append(dataset[next(idx_iter)])
                # Transpose list-of-samples into per-field stacked tensors.
                return [torch.stack([sample[j] for sample in data]) for j in range(len(data[0]))]

            def loop(move_to_device):
                # Fill the queue until the shutdown event fires; back off
                # briefly while the queue is full.
                while True:
                    if queue.full():
                        time.sleep(0.1)
                    else:
                        batch = next_batch()
                        if move_to_device:
                            batch = [x.to(device) for x in batch]
                        queue.put(batch)
                    if event.is_set():
                        break

            if device != 'cpu':
                # Stack/copy work happens on a side stream so it can overlap
                # with training-side CUDA work.
                with torch.cuda.stream(stream):
                    loop(move_to_device=True)
            else:
                loop(move_to_device=False)

        self.queue = Queue(maxsize=50)
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer'
        )

    def __call__(self, ctx: "OfflineRLContext"):
        """Block until a batch is available and store it on ``ctx.train_data``."""
        if not self.producer_thread.is_alive():
            # NOTE(review): fixed 5 s delay before (re)starting the producer —
            # presumably a grace period for other components; confirm intent.
            time.sleep(5)
            self.producer_thread.start()
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        # Signal the producer loop to exit; the thread notices the event on
        # its next iteration.
        if self.producer_thread.is_alive():
            self.event.set()
        del self.queue