
ding.framework.task


Task

Task will manage the execution order of the entire pipeline, register new middleware, and generate new context objects.

use(fn, lock=False)

Overview

Register a middleware with the task. Middleware are executed in the order in which they are registered.

Arguments:
- fn (:obj:Callable): The middleware, a function that takes a single argument: ctx.
- lock (:obj:Union[bool, Lock]): If set, at most one middleware holding this lock executes at any one time. Pass True to use the task's shared lock, or pass a Lock instance.
Returns:
- task (:obj:Task): The task itself, so calls can be chained.

use_wrapper(fn)

Overview

Register a wrapper with the task. A wrapper works like a decorator, but the task applies it on top of every middleware, including those already registered.

Arguments:
- fn (:obj:Callable): The wrapper. It is a decorator, so its first argument is a callable.
Returns:
- task (:obj:Task): The task itself, so calls can be chained.
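The wrapper idea can be sketched without the library. This is a minimal standalone illustration, not the DI-engine implementation: `log_calls`, `collect`, `train`, and the `middleware` list are hypothetical names, and a plain dict stands in for the context object.

```python
from functools import wraps
from typing import Callable, Dict, List


def log_calls(fn: Callable) -> Callable:
    """Hypothetical wrapper: record the middleware name, then call it."""

    @wraps(fn)
    def wrapped(ctx: Dict) -> None:
        ctx.setdefault("calls", []).append(fn.__name__)
        return fn(ctx)

    return wrapped


def collect(ctx: Dict) -> None:
    ctx["data"] = [1, 2, 3]


def train(ctx: Dict) -> None:
    ctx["loss"] = sum(ctx["data"])


# Stand-in for the task's middleware list: every entry gets the same wrapper.
middleware: List[Callable] = [log_calls(fn) for fn in (collect, train)]

ctx: Dict = {}
for fn in middleware:
    fn(ctx)

print(ctx["calls"], ctx["loss"])
```

Because the wrapper is applied to each entry, cross-cutting behavior such as logging or timing does not have to be repeated inside every middleware.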

match_labels(patterns)

Overview

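Check whether any of the task's labels matches any of the given patterns.

Arguments:
- patterns (:obj:Union[Iterable[str], str]): One or more glob-like patterns, e.g. node.1, node.*.
Returns:
- matched (:obj:bool): True if at least one label matches at least one pattern.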
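The matching rule can be reproduced with the standard `fnmatch` module. The sketch below is a standalone function that takes the label set as an explicit argument (an assumption made here so it runs without a task instance); the label values are illustrative.

```python
import fnmatch
from typing import Iterable, Set, Union


def match_labels(labels: Set[str], patterns: Union[Iterable[str], str]) -> bool:
    """Return True if at least one label matches at least one glob pattern."""
    if isinstance(patterns, str):
        patterns = [patterns]
    # fnmatch.filter returns the list of labels matching a single pattern.
    return any(fnmatch.filter(labels, p) for p in patterns)


labels = {"standalone", "node.1"}
print(match_labels(labels, "node.*"))
print(match_labels(labels, ["async", "distributed"]))
```

A single string is treated as a one-element pattern list, so both call styles behave the same way.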

run(max_step=int(1e12))

Overview

Run the iteration loop. The loop stops when max_step is reached or task.finish becomes True.

Arguments:
- max_step (:obj:int): The maximum number of iterations.
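The two stopping conditions can be illustrated with a simplified loop. This is a sketch under assumptions, not the library code: the `state` dict stands in for the task's finish flag, middleware here take `(ctx, state)` rather than only `ctx`, and the forward/backward split is omitted.

```python
from typing import Callable, Dict, List


def run(middleware: List[Callable], state: Dict, max_step: int) -> int:
    """Iterate until max_step is reached or state["finish"] is set.

    Returns the number of iterations that were executed.
    """
    steps = 0
    for i in range(max_step):
        ctx: Dict = {"step": i}  # a fresh context for each iteration
        for fn in middleware:
            fn(ctx, state)
        steps += 1
        if i == max_step - 1:
            state["finish"] = True
        if state["finish"]:
            break
    return steps


def stop_at_two(ctx: Dict, state: Dict) -> None:
    """Hypothetical middleware that requests termination on step 2."""
    if ctx["step"] == 2:
        state["finish"] = True


print(run([stop_at_two], {"finish": False}, max_step=10))
print(run([stop_at_two], {"finish": False}, max_step=2))
```

In the first call a middleware ends the loop early; in the second the step limit ends it first.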

wrap(fn, lock=False)

Overview

Wrap a middleware so that it can be called directly inside another middleware.

Arguments:
- fn (:obj:Callable): The middleware.
- lock (:obj:Union[bool, Lock]): If set, at most one middleware holding this lock executes at any one time.
Returns:
- fn_back (:obj:Callable): Calling the wrapped middleware returns a backward function, which runs the remainder of the middleware after its yield. If this backward function is never called, the remainder runs in the global backward step.

forward(fn, ctx=None)

Overview

Execute the middleware up to its first yield statement, or to its end if it does not yield.

Arguments:
- fn (:obj:Callable): The middleware, a function that takes ctx as its argument.
- ctx (:obj:Optional[Context]): A custom ctx to use instead of the global ctx.
Returns:
- g (:obj:Optional[Generator]): The paused generator if fn returned a generator that yielded, otherwise None.

backward(backward_stack=None)

Overview

Execute the remainder of each middleware (the part after its yield), in the reverse of the registration order.

Arguments:
- backward_stack (:obj:Optional[Dict[str, Generator]]): A custom stack to use instead of the global backward_stack.
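The forward and backward phases rely on Python generators: the part of a middleware before `yield` runs in the forward phase, and the part after it runs later, last registered first. The following standalone sketch mirrors that mechanism with module-level functions; the middleware names and the `log` list are illustrative assumptions.

```python
from collections import OrderedDict
from types import GeneratorType
from typing import Callable, Dict, List

log: List[str] = []
backward_stack = OrderedDict()


def forward(fn: Callable, ctx: Dict) -> None:
    """Run fn up to its first yield and remember the paused generator."""
    g = fn(ctx)
    if isinstance(g, GeneratorType):
        try:
            next(g)
            backward_stack[id(g)] = g
        except StopIteration:
            pass  # the generator finished without yielding


def backward() -> None:
    """Resume the paused generators, most recently registered first."""
    while backward_stack:
        _, g = backward_stack.popitem()  # pops the last inserted item
        try:
            next(g)
        except StopIteration:
            continue


def first(ctx: Dict):
    log.append("first:before")
    yield
    log.append("first:after")


def second(ctx: Dict):
    log.append("second:before")
    yield
    log.append("second:after")


def plain(ctx: Dict) -> None:
    log.append("plain")  # no yield, so nothing is left for backward


ctx: Dict = {}
for fn in (first, second, plain):
    forward(fn, ctx)
backward()
print(log)
```

The resulting order resembles nested context managers: the middleware that started first finishes last.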

serial(*fns)

Overview

Wrap several middleware into one that runs them serially, usually to avoid confusing dependencies in async mode.

Arguments:
- fns (:obj:Callable): The middleware to chain. They are wrapped into a single middleware function.

parallel(*fns)

Overview

Wrap several middleware into one that runs them in parallel. Do not use this function in async mode.

Arguments:
- fns (:obj:Callable): The middleware to parallelize. They are wrapped into a single middleware function.

renew()

Overview

Renew the context instance. Call this function after backward, at the end of each iteration.

stop()

Overview

Stop the task and clean up everything in its runtime.

async_executor(fn, *args, **kwargs)

Overview

Execute a function in the background, then append the resulting future to _async_stack.

Arguments:
- fn (:obj:Callable): The synchronous function to execute.

emit(event, *args, only_remote=False, only_local=False, **kwargs)

Overview

Emit an event and call its listeners.

Arguments:
- event (:obj:str): Event name.
- only_remote (:obj:bool): Only broadcast the event to the connected nodes, default is False.
- only_local (:obj:bool): Only emit the event locally, default is False.
- args (:obj:any): The remaining arguments, passed on to the listeners.

on(event, fn)

Overview

Subscribe to an event, execute this function every time the event is emitted.

Arguments:
- event (:obj:str): Event name.
- fn (:obj:Callable): The listener function.

once(event, fn)

Overview

Subscribe to an event, execute this function only once when the event is emitted.

Arguments:
- event (:obj:str): Event name.
- fn (:obj:Callable): The listener function.
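The difference between on and once can be shown with a small synchronous emitter. `TinyEventLoop` is a hypothetical stand-in written for this illustration; it is not the `EventLoop` class used by the task, and its treatment of `off` without a function (removing all listeners for the event) is an assumption.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List, Optional


class TinyEventLoop:
    """Hypothetical synchronous emitter with on / once / off / emit."""

    def __init__(self) -> None:
        self._listeners: DefaultDict[str, List[Callable]] = defaultdict(list)

    def on(self, event: str, fn: Callable) -> None:
        self._listeners[event].append(fn)

    def once(self, event: str, fn: Callable) -> None:
        def one_shot(*args, **kwargs):
            # Unsubscribe first so the listener fires at most once.
            self.off(event, one_shot)
            fn(*args, **kwargs)

        self.on(event, one_shot)

    def off(self, event: str, fn: Optional[Callable] = None) -> None:
        if fn is None:
            self._listeners.pop(event, None)  # assumed: drop all listeners
        elif fn in self._listeners[event]:
            self._listeners[event].remove(fn)

    def emit(self, event: str, *args, **kwargs) -> None:
        # Iterate over a copy so listeners may unsubscribe while being called.
        for fn in list(self._listeners[event]):
            fn(*args, **kwargs)


loop = TinyEventLoop()
every: List[int] = []
first_only: List[int] = []
loop.on("step", every.append)
loop.once("step", first_only.append)
for i in range(3):
    loop.emit("step", i)
print(every, first_only)
```

The `on` listener sees all three events, while the `once` listener is removed after the first one.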

off(event, fn=None)

Overview

Unsubscribe from an event.

Arguments:
- event (:obj:str): Event name.
- fn (:obj:Optional[Callable]): The listener function, default is None.

wait_for(event, timeout=math.inf, ignore_timeout_exception=True)

Overview

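Block the current thread until the event is emitted or the timeout expires.

Arguments:
- event (:obj:str): Event name.
- timeout (:obj:float): Timeout in seconds.
- ignore_timeout_exception (:obj:bool): If False, a TimeoutError is raised when the timeout expires. If True, the call returns without raising.
Returns:
- result (:obj:Any): A tuple (args, kwargs) holding the arguments the event was emitted with, or None if nothing was received.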
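The waiting behavior is a polling loop with a deadline. The sketch below reproduces that pattern in isolation; `wait_for_flag` and the `box` dict are hypothetical names, and a `threading.Timer` plays the role of whatever emits the event.

```python
import threading
import time
from typing import Any, Optional


def wait_for_flag(box: dict, timeout: float, ignore_timeout_exception: bool = True) -> Optional[Any]:
    """Poll until box["result"] is set or the timeout expires."""
    start = time.time()
    while time.time() - start < timeout:
        if "result" in box:
            return box["result"]
        time.sleep(0.01)
    if ignore_timeout_exception:
        return None
    raise TimeoutError("Timeout when waiting for result")


box: dict = {}
# A background timer delivers the "event" payload after 50 ms.
threading.Timer(0.05, lambda: box.update(result=(("hello",), {}))).start()

received = wait_for_flag(box, timeout=2.0)
missed = wait_for_flag({}, timeout=0.05)
print(received, missed)
```

With ignore_timeout_exception left at True, a caller must check for None to distinguish a timeout from a received event.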

get_attch_to_len()

Overview

Get the length of the 'attach_to' list in Parallel._mq. A RuntimeError is raised if the router is inactive.

Returns:
- length (:obj:int): The length of the 'attach_to' list.

enable_async(func)

Overview

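Give the function the ability to run asynchronously. If the task's async mode is enabled, the call is handed to the task's executor and the task is returned immediately. Otherwise the function runs synchronously.

Arguments:
- func (:obj:Callable): The original function.
Returns:
- runtime_handler (:obj:Callable): The wrapped function.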
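The dispatch logic of such a decorator can be sketched as follows. This is a simplified illustration, not the library code: `MiniTask` is a hypothetical class, and a plain `ThreadPoolExecutor` with `concurrent.futures` futures is swapped in for the asyncio event loop that the real task uses.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from functools import wraps
from typing import Callable, List, Optional


class MiniTask:
    """Hypothetical stand-in holding only what the decorator needs."""

    def __init__(self, async_mode: bool = False) -> None:
        self.async_mode = async_mode
        self.pool = ThreadPoolExecutor(max_workers=3)
        self.pending: List[Future] = []

    def sync(self) -> None:
        # Block until every queued call has finished; result() re-raises errors.
        while self.pending:
            self.pending.pop(0).result()


def enable_async(func: Callable) -> Callable:
    @wraps(func)
    def runtime_handler(task: MiniTask, *args, async_mode: Optional[bool] = None):
        if async_mode is None:
            async_mode = task.async_mode  # fall back to the task-wide setting
        if async_mode:
            task.pending.append(task.pool.submit(func, task, *args))
            return task
        return func(task, *args)

    return runtime_handler


@enable_async
def square(task: MiniTask, x: int, out: list) -> int:
    out.append(x * x)
    return x * x


task = MiniTask(async_mode=True)
results: list = []
square(task, 3, results)  # queued on the thread pool
direct = square(task, 4, results, async_mode=False)  # runs immediately
task.sync()
task.pool.shutdown()
print(direct, sorted(results))
```

The per-call `async_mode` keyword overrides the task-wide setting, which is how a caller can force synchronous execution inside an otherwise asynchronous task.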

Full Source Code

../ding/framework/task.py

from asyncio import InvalidStateError
from asyncio.tasks import FIRST_EXCEPTION
from collections import OrderedDict
from threading import Lock
import time
import asyncio
import concurrent.futures
import fnmatch
import math
import enum
from types import GeneratorType
from typing import Any, Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set, Union
import inspect

from ding.framework.context import Context
from ding.framework.parallel import Parallel
from ding.framework.event_loop import EventLoop
from functools import wraps


def enable_async(func: Callable) -> Callable:
    """
    Overview:
        Empower the function with async ability.
    Arguments:
        - func (:obj:`Callable`): The original function.
    Returns:
        - runtime_handler (:obj:`Callable`): The wrap function.
    """

    @wraps(func)
    def runtime_handler(task: "Task", *args, async_mode: Optional[bool] = None, **kwargs) -> "Task":
        """
        Overview:
            If task's async mode is enabled, execute the step in current loop executor asyncly,
            or execute the task sync.
        Arguments:
            - task (:obj:`Task`): The task instance.
            - async_mode (:obj:`Optional[bool]`): Whether using async mode.
        Returns:
            - result (:obj:`Union[Any, Awaitable]`): The result or future object of middleware.
        """
        if async_mode is None:
            async_mode = task.async_mode
        if async_mode:
            assert not kwargs, "Should not use kwargs in async_mode, use position parameters, kwargs: {}".format(kwargs)
            t = task._async_loop.run_in_executor(task._thread_pool, func, task, *args, **kwargs)
            task._async_stack.append(t)
            return task
        else:
            return func(task, *args, **kwargs)

    return runtime_handler


class Role(str, enum.Enum):
    LEARNER = "learner"
    COLLECTOR = "collector"
    EVALUATOR = "evaluator"
    FETCHER = 'fetcher'


class VoidMiddleware:

    def __call__(self, _):
        return


class Task:
    """
    Task will manage the execution order of the entire pipeline, register new middleware,
    and generate new context objects.
    """
    role = Role

    def __init__(self) -> None:
        self.router = Parallel()
        self._finish = False

    def start(
            self,
            async_mode: bool = False,
            n_async_workers: int = 3,
            ctx: Optional[Context] = None,
            labels: Optional[Set[str]] = None
    ) -> "Task":
        # This flag can be modified by external or associated processes
        self._finish = False
        # This flag can only be modified inside the class, it will be set to False in the end of stop
        self._running = True
        self._middleware = []
        self._wrappers = []
        self.ctx = ctx or Context()
        self._backward_stack = OrderedDict()
        self._roles = set()
        # Bind event loop functions
        self._event_loop = EventLoop("task_{}".format(id(self)))

        # Async segment
        self.async_mode = async_mode
        self.n_async_workers = n_async_workers
        self._async_stack = []
        self._async_loop = None
        self._thread_pool = None
        self._exception = None
        self._thread_lock = Lock()
        self.labels = labels or set()

        # Parallel segment
        if async_mode or self.router.is_active:
            self._activate_async()

        if self.router.is_active:

            def sync_finish(value):
                self._finish = value

            self.on("finish", sync_finish)

        self.init_labels()
        return self

    def add_role(self, role: Role):
        self._roles.add(role)

    def has_role(self, role: Role) -> bool:
        if len(self._roles) == 0:
            return True
        return role in self._roles

    @property
    def roles(self) -> Set[Role]:
        return self._roles

    def void(self):
        return VoidMiddleware()

    def init_labels(self):
        if self.async_mode:
            self.labels.add("async")
        if self.router.is_active:
            self.labels.add("distributed")
            self.labels.add("node.{}".format(self.router.node_id))
            for label in self.router.labels:
                self.labels.add(label)
        else:
            self.labels.add("standalone")

    def use(self, fn: Callable, lock: Union[bool, Lock] = False) -> 'Task':
        """
        Overview:
            Register middleware to task. The middleware will be executed by it's registry order.
        Arguments:
            - fn (:obj:`Callable`): A middleware is a function with only one argument: ctx.
            - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
        Returns:
            - task (:obj:`Task`): The task.
        """
        assert isinstance(fn, Callable), "Middleware function should be a callable object, current fn {}".format(fn)
        if isinstance(fn, VoidMiddleware):  # Skip void function
            return self
        for wrapper in self._wrappers:
            fn = wrapper(fn)
        self._middleware.append(self.wrap(fn, lock=lock))
        return self

    def use_wrapper(self, fn: Callable) -> 'Task':
        """
        Overview:
            Register wrappers to task. A wrapper works like a decorator, but task will apply this \
            decorator on top of each middleware.
        Arguments:
            - fn (:obj:`Callable`): A wrapper is a decorator, so the first argument is a callable function.
        Returns:
            - task (:obj:`Task`): The task.
        """
        # Wrap exist middlewares
        for i, middleware in enumerate(self._middleware):
            self._middleware[i] = fn(middleware)
        self._wrappers.append(fn)
        return self

    def match_labels(self, patterns: Union[Iterable[str], str]) -> bool:
        """
        Overview:
            A list of patterns to match labels.
        Arguments:
            - patterns (:obj:`Union[Iterable[str], str]`): Glob like pattern, e.g. node.1, node.*.
        """
        if isinstance(patterns, str):
            patterns = [patterns]
        return any([fnmatch.filter(self.labels, p) for p in patterns])

    def run(self, max_step: int = int(1e12)) -> None:
        """
        Overview:
            Execute the iterations, when reach the max_step or task.finish is true,
            The loop will be break.
        Arguments:
            - max_step (:obj:`int`): Max step of iterations.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if len(self._middleware) == 0:
            return
        for i in range(max_step):
            for fn in self._middleware:
                self.forward(fn)
            # Sync should be called before backward, otherwise it is possible
            # that some generators have not been pushed to backward_stack.
            self.sync()
            self.backward()
            self.sync()
            if i == max_step - 1:
                self.finish = True
            if self.finish:
                break
            self.renew()

    def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable:
        """
        Overview:
            Wrap the middleware, make it can be called directly in other middleware.
        Arguments:
            - fn (:obj:`Callable`): The middleware.
            - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
        Returns:
            - fn_back (:obj:`Callable`): It will return a backward function, which will call the rest part of
                the middleware after yield. If this backward function was not called, the rest part of the middleware
                will be called in the global backward step.
        """
        if lock is True:
            lock = self._thread_lock

        def forward(ctx: Context):
            if lock:
                with lock:
                    g = self.forward(fn, ctx, async_mode=False)
            else:
                g = self.forward(fn, ctx, async_mode=False)

            def backward():
                backward_stack = OrderedDict()
                key = id(g)
                backward_stack[key] = self._backward_stack.pop(key)
                if lock:
                    with lock:
                        self.backward(backward_stack, async_mode=False)
                else:
                    self.backward(backward_stack, async_mode=False)

            return backward

        if hasattr(fn, "__name__"):
            forward = wraps(fn)(forward)
        else:
            forward = wraps(fn.__class__)(forward)

        return forward

    @enable_async
    def forward(self, fn: Callable, ctx: Optional[Context] = None) -> Optional[Generator]:
        """
        Overview:
            This function will execute the middleware until the first yield statment,
            or the end of the middleware.
        Arguments:
            - fn (:obj:`Callable`): Function with contain the ctx argument in middleware.
            - ctx (:obj:`Optional[Context]`): Replace global ctx with a customized ctx.
        Returns:
            - g (:obj:`Optional[Generator]`): The generator if the return value of fn is a generator.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if not ctx:
            ctx = self.ctx
        g = fn(ctx)
        if isinstance(g, GeneratorType):
            try:
                next(g)
                self._backward_stack[id(g)] = g
                return g
            except StopIteration:
                pass

    @enable_async
    def backward(self, backward_stack: Optional[Dict[str, Generator]] = None) -> None:
        """
        Overview:
            Execute the rest part of middleware, by the reversed order of registry.
        Arguments:
            - backward_stack (:obj:`Optional[Dict[str, Generator]]`): Replace global backward_stack with a customized \
                stack.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if not backward_stack:
            backward_stack = self._backward_stack
        while backward_stack:
            # FILO
            _, g = backward_stack.popitem()
            try:
                next(g)
            except StopIteration:
                continue

    @property
    def running(self):
        return self._running

    def serial(self, *fns: List[Callable]) -> Callable:
        """
        Overview:
            Wrap functions and keep them run in serial, Usually in order to avoid the confusion
            of dependencies in async mode.
        Arguments:
            - fn (:obj:`Callable`): Chain a serial of middleware, wrap them into one middleware function.
        """

        def _serial(ctx):
            backward_keys = []
            for fn in fns:
                g = self.forward(fn, ctx, async_mode=False)
                if isinstance(g, GeneratorType):
                    backward_keys.append(id(g))
            yield
            backward_stack = OrderedDict()
            for k in backward_keys:
                backward_stack[k] = self._backward_stack.pop(k)
            self.backward(backward_stack=backward_stack, async_mode=False)

        name = ",".join([fn.__name__ for fn in fns])
        _serial.__name__ = "serial<{}>".format(name)
        return _serial

    def parallel(self, *fns: List[Callable]) -> Callable:
        """
        Overview:
            Wrap functions and keep them run in parallel, should not use this funciton in async mode.
        Arguments:
            - fn (:obj:`Callable`): Parallelized middleware, wrap them into one middleware function.
        """
        self._activate_async()

        def _parallel(ctx):
            backward_keys = []
            for fn in fns:
                g = self.forward(fn, ctx, async_mode=True)
                if isinstance(g, GeneratorType):
                    backward_keys.append(id(g))
            self.sync()
            yield
            backward_stack = OrderedDict()
            for k in backward_keys:
                backward_stack[k] = self._backward_stack.pop(k)
            self.backward(backward_stack, async_mode=True)
            self.sync()

        name = ",".join([fn.__name__ for fn in fns])
        _parallel.__name__ = "parallel<{}>".format(name)
        return _parallel

    def renew(self) -> 'Task':
        """
        Overview:
            Renew the context instance, this function should be called after backward in the end of iteration.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        self.ctx = self.ctx.renew()
        return self

    def __enter__(self) -> "Task":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def stop(self) -> None:
        """
        Overview:
            Stop and cleanup every thing in the runtime of task.
        """
        if self.router.is_active:
            self.emit("finish", True)
        if self._thread_pool:
            self._thread_pool.shutdown()
        self._event_loop.stop()
        self.router.off(self._wrap_event_name("*"))
        if self._async_loop:
            self._async_loop.stop()
            self._async_loop.close()
        # The middleware and listeners may contain some methods that reference to task,
        # If we do not clear them after the task exits, we may find that gc will not clean up the task object.
        self._middleware.clear()
        self._wrappers.clear()
        self._backward_stack.clear()
        self._async_stack.clear()
        self._running = False

    def sync(self) -> 'Task':
        if self._async_loop:
            self._async_loop.run_until_complete(self.sync_tasks())
        return self

    async def sync_tasks(self) -> Awaitable[None]:
        if self._async_stack:
            await asyncio.wait(self._async_stack, return_when=FIRST_EXCEPTION)
            while self._async_stack:
                t = self._async_stack.pop(0)
                try:
                    e = t.exception()
                    if e:
                        self._exception = e
                        raise e
                except InvalidStateError:
                    # Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
                    pass

    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
        """
        Overview:
            Execute task in background, then apppend the future instance in _async_stack.
        Arguments:
            - fn (:obj:`Callable`): Synchronization fuction.
        """
        if not self._async_loop:
            raise Exception("Event loop was not initialized, please call this function in async or parallel mode")
        t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
        self._async_stack.append(t)

    def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = False, **kwargs) -> None:
        """
        Overview:
            Emit an event, call listeners.
        Arguments:
            - event (:obj:`str`): Event name.
            - only_remote (:obj:`bool`): Only broadcast the event to the connected nodes, default is False.
            - only_local (:obj:`bool`): Only emit local event, default is False.
            - args (:obj:`any`): Rest arguments for listeners.
        """
        # Check if need to broadcast event to connected nodes, default is True
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        if only_local:
            self._event_loop.emit(event, *args, **kwargs)
        elif only_remote:
            if self.router.is_active:
                self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
        else:
            if self.router.is_active:
                self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
            self._event_loop.emit(event, *args, **kwargs)

    def on(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Subscribe to an event, execute this function every time the event is emitted.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.on(event, fn)
        if self.router.is_active:
            self.router.on(self._wrap_event_name(event), self._event_loop.emit)

    def once(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Subscribe to an event, execute this function only once when the event is emitted.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.once(event, fn)
        if self.router.is_active:
            self.router.on(self._wrap_event_name(event), self._event_loop.emit)

    def off(self, event: str, fn: Optional[Callable] = None) -> None:
        """
        Overview:
            Unsubscribe an event
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): The function.
        """
        self._event_loop.off(event, fn)
        if self.router.is_active:
            self.router.off(self._wrap_event_name(event))

    def wait_for(self, event: str, timeout: float = math.inf, ignore_timeout_exception: bool = True) -> Any:
        """
        Overview:
            Wait for an event and block the thread.
        Arguments:
            - event (:obj:`str`): Event name.
            - timeout (:obj:`float`): Timeout in seconds.
            - ignore_timeout_exception (:obj:`bool`): If this is False, an exception will occur when meeting timeout.
        """
        assert self._running, "Please make sure the task is running before calling the this method, see the task.start"
        received = False
        result = None

        def _receive_event(*args, **kwargs):
            nonlocal result, received
            result = (args, kwargs)
            received = True

        self.once(event, _receive_event)

        start = time.time()
        while time.time() - start < timeout:
            if received or self._exception:
                return result
            time.sleep(0.01)

        if ignore_timeout_exception:
            return result
        else:
            raise TimeoutError("Timeout when waiting for event: {}".format(event))

    @property
    def finish(self):
        return self._finish

    @finish.setter
    def finish(self, value: bool):
        self._finish = value

    def _wrap_event_name(self, event: str) -> str:
        """
        Overview:
            Wrap the event name sent to the router.
        Arguments:
            - event (:obj:`str`): Event name
        """
        return "task.{}".format(event)

    def _activate_async(self):
        if not self._thread_pool:
            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.n_async_workers)
        if not self._async_loop:
            self._async_loop = asyncio.new_event_loop()

    def get_attch_to_len(self) -> int:
        """
        Overview:
            Get the length of the 'attach_to' list in Parallel._mq.
        Returns:
            int: the length of the Parallel._mq.
        """
        if self.router.is_active:
            return self.router.get_attch_to_len()
        else:
            raise RuntimeError("The router is inactive, failed to be obtained the length of 'attch_to' list.")


task = Task()