ding.entry.cli
Full Source Code
../ding/entry/cli.py
from typing import List, Union
import os
import copy
import click
from click.core import Context, Option
import numpy as np

from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.config import read_config
from .predefined_config import get_predefined_config


def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Overview:
        Eager ``click`` callback for ``-v/--version``: print the package title, version and
        author information, then terminate the CLI.
    Arguments:
        - ctx (:obj:`Context`): The click runtime context.
        - param (:obj:`Option`): The triggering option (unused, required by click's callback signature).
        - value (:obj:`bool`): Whether the flag was passed on the command line.
    """
    # ``resilient_parsing`` is set by click during shell completion; stay silent then.
    if not value or ctx.resilient_parsing:
        return
    click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
    click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
    ctx.exit()


def print_registry(ctx: Context, param: Option, value: str):
    """
    Overview:
        Eager ``click`` callback for ``-q/--query-registry``: list every module/function
        registered under the given registry name together with its registration site, then exit.
    Arguments:
        - ctx (:obj:`Context`): The click runtime context.
        - param (:obj:`Option`): The triggering option (unused, required by click's callback signature).
        - value (:obj:`str`): The registry name to query; ``None`` means the option was not passed.
    """
    if value is None:
        return
    # Imported lazily so that a plain ``ding`` import does not pay the registry cost.
    from ding.utils import registries  # noqa
    if value not in registries:
        click.echo('[ERROR]: not support registry name: {}'.format(value))
    else:
        registered_info = registries[value].query_details()
        click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
        for alias, info in registered_info.items():
            # info is (path, extra) as returned by query_details; show where each alias was registered.
            click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
    ctx.exit()


CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option(
    '-q',
    '--query-registry',
    type=str,
    callback=print_registry,
    expose_value=False,
    is_eager=True,
    help='query registered module or function, show name and path'
)
@click.option(
    '-m',
    '--mode',
    type=click.Choice(
        [
            'serial',
            'serial_onpolicy',
            'serial_sqil',
            'serial_dqfd',
            'serial_trex',
            'serial_trex_onpolicy',
            'parallel',
            'dist',
            'eval',
            'serial_reward_model',
            'serial_gail',
            'serial_offline',
            'serial_ngu',
        ]
    ),
    help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@click.option(
    '-s',
    '--seed',
    type=int,
    default=[0],
    multiple=True,
    help='random generator seed(for all the possible package: random, numpy, torch and user env)'
)
@click.option('-e', '--env', type=str, help='RL env name')
@click.option('-p', '--policy', type=str, help='DRL policy name')
@click.option('--exp-name', type=str, help='experiment directory name')
@click.option('--train-iter', type=str, default='1e8', help='Maximum policy update iterations in training')
@click.option('--env-step', type=str, default='1e8', help='Maximum collected environment steps for training')
@click.option('--load-path', type=str, default=None, help='Path to load ckpt')
@click.option('--replay-path', type=str, default=None, help='Path to save replay')
# the following arguments are only applied to dist mode
@click.option('--enable-total-log', type=bool, help='whether enable the total DI-engine system log', default=False)
@click.option('--disable-flask-log', type=bool, help='whether disable flask log', default=True)
@click.option(
    '-P', '--platform', type=click.Choice(['local', 'slurm', 'k8s']), help='local or slurm or k8s', default='local'
)
@click.option(
    '-M',
    '--module',
    type=click.Choice(['config', 'collector', 'learner', 'coordinator', 'learner_aggregator', 'spawn_learner']),
    help='dist module type'
)
@click.option('--module-name', type=str, help='dist module name')
@click.option('-cdh', '--coordinator-host', type=str, help='coordinator host', default='0.0.0.0')
@click.option('-cdp', '--coordinator-port', type=int, help='coordinator port')
@click.option('-lh', '--learner-host', type=str, help='learner host', default='0.0.0.0')
@click.option('-lp', '--learner-port', type=int, help='learner port')
@click.option('-clh', '--collector-host', type=str, help='collector host', default='0.0.0.0')
@click.option('-clp', '--collector-port', type=int, help='collector port')
@click.option('-agh', '--aggregator-host', type=str, help='aggregator slave host', default='0.0.0.0')
@click.option('-agp', '--aggregator-port', type=int, help='aggregator slave port')
@click.option('--add', type=click.Choice(['collector', 'learner']), help='add replicas type')
@click.option('--delete', type=click.Choice(['collector', 'learner']), help='delete replicas type')
@click.option('--restart', type=click.Choice(['collector', 'learner']), help='restart replicas type')
@click.option('--kubeconfig', type=str, default=None, help='the path of Kubernetes configuration file')
@click.option('-cdn', '--coordinator-name', type=str, default=None, help='coordinator name')
@click.option('-ns', '--namespace', type=str, default=None, help='job namespace')
@click.option('-rs', '--replicas', type=int, default=1, help='number of replicas to add/delete/restart')
@click.option('-rpn', '--restart-pod-name', type=str, default=None, help='restart pod name')
@click.option('--cpus', type=int, default=0, help='The requested CPU, read the value from DIJob yaml by default')
@click.option('--gpus', type=int, default=0, help='The requested GPU, read the value from DIJob yaml by default')
@click.option(
    '--memory', type=str, default=None, help='The requested Memory, read the value from DIJob yaml by default'
)
@click.option(
    '--profile',
    type=str,
    default=None,
    help='profile Time cost by cProfile, and save the files into the specified folder path'
)
def cli(
    # serial/eval
    mode: str,
    config: str,
    seed: Union[int, List],
    exp_name: str,
    env: str,
    policy: str,
    train_iter: str,  # transform into int
    env_step: str,  # transform into int
    load_path: str,
    replay_path: str,
    # parallel/dist
    platform: str,
    coordinator_host: str,
    coordinator_port: int,
    learner_host: str,
    learner_port: int,
    collector_host: str,
    collector_port: int,
    aggregator_host: str,
    aggregator_port: int,
    enable_total_log: bool,
    disable_flask_log: bool,
    module: str,
    module_name: str,
    # add/delete/restart
    add: str,
    delete: str,
    restart: str,
    kubeconfig: str,
    coordinator_name: str,
    namespace: str,
    replicas: int,
    cpus: int,
    gpus: int,
    memory: str,
    restart_pod_name: str,
    profile: str,
):
    """
    Overview:
        DI-engine's unified command-line entry. Dispatches to the serial/parallel/dist/eval
        training pipelines according to ``--mode``, optionally wrapped by a cProfile session,
        and repeats the chosen pipeline once per ``--seed`` value (each extra seed runs inside
        its own ``seed{N}`` sub-directory).
    """
    if profile is not None:
        # Start profiling as early as possible so the whole pipeline is covered;
        # results are written into the folder given by ``--profile``.
        from ..utils.profiler_helper import Profiler
        profiler = Profiler()
        profiler.profile(profile)

    # CLI accepts scientific notation such as '1e8', so go through float first.
    train_iter = int(float(train_iter))
    env_step = int(float(env_step))

    def run_single_pipeline(seed, config):
        # Execute one full pipeline for a single seed. ``config`` may be a path (read from
        # disk) or None (derive a predefined config from --env/--policy).
        if config is None:
            config = get_predefined_config(env, policy)
        else:
            config = read_config(config)
        if exp_name is not None:
            config[0].exp_name = exp_name

        if mode == 'serial':
            from .serial_entry import serial_pipeline
            serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_onpolicy':
            from .serial_entry_onpolicy import serial_pipeline_onpolicy
            serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_sqil':
            from .serial_entry_sqil import serial_pipeline_sqil
            # SQIL needs a second (expert) config, asked for interactively.
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_reward_model':
            from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
            serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_gail':
            from .serial_entry_gail import serial_pipeline_gail
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_gail(
                config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
            )
        elif mode == 'serial_dqfd':
            from .serial_entry_dqfd import serial_pipeline_dqfd
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            # DQFD currently only supports q-learning expert models, and the expert config
            # file name must follow the '<prefix>_dqfd_config.py' convention.
            assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
                + "the models used in q learning now; However, one should still type the DQFD config in this "\
                + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
            serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex':
            from .serial_entry_trex import serial_pipeline_trex
            serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex_onpolicy':
            from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
            serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_offline':
            from .serial_entry_offline import serial_pipeline_offline
            serial_pipeline_offline(config, seed, max_train_iter=train_iter)
        elif mode == 'serial_ngu':
            from .serial_entry_ngu import serial_pipeline_ngu
            serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
        elif mode == 'parallel':
            from .parallel_entry import parallel_pipeline
            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
        elif mode == 'dist':
            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
            if module == 'config':
                dist_prepare_config(
                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                    learner_port, collector_port
                )
            elif module == 'coordinator':
                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
            elif module == 'learner_aggregator':
                dist_launch_learner_aggregator(
                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
                )
            elif module == 'collector':
                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
            elif module == 'learner':
                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif module == 'spawn_learner':
                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif add in ['collector', 'learner']:
                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
            elif delete in ['collector', 'learner']:
                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
            elif restart in ['collector', 'learner']:
                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
            else:
                # Fix: the original raised a bare, message-less Exception here; keep the type
                # (callers catching Exception are unaffected) but explain what went wrong.
                raise Exception(
                    "invalid dist arguments, module: {}, add: {}, delete: {}, restart: {}".format(
                        module, add, delete, restart
                    )
                )
        elif mode == 'eval':
            from .application_entry import eval
            eval(config, seed, load_path=load_path, replay_path=replay_path)

    if mode is None:
        raise RuntimeError("Please indicate at least one argument.")

    if isinstance(seed, (list, tuple)):
        assert len(seed) > 0, "Please input at least 1 seed"
        if len(seed) == 1:  # necessary
            run_single_pipeline(seed[0], config)
        else:
            # Multi-seed run: create one root directory, then one 'seed{N}' sub-directory per
            # seed, chdir into it so each pipeline writes its artifacts separately.
            if exp_name is None:
                multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
            else:
                multi_exp_root = exp_name
            # exist_ok avoids the exists()/makedirs() race of the original code.
            os.makedirs(multi_exp_root, exist_ok=True)
            abs_config_path = os.path.abspath(config)
            origin_root = os.getcwd()
            for s in seed:
                seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
                os.makedirs(seed_exp_root, exist_ok=True)
                os.chdir(seed_exp_root)
                # The config path must be absolute because the cwd changed above.
                run_single_pipeline(s, abs_config_path)
                os.chdir(origin_root)
    else:
        raise TypeError("invalid seed type: {}".format(type(seed)))