
ding.worker.adapter.learner_aggregator

LearnerAggregatorSlave

Bases: Slave

Overview

A slave whose master is the coordinator.

__init__(*args, callback_fn=None, **kwargs)

Overview

Additionally initializes callback functions. Callback functions are methods of LearnerAggregator. For the callback mechanism, refer to worker/learner/comm/flask_fs_learner.py.
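For orientation, here is a minimal sketch of how these callbacks are wired up in practice; it mirrors what LearnerAggregator.__init__ does in the full source below, with agg standing for an already-constructed LearnerAggregator and the host/port values being placeholders rather than prescribed defaults.

# Hypothetical wiring sketch: ``agg`` is an already-constructed LearnerAggregator,
# and the host/port values are placeholders.
callback_fn = {
    'deal_with_get_resource': agg.deal_with_get_resource,
    'deal_with_learner_start': agg.deal_with_learner_start,
    'deal_with_get_data': agg.deal_with_get_data,
    'deal_with_learn': agg.deal_with_learn,
}
slave = LearnerAggregatorSlave('0.0.0.0', 12345, callback_fn=callback_fn)
# When the coordinator sends a task such as {'name': 'learner_learn_task', ...},
# ``_process_task`` dispatches it to ``callback_fn['deal_with_learn']``.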

LearnerAggregator

Bases: object

Overview

Aggregate multiple learners.

Interfaces: __init__, start, close, merge_info

__init__(cfg)

Overview

Init method.

Arguments:
- cfg (:obj:`EasyDict`): Config dict.
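As a concrete (hypothetical) illustration of the fields the aggregator reads from this config, slave.host/port for its own slave endpoint, master.host/port for its master endpoint, and a learner dict whose values are (learner_id, learner_host, learner_port) tuples, a sketch with placeholder values:

from easydict import EasyDict

# Illustrative config only; the hosts, ports and learner ids are placeholders.
cfg = EasyDict(
    dict(
        slave=dict(host='0.0.0.0', port=12345),   # slave endpoint polled by the coordinator
        master=dict(host='0.0.0.0', port=12346),  # master endpoint that connects to learners
        learner=dict(
            learner0=('learner0', '127.0.0.1', 12347),
            learner1=('learner1', '127.0.0.1', 12348),
        ),
    )
)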

start()

Overview

Start the aggregator. Set up a master and build connections with all learners within the max retry time.

close()

Overview

Close the aggregator's slave, the connections with learners, and the master.
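Putting the two methods together, a minimal lifecycle sketch (assuming cfg follows the illustrative layout shown above and the referenced learners are reachable):

from ding.worker.adapter.learner_aggregator import LearnerAggregator

# Hypothetical usage sketch; ``cfg`` follows the illustrative layout shown above.
aggregator = LearnerAggregator(cfg)
aggregator.start()   # starts the slave and master, then connects to every configured learner
# ... the coordinator now drives the aggregator through its slave interface ...
aggregator.close()   # disconnects the learners and shuts down both endpoints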

Full Source Code

../ding/worker/adapter/learner_aggregator.py

from typing import Union, Optional
import traceback
import numbers
import copy
import time
from functools import reduce
from threading import Thread
from easydict import EasyDict

from ding.interaction import Master, Slave, TaskFail
from ding.interaction.master.task import TaskStatus
from ding.utils import build_logger, get_operator_server_kwargs, exist_operator_server
from ..coordinator.operator_server import OperatorServer


class LearnerAggregatorSlave(Slave):
    """
    Overview:
        A slave whose master is the coordinator.
    """

    def __init__(self, *args, callback_fn: Optional[dict] = None, **kwargs) -> None:
        """
        Overview:
            Additionally initializes callback functions. Callback functions are methods of ``LearnerAggregator``.
            For the callback mechanism, you can refer to ``worker/learner/comm/flask_fs_learner.py``.
        """
        super().__init__(*args, **kwargs)
        self._callback_fn = callback_fn

    def _process_task(self, task: dict) -> Union[dict, TaskFail]:
        """
        Overview:
            Process a task according to the input task info dict, which is passed in by the coordinator's master.
            For each type of task, you can refer to the corresponding callback function in
            ``LearnerAggregator`` for details.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name".
        Returns:
            - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
        """
        task_name = task['name']
        if task_name == 'resource':
            return self._callback_fn['deal_with_get_resource']()
        elif task_name == 'learner_start_task':
            return self._callback_fn['deal_with_learner_start'](task)
        elif task_name == 'learner_get_data_task':
            return self._callback_fn['deal_with_get_data'](task)
        elif task_name == 'learner_learn_task':
            return self._callback_fn['deal_with_learn'](task)
        else:
            raise TaskFail(
                result={'message': 'task name error'},
                message='illegal learner task <{}>'.format(task_name)
            )


class LearnerAggregator(object):
    """
    Overview:
        Aggregate multiple learners.
    Interfaces:
        __init__, start, close, merge_info
    """

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Init method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict.
        """
        self._cfg = cfg
        callback_fn = {
            'deal_with_get_resource': self.deal_with_get_resource,
            'deal_with_learner_start': self.deal_with_learner_start,
            'deal_with_get_data': self.deal_with_get_data,
            'deal_with_learn': self.deal_with_learn,
        }
        host, port = cfg.slave.host, cfg.slave.port
        self._slave = LearnerAggregatorSlave(host, port, callback_fn=callback_fn)
        self._logger, _ = build_logger(path='./log', name='learner_aggregator', need_tb=False)
        self._end_flag = True
        self._max_retry_second = 60

        # ``_world_size`` indicates how many learners are connected;
        # ``_learner_connection`` stores those connections in a dict.
        self._world_size = 0
        self._learner_connection = {}

        # Create the operator server client if one is available.
        if exist_operator_server():
            # Get kwargs from defaults or environment variables.
            server_kwargs = get_operator_server_kwargs(EasyDict({}))
            self._operator_server = OperatorServer(**server_kwargs)
            self._operator_server.set_worker_type('aggregator')
        else:
            self._operator_server = None

        # Failed connections.
        self._failed_learner_conn = set()

    def start(self) -> None:
        """
        Overview:
            Start the aggregator. Set up a master and build connections with all learners within the max retry time.
        """
        self._end_flag = False
        try:
            self._slave.start()
        except Exception as e:
            self._logger.error(
                "learner_aggregator slave start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
            )
            return
        try:
            self._master = Master(self._cfg.master.host, self._cfg.master.port)
            self._master.start()
            self._master.ping()
        except Exception as e:
            self._logger.error(
                "learner_aggregator master start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
            )
            return
        self._world_size = 0
        for _, (learner_id, learner_host, learner_port) in self._cfg.learner.items():
            self._new_connection_learner(learner_id, learner_host, int(learner_port))

        if self._operator_server:
            self._init_conn_flag = False
            # Create a thread that periodically syncs with the operator server.
            self._period_sync_with_server_thread = Thread(
                target=self._period_sync_with_server, name="period_sync", daemon=True
            )
            self._period_sync_with_server_thread.start()
            # Wait (up to the max retry time) for the first sync with the operator server.
            start_time = time.time()
            while time.time() - start_time <= self._max_retry_second and not self._end_flag:
                if not self._init_conn_flag:
                    time.sleep(0.2)

        # Exceeds max retry time and no learner connection found.
        if len(self._learner_connection) == 0:
            self._logger.error("learner_aggregator master max retries failed")
        else:
            self._logger.info("learner aggregator is started")

    def close(self) -> None:
        """
        Overview:
            Close the aggregator's slave, the connections with learners, and the master.
        """
        if self._end_flag:
            return
        self._end_flag = True
        try:
            self._slave.close()
            for _, conn in self._learner_connection.items():
                conn.disconnect()
                assert not conn.is_connected
            self._master.close()
        except Exception:  # Ignore close exception.
            pass

    def deal_with_get_resource(self) -> dict:
        return {'gpu': self._world_size}

    def deal_with_learner_start(self, task: dict) -> dict:
        if len(self._learner_connection) == 0:
            raise TaskFail(message='no connected learner', result={'message': 'no connected learner'})
        name = task['name']
        start_task = {}
        for k, v in self._learner_connection.items():
            start_task[k] = v.new_task({'name': name, 'task_info': task['task_info']})
            start_task[k].start()
        for k, v in start_task.items():
            v.join()
        task_status = [v.status for v in start_task.values()]
        if any([s != TaskStatus.COMPLETED for s in task_status]):
            # TODO(nyz) dynamic learner gpu add/remove
            message = "one of learner can't start_task"
            raise TaskFail(message=message, result={'message': message})
        return {'message': 'learner task has started'}

    def deal_with_get_data(self, task: dict) -> dict:
        data_task = {}
        for k, v in self._learner_connection.items():
            data_task[k] = v.new_task({'name': task['name']})
            data_task[k].start()
        for k, v in data_task.items():
            v.join()
        # TODO deal with task fail
        self._data_demand = {k: v.result for k, v in data_task.items()}
        demand_list = list(self._data_demand.values())
        # Merge data demand info by adding up all learners' demand batch size.
        merged_demand = copy.deepcopy(demand_list[0])
        merged_demand['batch_size'] = sum([d['batch_size'] for d in demand_list])
        return merged_demand

    def deal_with_learn(self, task: dict) -> dict:
        learn_task = {}
        merged_data = task['data']
        # Split training data for each learner according to ``self._data_demand``.
        split_data = []
        start = 0
        for item in self._data_demand.values():
            end = item['batch_size'] + start
            split_data.append(merged_data[start:end])
            start = end
        for (k, v), d in zip(self._learner_connection.items(), split_data):
            learn_task[k] = v.new_task({'name': task['name'], 'data': d})
            learn_task[k].start()
        for k, v in learn_task.items():
            v.join()
        # TODO deal with task fail
        info_list = [v.result for v in learn_task.values()]
        # Merge learn info through the ``merge_info`` method.
        merged_info = self.merge_info(info_list)
        return merged_info

    @staticmethod
    def merge_info(info: list) -> dict:
        homogeneous_keys = ['learner_step', 'buffer_id', 'task_id', 'learner_done']
        elem = info[0]
        if elem is None:
            return info
        elif isinstance(elem, numbers.Integral) or isinstance(elem, str) or isinstance(elem, float):
            return info
        elif isinstance(elem, list) or isinstance(elem, tuple):
            return list(reduce(lambda x, y: x + y, info))
        elif isinstance(elem, dict):
            ret = {}
            for k in elem.keys():
                if k in homogeneous_keys:
                    ret[k] = elem[k]
                else:
                    ret[k] = LearnerAggregator.merge_info([e[k] for e in info])
            return ret
        else:
            raise TypeError("not support type: {}".format(type(elem)))

    def _new_connection_learner(self, learner_id: str, learner_host: str, learner_port: int) -> None:
        start_time = time.time()
        conn = None
        while time.time() - start_time <= self._max_retry_second and not self._end_flag:
            try:
                if conn is None or not conn.is_connected:
                    conn = self._master.new_connection(learner_id, learner_host, learner_port)
                    conn.connect()
                    assert conn.is_connected
                    self._learner_connection[learner_id] = conn
                    self._world_size += 1
                    break
            except Exception as e:
                self._logger.error(
                    f"learner({learner_id}) connection start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) +
                    repr(e) + '\nAuto Retry...'
                )
                time.sleep(2)

        if learner_id in self._learner_connection:
            self._logger.info(f"Succeed to connect to learner({learner_id})")
        else:
            self._logger.info(f"Fail to connect to learner({learner_id})")
            self._failed_learner_conn.add(learner_id)

    def _update_connection_learner(self, cur_learners) -> None:
        conn_learners = list(self._learner_connection.keys())
        new_c = set(cur_learners) - set(conn_learners)
        del_c = set(conn_learners) - (set(cur_learners) | self._failed_learner_conn)
        # Connections which have been terminated on the server side are cleared up.
        self._failed_learner_conn = self._failed_learner_conn & set(cur_learners)

        # Connect to each new learner.
        for learner_id in new_c:
            learner_host, learner_port = learner_id.split(':')
            self._new_connection_learner(learner_id, learner_host, int(learner_port))

        for learner_id in del_c:
            if learner_id in conn_learners:
                if self._learner_connection[learner_id].is_connected:
                    conn = self._learner_connection.pop(learner_id)
                    conn.disconnect()
                    assert not conn.is_connected
                else:
                    # Skip the disconnect operation, since the pod will be terminated by the server;
                    # just drop the connection.
                    self._learner_connection.pop(learner_id)

    def _period_sync_with_server(self) -> None:
        while not self._end_flag:
            # First: send the failed list to notify the server which replicas have failed, so it can terminate them.
            if len(self._failed_learner_conn) > 0:
                learner_conn = []
                for replica_conn in self._failed_learner_conn:
                    dns_name = replica_conn.split(":")[0]
                    pod_name_list = dns_name.split(".")[:-1]
                    pod_name = ".".join(pod_name_list)
                    if pod_name not in learner_conn:
                        learner_conn.append(pod_name)
                success, _, message, _ = self._operator_server.post_replicas_failed(learners=list(learner_conn))
                if success:
                    # Do not update learners instantly; update when GET /replicas returns.
                    self._failed_learner_conn.clear()
                else:
                    self._logger.error("Failed to send failed list to server, message: {}".format(message))

            # Get the current replica list from the server.
            success, _, message, data = self._operator_server.get_replicas()
            if success:
                cur_learners = data["learners"]
                self._update_connection_learner(cur_learners)
                self._init_conn_flag = self._init_conn_flag | True
            else:
                self._logger.error("Failed to sync with server, message: {}".format(message))

            time.sleep(3)
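
To illustrate merge_info's behavior: homogeneous keys (e.g. learner_step) are taken from the first learner, list values are concatenated across learners, and nested dicts are merged recursively. A small example with made-up learn info:

from ding.worker.adapter.learner_aggregator import LearnerAggregator

# Made-up learn info from two learners, for illustration only.
info = [
    {'learner_step': 100, 'loss': [0.52], 'priority_info': {'priority': [1.0, 2.0]}},
    {'learner_step': 100, 'loss': [0.47], 'priority_info': {'priority': [3.0]}},
]
merged = LearnerAggregator.merge_info(info)
# merged == {
#     'learner_step': 100,                             # homogeneous key: taken from the first learner
#     'loss': [0.52, 0.47],                            # lists are concatenated
#     'priority_info': {'priority': [1.0, 2.0, 3.0]},  # nested dicts are merged recursively
# }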