Skip to content

ding.framework.supervisor

ding.framework.supervisor

Child

Bases: ABC

Abstract class of child process/thread.

Supervisor

send(payload)

Overview

Send a message to one child process or thread.

Arguments:
- payload (`SendPayload`): The payload to send; its `proc_id` selects the target child.

recv(ignore_err=False, timeout=None)

Overview

Wait for a message from any child process.

Arguments:
- ignore_err (`bool`): If True, an error raised in the child is returned in the `err` field of the payload. Otherwise, the error is raised.
- timeout (`float`): Timeout in seconds for `queue.get`; raises `queue.Empty` if no message arrives in time. `None` waits indefinitely.

Returns:
- recv_payload (`RecvPayload`): The received payload.

recv_all(send_payloads, ignore_err=False, callback=None, timeout=None)

Overview

Wait until a response has been received for every request id in `send_payloads`. Messages belonging to other requests are put back on the queue.

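Arguments:
- send_payloads (`List[SendPayload]`): Request payloads whose responses to wait for.
- ignore_err (`bool`): If True, an error raised in a child is returned in the `err` field of the corresponding payload. Otherwise, the error is raised. This option also suppresses timeouts: requests still pending when a wait times out receive a payload whose `err` is a `TimeoutError`.
- callback (`Callable`): Called as `callback(recv_payload, remain_payloads)` for each received payload. `remain_payloads` is the dict of still-pending requests keyed by req id.
- timeout (`Optional[float]`): Timeout in seconds for each individual wait on the receive queue, not for the whole call.

Returns:
- recv_payload (`List[RecvPayload]`): Received payloads, in the same order as `send_payloads`; may contain timeout errors.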

get_child_attr(proc_id, key)

Overview

Get an attribute of a single child instance.

Arguments:
- proc_id (`int`): Index of the child, as assigned by `register` (0-based, in registration order).
- key (`str`): Attribute name.

Returns:
- attr (`Any`): Value of that attribute on the child instance.

Full Source Code

../ding/framework/supervisor.py

from abc import ABC, abstractmethod
import functools
import torch.multiprocessing as mp
from multiprocessing.context import BaseContext
import threading
import queue
import platform
import traceback
import uuid
import time
from ditk import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union
from enum import Enum


@functools.lru_cache(maxsize=1)
def get_mp_ctx() -> BaseContext:
    """
    Overview:
        Return the default multiprocessing context: ``spawn`` on Windows, ``fork`` on every other
        platform. The result is cached, so all callers share the same context object.
    """
    context = 'spawn' if platform.system().lower() == 'windows' else 'fork'
    mp_ctx = mp.get_context(context)
    return mp_ctx


@dataclass
class SendPayload:
    """
    Overview:
        A request sent from the supervisor to one child.
    """
    # Index of the target child (assigned by ``Supervisor.register``).
    proc_id: int
    # Use uuid1 here to include the timestamp
    req_id: str = field(default_factory=lambda: uuid.uuid1().hex)
    # Name of the child method to call, or a ``ReserveMethod`` member for built-in commands.
    method: str = None
    args: List = field(default_factory=list)
    kwargs: Dict = field(default_factory=dict)


@dataclass
class RecvPayload:
    """
    Overview:
        A response sent from a child back to the supervisor.
    """
    proc_id: int
    # Copied from the originating ``SendPayload`` so responses can be matched to requests.
    req_id: str = None
    method: str = None
    # Return value of the called method, or the attribute value for ``ReserveMethod.GETATTR``.
    data: Any = None
    # Exception raised in the child while handling the request, if any.
    err: Exception = None
    # Not set anywhere in this module; presumably filled by a ``shm_callback``. TODO confirm.
    extra: Any = None


class ReserveMethod(Enum):
    """Built-in commands understood by every child, independent of the wrapped object."""
    SHUTDOWN = "_shutdown"
    GETATTR = "_getattr"


class ChildType(Enum):
    """Whether children run as separate processes or as threads of the current process."""
    PROCESS = "process"
    THREAD = "thread"


class Child(ABC):
    """
    Abstract class of child process/thread.
    """

    def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Index of this child inside its supervisor.
            - init (:obj:`Union[Callable, object]`): Either a callable that builds the worker
              instance inside the child, or the instance itself.
        """
        self._proc_id = proc_id
        self._init = init
        self._recv_queue = None
        self._send_queue = None

    @abstractmethod
    def start(self, recv_queue: Union[mp.Queue, queue.Queue]):
        """Start the child; its responses will be put on ``recv_queue``."""
        raise NotImplementedError

    def restart(self):
        """Shut the child down and start it again on the same receive queue."""
        self.shutdown()
        self.start(self._recv_queue)

    @abstractmethod
    def shutdown(self, timeout: Optional[float] = None):
        """Stop the child, waiting at most ``timeout`` seconds for it to exit."""
        raise NotImplementedError

    @abstractmethod
    def send(self, payload: SendPayload):
        """Enqueue ``payload`` for the child to process."""
        raise NotImplementedError

    def _target(
        self,
        proc_id: int,
        init: Union[Callable, object],
        send_queue: Union[mp.Queue, queue.Queue],
        recv_queue: Union[mp.Queue, queue.Queue],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None
    ):
        """
        Overview:
            Main loop run inside the child. It builds (or adopts) the worker instance, then serves
            requests from ``send_queue`` until ``ReserveMethod.SHUTDOWN`` arrives. Each request
            produces exactly one ``RecvPayload`` on ``recv_queue``: it carries the result, or the
            exception if handling failed.
        """
        # Placeholder so the except branch has a payload to reference if the first get() fails.
        send_payload = SendPayload(proc_id=proc_id)
        if isinstance(init, Callable):
            child_ins = init()
        else:
            child_ins = init
        while True:
            try:
                send_payload: SendPayload = send_queue.get()
                if send_payload.method == ReserveMethod.SHUTDOWN:
                    break
                if send_payload.method == ReserveMethod.GETATTR:
                    data = getattr(child_ins, send_payload.args[0])
                else:
                    data = getattr(child_ins, send_payload.method)(*send_payload.args, **send_payload.kwargs)
                recv_payload = RecvPayload(
                    proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data
                )
                if shm_callback is not None and shm_buffer is not None:
                    shm_callback(recv_payload, shm_buffer)
                recv_queue.put(recv_payload)
            except Exception as e:
                # Report the failure back to the supervisor instead of killing the loop.
                logging.warning(traceback.format_exc())
                logging.warning("Error in child process! id: {}, error: {}".format(self._proc_id, e))
                recv_payload = RecvPayload(
                    proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, err=e
                )
                recv_queue.put(recv_payload)

    def __del__(self):
        self.shutdown()


class ChildProcess(Child):
    """Child that runs ``Child._target`` in a separate process."""

    def __init__(
        self,
        proc_id: int,
        init: Union[Callable, object],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None,
        mp_ctx: Optional[BaseContext] = None,
        **kwargs
    ) -> None:
        """
        Arguments:
            - shm_buffer / shm_callback: If both are given, ``shm_callback(recv_payload, shm_buffer)``
              is called in the child before each response is queued.
            - mp_ctx (:obj:`Optional[BaseContext]`): Multiprocessing context; defaults to ``get_mp_ctx()``.
        """
        super().__init__(proc_id, init, **kwargs)
        self._proc = None
        self._mp_ctx = mp_ctx
        self._shm_buffer = shm_buffer
        self._shm_callback = shm_callback

    def start(self, recv_queue: mp.Queue):
        """Spawn the daemon process (no-op if it is already started)."""
        if self._proc is None:
            self._recv_queue = recv_queue
            ctx = self._mp_ctx or get_mp_ctx()
            self._send_queue = ctx.Queue()
            proc = ctx.Process(
                target=self._target,
                args=(
                    self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback
                ),
                name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
                daemon=True
            )
            proc.start()
            self._proc = proc

    def shutdown(self, timeout: Optional[float] = None):
        """Stop the process and release its send queue."""
        if self._proc:
            self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
            # NOTE: terminate() follows immediately; the process is not given time to act on SHUTDOWN.
            self._proc.terminate()
            self._proc.join(timeout=timeout)
            if hasattr(self._proc, "close"):  # Compatible with 3.6
                self._proc.close()
            self._proc = None
            self._send_queue.close()
            self._send_queue.join_thread()
            self._send_queue = None

    def send(self, payload: SendPayload):
        """Enqueue ``payload``; logs a warning and drops it if the child is not running."""
        if self._send_queue is None:
            logging.warning("Child worker has been terminated or not started.")
            return
        self._send_queue.put(payload)


class ChildThread(Child):
    """Child that runs ``Child._target`` in a daemon thread of the current process."""

    def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None:
        super().__init__(proc_id, init, *args, **kwargs)
        self._thread = None

    def start(self, recv_queue: queue.Queue):
        """Start the daemon thread (no-op if it is already started). shm arguments are not used."""
        if self._thread is None:
            self._recv_queue = recv_queue
            self._send_queue = queue.Queue()
            thread = threading.Thread(
                target=self._target,
                args=(self._proc_id, self._init, self._send_queue, self._recv_queue),
                name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
                daemon=True
            )
            thread.start()
            self._thread = thread

    def shutdown(self, timeout: Optional[float] = None):
        """Ask the thread to exit via SHUTDOWN and wait at most ``timeout`` seconds for it."""
        if self._thread:
            self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
            self._thread.join(timeout=timeout)
            self._thread = None
            self._send_queue = None

    def send(self, payload: SendPayload):
        """Enqueue ``payload``; logs a warning and drops it if the child is not running."""
        if self._send_queue is None:
            logging.warning("Child worker has been terminated or not started.")
            return
        self._send_queue.put(payload)


class Supervisor:
    """
    Overview:
        Manages a group of child processes or threads that all report to one shared receive queue.
        Attribute access on the supervisor that is not resolved normally is forwarded to every
        child (see ``__getattr__``).
    """

    TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread}

    def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None:
        self._children: List[Child] = []
        self._type = type_
        self._child_class = self.TYPE_MAPPING[self._type]
        self._running = False
        self.__queue = None
        self._mp_ctx = mp_ctx or get_mp_ctx()

    def register(
        self,
        init: Union[Callable, object],
        shm_buffer: Optional[Any] = None,
        shm_callback: Optional[Callable] = None
    ) -> None:
        """
        Overview:
            Add a child. Its ``proc_id`` is its 0-based registration index. Children are started
            by ``start_link``.
        """
        proc_id = len(self._children)
        self._children.append(
            self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx)
        )

    @property
    def _recv_queue(self) -> Union[queue.Queue, mp.Queue]:
        """Shared response queue, created lazily to match the child type."""
        if self.__queue is None:
            if self._type is ChildType.PROCESS:
                self.__queue = self._mp_ctx.Queue()
            elif self._type is ChildType.THREAD:
                self.__queue = queue.Queue()
        return self.__queue

    @_recv_queue.setter
    def _recv_queue(self, value: Union[queue.Queue, mp.Queue]):
        self.__queue = value

    def start_link(self) -> None:
        """Start all registered children (no-op if already running)."""
        if not self._running:
            for child in self._children:
                child.start(recv_queue=self._recv_queue)
            self._running = True

    def send(self, payload: SendPayload) -> None:
        """
        Overview:
            Send message to child process.
        Arguments:
            - payload (:obj:`SendPayload`): Send payload.
        """
        if not self._running:
            logging.warning("Please call start_link before sending any payload to child process.")
            return
        self._children[payload.proc_id].send(payload)

    def recv(self, ignore_err: bool = False, timeout: Optional[float] = None) -> RecvPayload:
        """
        Overview:
            Wait for message from child process
        Arguments:
            - ignore_err (:obj:`bool`): If ignore_err is True, put the err in the property of recv_payload. \
                Otherwise, an exception will be raised.
            - timeout (:obj:`float`): Timeout for queue.get, will raise an Empty exception if timeout.
        Returns:
            - recv_payload (:obj:`RecvPayload`): Recv payload.
        """
        recv_payload: RecvPayload = self._recv_queue.get(timeout=timeout)
        if recv_payload.err and not ignore_err:
            raise recv_payload.err
        return recv_payload

    def recv_all(
        self,
        send_payloads: List[SendPayload],
        ignore_err: bool = False,
        callback: Optional[Callable] = None,
        timeout: Optional[float] = None
    ) -> List[RecvPayload]:
        """
        Overview:
            Wait for messages with specific req ids until all ids are fulfilled.
        Arguments:
            - send_payloads (:obj:`List[SendPayload]`): Request payloads.
            - ignore_err (:obj:`bool`): If ignore_err is True, \
                put the err in the property of recv_payload. Otherwise, an exception will be raised. \
                This option will also ignore timeout error.
            - callback (:obj:`Callable`): Called as ``callback(recv_payload, remain_payloads)`` for each \
                recv payload, where ``remain_payloads`` maps still-pending req ids to their send payloads.
            - timeout (:obj:`Optional[float]`): Timeout in seconds for each individual wait on the \
                receive queue (not for the whole call).
        Returns:
            - recv_payload (:obj:`List[RecvPayload]`): Recv payloads in the order of ``send_payloads``, \
                may contain timeout error.
        """
        assert send_payloads, "Req payload is empty!"
        recv_payloads = {}
        remain_payloads = {payload.req_id: payload for payload in send_payloads}
        unrelated_payloads = []
        try:
            while remain_payloads:
                try:
                    recv_payload: RecvPayload = self._recv_queue.get(block=True, timeout=timeout)
                    if recv_payload.req_id in remain_payloads:
                        del remain_payloads[recv_payload.req_id]
                        recv_payloads[recv_payload.req_id] = recv_payload
                        if recv_payload.err and not ignore_err:
                            raise recv_payload.err
                        if callback:
                            callback(recv_payload, remain_payloads)
                    else:
                        # Belongs to another request; held aside and re-queued in ``finally``.
                        unrelated_payloads.append(recv_payload)
                except queue.Empty:
                    if ignore_err:
                        req_ids = list(remain_payloads.keys())
                        logging.warning("Timeout ({}s) when receving payloads! Req ids: {}".format(timeout, req_ids))
                        # Synthesise a TimeoutError response for every request still pending.
                        for req_id in req_ids:
                            send_payload = remain_payloads.pop(req_id)
                            recv_payload = RecvPayload(
                                proc_id=send_payload.proc_id,
                                req_id=send_payload.req_id,
                                method=send_payload.method,
                                err=TimeoutError("Timeout on req_id ({})".format(req_id))
                            )
                            recv_payloads[req_id] = recv_payload
                            if callback:
                                callback(recv_payload, remain_payloads)
                    else:
                        raise TimeoutError("Timeout ({}s) when receving payloads!".format(timeout))
        finally:
            # Put back the unrelated payload.
            for payload in unrelated_payloads:
                self._recv_queue.put(payload)

        # Keep the original order of requests.
        return [recv_payloads[p.req_id] for p in send_payloads]

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """Shut down every child, drain and release the receive queue (no-op if not running)."""
        if self._running:
            for child in self._children:
                child.shutdown(timeout=timeout)
            self._cleanup_queue()
            self._running = False

    def _cleanup_queue(self):
        """Drain the receive queue, close it if it is a multiprocessing queue, and drop it."""
        while True:
            while not self._recv_queue.empty():
                self._recv_queue.get()
            time.sleep(0.1)  # mp.Queue is not reliable.
            if self._recv_queue.empty():
                break
        if hasattr(self._recv_queue, "close"):
            self._recv_queue.close()
            self._recv_queue.join_thread()
        self._recv_queue = None

    def __getattr__(self, key: str) -> List[Any]:
        """
        Overview:
            Fallback for attributes not found on the supervisor: fetch ``key`` from every child and
            return the values in child order.
        Raises:
            - AttributeError: For dunder names, or when the supervisor is not fully initialised.
            - AssertionError: When the supervisor has not been started.
        """
        # __getattr__ only runs after normal lookup fails. Dunder probes (e.g. pickle looking up
        # ``__setstate__``) and any lookup on a half-initialised instance (``__init__`` raised,
        # or an unpickled instance without state) must fail with AttributeError. Otherwise
        # reading ``self._running`` below would re-enter __getattr__ and recurse without bound.
        if (key.startswith('__') and key.endswith('__')) or '_running' not in self.__dict__:
            raise AttributeError(key)
        assert self._running, "Supervisor is not running, please call start_link first!"
        send_payloads = []
        for i, child in enumerate(self._children):
            payload = SendPayload(proc_id=i, method=ReserveMethod.GETATTR, args=[key])
            send_payloads.append(payload)
            child.send(payload)
        return [payload.data for payload in self.recv_all(send_payloads)]

    def get_child_attr(self, proc_id: int, key: str) -> Any:
        """
        Overview:
            Get attr of one child process instance.
        Arguments:
            - proc_id (:obj:`int`): Index of the child, as assigned by ``register``.
            - key (:obj:`str`): Attribute key.
        Returns:
            - attr (:obj:`Any`): Attribute of child.
        """
        assert self._running, "Supervisor is not running, please call start_link first!"
        payload = SendPayload(proc_id=proc_id, method=ReserveMethod.GETATTR, args=[key])
        self._children[proc_id].send(payload)
        payloads = self.recv_all([payload])
        return payloads[0].data

    def __del__(self) -> None:
        # If __init__ raised before ``_running`` was set, there is nothing to shut down.
        if '_running' not in self.__dict__:
            return
        self.shutdown(timeout=5)
        self._children.clear()