ding.config.utils
ding.config.utils
¶
save_config_formatted(config_, path='formatted_total_config.py')
¶
Overview
Save the formatted configuration to a Python file that can be read directly by serial_pipeline.
Arguments:
- config_ (:obj:dict): Config dict to save
- path (:obj:str): Path of the Python file to write
Full Source Code
../ding/config/utils.py
from typing import Any, Optional, List, Union
import copy
import math
from easydict import EasyDict

from ding.utils import find_free_port, find_free_port_slurm, node_to_partition, node_to_host, pretty_print, \
    DEFAULT_K8S_COLLECTOR_PORT, DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_COORDINATOR_PORT
from dizoo.classic_control.cartpole.config.parallel import cartpole_dqn_config

default_host = '0.0.0.0'
default_port = 22270


def _assign_auto_host(worker_cfg: EasyDict, hosts: Union[str, List[str]], count: int, role: str) -> int:
    """
    Overview:
        Resolve ``'auto'`` host/port of a single learner or collector config in place.
        An ``'auto'`` host is taken from ``hosts``: a single string is shared by every worker, a list is
        consumed one entry per worker in iteration order. An ``'auto'`` port is replaced by a free port
        on the resolved host.
    Arguments:
        - worker_cfg (:obj:`EasyDict`): Config of one learner/collector; mutated in place.
        - hosts (:obj:`Union[str, List[str]]`): Shared host, or list of hosts consumed in order.
        - count (:obj:`int`): Number of list entries of ``hosts`` consumed so far.
        - role (:obj:`str`): ``'learner'`` or ``'collector'``; only used in the error message.
    Returns:
        - count (:obj:`int`): Updated number of consumed list entries.
    Raises:
        - TypeError: If the host is ``'auto'`` and ``hosts`` is neither a list nor a string.
    """
    if worker_cfg.host == 'auto':
        if isinstance(hosts, list):
            worker_cfg.host = hosts[count]
            count += 1
        elif isinstance(hosts, str):
            worker_cfg.host = hosts
        else:
            raise TypeError("not support {}_host type: {}".format(role, hosts))
    if worker_cfg.port == 'auto':
        worker_cfg.port = find_free_port(worker_cfg.host)
    return count


def set_host_port(
        cfg: EasyDict, coordinator_host: str, learner_host: Union[str, List[str]],
        collector_host: Union[str, List[str]]
) -> EasyDict:
    """
    Overview:
        Fill in ``'auto'`` hosts/ports of the coordinator, learners and collectors of a system config.
        Learners are marked as not using an aggregator.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with ``coordinator``, ``learner*`` and ``collector*`` entries.
        - coordinator_host (:obj:`str`): Host of the coordinator.
        - learner_host (:obj:`Union[str, List[str]]`): Shared learner host, or one host per learner.
        - collector_host (:obj:`Union[str, List[str]]`): Shared collector host, or one host per collector.
    Returns:
        - cfg (:obj:`EasyDict`): The same config, mutated in place.
    Raises:
        - NotImplementedError: If the config already contains learner aggregators.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    learner_count = 0
    collector_count = 0
    for k in cfg.keys():
        # Aggregator keys are suffixed (e.g. 'learner_aggregator0'), so match by prefix.
        if k.startswith('learner_aggregator'):
            raise NotImplementedError
        if k.startswith('learner'):
            learner_count = _assign_auto_host(cfg[k], learner_host, learner_count, 'learner')
            cfg[k].aggregator = False
        if k.startswith('collector'):
            collector_count = _assign_auto_host(cfg[k], collector_host, collector_count, 'collector')
    return cfg


def set_host_port_slurm(
        cfg: EasyDict, coordinator_host: str, learner_node: Optional[Union[str, List[str]]],
        collector_node: Optional[Union[str, List[str]]]
) -> EasyDict:
    """
    Overview:
        Fill in node/partition/host/port of workers for a slurm deployment.
        Nodes are assigned round-robin. A learner with ``gpu_num > 1`` gets one port per GPU plus a
        ``learner_aggregator<suffix>`` entry (master/slave endpoints) that fronts those ports.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with ``coordinator``, ``learner*`` and ``collector*`` entries.
        - coordinator_host (:obj:`str`): Host of the coordinator.
        - learner_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for learners; ``None`` skips them.
        - collector_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for collectors; ``None`` skips them.
    Returns:
        - cfg (:obj:`EasyDict`): The same config, mutated in place (aggregator entries may be added).
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    if isinstance(learner_node, str):
        learner_node = [learner_node]
    if isinstance(collector_node, str):
        collector_node = [collector_node]
    learner_count, collector_count = 0, 0
    learner_multi = {}
    for k in cfg.keys():
        if learner_node is not None and k.startswith('learner'):
            node = learner_node[learner_count % len(learner_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            gpu_num = cfg[k].gpu_num
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                if gpu_num == 1:
                    cfg[k].port = find_free_port_slurm(node)
                    learner_multi[k] = False
                else:
                    cfg[k].port = [find_free_port_slurm(node) for _ in range(gpu_num)]
                    learner_multi[k] = True
            else:
                # A user-specified port still needs an aggregator decision; otherwise ``aggregator`` is
                # never set and the coordinator wiring step fails. A port list means multi-GPU.
                learner_multi[k] = isinstance(cfg[k].port, (list, tuple))
            learner_count += 1
        if collector_node is not None and k.startswith('collector'):
            node = collector_node[collector_count % len(collector_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port_slurm(node)
            collector_count += 1
    for k, flag in learner_multi.items():
        if flag:
            host = cfg[k].host
            learner_interaction_cfg = {str(i): [str(i), host, p] for i, p in enumerate(cfg[k].port)}
            aggregator_cfg = dict(
                master=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                slave=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                learner=learner_interaction_cfg,
                node=cfg[k].node,
                partition=cfg[k].partition,
            )
            cfg[k].aggregator = True
            # 'learner3' -> 'learner_aggregator3': keep the learner's suffix after the 7-char 'learner'.
            cfg['learner_aggregator' + k[7:]] = aggregator_cfg
        else:
            cfg[k].aggregator = False
    return cfg


def set_host_port_k8s(
        cfg: EasyDict, coordinator_port: Optional[int], learner_port: Optional[int], collector_port: Optional[int]
) -> EasyDict:
    """
    Overview:
        Fill in hosts/ports for a k8s deployment. Every learner shares ``learner_port`` and every
        collector shares ``collector_port``. Adds ``learner``/``collector`` entries that are deep copies
        of the first learner/collector with host ``default_host`` (``None`` if there is none).
    Arguments:
        - cfg (:obj:`EasyDict`): System config with ``coordinator``, ``learner*`` and ``collector*`` entries.
        - coordinator_port (:obj:`Optional[int]`): Coordinator port; ``None`` uses the k8s default.
        - learner_port (:obj:`Optional[int]`): Learner port; ``None`` uses the k8s default.
        - collector_port (:obj:`Optional[int]`): Collector port; ``None`` uses the k8s default.
    Returns:
        - cfg (:obj:`EasyDict`): The same config, mutated in place.
    """
    cfg.coordinator.host = default_host
    cfg.coordinator.port = coordinator_port if coordinator_port is not None else DEFAULT_K8S_COORDINATOR_PORT
    base_learner_cfg = None
    base_collector_cfg = None
    if learner_port is None:
        learner_port = DEFAULT_K8S_LEARNER_PORT
    if collector_port is None:
        collector_port = DEFAULT_K8S_COLLECTOR_PORT
    for k in cfg.keys():
        if k.startswith('learner'):
            # create the base learner config
            if base_learner_cfg is None:
                base_learner_cfg = copy.deepcopy(cfg[k])
                base_learner_cfg.host = default_host
                base_learner_cfg.port = learner_port
            cfg[k].port = learner_port
        elif k.startswith('collector'):
            # create the base collector config
            if base_collector_cfg is None:
                base_collector_cfg = copy.deepcopy(cfg[k])
                base_collector_cfg.host = default_host
                base_collector_cfg.port = collector_port
            cfg[k].port = collector_port
    cfg['learner'] = base_learner_cfg
    cfg['collector'] = base_collector_cfg
    return cfg


def set_learner_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build ``cfg.coordinator.learner`` as ``{name: [name, host, port]}`` for every learner.
        Learners that use an aggregator are reached through their aggregator's ``slave`` endpoint.
    Arguments:
        - cfg (:obj:`EasyDict`): System config whose learners already have ``aggregator`` set.
    Returns:
        - cfg (:obj:`EasyDict`): The same config, mutated in place.
    """
    cfg.coordinator.learner = {}
    for k in cfg.keys():
        if k.startswith('learner') and not k.startswith('learner_aggregator'):
            if cfg[k].aggregator:
                dst_k = 'learner_aggregator' + k[7:]
                cfg.coordinator.learner[k] = [k, cfg[dst_k].slave.host, cfg[dst_k].slave.port]
            else:
                dst_k = k
                cfg.coordinator.learner[k] = [k, cfg[dst_k].host, cfg[dst_k].port]
    return cfg


def set_collector_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build ``cfg.coordinator.collector`` as ``{name: [name, host, port]}`` for every collector.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with resolved collector hosts/ports.
    Returns:
        - cfg (:obj:`EasyDict`): The same config, mutated in place.
    """
    cfg.coordinator.collector = {}
    for k in cfg.keys():
        if k.startswith('collector'):
            cfg.coordinator.collector[k] = [k, cfg[k].host, cfg[k].port]
    return cfg


def set_system_cfg(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build a fresh system config with one ``learner{i}``/``collector{i}`` entry per worker, all with
        ``'auto'`` host/port, from ``cfg.main.policy`` worker counts and the ``cfg.system`` section.
    Arguments:
        - cfg (:obj:`EasyDict`): Total config containing ``main`` and ``system``.
    Returns:
        - system_cfg (:obj:`EasyDict`): The new system config (``cfg`` is not modified).
    """
    learner_num = cfg.main.policy.learn.learner.learner_num
    collector_num = cfg.main.policy.collect.collector.collector_num
    path_data = cfg.system.path_data
    path_policy = cfg.system.path_policy
    coordinator_cfg = cfg.system.coordinator
    communication_mode = cfg.system.communication_mode
    assert communication_mode in ['auto'], communication_mode
    learner_gpu_num = cfg.system.learner_gpu_num
    learner_multi_gpu = learner_gpu_num > 1
    new_cfg = dict(coordinator=dict(
        host='auto',
        port='auto',
    ))
    new_cfg['coordinator'].update(coordinator_cfg)
    for i in range(learner_num):
        new_cfg[f'learner{i}'] = dict(
            type=cfg.system.comm_learner.type,
            import_names=cfg.system.comm_learner.import_names,
            host='auto',
            port='auto',
            path_data=path_data,
            path_policy=path_policy,
            multi_gpu=learner_multi_gpu,
            gpu_num=learner_gpu_num,
        )
    for i in range(collector_num):
        new_cfg[f'collector{i}'] = dict(
            type=cfg.system.comm_collector.type,
            import_names=cfg.system.comm_collector.import_names,
            host='auto',
            port='auto',
            path_data=path_data,
            path_policy=path_policy,
        )
    return EasyDict(new_cfg)


def parallel_transform(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_host: Optional[Union[str, List[str]]] = None,
        collector_host: Optional[Union[str, List[str]]] = None
) -> EasyDict:
    """
    Overview:
        Expand ``cfg.system`` into a full parallel system config for explicitly given hosts.
        Missing hosts default to ``default_host``.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed total config.
    """
    coordinator_host = default_host if coordinator_host is None else coordinator_host
    collector_host = default_host if collector_host is None else collector_host
    learner_host = default_host if learner_host is None else learner_host
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port(cfg.system, coordinator_host, learner_host, collector_host)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    return cfg


def parallel_transform_slurm(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_node: Optional[List[str]] = None,
        collector_node: Optional[List[str]] = None
) -> EasyDict:
    """
    Overview:
        Expand ``cfg.system`` into a full parallel system config for a slurm cluster, then print it.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed total config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_slurm(cfg.system, coordinator_host, learner_node, collector_node)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    pretty_print(cfg)
    return cfg


def parallel_transform_k8s(
        cfg: dict,
        coordinator_port: Optional[int] = None,
        learner_port: Optional[int] = None,
        collector_port: Optional[int] = None
) -> EasyDict:
    """
    Overview:
        Expand ``cfg.system`` into a full parallel system config for k8s, then print it.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed total config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_k8s(cfg.system, coordinator_port, learner_port, collector_port)
    # learner/collector is created by operator, so the following field is placeholder
    cfg.system.coordinator.collector = {}
    cfg.system.coordinator.learner = {}
    pretty_print(cfg)
    return cfg


# One indentation level of the generated config file.
_INDENT = '    '

# Which parts of ``policy`` are expanded into nested ``dict(...)`` calls in the generated file.
# A dict maps the keys to expand to their own layout (unlisted keys are written on one line via ``str``);
# an int N expands every dict-valued entry, recursively, up to N more levels.
_POLICY_LAYOUT = {
    'learn': {
        'learner': {
            'dataloader': 0,
            'hook': 0
        }
    },
    'collect': {
        'collector': 0
    },
    'eval': {
        'evaluator': 0
    },
    'model': 0,
    'other': {
        'replay_buffer': 3
    },
}


def _format_config_value(value: Any) -> str:
    """
    Overview:
        Render one config value as Python source text. Strings use ``repr`` so embedded quotes and
        backslashes stay valid; infinite floats become ``float('inf')``/``float('-inf')`` because a bare
        ``inf`` is not a valid expression; everything else falls back to ``str``.
    """
    if isinstance(value, str):
        return repr(value)
    if isinstance(value, float) and math.isinf(value):
        return "float('{}')".format(value)
    return '{}'.format(value)


def _write_config_field(f, indent: int, key: str, value: Any) -> None:
    """Write one ``key=value,`` line at the given indentation level."""
    f.write('{}{}={},\n'.format(_INDENT * indent, key, _format_config_value(value)))


def _write_config_dict(f, indent: int, name: str, mapping: dict, layout: Union[int, dict] = 0) -> None:
    """
    Overview:
        Write ``name=dict(...)`` with one field per line, expanding nested dicts according to ``layout``
        (see ``_POLICY_LAYOUT`` for the layout format).
    """
    pad = _INDENT * indent
    f.write('{}{}=dict(\n'.format(pad, name))
    for key, value in mapping.items():
        if isinstance(layout, dict):
            child_layout = layout.get(key)
        else:
            child_layout = layout - 1 if layout > 0 else None
        if child_layout is not None and isinstance(value, dict):
            _write_config_dict(f, indent + 1, key, value, child_layout)
        else:
            _write_config_field(f, indent + 1, key, value)
    f.write('{}),\n'.format(pad))


def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py') -> None:
    """
    Overview:
        Save the formatted configuration to a Python file that can be read by serial_pipeline directly.
        The file defines ``main_config`` (exp_name, env fields and policy fields) and ``create_config``
        (env/env_manager ``type``/``import_names`` and the policy type), each wrapped in ``EasyDict``.
    Arguments:
        - config_ (:obj:`dict`): Config dict; must contain ``exp_name`` and ``policy.type``.
        - path (:obj:`str`): Path of the Python file to write.
    """
    with open(path, "w", encoding='utf-8') as f:
        f.write('from easydict import EasyDict\n\n')
        f.write('main_config = dict(\n')
        _write_config_field(f, 1, 'exp_name', str(config_['exp_name']))
        for key, value in config_.items():
            if key == 'env':
                # type/import_names (and the manager's type/cfg_type) go to create_config below.
                env_cfg = {k: v for k, v in value.items() if k not in ('type', 'import_names')}
                if isinstance(env_cfg.get('manager'), dict):
                    env_cfg['manager'] = {
                        k: v
                        for k, v in env_cfg['manager'].items() if k not in ('cfg_type', 'type')
                    }
                _write_config_dict(f, 1, 'env', env_cfg, {'manager': 0})
            elif key == 'policy':
                policy_cfg = {k: v for k, v in value.items() if k != 'type'}
                _write_config_dict(f, 1, 'policy', policy_cfg, _POLICY_LAYOUT)
        f.write(')\n')
        f.write('main_config = EasyDict(main_config)\n')
        f.write('main_config = main_config\n')
        f.write('create_config = dict(\n')
        for key, value in config_.items():
            if key == 'env':
                env_create = {k: v for k, v in value.items() if k in ('type', 'import_names')}
                _write_config_dict(f, 1, 'env', env_create)
                if 'manager' in value:
                    manager_create = {k: v for k, v in value['manager'].items() if k in ('cfg_type', 'type')}
                    _write_config_dict(f, 1, 'env_manager', manager_create)
        policy_type = config_['policy']['type']
        # Strip the '_command' suffix so create_config names the base policy (e.g. 'dqn_command' -> 'dqn').
        suffix = '_command'
        if policy_type.endswith(suffix):
            policy_type = policy_type[:-len(suffix)]
        f.write("{}policy=dict(type={}),\n".format(_INDENT, _format_config_value(policy_type)))
        f.write(")\n")
        f.write('create_config = EasyDict(create_config)\n')
        f.write('create_config = create_config\n')


parallel_test_main_config = cartpole_dqn_config
parallel_test_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn_command'),
    comm_learner=dict(
        type='flask_fs',
        import_names=['ding.worker.learner.comm.flask_fs_learner'],
    ),
    comm_collector=dict(
        type='flask_fs',
        import_names=['ding.worker.collector.comm.flask_fs_collector'],
    ),
    learner=dict(
        type='base',
        import_names=['ding.worker.learner.base_learner'],
    ),
    collector=dict(
        type='zergling',
        import_names=['ding.worker.collector.zergling_parallel_collector'],
    ),
    commander=dict(
        type='naive',
        import_names=['ding.worker.coordinator.base_parallel_commander'],
    ),
)
parallel_test_create_config = EasyDict(parallel_test_create_config)
parallel_test_system_config = dict(
    coordinator=dict(),
    path_data='.',
    path_policy='.',
    communication_mode='auto',
    learner_gpu_num=1,
)
parallel_test_system_config = EasyDict(parallel_test_system_config)