from time import sleep, time
from ditk import logging
from ding.framework import task
from ding.utils.lock_helper import LockContext, LockContextType
from ding.utils.design_helper import SingletonMetaclass


class BarrierRuntime(metaclass=SingletonMetaclass):

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
            the detection is completed. We don't have a message retransmission mechanism, and losing
            a message means deadlock.
        Arguments:
            - node_id (int): Process ID.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        # Number of decimal digits reserved for the node id inside a barrier tag.
        # See 'pickle_barrier_tag' / '_unpickle_barrier_tag' for the encoding.
        self._range_len = len(str(max_world_size)) + 1

        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        # Maps peer node id -> list of epochs for which that peer has sent a barrier request.
        self._barrier_recv_peers = dict()
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)

        self.mq_type = task.router.mq_type
        # Maps peer node id -> timestamp of the last liveness message received from it.
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False

        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        """
        Overview:
            Record (or refresh) the timestamp of the last liveness message from 'peer_id'.
        """
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        """
        Overview:
            Decode a barrier request tag and append the epoch to the sending peer's queue.
        Arguments:
            - msg (int): A tag produced by the sender's 'pickle_barrier_tag'.
        """
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            if peer not in self._barrier_recv_peers:
                self._barrier_recv_peers[peer] = []
            self._barrier_recv_peers[peer].append(epoch)

    def _add_barrier_ack(self, peer):
        """
        Overview:
            Record that 'peer' has acknowledged the current barrier.
        """
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        """
        Overview:
            Inverse of 'pickle_barrier_tag'; returns '(node_id, epoch)'.
        """
        # Bug fix: the tag is a decimal positional encoding whose low '_range_len' digits
        # hold the node id, so the modulus/divisor must be 10 ** _range_len, not
        # '_range_len' itself — otherwise any node_id >= _range_len (e.g. node 5 with the
        # default 'max_world_size') would be decoded as the wrong '(peer, epoch)' pair.
        base = 10 ** self._range_len
        return msg % base, msg // base

    def pickle_barrier_tag(self):
        """
        Overview:
            Pack '(epoch, node_id)' into a single int: the low '_range_len' decimal digits
            hold the node id, the remaining high digits hold the epoch. Must stay
            consistent with '_unpickle_barrier_tag'.
        """
        return int(self._barrier_epoch * 10 ** self._range_len + self.node_id)

    def reset_all_peers(self):
        """
        Overview:
            Consume the just-finished epoch from every peer's request queue, clear the
            ack list and advance to the next barrier epoch.
        """
        with self._barrier_lock:
            for peer, q in self._barrier_recv_peers.items():
                if len(q) != 0:
                    # The head of each non-empty queue must belong to the epoch we just finished.
                    assert q.pop(0) == self._barrier_epoch
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        """
        Overview:
            Return how many peers have already sent a barrier request for the current epoch.
        """
        count = 0
        with self._barrier_lock:
            if len(self._barrier_recv_peers) > 0:
                for _, q in self._barrier_recv_peers.items():
                    if len(q) > 0 and q[0] == self._barrier_epoch:
                        count += 1
        return count

    def get_ack_num(self):
        """
        Overview:
            Return how many peers have acknowledged the current barrier.
        """
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        """
        Overview:
            Block until 'expected' distinct peers have answered our liveness broadcast,
            registering the barrier request/ack handlers along the way.
        Arguments:
            - expected (int): Number of peers that must be detected before returning.
            - timeout (int): Maximum number of seconds to wait before raising TimeoutError.
        Raises:
            - TimeoutError: If fewer than 'expected' peers were detected within 'timeout'.
        """
        # The barrier can only block other nodes within the visible range of the current node.
        # If the 'attch_to' list of a node is empty, it does not know how many nodes will
        # attach to it, so we cannot specify the effective range of a barrier in advance.
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # In case the other node has not had time to receive our detect message,
            # we will send an additional round.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                if len(self._connected_peers) == expected:
                    self._has_detected = True

            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )


class BarrierContext:

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        """
        Overview:
            Context manager that runs peer detection on entry (once per process) and
            resets the barrier runtime state on exit.
        Arguments:
            - runtime (BarrierRuntime): The singleton barrier runtime of this process.
            - detect_timeout: Timeout (seconds) forwarded to 'detect_alive'.
            - expected_peer_num (int, optional): Number of peers expected to be detected.
        """
        self._runtime = runtime
        self._expected_peer_num = expected_peer_num
        self._timeout = detect_timeout

    def __enter__(self):
        # Detection is only performed once; subsequent barriers reuse the established peers.
        if not self._runtime._has_detected:
            self._runtime.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        # Advance the epoch even on error so a later barrier does not see stale state.
        self._runtime.reset_all_peers()


class Barrier:

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Barrier() is a middleware for debug or profiling. It can synchronize the task step of each
            process within the scope of all visible processes. When using Barrier(), you need to pay
            attention to the following points:

            1. All processes must call the same number of Barrier(), otherwise a deadlock occurs.

            2. 'attch_from_nums' is a very important variable, This value indicates the number of times
                the current process will be attached to by other processes (the number of connections
                established).
                For example:
                    Node0: address: 127.0.0.1:12345, attach_to = []
                    Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
                For Node0, the 'attch_from_nums' value is 1. (It will be attached by Node1)
                For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1)
                Please note that this value must be given correctly, otherwise, for a node whose 'attach_to'
                list is empty, it cannot perceive how many processes will establish connections with it,
                resulting in any form of synchronization cannot be performed.

            3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
                to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.

            4. In normal training tasks, please do not use Barrier(), which will force the step synchronization
                between each process, so it will greatly damage the training efficiency. In addition, if your
                training task has dynamic processes, do not use Barrier() to prevent deadlock.

        Arguments:
            - attch_from_nums (int): The number of processes that will attach to (connect to) the
                current process; see point 2 above.
            - timeout (int, optional): The timeout for successful detection of 'expected_peer_num'
                number of nodes, the default value is 60 seconds.
        """
        self.node_id = task.router.node_id
        self.timeout = timeout
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        # Total peers to synchronize with: outgoing attachments plus incoming ones.
        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums

        logging.info(
            "Node:[{}], attach to num is:{}, attach from num is:{}".format(
                self.node_id, task.get_attch_to_len(), attch_from_nums
            )
        )

    def __call__(self, ctx):
        # Synchronize once before and once after the wrapped middleware step.
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _wait_barrier(self, ctx):
        """
        Overview:
            Perform one barrier round: announce arrival, wait for all peers' requests,
            acknowledge, then wait for all peers' acks.
        Raises:
            - TimeoutError: If peers do not reach the barrier within 'self.timeout' seconds.
        """
        self_ready = False
        with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))
            # Step1: Notifies all the attached nodes that we have reached the barrier.
            task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sended barrier request".format(self.node_id))

            # Step2: We check the number of flags we have received.
            # In the current CI design of DI-engine, there will always be a node whose 'attach_to' list is empty,
            # so there will always be a node that will send ACK unconditionally, so deadlock will not occur.
            if self._runtime.get_recv_num() == self._barrier_peers_nums:
                self_ready = True

            # Step3: Waiting for our own to be ready.
            # Even if the current process has reached the barrier, we will not send an ack immediately,
            # we need to wait for the slowest directly connected or indirectly connected peer to
            # reach the barrier.
            start = time()
            if not self_ready:
                while True:
                    if time() - start > self.timeout:
                        raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                    if self._runtime.get_recv_num() != self._barrier_peers_nums:
                        sleep(0.1)
                    else:
                        break

            # Step4: Notifies all attached nodes that we are ready.
            task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sended barrier ack".format(self.node_id))

            # Step5: Wait until all directly or indirectly connected nodes are ready.
            start = time()
            while True:
                if time() - start > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))

                if self._runtime.get_ack_num() != self._barrier_peers_nums:
                    sleep(0.1)
                else:
                    break

            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))