Skip to content

ding.entry.dist_entry

ding.entry.dist_entry

Full Source Code

../ding/entry/dist_entry.py

import os
import sys
import subprocess
import signal
import pickle
from ditk import logging
import time
from threading import Thread
from easydict import EasyDict
import numpy as np
from ding.worker import Coordinator, create_comm_collector, create_comm_learner, LearnerAggregator
from ding.config import read_config_with_system, compile_config_parallel
from ding.utils import set_pkg_seed, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, pod_exec_command


def _resolve_port(cfg, cli_port: int, env_key: str) -> None:
    """
    Resolve the listening port in-place on ``cfg`` with priority: CLI > env variable > config.

    - If ``cli_port`` is given, it wins unconditionally.
    - Else, if the environment variable ``env_key`` is set to a digit string, use it.
      (Non-digit values are silently ignored, keeping whatever ``cfg`` already holds.)
    - Else, fall back to the config's pre-defined ``port``, which must exist and be scalar.
    """
    if cli_port is not None:
        cfg.port = cli_port
    elif os.environ.get(env_key, None):
        port = os.environ[env_key]
        if port.isdigit():
            cfg.port = int(port)
    else:  # use config pre-defined value
        assert 'port' in cfg and np.isscalar(cfg.port)


def dist_prepare_config(
    filename: str,
    seed: int,
    platform: str,
    coordinator_host: str,
    learner_host: str,
    collector_host: str,
    coordinator_port: int,
    learner_port: int,
    collector_port: int,
) -> str:
    """
    Compile the parallel-mode config read from ``filename`` and dump it to disk as a pickle.

    Returns the path of the pickled config (``filename + '.pkl'``), which the
    ``dist_launch_*`` entry points below load back.
    """
    set_pkg_seed(seed)
    main_cfg, create_cfg, system_cfg = read_config_with_system(filename)
    config = compile_config_parallel(
        main_cfg,
        create_cfg=create_cfg,
        system_cfg=system_cfg,
        seed=seed,
        platform=platform,
        coordinator_host=coordinator_host,
        learner_host=learner_host,
        collector_host=collector_host,
        coordinator_port=coordinator_port,
        learner_port=learner_port,
        collector_port=collector_port,
    )
    # Pickle dump config to disk for later use.
    real_filename = filename + '.pkl'
    with open(real_filename, 'wb') as f:
        pickle.dump(config, f)
    return real_filename


def dist_launch_coordinator(
    filename: str,
    seed: int,
    coordinator_port: int,
    disable_flask_log: bool,
    enable_total_log: bool = False
) -> None:
    """
    Launch the coordinator described by the pickled config ``filename`` and block
    until the whole system signals shutdown via ``coordinator.system_shutdown_flag``.
    """
    set_pkg_seed(seed)
    # Disable some part of DI-engine log
    if not enable_total_log:
        coordinator_log = logging.getLogger('coordinator_logger')
        coordinator_log.disabled = True
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    with open(filename, 'rb') as f:
        config = pickle.load(f)
    # CLI > ENV VARIABLE > CONFIG
    _resolve_port(config.system.coordinator, coordinator_port, 'COORDINATOR_PORT')
    coordinator = Coordinator(config)
    coordinator.start()

    # Monitor thread: Coordinator will remain running until its ``system_shutdown_flag`` is set to False.
    def shutdown_monitor():
        while True:
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                coordinator.close()
                break

    shutdown_monitor_thread = Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    shutdown_monitor_thread.join()
    print("[DI-engine dist pipeline]Your RL agent is converged, you can refer to 'log' and 'tensorboard' for details")


def dist_launch_learner(
    filename: str, seed: int, learner_port: int, name: str = None, disable_flask_log: bool = True
) -> None:
    """
    Launch a comm learner from the pickled config entry ``system[name]``
    (default ``name``: ``'learner'``).
    """
    set_pkg_seed(seed)
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    if name is None:
        name = 'learner'
    with open(filename, 'rb') as f:
        config = pickle.load(f).system[name]
    # CLI > ENV VARIABLE > CONFIG
    _resolve_port(config, learner_port, 'LEARNER_PORT')
    learner = create_comm_learner(config)
    learner.start()


def dist_launch_collector(
    filename: str, seed: int, collector_port: int, name: str = None, disable_flask_log: bool = True
) -> None:
    """
    Launch a comm collector from the pickled config entry ``system[name]``
    (default ``name``: ``'collector'``).
    """
    set_pkg_seed(seed)
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    if name is None:
        name = 'collector'
    with open(filename, 'rb') as f:
        config = pickle.load(f).system[name]
    # CLI > ENV VARIABLE > CONFIG
    _resolve_port(config, collector_port, 'COLLECTOR_PORT')
    collector = create_comm_collector(config)
    collector.start()
def dist_launch_learner_aggregator(
    filename: str,
    seed: int,
    aggregator_host: str,
    aggregator_port: int,
    name: str = None,
    disable_flask_log: bool = True
) -> None:
    """
    Launch a learner aggregator, either from the pickled config entry ``system[name]``
    (default ``name``: ``'learner_aggregator'``) or, when ``filename`` is None,
    from a minimal config built on the fly (k8s mode).
    """
    set_pkg_seed(seed)
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    if filename is not None:
        if name is None:
            name = 'learner_aggregator'
        with open(filename, 'rb') as f:
            config = pickle.load(f).system[name]
    else:
        # start without config (create a fake one)
        # Port priority: CLI > AGGREGATOR_PORT env variable > default constant.
        host, port = aggregator_host, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT
        if aggregator_port is not None:
            port = aggregator_port
        elif os.environ.get('AGGREGATOR_PORT', None):
            _port = os.environ['AGGREGATOR_PORT']
            if _port.isdigit():
                port = int(_port)
        # Master listens one port above the slave.
        config = dict(
            master=dict(host=host, port=port + 1),
            slave=dict(host=host, port=port + 0),
            learner={},
        )
        config = EasyDict(config)
    learner_aggregator = LearnerAggregator(config)
    learner_aggregator.start()


def dist_launch_spawn_learner(
    filename: str, seed: int, learner_port: int, name: str = None, disable_flask_log: bool = True
) -> None:
    """
    Spawn ``LOCAL_WORLD_SIZE`` learner subprocesses via the ``ding`` CLI and supervise them.

    Each child receives ``RANK``/``LOCAL_RANK`` environment variables
    (``RANK`` offset by ``START_RANK``). If any child exits non-zero, all remaining
    children are killed and ``subprocess.CalledProcessError`` is raised; SIGINT/SIGTERM
    on the parent are likewise forwarded to the children.
    """
    current_env = os.environ.copy()
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    processes = []

    # Loop-invariant work hoisted out of the spawn loop: the executable lookup, the
    # command line, and the signal-handler registration are identical for every rank.
    executable = subprocess.getoutput('which ding')
    assert len(executable) > 0, "cannot find executable \"ding\""

    cmd = [executable, '-m', 'dist', '--module', 'learner']
    if filename is not None:
        # Fixed: the config path was previously replaced by a literal placeholder string.
        cmd += ['-c', f'{filename}']
    if seed is not None:
        cmd += ['-s', f'{seed}']
    if learner_port is not None:
        cmd += ['-lp', f'{learner_port}']
    if name is not None:
        cmd += ['--module-name', f'{name}']
    if disable_flask_log is not None:
        cmd += ['--disable-flask-log', f'{int(disable_flask_log)}']

    sig_names = {2: "SIGINT", 15: "SIGTERM"}
    last_return_code = None

    def sigkill_handler(signum, frame):
        # Kill every child, then either surface the failing child's return code or
        # exit because the parent itself was signalled.
        for process in processes:
            print(f"Killing subprocess {process.pid}")
            try:
                process.kill()
            except Exception:
                pass
        if last_return_code is not None:
            raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
        if signum in sig_names:
            print(f"Main process received {sig_names[signum]}, exiting")
            sys.exit(1)

    # pass SIGINT/SIGTERM to children if the parent is being terminated
    signal.signal(signal.SIGINT, sigkill_handler)
    signal.signal(signal.SIGTERM, sigkill_handler)

    for local_rank in range(0, local_world_size):
        dist_rank = int(os.environ.get('START_RANK', 0)) + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)
        process = subprocess.Popen(cmd, env=current_env, stdout=None, stderr=None)
        processes.append(process)

    try:
        alive_processes = set(processes)
        while len(alive_processes):
            finished_processes = []
            for process in alive_processes:
                if process.poll() is None:
                    # the process is still running
                    continue
                else:
                    if process.returncode != 0:
                        last_return_code = process.returncode  # for sigkill_handler
                        sigkill_handler(signal.SIGTERM, None)  # not coming back
                    else:
                        # exited cleanly
                        finished_processes.append(process)
            alive_processes = set(alive_processes) - set(finished_processes)

            time.sleep(1)
    finally:
        # close open file descriptors
        pass


def dist_add_replicas(
    replicas_type: str,
    kubeconfig: str,
    replicas: int,
    coordinator_name: str,
    namespace: str,
    cpus: int,
    gpus: int,
    memory: str,
) -> None:
    """
    Add ``replicas`` collector/learner replicas by executing a curl POST against the
    orchestrator's ``/v1alpha1/replicas`` endpoint inside the coordinator pod.
    Optional resource requirements (cpus/gpus/memory) are attached when positive/non-empty.
    """
    assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"

    import json
    data = {
        "namespace": namespace,
        "coordinator": coordinator_name,
    }
    res = {"replicas": replicas}
    if cpus > 0:
        res['cpus'] = cpus
    if gpus > 0:
        res['gpus'] = gpus
    if memory:
        res['memory'] = memory
    if replicas_type == 'collector':
        data['collectors'] = res
    elif replicas_type == 'learner':
        data['learners'] = res
    cmd = 'curl -X POST $KUBERNETES_SERVER_URL/v1alpha1/replicas ' \
        '-H "content-type: application/json" ' \
        f'-d \'{json.dumps(data)}\''
    ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
    if ret == 0:
        print(f'{replicas_type} add successfully')
    else:
        print(f'Failed to add {replicas_type}, return code: {ret}, message: {msg}')


def dist_delete_replicas(
    replicas_type: str, kubeconfig: str, replicas: int, coordinator_name: str, namespace: str
) -> None:
    """
    Delete ``replicas`` collector/learner replicas by executing a curl DELETE against
    the orchestrator's ``/v1alpha1/replicas`` endpoint inside the coordinator pod.
    """
    assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"

    import json
    data = {
        "namespace": namespace,
        "coordinator": coordinator_name,
    }
    if replicas_type == 'collector':
        data['collectors'] = {"replicas": replicas}
    elif replicas_type == 'learner':
        data['learners'] = {"replicas": replicas}
    cmd = 'curl -X DELETE $KUBERNETES_SERVER_URL/v1alpha1/replicas ' \
        '-H "content-type: application/json" ' \
        f'-d \'{json.dumps(data)}\''
    ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
    if ret == 0:
        print(f'{replicas_type} delete successfully')
    else:
        print(f'Failed to delete {replicas_type}, return code: {ret}, message: {msg}')


def dist_restart_replicas(
    replicas_type: str, kubeconfig: str, coordinator_name: str, namespace: str, restart_pod_name: str
) -> None:
    """
    Restart one collector/learner pod by reporting it failed: executes a curl POST
    against the orchestrator's ``/v1alpha1/replicas/failed`` endpoint inside the
    coordinator pod, naming ``restart_pod_name``.
    """
    assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"

    import json
    data = {
        "namespace": namespace,
        "coordinator": coordinator_name,
    }
    assert restart_pod_name, "Please provide restart pod name with --restart-pod-name"
    if replicas_type == 'collector':
        data['collectors'] = [restart_pod_name]
    elif replicas_type == 'learner':
        data['learners'] = [restart_pod_name]
    cmd = 'curl -X POST $KUBERNETES_SERVER_URL/v1alpha1/replicas/failed ' \
        '-H "content-type: application/json" ' \
        f'-d \'{json.dumps(data)}\''
    ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
    if ret == 0:
        print(f'{replicas_type} restart successfully')
    else:
        print(f'Failed to restart {replicas_type}, return code: {ret}, message: {msg}')