import os
import time
from typing import List, Union, Dict, Callable, Any, Optional
from functools import partial
from queue import Queue
from threading import Thread

from ding.utils import read_file, save_file, get_data_decompressor, COMM_LEARNER_REGISTRY
from ding.utils.file_helper import read_from_di_store
from ding.interaction import Slave, TaskFail
from .base_comm_learner import BaseCommLearner
from ..learner_hook import LearnerHook


class LearnerSlave(Slave):
    """
    Overview:
        A slave, whose master is coordinator.
        Used to pass message between comm learner and coordinator.
    """

    def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None:
        """
        Overview:
            Init callback functions additionally. Callback functions are methods in comm learner.
        Arguments:
            - callback_fn (:obj:`Dict[str, Callable]`): Mapping from callback name to callable. \
                ``_process_task`` looks up the keys 'deal_with_resource', 'deal_with_learner_start', \
                'deal_with_get_data', 'deal_with_learner_learn' and 'deal_with_learner_close'.
            - args, kwargs: Forwarded unchanged to ``Slave.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._callback_fn = callback_fn

    def _process_task(self, task: dict) -> Union[dict, TaskFail]:
        """
        Overview:
            Process a task according to input task info dict, which is passed in by master coordinator.
            For each type of task, you can refer to corresponding callback function in comm learner for details.
        Arguments:
            - task (:obj:`dict`): Task dict. Must contain key "name"; 'learner_start_task' additionally needs \
                key "task_info" and 'learner_learn_task' needs key "data".
        Returns:
            - result (:obj:`Union[dict, TaskFail]`): Task result dict.
        Raises:
            - TaskFail: If ``task['name']`` is not one of the known task names.

        .. note::
            ``_current_task_info`` is only assigned by 'learner_start_task'; the get-data, learn and close
            tasks read it, so they must arrive after a start task.
        """
        task_name = task['name']
        if task_name == 'resource':
            return self._callback_fn['deal_with_resource']()
        elif task_name == 'learner_start_task':
            # Remember the task info so that later tasks can echo task_id / buffer_id back to the master.
            self._current_task_info = task['task_info']
            self._callback_fn['deal_with_learner_start'](self._current_task_info)
            return {'message': 'learner task has started'}
        elif task_name == 'learner_get_data_task':
            data_demand = self._callback_fn['deal_with_get_data']()
            ret = {
                'task_id': self._current_task_info['task_id'],
                'buffer_id': self._current_task_info['buffer_id'],
            }
            ret.update(data_demand)
            return ret
        elif task_name == 'learner_learn_task':
            info = self._callback_fn['deal_with_learner_learn'](task['data'])
            data = {'info': info}
            data['buffer_id'] = self._current_task_info['buffer_id']
            data['task_id'] = self._current_task_info['task_id']
            return data
        elif task_name == 'learner_close_task':
            self._callback_fn['deal_with_learner_close']()
            return {
                'task_id': self._current_task_info['task_id'],
                'buffer_id': self._current_task_info['buffer_id'],
            }
        else:
            raise TaskFail(result={'message': 'task name error'}, message='illegal learner task <{}>'.format(task_name))


@COMM_LEARNER_REGISTRY.register('flask_fs')
class FlaskFileSystemLearner(BaseCommLearner):
    """
    Overview:
        An implementation of CommLearner, using flask and the file system.
    Interfaces:
        __init__, send_policy, get_data, send_learn_info, start, close
    Property:
        hooks4call
    """

    def __init__(self, cfg: 'EasyDict') -> None:  # noqa
        """
        Overview:
            Init method.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict. This method reads ``host``, ``port``, ``path_data`` \
                and ``path_policy`` from it.
        """
        BaseCommLearner.__init__(self, cfg)

        # Callback functions for message passing between comm learner and coordinator.
        self._callback_fn = {
            'deal_with_resource': self.deal_with_resource,
            'deal_with_learner_start': self.deal_with_learner_start,
            'deal_with_get_data': self.deal_with_get_data,
            'deal_with_learner_learn': self.deal_with_learner_learn,
            'deal_with_learner_close': self.deal_with_learner_close,
        }
        # Learner slave to implement those callback functions. Host and port is used to build connection with master.
        # NOTE(review): ``_rank`` / ``_world_size`` are not defined in this file; presumably set by
        # ``BaseCommLearner.__init__`` - confirm.
        host, port = cfg.host, cfg.port
        if isinstance(port, list):
            # One explicit port per rank.
            port = port[self._rank]
        elif isinstance(port, int) and self._world_size > 1:
            # A single base port: offset it by rank so that every process gets its own port.
            port = port + self._rank
        self._slave = LearnerSlave(host, port, callback_fn=self._callback_fn)

        self._path_data = cfg.path_data  # path to read data from
        self._path_policy = cfg.path_policy  # path to save policy

        # Queues to store info dicts. Only one info is needed to pass between learner and coordinator at a time.
        self._data_demand_queue = Queue(maxsize=1)
        self._data_result_queue = Queue(maxsize=1)
        self._learn_info_queue = Queue(maxsize=1)

        # Task-level learner and policy will only be set once received the task.
        self._learner = None
        self._policy_id = None
        # Also task-level; initialised here so that ``send_policy`` never hits a missing attribute.
        self._league_save_checkpoint_path = None

    def start(self) -> None:
        """
        Overview:
            Start comm learner itself and the learner slave.
        """
        BaseCommLearner.start(self)
        self._slave.start()

    def close(self) -> None:
        """
        Overview:
            Join learner thread and close learner if still running.
            Then close learner slave and comm learner itself.
            Does nothing if ``_end_flag`` is already set.
        """
        if self._end_flag:
            return
        if self._learner is not None:
            self.deal_with_learner_close()
        self._slave.close()
        BaseCommLearner.close(self)

    def __del__(self) -> None:
        """
        Overview:
            Call ``close`` for deletion.
        """
        self.close()

    def deal_with_resource(self) -> dict:
        """
        Overview:
            Callback function. Return how many resources are needed to start current learner.
        Returns:
            - resource (:obj:`dict`): Resource info dict, including ["gpu"].
        """
        return {'gpu': self._world_size}

    def deal_with_learner_start(self, task_info: dict) -> None:
        """
        Overview:
            Callback function. Create a learner and help register its hooks. Start a learner thread of the created one.
        Arguments:
            - task_info (:obj:`dict`): Task info dict. Must contain key "policy_id"; \
                "league_save_checkpoint_path" is optional.

        .. note::
            In ``_create_learner`` method in base class ``BaseCommLearner``, 3 methods
            ('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
            You can refer to it for details.
        """
        self._policy_id = task_info['policy_id']
        self._league_save_checkpoint_path = task_info.get('league_save_checkpoint_path', None)
        self._learner = self._create_learner(task_info)
        for h in self.hooks4call:
            self._learner.register_hook(h)
        self._learner_thread = Thread(target=self._learner.start, args=(), daemon=True, name='learner_start')
        self._learner_thread.start()

    def deal_with_get_data(self) -> Any:
        """
        Overview:
            Callback function. Get data demand info dict from ``_data_demand_queue``,
            which will be sent to coordinator afterwards. Blocks until ``get_data`` has put a demand.
        Returns:
            - data_demand (:obj:`Any`): Data demand info dict.
        """
        data_demand = self._data_demand_queue.get()
        return data_demand

    def deal_with_learner_learn(self, data: dict) -> dict:
        """
        Overview:
            Callback function. Put training data info dict (i.e. meta data), which is received from coordinator, into
            ``_data_result_queue``, and wait for ``get_data`` to retrieve. Then block until a learn info dict is
            available in ``_learn_info_queue`` (filled by ``send_learn_info``) and return it.
            This method does not close the learner; that is done by ``deal_with_learner_close``.
        Arguments:
            - data (:obj:`dict`): Meta data received from coordinator. ``get_data`` asserts that the value it \
                takes from the queue is a list.
        Returns:
            - learn_info (:obj:`Any`): Learn info dict.
        """
        self._data_result_queue.put(data)
        learn_info = self._learn_info_queue.get()
        return learn_info

    def deal_with_learner_close(self) -> None:
        """
        Overview:
            Callback function. Close the current learner, join its thread, and reset the task-level
            state (``_learner`` and ``_policy_id``) to ``None``.
        """
        self._learner.close()
        self._learner_thread.join()
        del self._learner_thread
        self._learner = None
        self._policy_id = None

    # override
    def send_policy(self, state_dict: dict) -> None:
        """
        Overview:
            Save learner's policy in corresponding path, called by ``SendPolicyHook``.
            If a league checkpoint path was given in the task info, the same state dict is saved there as well.
        Arguments:
            - state_dict (:obj:`dict`): State dict of the policy.
        """
        # ``exist_ok=True`` avoids the check-then-create race between concurrent savers and also
        # creates missing parent directories.
        os.makedirs(self._path_policy, exist_ok=True)
        path = self._policy_id
        # Only prepend the policy directory when the policy id does not already contain it.
        if self._path_policy not in path:
            path = os.path.join(self._path_policy, path)
        self._latest_policy_path = path
        save_file(path, state_dict, use_lock=True)

        if self._league_save_checkpoint_path is not None:
            save_file(self._league_save_checkpoint_path, state_dict, use_lock=True)

    @staticmethod
    def load_data_fn(path, meta: Dict[str, Any], decompressor: Callable) -> Any:
        """
        Overview:
            The function that is used to load data file.
        Arguments:
            - path: Location of the data, passed to ``read_from_di_store`` or ``read_file``.
            - meta (:obj:`Dict[str, Any]`): Meta data info dict. Optional keys "unroll_len" (default 1) \
                and "unroll_split_begin" select a sub-range of the loaded data.
            - decompressor (:obj:`Callable`): Decompress function.
        Returns:
            - s (:obj:`Any`): Data which is read from file, with ``meta`` merged into it \
                (into every step when a slice of several steps is returned).
        """
        # Due to read-write conflict, read_file raise an error, therefore we set a while loop.
        # NOTE(review): the retry is unbounded and swallows every ``Exception``, so a permanently
        # unreadable or corrupt file makes this call spin forever.
        while True:
            try:
                s = read_from_di_store(path) if read_from_di_store else read_file(path, use_lock=False)
                s = decompressor(s)
                break
            except Exception:
                time.sleep(0.01)
        unroll_len = meta.get('unroll_len', 1)
        if 'unroll_split_begin' in meta:
            begin = meta['unroll_split_begin']
            if unroll_len == 1:
                s = s[begin]
                s.update(meta)
            else:
                end = begin + unroll_len
                s = s[begin:end]
                # add metadata key-value to stepdata
                for step in s:
                    step.update(meta)
        else:
            s.update(meta)
        return s

    # override
    def get_data(self, batch_size: int) -> List[Callable]:
        """
        Overview:
            Get a list of data loading function, which can be implemented by dataloader to read data from files.
            Blocks until a learner has been created and until the coordinator has answered the data demand.
        Arguments:
            - batch_size (:obj:`int`): Batch size.
        Returns:
            - data (:obj:`List[Callable]`): A list of callable data loading function.
        """
        # The learner is only created once a start task has been received; poll until then.
        while self._learner is None:
            time.sleep(1)
        # Tell coordinator that we need training data, by putting info dict in data_demand_queue.
        assert self._data_demand_queue.qsize() == 0
        self._data_demand_queue.put({'batch_size': batch_size, 'cur_learner_iter': self._learner.last_iter.val})
        # Get a list of meta data (data info dict) from coordinator, by getting info dict from data_result_queue.
        data = self._data_result_queue.get()
        assert isinstance(data, list)
        assert len(data) == batch_size, '{}/{}'.format(len(data), batch_size)
        # Transform meta data to callable data loading function (partial ``load_data_fn``).
        # The compressor is taken from the first meta dict and used for the whole batch.
        decompressor = get_data_decompressor(data[0].get('compressor', 'none'))
        data = [
            partial(
                FlaskFileSystemLearner.load_data_fn,
                path=m['object_ref'] if read_from_di_store else os.path.join(self._path_data, m['data_id']),
                meta=m,
                decompressor=decompressor,
            ) for m in data
        ]
        return data

    # override
    def send_learn_info(self, learn_info: dict) -> None:
        """
        Overview:
            Store learn info dict in queue, which will be retrieved by callback function "deal_with_learner_learn"
            in learner slave, then will be sent to coordinator.
        Arguments:
            - learn_info (:obj:`dict`): Learn info in `dict` type. Keys are like 'learner_step', 'priority_info',
                'finished_task', etc. You can refer to ``learn_info``(``worker/learner/base_learner.py``) for details.
        """
        assert self._learn_info_queue.qsize() == 0
        self._learn_info_queue.put(learn_info)

    @property
    def hooks4call(self) -> List[LearnerHook]:
        """
        Overview:
            Return the hooks that are related to message passing with coordinator.
        Returns:
            - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
        """
        return [
            SendPolicyHook('send_policy', 100, position='before_run', ext_args={}),
            SendPolicyHook('send_policy', 100, position='after_iter', ext_args={'send_policy_freq': 1}),
            SendLearnInfoHook(
                'send_learn_info',
                100,
                position='after_iter',
                ext_args={'freq': 10},
            ),
            SendLearnInfoHook(
                'send_learn_info',
                100,
                position='after_run',
                ext_args={'freq': 1},
            ),
        ]


class SendPolicyHook(LearnerHook):
    """
    Overview:
        Hook to send policy
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: Optional[dict] = None, **kwargs) -> None:
        """
        Overview:
            init SendpolicyHook
        Arguments:
            - ext_args (:obj:`Optional[dict]`): Extended arguments. Use ``ext_args['send_policy_freq']`` to set \
                the sending frequency in iterations; defaults to 1 when the key (or ``ext_args``) is absent.
            - args, kwargs: Forwarded unchanged to ``LearnerHook.__init__``.
        """
        super().__init__(*args, **kwargs)
        # ``None`` instead of a shared ``{}`` default avoids the mutable-default-argument pitfall.
        if ext_args is None:
            ext_args = {}
        self._freq = ext_args.get('send_policy_freq', 1)

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Save learner's policy in corresponding path at interval iterations by calling ``engine``'s ``send_policy``.
            Saved file includes model_state_dict, learner_last_iter.
        Arguments:
            - engine (:obj:`BaseLearner`): The BaseLearner.

        .. note::
            Only rank == 0 learner will save policy.
        """
        last_iter = engine.last_iter.val
        if engine.rank == 0 and last_iter % self._freq == 0:
            state_dict = {'model': engine.policy.state_dict()['model'], 'iter': last_iter}
            engine.send_policy(state_dict)
            engine.debug('{} save iter{} policy'.format(engine.instance_name, last_iter))


class SendLearnInfoHook(LearnerHook):
    """
    Overview:
        Hook to send learn info
    Interfaces:
        __init__, __call__
    Property:
        name, priority, position
    """

    def __init__(self, *args, ext_args: dict, **kwargs) -> None:
        """
        Overview:
            init SendLearnInfoHook
        Arguments:
            - ext_args (:obj:`dict`): Extended arguments. Must contain key "freq", the interval (in iterations) \
                at which ``__call__`` writes its debug log.
            - args, kwargs: Forwarded unchanged to ``LearnerHook.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._freq = ext_args['freq']

    def __call__(self, engine: 'BaseLearner') -> None:  # noqa
        """
        Overview:
            Send learn info (``engine.learn_info``) on every call, and write a debug log at interval iterations.
        Arguments:
            - engine (:obj:`BaseLearner`): the BaseLearner

        .. note::
            ``freq`` only throttles the debug message. The learn info itself is sent unconditionally,
            because ``FlaskFileSystemLearner.deal_with_learner_learn`` blocks on it.
        """
        last_iter = engine.last_iter.val
        engine.send_learn_info(engine.learn_info)
        if last_iter % self._freq == 0:
            engine.debug('{} save iter{} learn_info'.format(engine.instance_name, last_iter))