ding.entry.cli_ditask
Full Source Code
../ding/entry/cli_ditask.py
import click
import os
import sys
import importlib
import importlib.util
import json
from click.core import Context, Option

from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel
from ding.entry.cli_parsers import PLATFORM_PARSERS


def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Eager click callback for ``-v/--version``: print version/author info and exit.

    Does nothing when the flag was not given, or when click is only resolving
    parameters without running the command (``ctx.resilient_parsing``).
    """
    if not value or ctx.resilient_parsing:
        return
    click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
    click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
    ctx.exit()


CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: k8s, local: 50515, slurm: 15151"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("--platform-spec", type=str, help="Platform specific configure.")
@click.option("--platform", type=str, help="Platform type: slurm, k8s.")
@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.")
@click.option("--redis-host", type=str, help="Redis host.")
@click.option("--redis-port", type=int, help="Redis port.")
@click.option("-m", "--main", type=str, help="Main function of entry module.")
@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.")
@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP")
def cli_ditask(*args, **kwargs):
    """``ditask`` command entry point; forwards every parsed option to ``_cli_ditask``."""
    return _cli_ditask(*args, **kwargs)


def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Rewrite ``_cli_ditask`` arguments using the parser registered for ``platform``.

    Args:
        platform: Platform type; must be a key of ``PLATFORM_PARSERS``.
        platform_spec: Either a path to a ``.json`` file or an inline JSON string.
            When empty/None it is handed to the platform parser unchanged.
        all_args: Keyword arguments of ``_cli_ditask``. The ``platform`` and
            ``platform_spec`` entries are removed in place; the rest are
            forwarded to the platform parser.

    Returns:
        The platform parser's result, which ``_cli_ditask`` unpacks as keyword arguments.

    Exits the process with status 1 on an unreadable/invalid spec or an unknown
    platform; re-raises any exception raised by the platform parser.
    """
    if platform_spec:
        try:
            # splitext returns (root, ext): compare the extension, not the whole tuple,
            # otherwise a .json path is never opened and json.loads() rejects the path.
            if os.path.splitext(platform_spec)[1].lower() == ".json":
                with open(platform_spec) as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except (OSError, ValueError):  # unreadable file, or malformed JSON (JSONDecodeError is a ValueError)
            click.echo("platform_spec is not a valid json!")
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)
    all_args.pop("platform")
    all_args.pop("platform_spec")
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise

    return parsed_args


def _cli_ditask(
    package: str,
    main: str,
    parallel_workers: int,
    protocol: str,
    ports: str,
    attach_to: str,
    address: str,
    labels: str,
    node_ids: str,
    topology: str,
    mq_type: str,
    redis_host: str,
    redis_port: int,
    startup_interval: int,
    local_rank: int = 0,
    platform: str = None,
    platform_spec: str = None,
):
    """Import the user's main function and launch it through ``Parallel.runner``.

    If ``platform`` is set, the arguments are first rewritten by the matching
    platform parser and this function is re-invoked with the result.

    Args:
        package: Directory or zip file holding the user code; defaults to the
            current working directory. It is appended to ``sys.path``.
        main: Dotted ``module.function`` path of the entry function. When omitted,
            the module name is the package's base name without extension and the
            function name is ``"main"``.
        ports: Comma-separated ports (default 50515). One port is passed as an
            int, several as a list of ints.
        attach_to: Comma-separated addresses; each entry is whitespace-stripped.
        labels: Comma-separated labels, turned into a set of stripped strings.
        node_ids: Comma-separated integer node ids (an int is passed through).
        local_rank: Accepted only for PyTorch DDP launcher compatibility; unused here.
        platform / platform_spec: See ``_parse_platform_args``.
        Remaining arguments are forwarded unchanged to ``Parallel.runner``.
    """
    # Must stay the first statement so the snapshot holds only the parameters.
    all_args = locals()
    if platform:
        parsed_args = _parse_platform_args(platform, platform_spec, all_args)
        return _cli_ditask(**parsed_args)

    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # abspath normalises "pkg/" and "." so the base name is never "" or ".".
        mod_name = os.path.basename(os.path.abspath(package))
        mod_name, _ = os.path.splitext(mod_name)
        func_name = "main"
    else:
        mod_name, func_name = main.rsplit(".", 1)
    root_mod_name = mod_name.split(".", 1)[0]
    sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)

    # Parse arguments
    ports = ports or 50515
    if not isinstance(ports, int):
        ports = [int(p) for p in ports.split(",")]
        ports = ports[0] if len(ports) == 1 else ports
    if attach_to:
        attach_to = [addr.strip() for addr in attach_to.split(",")]
    if labels:
        labels = {label.strip() for label in labels.split(",")}
    if node_ids and not isinstance(node_ids, int):
        node_ids = [int(i) for i in node_ids.split(",")]
    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids,
        mq_type=mq_type,
        redis_host=redis_host,
        redis_port=redis_port,
        startup_interval=startup_interval
    )(main_func)