Skip to content

ding.framework.middleware.distributer

ding.framework.middleware.distributer

ContextExchanger

__init__(skip_n_iter=1, storage_loader=None)

Overview

Exchange context between processes, support properties: trajectories, episodes, env_step, env_episode, train_iter

Arguments: - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting for the first n iterations to collect data for the learner to learn. This parameter will not work on the learner. - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into a storage class to reduce the network overhead.

put(payload)

Overview

Get attributes from ctx on the callback of an event. Each attribute should have a standalone put handler, which is named `_put_{key}`.

fetch(ctx)

Overview

Fetch attributes from ctx before emitting them to the event bus. Each attribute should have a standalone fetch handler, which is named `_fetch_{key}`.

ModelExchanger

__init__(model, model_loader=None)

Overview

Exchange model between processes, only the learner will send the model, otherwise the model will only be received. If you are using a shared model on a single host, there is no need to use this middleware.

Arguments: - model (:obj:`torch.nn.Module`): Pytorch module. - model_loader (:obj:`ModelLoader`): Encode model in subprocess.

PeriodicalModelExchanger

__init__(model, mode, period=1, delay_toleration=np.inf, stale_toleration=1, event_name='model_exchanger', model_loader=None)

Overview

Exchange model between processes, set the mode to "send" or "receive" to specify the role of the process. If you are using a shared model on a single host, there is no need to use this middleware.

Arguments: - model (:obj:`torch.nn.Module`): Pytorch module. - mode (:obj:`str`): "send" or "receive". - period (:obj:`int`): The period of model exchange. - delay_toleration (:obj:`float`): The permitted time interval for receiving a model after it has been sent. - stale_toleration (:obj:`int`): The permitted number of iterations for receiving a model after it has been sent. - event_name (:obj:`str`): The event name for model exchange. - model_loader (:obj:`ModelLoader`): ModelLoader for this PeriodicalModelExchanger to use.

Full Source Code

../ding/framework/middleware/distributer.py

import numpy as np
from time import sleep, time
from dataclasses import fields
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
from ditk import logging
from ding.framework import task
from ding.data import StorageLoader, Storage, ModelLoader
if TYPE_CHECKING:
    from ding.framework.context import Context
    from torch.nn import Module


class ContextExchanger:
    # Generator middleware that exchanges selected Context attributes between
    # distributed processes over the task event bus. Each exchangeable attribute
    # has a `_put_{key}` handler (applied on receive) and a `_fetch_{key}` handler
    # (selects what to send), filtered by the current process's role.

    def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
        """
        Overview:
            Exchange context between processes,
            support properties: trajectories, episodes, env_step, env_episode, train_iter
        Arguments:
            - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \
                for the first n iterations to collect data for the learner to learn. This parameter \
                will not work on learner.
            - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into storage class to reduce \
                the network overhead.
        """
        if not task.router.is_active:
            raise RuntimeError("ContextHandler should be used in parallel mode!")
        self._state = {}
        self._local_state = {}  # just save local state, not send to remote node
        if task.has_role(task.role.COLLECTOR):
            # Collectors track the last sent counters so they can send increments.
            self._local_state['env_step'] = 0
            self._local_state['env_episode'] = 0
        self._event_name = "context_exchanger_{role}"
        self._skip_n_iter = skip_n_iter
        self._storage_loader = storage_loader
        for role in task.role:  # Only subscribe to other roles
            if not task.has_role(role):
                task.on(self._event_name.format(role=role), self.put)
        if storage_loader:
            task.once("finish", lambda _: storage_loader.shutdown())

    def __new__(cls, *args, **kwargs):
        # Degrade to a no-op middleware (task.void) when not running in
        # parallel mode or when the task declares no roles.
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ContextExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Use multiple roles in one exchanger may lead to unexpected result, please check your code."
            )

        return super(ContextExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context"):
        # Before the wrapped pipeline runs: merge remotely received state into ctx.
        # After it runs (past the yield): fetch local attributes and broadcast them.
        self.merge(ctx)
        yield
        payload = self.fetch(ctx)
        if payload:
            if self._storage_loader and task.has_role(task.role.COLLECTOR):
                # Offload trajectory payloads into storage to reduce network overhead.
                payload = self._storage_loader.save(payload)
            for role in task.roles:
                task.emit(self._event_name.format(role=role), payload, only_remote=True)

    def __del__(self):
        if self._storage_loader:
            self._storage_loader.shutdown()

    def put(self, payload: Union[Dict, Storage]):
        """
        Overview:
            Get attributes from ctx on the callback of event.
            Each attribute should have a standalone put handler, which is named `_put_{key}`
        """

        def callback(payload: Dict):
            for key, item in payload.items():
                fn_name = "_put_{}".format(key)
                if hasattr(self, fn_name):
                    getattr(self, fn_name)(item)
                else:
                    logging.warning("Receive unexpected key ({}) in context exchanger".format(key))

        if isinstance(payload, Storage):
            # Storage payloads are loaded asynchronously; callback fires on completion.
            assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
            self._storage_loader.load(payload, callback)
        else:
            callback(payload)

    def fetch(self, ctx: "Context") -> Dict[str, Any]:
        """
        Overview:
            Fetch attributes from ctx before emitting them to the event bus.
            Each attribute should have a standalone fetch handler, which is named `_fetch_{key}`
        """
        payload = {}
        for field in fields(ctx):
            key, item = field.name, getattr(ctx, field.name)
            fn_name = "_fetch_{}".format(key)
            if hasattr(self, fn_name):
                value = getattr(self, fn_name)(item)
                # None means "nothing to send for this key on this role".
                if value is not None:
                    payload[key] = value
        return payload

    def merge(self, ctx: "Context"):
        # Wait for remote state, then apply it onto ctx. `increment_*` keys are
        # added to the existing ctx value (except on collectors, which own the
        # authoritative counters); other keys overwrite.
        if task.has_role(task.role.LEARNER):
            # Learner should always wait for trajs.
            # TODO: Automatically wait based on properties, not roles.
            while len(self._state) == 0:
                sleep(0.01)
        elif ctx.total_step >= self._skip_n_iter:
            start = time()
            while len(self._state) == 0:
                if time() - start > 60:
                    logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
                    break
                sleep(0.01)

        for k, v in self._state.items():
            if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'):
                pure_k = k.split('increment_')[-1]
                setattr(ctx, pure_k, getattr(ctx, pure_k) + v)
            else:
                setattr(ctx, k, v)
        self._state = {}

    # Handle each attribute of context
    def _put_trajectories(self, traj: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectories" not in self._state:
            self._state["trajectories"] = []
        self._state["trajectories"].extend(traj)

    def _fetch_trajectories(self, traj: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return traj

    def _put_episodes(self, episodes: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "episodes" not in self._state:
            self._state["episodes"] = []
        self._state["episodes"].extend(episodes)

    def _fetch_episodes(self, episodes: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return episodes

    def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectory_end_idx" not in self._state:
            self._state["trajectory_end_idx"] = []
        self._state["trajectory_end_idx"].extend(trajectory_end_idx)

    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if task.has_role(task.role.COLLECTOR):
            return trajectory_end_idx

    def _put_env_step(self, increment_env_step: int):
        # Non-collectors accumulate increments; merge() adds them onto ctx.env_step.
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_step' not in self._state:
                self._state['increment_env_step'] = 0
            self._state["increment_env_step"] += increment_env_step

    def _fetch_env_step(self, env_step: int):
        # Collectors send the delta since the last fetch, not the absolute value.
        if task.has_role(task.role.COLLECTOR):
            increment_env_step = env_step - self._local_state['env_step']
            self._local_state['env_step'] = env_step
            return increment_env_step

    def _put_env_episode(self, increment_env_episode: int):
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_episode' not in self._state:
                self._state['increment_env_episode'] = 0
            self._state["increment_env_episode"] += increment_env_episode

    def _fetch_env_episode(self, env_episode: int):
        if task.has_role(task.role.COLLECTOR):
            increment_env_episode = env_episode - self._local_state['env_episode']
            self._local_state['env_episode'] = env_episode
            return increment_env_episode

    def _put_train_iter(self, train_iter: int):
        if not task.has_role(task.role.LEARNER):
            self._state["train_iter"] = train_iter

    def _fetch_train_iter(self, train_iter: int):
        if task.has_role(task.role.LEARNER):
            return train_iter


class ModelExchanger:
    # Generator middleware: the learner broadcasts its model state dict after
    # each iteration; every other role receives and loads it.

    def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None:
        """
        Overview:
            Exchange model between processes, only the learner will send the model,
            otherwise the model will only be received.
            If you are using a shared model on a single host, there is no need to use this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - model_loader (:obj:`ModelLoader`): Encode model in subprocess.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = "model_exchanger"
        self._state_dict_cache: Optional[Union[object, Storage]] = None
        self._is_learner = task.has_role(task.role.LEARNER)
        if not self._is_learner:
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, state_dict: Union[object, Storage]):
        # Event callback: stash the latest received state dict for _update_model.
        self._state_dict_cache = state_dict

    def __new__(cls, *args, **kwargs):
        # Degrade to a no-op middleware (task.void) when not running in
        # parallel mode or when the task declares no roles.
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ModelExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Use multiple roles in one exchanger may lead to unexpected result, please check your code."
            )

        return super(ModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if not self._is_learner:
            if ctx.total_step != 0:  # Skip first iteration
                self._update_model()
        else:
            # Learner: run the wrapped pipeline first, then broadcast the model.
            yield
            self._send_model()

    def _update_model(self):
        # Block (up to 60s) until a new state dict arrives, then load it.
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                break
            if self._state_dict_cache is None:
                sleep(0.01)
            else:
                if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                    try:
                        self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                        self._state_dict_cache = None
                        break
                    except FileNotFoundError as e:
                        # The backing file may expire between emit and load; keep waiting.
                        logging.warning(
                            "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                task.router.node_id
                            )
                        )
                        self._state_dict_cache = None
                        continue
                else:
                    self._model.load_state_dict(self._state_dict_cache)
                    self._state_dict_cache = None
                    break

    def _send_model(self):
        if self._model_loader:
            # Encode in a subprocess; _send_callback emits once storage is ready.
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, self._model.state_dict(), only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()


class PeriodicalModelExchanger:
    # Like ModelExchanger, but the role is chosen explicitly via `mode` and the
    # sender only emits every `period`-th iteration; the receiver tolerates a
    # bounded number of stale iterations (`stale_toleration`) and a bounded
    # message age (`delay_toleration`).

    def __init__(
        self,
        model: "Module",
        mode: str,
        period: int = 1,
        delay_toleration: float = np.inf,
        stale_toleration: int = 1,
        event_name: str = "model_exchanger",
        model_loader: Optional[ModelLoader] = None
    ) -> None:
        """
        Overview:
            Exchange model between processes, set the mode to "send" or "receive" to specify the role of the process.
            If you are using a shared model on a single host, there is no need to use this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - mode (:obj:`str`): "send" or "receive".
            - period (:obj:`int`): The period of model exchange.
            - delay_toleration (:obj:`float`): The permitted time interval for receiving model after being sent.
            - stale_toleration (:obj:`int`): The permitted number of iterations for receiving model after being sent.
            - event_name (:obj:`str`): The event name for model exchange.
            - model_loader (:obj:`ModelLoader`): ModelLoader for this PeriodicalModelExchanger to use.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = event_name
        self._period = period
        self._mode = mode
        if self._mode == "receive":
            # -1 markers: no message received / no model loaded yet.
            self._id_counter = -1
            self._model_id = -1
        else:
            self._id_counter = 0
        self._stale_toleration = stale_toleration
        # Starting at stale_toleration forces the first _update_model call to
        # block until a model arrives (and short-circuits the self._time read,
        # which is only set after the first received message).
        self._model_stale = stale_toleration
        self._delay_toleration = delay_toleration
        self._state_dict_cache: Optional[Union[object, Storage]] = None

        if self._mode == "receive":
            task.on(self._event_name, self._cache_state_dict)
            if model_loader:
                task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, msg: Dict[str, Any]):
        # Only accept messages whose id lands on the exchange period.
        if msg['id'] % self._period == 0:
            self._state_dict_cache = msg['model']
            self._id_counter = msg['id']
            self._time = msg['time']

    def __new__(cls, *args, **kwargs):
        return super(PeriodicalModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if self._mode == "receive":
            if ctx.total_step != 0:  # Skip first iteration
                self._update_model()
        elif self._mode == "send":
            yield
            if self._id_counter % self._period == 0:
                self._send_model(id=self._id_counter)
            self._id_counter += 1
        else:
            raise NotImplementedError

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                self._model_stale += 1
                break
            if self._state_dict_cache is None:
                # No new message: proceed with the stale model if still within
                # stale/delay tolerance, otherwise keep waiting.
                if self._model_stale < self._stale_toleration and time() - self._time < self._delay_toleration:
                    self._model_stale += 1
                    break
                else:
                    sleep(0.01)
            else:
                # Load only if the message is newer than the current model and fresh enough.
                if self._id_counter > self._model_id and time() - self._time < self._delay_toleration:
                    if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                        try:
                            self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                            self._state_dict_cache = None
                            self._model_id = self._id_counter
                            self._model_stale = 1
                            break
                        except FileNotFoundError as e:
                            # The backing file may expire between emit and load; keep waiting.
                            logging.warning(
                                "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                    task.router.node_id
                                )
                            )
                            self._state_dict_cache = None
                            continue
                    else:
                        self._model.load_state_dict(self._state_dict_cache)
                        self._state_dict_cache = None
                        self._model_id = self._id_counter
                        self._model_stale = 1
                        break
                else:
                    self._model_stale += 1

    def _send_model(self, id: int):
        if self._model_loader:
            # NOTE(review): when a model_loader is set, `id` is ignored and
            # _send_callback emits the raw storage object, while the receive-side
            # _cache_state_dict indexes msg['id']/msg['model']/msg['time'] — this
            # looks inconsistent with the loader-based send path; verify before use.
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()