Skip to content

ding.framework.parallel

ding.framework.parallel

MQ

Overview

Abstract basic mq class.

__init__(*args, **kwargs)

Overview

The init method of any subclass must support extra keyword arguments (kwargs).

listen()

Overview

Bind to local socket or connect to third party components.

publish(topic, data)

Overview

Send data to mq.

Arguments: - topic (:obj:str): Topic. - data (:obj:bytes): Payload data.

subscribe(topic)

Overview

Subscribe to the topic.

Arguments: - topic (:obj:str): Topic

unsubscribe(topic)

Overview

Unsubscribe from the topic.

Arguments: - topic (:obj:str): Topic

recv()

Overview

Wait for incoming message, this function will block the current thread.

Returns: - data (:obj:Any): The sent payload.

stop()

Overview

Unsubscribe from all topics and stop the connection to the message queue server.

RedisMQ

Bases: MQ

__init__(redis_host, redis_port, **kwargs)

Overview

Connect distributed processes with redis

Arguments: - redis_host (:obj:str): Redis server host. - redis_port (:obj:int): Redis server port.

NNGMQ

Bases: MQ

__init__(listen_to, attach_to=None, **kwargs)

Overview

Connect distributed processes with nng

Arguments: - listen_to (:obj:Optional[List[str]]): The node address to listen on. - attach_to (:obj:Optional[List[str]]): The node's addresses you want to attach to.

Parallel

runner(n_parallel_workers, mq_type='nng', attach_to=None, protocol='ipc', address=None, ports=None, topology='mesh', labels=None, node_ids=None, auto_recover=False, max_retries=float('inf'), redis_host=None, redis_port=None, startup_interval=1) classmethod

Overview

This method allows you to configure parallel parameters, and now you are still in the parent process.

Arguments: - n_parallel_workers (:obj:int): Workers to spawn. - mq_type (:obj:str): Embedded message queue type, i.e. nng, redis. - attach_to (:obj:Optional[List[str]]): The node's addresses you want to attach to. - protocol (:obj:str): Network protocol. - address (:obj:Optional[str]): Bind address, ip or file path. - ports (:obj:Optional[List[int]]): Candidate ports. - topology (:obj:str): Network topology, includes: mesh (default): fully connected between each other; star: only connect to the first node; alone: do not connect to any node, except the node attached to; - labels (:obj:Optional[Set[str]]): Labels. - node_ids (:obj:Optional[List[int]]): Candidate node ids. - auto_recover (:obj:bool): Auto recover from uncaught exceptions from main. - max_retries (:obj:int): Max retries for auto recover. - redis_host (:obj:str): Redis server host. - redis_port (:obj:int): Redis server port. - startup_interval (:obj:int): Start up interval between each task. Returns: - _runner (:obj:Callable): The wrapper function for main.

padding_param(int_or_list, n_max, start_value) classmethod

Overview

Padding int or list param to the length of n_max.

Arguments: - int_or_list (:obj:Optional[Union[List[int], int]]): Int or list typed value. - n_max (:obj:int): Max length. - start_value (:obj:int): Start from value.

on(event, fn)

Overview

Register a remote event on the parallel instance; this function will be executed when a remote process emits this event via network.

Arguments: - event (:obj:str): Event name. - fn (:obj:Callable): Function body.

once(event, fn)

Overview

Register a remote event which will only be called once on the parallel instance; this function will be executed when a remote process emits this event via network.

Arguments: - event (:obj:str): Event name. - fn (:obj:Callable): Function body.

off(event)

Overview

Unregister an event.

Arguments: - event (:obj:str): Event name.

emit(event, *args, **kwargs)

Overview

Send a remote event via network to subscribed processes.

Arguments: - event (:obj:str): Event name.

get_attch_to_len()

Overview

Get the length of the 'attach_to' list of message queue.

Returns: int: the length of the self._mq.attach_to. Returns 0 if self._mq is not initialized

Full Source Code

../ding/framework/parallel.py

import atexit
import os
import random
import time
import traceback
import pickle
from mpire.pool import WorkerPool
from ditk import logging
import tempfile
import socket
from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from threading import Thread
from ding.framework.event_loop import EventLoop
from ding.utils.design_helper import SingletonMetaclass
from ding.framework.message_queue import *
from ding.utils.registry_factory import MQ_REGISTRY

# Avoid ipc address conflict, random should always use random seed
random = random.Random()


class Parallel(metaclass=SingletonMetaclass):
    """
    Overview:
        Per-process singleton router that connects distributed processes through an
        embedded message queue (nng or redis) and dispatches remote events to the
        local event loop.
    """

    def __init__(self) -> None:
        # Init will only be called once in a process (enforced by SingletonMetaclass).
        self._listener = None
        self.is_active = False
        self.node_id = None
        self.local_id = None
        self.labels = set()
        self._event_loop = EventLoop("parallel_{}".format(id(self)))
        self._retries = 0  # Retries in auto recovery

    def _run(
        self,
        node_id: int,
        local_id: int,
        n_parallel_workers: int,
        labels: Optional[Set[str]] = None,
        auto_recover: bool = False,
        max_retries: Union[int, float] = float("inf"),
        mq_type: str = "nng",
        startup_interval: int = 1,
        **kwargs
    ) -> None:
        """
        Overview:
            Initialize this node inside the subprocess: create the message queue and
            start the background listener thread. Extra kwargs are forwarded to the
            message queue constructor.
        Arguments:
            - node_id (:obj:`int`): This node's id.
            - local_id (:obj:`int`): Index of this worker within the local group.
            - n_parallel_workers (:obj:`int`): Total number of spawned workers.
            - labels (:obj:`Optional[Set[str]]`): Labels attached to this node.
            - auto_recover (:obj:`bool`): Auto recover from uncaught exceptions from main.
            - max_retries (:obj:`Union[int, float]`): Max retries for auto recover.
            - mq_type (:obj:`str`): Embedded message queue type, i.e. nng, redis.
            - startup_interval (:obj:`int`): Start up interval between each task.
        """
        self.node_id = node_id
        self.local_id = local_id
        self.startup_interval = startup_interval
        self.n_parallel_workers = n_parallel_workers
        self.labels = labels or set()
        self.auto_recover = auto_recover
        self.max_retries = max_retries
        self._mq = MQ_REGISTRY.get(mq_type)(**kwargs)
        # Stagger worker start-up to reduce connection races between peers.
        time.sleep(self.local_id * self.startup_interval)
        self._listener = Thread(target=self.listen, name="mq_listener", daemon=True)
        self._listener.start()

        self.mq_type = mq_type
        self.barrier_runtime = Parallel.get_barrier_runtime()(self.node_id)

    @classmethod
    def runner(
        cls,
        n_parallel_workers: int,
        mq_type: str = "nng",
        attach_to: Optional[List[str]] = None,
        protocol: str = "ipc",
        address: Optional[str] = None,
        ports: Optional[Union[List[int], int]] = None,
        topology: str = "mesh",
        labels: Optional[Set[str]] = None,
        node_ids: Optional[Union[List[int], int]] = None,
        auto_recover: bool = False,
        max_retries: Union[int, float] = float("inf"),
        redis_host: Optional[str] = None,
        redis_port: Optional[int] = None,
        startup_interval: int = 1
    ) -> Callable:
        """
        Overview:
            This method allows you to configure parallel parameters, and now you are still in the parent process.
        Arguments:
            - n_parallel_workers (:obj:`int`): Workers to spawn.
            - mq_type (:obj:`str`): Embedded message queue type, i.e. nng, redis.
            - attach_to (:obj:`Optional[List[str]]`): The node's addresses you want to attach to.
            - protocol (:obj:`str`): Network protocol.
            - address (:obj:`Optional[str]`): Bind address, ip or file path.
            - ports (:obj:`Optional[List[int]]`): Candidate ports.
            - topology (:obj:`str`): Network topology, includes:
                `mesh` (default): fully connected between each other;
                `star`: only connect to the first node;
                `alone`: do not connect to any node, except the node attached to;
            - labels (:obj:`Optional[Set[str]]`): Labels.
            - node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
            - auto_recover (:obj:`bool`): Auto recover from uncaught exceptions from main.
            - max_retries (:obj:`Union[int, float]`): Max retries for auto recover.
            - redis_host (:obj:`str`): Redis server host.
            - redis_port (:obj:`int`): Redis server port.
            - startup_interval (:obj:`int`): Start up interval between each task.
        Returns:
            - _runner (:obj:`Callable`): The wrapper function for main.
        """
        all_args = locals()
        del all_args["cls"]
        args_parsers = {"nng": cls._nng_args_parser, "redis": cls._redis_args_parser}

        assert n_parallel_workers > 0, "Parallel worker number should be bigger than 0"

        def _runner(main_process: Callable, *args, **kwargs) -> None:
            """
            Overview:
                Prepare to run in subprocess.
            Arguments:
                - main_process (:obj:`Callable`): The main function, your program start from here.
            """
            runner_params = args_parsers[mq_type](**all_args)
            params_group = []
            for i, runner_kwargs in enumerate(runner_params):
                runner_kwargs["local_id"] = i
                params_group.append([runner_kwargs, (main_process, args, kwargs)])

            if n_parallel_workers == 1:
                cls._subprocess_runner(*params_group[0])
            else:
                with WorkerPool(n_jobs=n_parallel_workers, start_method="spawn", daemon=False) as pool:
                    # Cleanup the pool just in case the program crashes.
                    atexit.register(pool.__exit__)
                    pool.map(cls._subprocess_runner, params_group)

        return _runner

    @classmethod
    def _nng_args_parser(
        cls,
        n_parallel_workers: int,
        attach_to: Optional[List[str]] = None,
        protocol: str = "ipc",
        address: Optional[str] = None,
        ports: Optional[Union[List[int], int]] = None,
        topology: str = "mesh",
        node_ids: Optional[Union[List[int], int]] = None,
        **kwargs
    ) -> Dict[str, dict]:
        """
        Overview:
            Build one kwargs dict per worker for the nng message queue, wiring the
            nodes together according to the requested topology.
        """
        attach_to = attach_to or []
        nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports)

        def cleanup_nodes():
            # Remove leftover ipc socket files on exit.
            for node in nodes:
                protocol, file_path = node.split("://")
                if protocol == "ipc" and path.exists(file_path):
                    os.remove(file_path)

        atexit.register(cleanup_nodes)

        def topology_network(i: int) -> List[str]:
            # Addresses node i should attach to, depending on the topology.
            if topology == "mesh":
                return nodes[:i] + attach_to
            elif topology == "star":
                return nodes[:min(1, i)] + attach_to
            elif topology == "alone":
                return attach_to
            else:
                raise ValueError("Unknown topology: {}".format(topology))

        runner_params = []
        candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
        for i in range(n_parallel_workers):
            runner_kwargs = {
                **kwargs,
                "node_id": candidate_node_ids[i],
                "listen_to": nodes[i],
                "attach_to": topology_network(i),
                "n_parallel_workers": n_parallel_workers,
            }
            runner_params.append(runner_kwargs)

        return runner_params

    @classmethod
    def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[List[int], int]] = None, **kwargs):
        """
        Overview:
            Build one kwargs dict per worker for the redis message queue.
        """
        runner_params = []
        candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
        for i in range(n_parallel_workers):
            runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]}
            runner_params.append(runner_kwargs)
        return runner_params

    @classmethod
    def _subprocess_runner(cls, runner_kwargs: dict, main_params: Tuple[Union[List, Dict]]) -> None:
        """
        Overview:
            Really run in subprocess.
        Arguments:
            - runner_kwargs (:obj:`dict`): Kwargs for the router's `_run`.
            - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for main function.
        """
        logging.getLogger().setLevel(logging.INFO)
        main_process, args, kwargs = main_params

        with Parallel() as router:
            router.is_active = True
            router._run(**runner_kwargs)
            time.sleep(0.3)  # Waiting for network pairing
            router._supervised_runner(main_process, *args, **kwargs)

    def _supervised_runner(self, main: Callable, *args, **kwargs) -> None:
        """
        Overview:
            Run in supervised mode: if auto_recover is set, retry main on uncaught
            exceptions up to max_retries times, otherwise run it once.
        Arguments:
            - main (:obj:`Callable`): Main function.
        """
        if self.auto_recover:
            while True:
                try:
                    main(*args, **kwargs)
                    break
                except Exception as e:
                    if self._retries < self.max_retries:
                        logging.warning(
                            "Auto recover from exception: {}, node: {}, retries: {}".format(
                                e, self.node_id, self._retries
                            )
                        )
                        logging.warning(traceback.format_exc())
                        self._retries += 1
                    else:
                        logging.warning(
                            "Exceed the max retries, node: {}, retries: {}, max_retries: {}".format(
                                self.node_id, self._retries, self.max_retries
                            )
                        )
                        raise e
        else:
            main(*args, **kwargs)

    @classmethod
    def get_node_addrs(
        cls,
        n_workers: int,
        protocol: str = "ipc",
        address: Optional[str] = None,
        ports: Optional[Union[List[int], int]] = None
    ) -> List[str]:
        """
        Overview:
            Generate one address per worker, either ipc file paths in the temp dir or
            tcp host:port pairs.
        Arguments:
            - n_workers (:obj:`int`): Number of addresses to generate.
            - protocol (:obj:`str`): Network protocol, "ipc" or "tcp".
            - address (:obj:`Optional[str]`): Bind address for tcp (defaults to local ip).
            - ports (:obj:`Optional[Union[List[int], int]]`): Candidate ports for tcp.
        Returns:
            - nodes (:obj:`List[str]`): The generated addresses.
        """
        if protocol == "ipc":
            node_name = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=4))
            tmp_dir = tempfile.gettempdir()
            nodes = ["ipc://{}/ditask_{}_{}.ipc".format(tmp_dir, node_name, i) for i in range(n_workers)]
        elif protocol == "tcp":
            address = address or cls.get_ip()
            ports = cls.padding_param(ports, n_workers, 50515)
            assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
now there are {} ports and {} workers".format(len(ports), n_workers)
            nodes = ["tcp://{}:{}".format(address, port) for port in ports]
        else:
            raise Exception("Unknown protocol {}".format(protocol))
        return nodes

    @classmethod
    def padding_param(cls, int_or_list: Optional[Union[List[int], int]], n_max: int, start_value: int) -> List[int]:
        """
        Overview:
            Padding int or list param to the length of n_max.
        Arguments:
            - int_or_list (:obj:`Optional[Union[List[int], int]]`): Int or list typed value.
            - n_max (:obj:`int`): Max length.
            - start_value (:obj:`int`): Start from value.
        Returns:
            - param (:obj:`List[int]`): The padded list of length n_max (or the input
              list as-is if it was already a list longer than 1).
        """
        param = int_or_list
        if isinstance(param, list) and len(param) == 1:
            param = param[0]  # List with only 1 element is equal to int

        if isinstance(param, int):
            param = range(param, param + n_max)
        else:
            param = param or range(start_value, start_value + n_max)
        # Materialize as a list so the declared return type holds for range inputs too.
        return list(param)

    def listen(self):
        """
        Overview:
            Listener thread body: bind/connect the message queue, then block on
            incoming messages and dispatch them until the queue is closed.
        """
        self._mq.listen()
        while True:
            if not self._mq:
                break
            msg = self._mq.recv()
            # msg is none means that the message queue is no longer being listened to,
            # especially if the message queue is already closed
            if not msg:
                break
            topic, msg = msg
            self._handle_message(topic, msg)

    def on(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Register a remote event on the parallel instance; this function will be executed \
                when a remote process emits this event via network.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): Function body.
        """
        if self.is_active:
            self._mq.subscribe(event)
        self._event_loop.on(event, fn)

    def once(self, event: str, fn: Callable) -> None:
        """
        Overview:
            Register a remote event which will only be called once on the parallel instance;
            this function will be executed when a remote process emits this event via network.
        Arguments:
            - event (:obj:`str`): Event name.
            - fn (:obj:`Callable`): Function body.
        """
        if self.is_active:
            self._mq.subscribe(event)
        self._event_loop.once(event, fn)

    def off(self, event: str) -> None:
        """
        Overview:
            Unregister an event.
        Arguments:
            - event (:obj:`str`): Event name.
        """
        if self.is_active:
            self._mq.unsubscribe(event)
        self._event_loop.off(event)

    def emit(self, event: str, *args, **kwargs) -> None:
        """
        Overview:
            Send a remote event via network to subscribed processes.
        Arguments:
            - event (:obj:`str`): Event name.
        """
        if self.is_active:
            payload = {"a": args, "k": kwargs}
            try:
                data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
            except (AttributeError, TypeError, pickle.PicklingError) as e:
                # pickle.dumps raises TypeError/PicklingError for unpicklable payloads,
                # not only AttributeError; log then re-raise in every case.
                logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args))
                raise e
            self._mq.publish(event, data)

    def _handle_message(self, topic: str, msg: bytes) -> None:
        """
        Overview:
            Recv and parse payload from other processes, and call local functions.
        Arguments:
            - topic (:obj:`str`): Received topic.
            - msg (:obj:`bytes`): Received message.
        """
        event = topic
        if not self._event_loop.listened(event):
            logging.debug("Event {} was not listened in parallel {}".format(event, self.node_id))
            return
        try:
            payload = pickle.loads(msg)
        except Exception as e:
            logging.error("Error when unpacking message on node {}, msg: {}".format(self.node_id, e))
            return
        self._event_loop.emit(event, *payload["a"], **payload["k"])

    @classmethod
    def get_ip(cls):
        """
        Overview:
            Best-effort detection of the local ip by opening a UDP socket; falls back
            to 127.0.0.1 when no route is available.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # doesn't even have to be reachable
            s.connect(('10.255.255.255', 1))
            ip = s.getsockname()[0]
        except Exception:
            ip = '127.0.0.1'
        finally:
            s.close()
        return ip

    def get_attch_to_len(self) -> int:
        """
        Overview:
            Get the length of the 'attach_to' list of message queue.
        Returns:
            int: the length of the self._mq.attach_to.
                Returns 0 if self._mq is not initialized
        """
        if self._mq:
            if hasattr(self._mq, 'attach_to'):
                return len(self._mq.attach_to)
        return 0

    def __enter__(self) -> "Parallel":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always stop the router when leaving the context, even on exceptions.
        self.stop()

    def stop(self):
        """
        Overview:
            Unsubscribe from all topics, stop the message queue and the listener
            thread, and shut down the local event loop.
        """
        logging.info("Stopping parallel worker on node: {}".format(self.node_id))
        self.is_active = False
        time.sleep(0.03)
        if self._mq:
            self._mq.stop()
            self._mq = None
        if self._listener:
            self._listener.join(timeout=1)
            self._listener = None
        self._event_loop.stop()

    @classmethod
    def get_barrier_runtime(cls):
        # We get the BarrierRuntime object in the closure to avoid circular import.
        from ding.framework.middleware.barrier import BarrierRuntime
        return BarrierRuntime