import traceback
import time
import sys
import requests
from typing import Dict, Callable
from threading import Thread

from ding.utils import LockContext, LockContextType, get_operator_server_kwargs
from ding.interaction import Master
from ding.interaction.master.task import TaskStatus
from .resource_manager import NaiveResourceManager
from .operator_server import OperatorServer


class CommCoordinator(object):
    r"""
    Overview:
        The communication part of the coordinator. It owns the ``Master`` endpoint, the connections to
        collectors and learners, the resource manager and (optionally) the operator-server client.
    Interface:
        __init__ , start, close, __del__, send_collector_task, send_learner_task
    """

    def __init__(self, cfg: dict, callback_fn: Dict[str, Callable], logger: 'logging.Logger') -> None:  # noqa
        r"""
        Overview:
            Init the interactor of coordinator. No network activity happens here; see ``start``.
        Arguments:
            - cfg (:obj:`dict`): The config of the communication coordinator. It is read through attribute \
                access (``cfg.host``, ``cfg.port``, ``cfg.learner``, ``cfg.collector``, ``cfg.operator_server``).
            - callback_fn (:obj:`Dict[str, Callable]`): The callback functions given by coordinator
            - logger (:obj:`logging.Logger`): The text logger.
        """
        self._cfg = cfg
        self._callback_fn = callback_fn
        self._logger = logger
        # Upper bound (seconds) for every retry loop in this class.
        self._max_retry_second = 120
        # True means "not running": ``close`` is a no-op and worker loops stop.
        self._end_flag = True

        # Created in ``start``; kept as None until then so ``close`` can be called safely at any time.
        self._master = None
        self._period_sync_with_server_thread = None

        self._connection_collector = {}
        self._connection_learner = {}
        self._resource_manager = NaiveResourceManager()

        # Ids of tasks whose execution thread is still alive; ``close`` waits for them to drain.
        self._remain_task_lock = LockContext(LockContextType.THREAD_LOCK)
        self._remain_collector_task = set()
        self._remain_learner_task = set()

        if self._cfg.operator_server:
            server_kwargs = get_operator_server_kwargs(self._cfg.operator_server)
            self._operator_server = OperatorServer(**server_kwargs)
            self._operator_server.set_worker_type('coordinator')
            self._collector_target_num = self._cfg.operator_server.collector_target_num
            self._learner_target_num = self._cfg.operator_server.learner_target_num
        else:
            self._operator_server = None

        # for update resource
        self._resource_lock = LockContext(LockContextType.THREAD_LOCK)

        # failed connection
        self._failed_learner_conn = set()
        self._failed_collector_conn = set()

    def start(self) -> None:
        r"""
        Overview:
            Start the coordinator interactor: start the master, connect to the learners/collectors listed
            in the config and, when an operator server is configured, request replicas from it, launch the
            periodic sync thread and wait until enough workers are connected.
            The process exits with status 1 (``sys.exit``) when the replica request never succeeds, when
            not enough workers connect within ``_max_retry_second``, or when the interactor has been closed
            in the meantime.
        """
        self._end_flag = False
        self._master = Master(self._cfg.host, self._cfg.port)
        self._master.start()
        self._master.ping()

        # new connection from config
        for _, (learner_id, learner_host, learner_port) in self._cfg.learner.items():
            self._new_connection_learner(learner_id, learner_host, learner_port)
        for _, (collector_id, collector_host, collector_port) in self._cfg.collector.items():
            self._new_connection_collector(collector_id, collector_host, collector_port)

        if self._operator_server:
            # post init learner/collector demand
            start_time, init_flag = time.time(), False
            while time.time() - start_time <= self._max_retry_second and not self._end_flag:
                success, _, message, _ = self._operator_server.post_replicas(
                    self._cfg.operator_server.init_replicas_request
                )
                if success:
                    self._logger.info("Post replicas demand to server successfully")
                    init_flag = True
                    break
                else:
                    self._logger.info("Failed to post replicas request to server, message: {}".format(message))
                    time.sleep(2)

            if not init_flag:
                self._logger.info('Exit since cannot request replicas to operator-server...')
                # NOTE: the sync thread does not exist yet here; ``close`` tolerates that.
                self.close()
                sys.exit(1)

            # create sync learner/collector thread
            self._period_sync_with_server_thread = Thread(
                target=self._period_sync_with_server, name="period_sync", daemon=True
            )
            self._period_sync_with_server_thread.start()

            # wait for enough collector/learner
            start_time = time.time()
            enough_flag = False
            while time.time() - start_time <= self._max_retry_second:
                collector_num = len(self._connection_collector)
                learner_num = len(self._connection_learner)
                # Keep waiting while EITHER kind of worker is below its target; the pipeline needs both.
                if collector_num < self._collector_target_num or learner_num < self._learner_target_num:
                    self._logger.info(
                        "Only can connect {} collectors, {} learners.".format(collector_num, learner_num)
                    )
                    time.sleep(2)
                else:
                    self._logger.info(
                        "Have connected {} collectors, {} learners, match limit requests.".format(
                            collector_num, learner_num
                        )
                    )
                    self._logger.info("Total DI-engine pipeline start...")
                    enough_flag = True
                    break

            if not enough_flag:
                self._logger.error(
                    "Exit since only can connect {} collectors, {} learners.".format(
                        len(self._connection_collector), len(self._connection_learner)
                    )
                )
                self.close()
                sys.exit(1)

        if self._end_flag:
            self._logger.error("connection max retries failed")
            sys.exit(1)

    def _new_connection_collector(
            self,
            collector_id: str,
            collector_host: str,
            collector_port: int,
            increase_task_space: bool = False,
    ) -> None:
        r"""
        Overview:
            Try to connect to one collector and register its resource, retrying for at most
            ``_max_retry_second`` seconds or until the interactor is closed. On failure the id is added
            to ``_failed_collector_conn``.
        Arguments:
            - collector_id (:obj:`str`): The key used for the connection and the resource manager.
            - collector_host (:obj:`str`): Host passed to ``Master.new_connection``.
            - collector_port (:obj:`int`): Port passed to ``Master.new_connection``.
            - increase_task_space (:obj:`bool`): Whether to call the ``deal_with_increase_collector`` \
                callback after a successful connection.
        """
        start_time = time.time()
        conn = None
        while time.time() - start_time <= self._max_retry_second and not self._end_flag:
            try:
                if conn is None or not conn.is_connected:
                    conn = self._master.new_connection(collector_id, collector_host, collector_port)
                    conn.connect()
                    assert conn.is_connected
                resource_task = self._get_resource(conn)
                if resource_task.status != TaskStatus.COMPLETED:
                    self._logger.error("can't acquire resource for collector({})".format(collector_id))
                    # Back off before retrying, same as on the exception path below.
                    time.sleep(2)
                    continue
                else:
                    with self._resource_lock:
                        self._resource_manager.update('collector', collector_id, resource_task.result)
                    self._connection_collector[collector_id] = conn
                    if increase_task_space:
                        self._callback_fn['deal_with_increase_collector']()
                    break

            except Exception as e:
                self._logger.error(
                    f"Collector({collector_id}) connection start error:\n" +
                    ''.join(traceback.format_tb(e.__traceback__)) + repr(e) + '\nAuto Retry...'
                )
                time.sleep(2)

        if collector_id in self._connection_collector:
            self._logger.info(f"Succeed to connect to collector({collector_id})")
        else:
            self._logger.info(f"Fail to connect to collector({collector_id})")
            self._failed_collector_conn.add(collector_id)

    def _new_connection_learner(self, learner_id: str, learner_host: str, learner_port: int) -> None:
        r"""
        Overview:
            Try to connect to one learner and register its resource, retrying for at most
            ``_max_retry_second`` seconds or until the interactor is closed. On failure the id is added
            to ``_failed_learner_conn``.
        Arguments:
            - learner_id (:obj:`str`): The key used for the connection and the resource manager.
            - learner_host (:obj:`str`): Host passed to ``Master.new_connection``.
            - learner_port (:obj:`int`): Port passed to ``Master.new_connection``.
        """
        start_time = time.time()
        conn = None
        while time.time() - start_time <= self._max_retry_second and not self._end_flag:
            try:
                if conn is None or not conn.is_connected:
                    conn = self._master.new_connection(learner_id, learner_host, learner_port)
                    conn.connect()
                    assert conn.is_connected
                resource_task = self._get_resource(conn)
                if resource_task.status != TaskStatus.COMPLETED:
                    self._logger.error("can't acquire resource for learner({})".format(learner_id))
                    # Back off before retrying, same as on the exception path below.
                    time.sleep(2)
                    continue
                else:
                    with self._resource_lock:
                        self._resource_manager.update('learner', learner_id, resource_task.result)
                    self._connection_learner[learner_id] = conn
                    break

            except Exception as e:
                self._logger.error(
                    f"learner({learner_id}) connection start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) +
                    repr(e) + '\nAuto Retry...'
                )
                time.sleep(2)

        if learner_id in self._connection_learner:
            self._logger.info(f"Succeed to connect to learner({learner_id})")
        else:
            self._logger.info(f"Fail to connect to learner({learner_id})")
            self._failed_learner_conn.add(learner_id)

    def close(self) -> None:
        r"""
        Overview:
            Close the coordinator interactor. Sets the end flag, waits for the sync thread (if it was
            started), waits up to 60 seconds for running task threads to finish, then disconnects every
            connection and closes the master. Calling it when the interactor is not running is a no-op.
        """
        if self._end_flag:
            return
        self._end_flag = True
        # wait for execute thread
        start_time = time.time()
        # TODO
        if self._operator_server:
            # ``start`` may call ``close`` before the sync thread has been created.
            sync_thread = self._period_sync_with_server_thread
            if sync_thread is not None:
                sync_thread.join()
            # wait from all slave receive DELETE
            time.sleep(5)
        while time.time() - start_time <= 60:
            if not self._remain_learner_task and not self._remain_collector_task:
                break
            time.sleep(1)
        for conn in self._connection_collector.values():
            conn.disconnect()
            assert not conn.is_connected
        for conn in self._connection_learner.values():
            conn.disconnect()
            assert not conn.is_connected
        # ``_master`` stays None when ``Master(...)`` itself failed in ``start``.
        if self._master is not None:
            self._master.close()

    def __del__(self) -> None:
        r"""
        Overview:
            __del__ method will close the coordinator interactor
        """
        # ``__init__`` may have raised before any attribute was set; nothing to close in that case.
        if hasattr(self, '_end_flag'):
            self.close()

    def _get_resource(self, conn: 'Connection') -> 'TaskResult':  # noqa
        r"""
        Overview:
            Run a ``resource`` task on the given connection and wait for it to finish.
        Arguments:
            - conn (:obj:`Connection`): the connection to get resource_task
        Returns:
            - resource_task: The finished task object; callers inspect its ``status`` and ``result``.
        """
        resource_task = conn.new_task({'name': 'resource'})
        resource_task.start().join()
        return resource_task

    def send_collector_task(self, collector_task: dict) -> bool:
        r"""
        Overview:
            Assign the collector_task to a collector, start it remotely and, on success, spawn a thread
            that drives the task until it finishes.
        Arguments:
            - collector_task (:obj:`dict`): the collector_task to send; must contain ``task_id``. It is \
                updated in place with the assignment returned by the resource manager.
        Returns:
            - success (:obj:`bool`): False when no collector is available or the start task fails.
        """
        # assert not self._end_flag, "please start interaction first"
        task_id = collector_task['task_id']
        # according to resource info, assign task to a specific collector and adapt task
        assigned_collector = self._resource_manager.assign_collector(collector_task)
        if assigned_collector is None:
            self._logger.error("collector task({}) doesn't have enough collector to execute".format(task_id))
            return False
        collector_task.update(assigned_collector)

        collector_id = collector_task['collector_id']
        start_task = self._connection_collector[collector_id].new_task(
            {
                'name': 'collector_start_task',
                'task_info': collector_task
            }
        )
        start_task.start().join()
        if start_task.status != TaskStatus.COMPLETED:
            # give the resource back since the task never started
            self._resource_manager.update(
                'collector', assigned_collector['collector_id'], assigned_collector['resource_info']
            )
            self._logger.error('collector_task({}) start failed: {}'.format(task_id, start_task.result))
            return False

        self._logger.info('collector task({}) is assigned to collector({})'.format(task_id, collector_id))
        with self._remain_task_lock:
            self._remain_collector_task.add(task_id)
        collector_task_thread = Thread(
            target=self._execute_collector_task, args=(collector_task, ), name='coordinator_collector_task'
        )
        collector_task_thread.start()
        return True

    def _execute_collector_task(self, collector_task: dict) -> None:
        r"""
        Overview:
            Drive one collector task: repeatedly run ``collector_data_task`` and dispatch its result to
            the callbacks until the task is judged finished or the interactor is closed, then close the
            remote task. The task id is always removed from ``_remain_collector_task`` on exit.
        Arguments:
            - collector_task (:obj:`dict`): the collector task to execute; must contain ``collector_id`` \
                and ``task_id``.
        """
        close_flag = False
        collector_id = collector_task['collector_id']
        # The id registered by ``send_collector_task``. Results may carry no id (or none may ever arrive),
        # so the bookkeeping below must not depend on them.
        registered_task_id = collector_task['task_id']
        try:
            while not self._end_flag:
                try:
                    # data task
                    data_task = self._connection_collector[collector_id].new_task({'name': 'collector_data_task'})
                    self._logger.info('collector data task begin')
                    data_task.start().join()
                    self._logger.info('collector data task end')
                    if data_task.status != TaskStatus.COMPLETED:
                        # TODO(deal with fail task)
                        self._logger.error('collector data task is failed')
                        continue
                    result = data_task.result
                    task_id = result.get('task_id', None)
                    # data result
                    if 'data_id' in result:
                        buffer_id = result.get('buffer_id', None)
                        data_id = result.get('data_id', None)
                        self._callback_fn['deal_with_collector_send_data'](task_id, buffer_id, data_id, result)
                    # info result
                    else:
                        is_finished = self._callback_fn['deal_with_collector_judge_finish'](task_id, result)
                        if not is_finished:
                            continue
                        # close task
                        self._logger.error('close_task: {}\n{}'.format(task_id, result))
                        close_task = self._connection_collector[collector_id].new_task(
                            {'name': 'collector_close_task'}
                        )
                        close_task.start().join()
                        if close_task.status != TaskStatus.COMPLETED:
                            # TODO(deal with fail task)
                            self._logger.error('collector close is failed')
                            break
                        result = close_task.result
                        task_id = result.get('task_id', None)
                        self._callback_fn['deal_with_collector_finish_task'](task_id, result)
                        resource_task = self._get_resource(self._connection_collector[collector_id])
                        if resource_task.status == TaskStatus.COMPLETED:
                            self._resource_manager.update('collector', collector_id, resource_task.result)
                        close_flag = True
                        break
                except requests.exceptions.HTTPError:
                    # HTTP errors are expected while shutting down; otherwise propagate.
                    if self._end_flag:
                        break
                    raise

            if not close_flag:
                close_task = self._connection_collector[collector_id].new_task({'name': 'collector_close_task'})
                close_task.start().join()
        finally:
            # Always release the slot so ``close`` does not wait for a thread that is already gone.
            with self._remain_task_lock:
                self._remain_collector_task.discard(registered_task_id)

    def send_learner_task(self, learner_task: dict) -> bool:
        r"""
        Overview:
            Assign the learner_task to a learner, start it remotely and, on success, spawn a thread that
            drives the task until it finishes.
        Arguments:
            - learner_task (:obj:`dict`): the learner_task to send; must contain ``task_id``. It is \
                updated in place with the assignment returned by the resource manager.
        Returns:
            - success (:obj:`bool`): False when no learner is available or the start task fails.
        """
        # assert not self._end_flag, "please start interaction first"
        task_id = learner_task['task_id']
        assigned_learner = self._resource_manager.assign_learner(learner_task)
        if assigned_learner is None:
            self._logger.error("learner task({}) doesn't have enough learner to execute".format(task_id))
            return False
        learner_task.update(assigned_learner)

        learner_id = learner_task['learner_id']
        start_task = self._connection_learner[learner_id].new_task(
            {
                'name': 'learner_start_task',
                'task_info': learner_task
            }
        )
        start_task.start().join()
        if start_task.status != TaskStatus.COMPLETED:
            # give the resource back since the task never started
            self._resource_manager.update('learner', assigned_learner['learner_id'], assigned_learner['resource_info'])
            # logged as an error, matching the collector counterpart
            self._logger.error('learner_task({}) start failed: {}'.format(task_id, start_task.result))
            return False

        self._logger.info('learner task({}) is assigned to learner({})'.format(task_id, learner_id))
        with self._remain_task_lock:
            self._remain_learner_task.add(task_id)
        learner_task_thread = Thread(
            target=self._execute_learner_task, args=(learner_task, ), name='coordinator_learner_task'
        )
        learner_task_thread.start()
        return True

    def _execute_learner_task(self, learner_task: dict) -> None:
        r"""
        Overview:
            Drive one learner task: repeatedly ask the learner which data it needs, fetch that data via
            the ``deal_with_learner_get_data`` callback (polling with a growing sleep while it returns
            None), run ``learner_learn_task`` and dispatch the result, until the task is judged finished
            or the interactor is closed. The task id is always removed from ``_remain_learner_task``
            on exit.
        Arguments:
            - learner_task (:obj:`dict`): the learner task to execute; must contain ``learner_id`` \
                and ``task_id``.
        """
        close_flag = False
        learner_id = learner_task['learner_id']
        # The id registered by ``send_learner_task``. Results may carry no id (or none may ever arrive),
        # so the bookkeeping below must not depend on them.
        registered_task_id = learner_task['task_id']
        try:
            while not self._end_flag:
                try:
                    # get data
                    get_data_task = self._connection_learner[learner_id].new_task({'name': 'learner_get_data_task'})
                    get_data_task.start().join()
                    if get_data_task.status != TaskStatus.COMPLETED:
                        # TODO(deal with fail task)
                        self._logger.error('learner get_data_task failed: {}'.format(get_data_task.result))
                        continue
                    result = get_data_task.result
                    task_id, buffer_id, batch_size = result['task_id'], result['buffer_id'], result['batch_size']
                    cur_learner_iter = result['cur_learner_iter']
                    sleep_count = 1
                    while True:
                        data = self._callback_fn['deal_with_learner_get_data'](
                            task_id, buffer_id, batch_size, cur_learner_iter
                        )
                        if self._end_flag or data is not None:
                            self._logger.info('sample result is ok')
                            break
                        else:
                            self._logger.info('sample result is None')
                            time.sleep(sleep_count)
                            sleep_count += 2
                    if self._end_flag:
                        break

                    # learn task
                    learn_task = self._connection_learner[learner_id].new_task(
                        {
                            'name': 'learner_learn_task',
                            'data': data
                        }
                    )
                    learn_task.start().join()
                    if learn_task.status != TaskStatus.COMPLETED:
                        # TODO(deal with fail task)
                        self._logger.error('learner learn_task failed: {}'.format(learn_task.result))
                        continue
                    result = learn_task.result
                    task_id, info = result['task_id'], result['info']
                    is_finished = self._callback_fn['deal_with_learner_judge_finish'](task_id, info)
                    if is_finished:
                        # close task and update resource
                        close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
                        close_task.start().join()
                        if close_task.status != TaskStatus.COMPLETED:
                            self._logger.error('learner close_task failed: {}'.format(close_task.result))
                            break
                        result = close_task.result
                        task_id = result.get('task_id', None)
                        self._callback_fn['deal_with_learner_finish_task'](task_id, result)
                        resource_task = self._get_resource(self._connection_learner[learner_id])
                        if resource_task.status == TaskStatus.COMPLETED:
                            self._resource_manager.update('learner', learner_id, resource_task.result)
                        close_flag = True
                        break
                    else:
                        # update info
                        buffer_id = result['buffer_id']
                        self._callback_fn['deal_with_learner_send_info'](task_id, buffer_id, info)
                except requests.exceptions.HTTPError:
                    # HTTP errors are expected while shutting down; otherwise propagate.
                    if self._end_flag:
                        break
                    raise

            if not close_flag:
                close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
                close_task.start().join()
        finally:
            # Always release the slot so ``close`` does not wait for a thread that is already gone.
            with self._remain_task_lock:
                self._remain_learner_task.discard(registered_task_id)

    @staticmethod
    def _replica_pod_name(replica_conn: str) -> str:
        r"""
        Overview:
            Turn a replica connection string into the name reported to the operator server: drop
            everything from the first ``:`` on, then drop the last dot-separated component.
        Arguments:
            - replica_conn (:obj:`str`): A connection id as stored in the failed-connection sets.
        Returns:
            - pod_name (:obj:`str`): The remaining dotted prefix (empty if the host part has no dot).
        """
        dns_name = replica_conn.split(":")[0]
        return ".".join(dns_name.split(".")[:-1])

    def _period_sync_with_server(self) -> None:
        r"""
        Overview:
            Body of the ``period_sync`` thread. Once per second, until the interactor is closed: report
            the failed connections to the operator server, then fetch the current replica list and
            update the collector/learner connections accordingly.
        """
        while not self._end_flag:
            # First: send failed list to notify DI-engine server which replicas are failed,
            # then terminate such replicas.
            # self._logger.info("failed list:", list(self._failed_collector_conn), list(self._failed_learner_conn))
            if self._failed_learner_conn or self._failed_collector_conn:
                collector_conn = [self._replica_pod_name(c) for c in self._failed_collector_conn]
                learner_conn = [self._replica_pod_name(c) for c in self._failed_learner_conn]

                success, _, message, _ = self._operator_server.post_replicas_failed(
                    learners=learner_conn, collectors=collector_conn
                )
                if success:
                    # do not update collector or learner instantly, update at /GET replicas
                    self._failed_collector_conn.clear()
                    self._failed_learner_conn.clear()
                else:
                    self._logger.error("Failed to send failed list to server, message: {}".format(message))

            # get list from server
            success, _, message, data = self._operator_server.get_replicas()
            if success:
                cur_collectors = data["collectors"]
                cur_learners = data["learners"]
                # self._logger.info("current list:", cur_collectors, cur_learners)
                self._update_connection_collector(cur_collectors)
                self._update_connection_learner(cur_learners)
            else:
                self._logger.error("Failed to sync with server, message: {}".format(message))

            time.sleep(1)

    def _update_connection_collector(self, cur_collectors: list) -> None:
        r"""
        Overview:
            Reconcile the collector connections with the list reported by the operator server: connect
            to collectors that are new, and drop connections to collectors that are no longer listed
            (and are not pending in the failed set).
        Arguments:
            - cur_collectors (:obj:`list`): Collector ids of the form ``host:port``.
        """
        cur_set = set(cur_collectors)
        conn_set = set(self._connection_collector.keys())
        new_c = cur_set - conn_set
        del_c = conn_set - (cur_set | self._failed_collector_conn)
        # conns which have terminated in server side, clear up
        self._failed_collector_conn = self._failed_collector_conn & cur_set

        # connect to each new collector
        for collector_id in new_c:
            collector_host, collector_port = collector_id.split(':')
            self._new_connection_collector(collector_id, collector_host, int(collector_port), True)

        # every id in del_c comes from the current connection dict, so it can be popped directly
        for collector_id in del_c:
            # TODO(nyz) whether to need to close task first
            with self._resource_lock:
                if not self._resource_manager.have_assigned('collector', collector_id):
                    self._resource_manager.delete("collector", collector_id)

            conn = self._connection_collector.pop(collector_id)
            if conn.is_connected:
                conn.disconnect()
                assert not conn.is_connected
                self._callback_fn['deal_with_decrease_collector']()
            # else: ignore the operation of disconnect, since the pod will be terminated by server,
            # just throw the connection

    def _update_connection_learner(self, cur_learners) -> None:
        r"""
        Overview:
            Reconcile the learner connections with the list reported by the operator server: connect to
            learners that are new, and drop connections to learners that are no longer listed (and are
            not pending in the failed set).
        Arguments:
            - cur_learners (:obj:`list`): Learner ids of the form ``host:port``.
        """
        cur_set = set(cur_learners)
        conn_set = set(self._connection_learner.keys())
        new_c = cur_set - conn_set
        del_c = conn_set - (cur_set | self._failed_learner_conn)
        # conns which have terminated in server side, clear up
        self._failed_learner_conn = self._failed_learner_conn & cur_set

        # connect to each new learner
        for learner_id in new_c:
            learner_host, learner_port = learner_id.split(':')
            self._new_connection_learner(learner_id, learner_host, int(learner_port))

        # every id in del_c comes from the current connection dict, so it can be popped directly
        for learner_id in del_c:
            # TODO(nyz) whether to need to close task first
            with self._resource_lock:
                if not self._resource_manager.have_assigned('learner', learner_id):
                    self._resource_manager.delete("learner", learner_id)

            conn = self._connection_learner.pop(learner_id)
            if conn.is_connected:
                conn.disconnect()
                assert not conn.is_connected
            # else: ignore the operation of disconnect, since the pod will be terminated by server,
            # just throw the connection