
ding.entry.serial_entry_sqil

serial_pipeline_sqil(input_cfg, expert_cfg, seed=0, env_setting=None, model=None, expert_model=None, max_train_iter=int(1e10), max_env_step=int(1e10))

Overview

Serial pipeline SQIL entry: this serial pipeline implements SQIL in DI-engine. For now, the following environments are supported: CartPole, LunarLander, Pong, SpaceInvaders, Qbert. The demonstration data come from the expert model: a well-trained expert model is used to generate demonstration data online.

Arguments:

- input_cfg (:obj:Union[str, Tuple[dict, dict]]): Config in dict type. str type means config file path. Tuple[dict, dict] type means [user_config, create_cfg].
- expert_cfg (:obj:Union[str, Tuple[dict, dict]]): Expert config, in the same format as input_cfg; it configures the expert policy that generates the demonstration data.
- seed (:obj:int): Random seed.
- env_setting (:obj:Optional[List[Any]]): A list with 3 elements: BaseEnv subclass, collector env config, and evaluator env config.
- model (:obj:Optional[torch.nn.Module]): Instance of torch.nn.Module.
- expert_model (:obj:Optional[torch.nn.Module]): Instance of torch.nn.Module. The default model is DQN(**cfg.policy.model).
- max_train_iter (:obj:Optional[int]): Maximum policy update iterations in training.
- max_env_step (:obj:Optional[int]): Maximum collected environment interaction steps.

Returns:

- policy (:obj:Policy): Converged policy.
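A minimal usage sketch follows. The config file paths are placeholders, not files documented on this page; any DI-engine config pair for a supported environment works the same way. Note that the pipeline loads the expert checkpoint from the SQIL config's policy.collect.model_path, so that field must point to a trained expert model.

# Hypothetical usage sketch: the config file paths below are placeholders,
# not files documented on this page.
from ding.entry.serial_entry_sqil import serial_pipeline_sqil

if __name__ == '__main__':
    policy = serial_pipeline_sqil(
        # SQIL (learning) config; its policy.collect.model_path must point to a
        # trained expert checkpoint, which the pipeline loads for the expert policy.
        input_cfg='path/to/cartpole_sqil_config.py',
        # Expert (e.g. DQN) config used to build the demonstration collector.
        expert_cfg='path/to/cartpole_dqn_config.py',
        seed=0,
        max_env_step=int(1e6),
    )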

Full Source Code

../ding/entry/serial_entry_sqil.py

from typing import Union, Optional, List, Any, Tuple
import os
import torch
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_sqil(
        input_cfg: Union[str, Tuple[dict, dict]],
        expert_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        expert_model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline sqil entry: we create this serial pipeline in order to\
        implement SQIL in DI-engine. For now, we support the following envs\
        Cartpole, Lunarlander, Pong, Spaceinvader, Qbert. The demonstration\
        data come from the expert model. We use a well-trained model to \
        generate demonstration data online
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
            The default model is DQN(**cfg.policy.model)
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
        expert_cfg, expert_create_cfg = read_config(expert_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
        expert_cfg, expert_create_cfg = expert_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    expert_cfg = compile_config(
        expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
    )
    # expert config must have the same `n_sample`. The line below ensures we do not need to modify the expert configs
    expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    expert_collector_env = create_env_manager(
        expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
    )
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    expert_collector_env.seed(cfg.seed)
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    expert_collector = create_serial_collector(
        expert_cfg.policy.collect.collector,
        env=expert_collector_env,
        policy=expert_policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=expert_cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    expert_buffer = create_buffer(expert_cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    expert_commander = BaseSerialCommander(
        expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
        expert_policy.command_mode
    )  # we create this to avoid the issue of eps, this is an issue due to the sample collector part.
    expert_collect_kwargs = expert_commander.step()
    if 'eps' in expert_collect_kwargs:
        expert_collect_kwargs['eps'] = -1
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    # Accumulate plenty of data at the beginning of training.
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
    if cfg.policy.get('expert_random_collect_size', 0) > 0:
        random_collect(
            expert_cfg.policy, expert_policy, expert_collector, expert_collector_env, expert_commander, expert_buffer
        )
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
        expert_data = expert_collector.collect(
            train_iter=learner.train_iter, policy_kwargs=expert_collect_kwargs
        )  # policy_kwargs={'eps': -1}
        for i in range(len(new_data)):
            device_1 = new_data[i]['obs'].device
            device_2 = expert_data[i]['obs'].device
            new_data[i]['reward'] = torch.zeros(cfg.policy.nstep).to(device_1)
            expert_data[i]['reward'] = torch.ones(cfg.policy.nstep).to(device_2)
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        expert_buffer.push(expert_data, cur_collector_envstep=collector.envstep)
        # Learn policy from collected data
        for i in range(cfg.policy.learn.update_per_collect):
            # Learner will train ``update_per_collect`` times in one iteration.
            train_data = replay_buffer.sample((learner.policy.get_attribute('batch_size')) // 2, learner.train_iter)
            train_data_demonstration = expert_buffer.sample(
                (learner.policy.get_attribute('batch_size')) // 2, learner.train_iter
            )
            if train_data is None:
                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                logging.warning(
                    "Replay buffer's data can only train for {} steps. ".format(i) +
                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
                )
                break
            train_data = train_data + train_data_demonstration
            learner.train(train_data, collector.envstep)
            if learner.policy.get_attribute('priority'):
                replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    return policy
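The step that makes this pipeline SQIL rather than plain DQN is the reward relabelling in the main loop: transitions collected by the learning policy are pushed to the replay buffer with reward 0, expert transitions are pushed to the expert buffer with reward 1, and each learner batch mixes the two sources half and half. A minimal sketch of just that step is given below; nstep and the transition dicts are illustrative stand-ins for cfg.policy.nstep and the collector output, not part of the pipeline API.

import torch

# Illustrative stand-ins for cfg.policy.nstep and the collector output.
nstep = 1
new_data = [{'obs': torch.randn(4), 'reward': torch.randn(nstep)}]     # agent transitions
expert_data = [{'obs': torch.randn(4), 'reward': torch.randn(nstep)}]  # expert transitions

# SQIL relabelling: agent transitions get reward 0, expert transitions get reward 1,
# regardless of the reward actually returned by the environment.
for i in range(len(new_data)):
    new_data[i]['reward'] = torch.zeros(nstep).to(new_data[i]['obs'].device)
    expert_data[i]['reward'] = torch.ones(nstep).to(expert_data[i]['obs'].device)

# Each learner batch is half agent data and half expert data:
# batch_size // 2 is sampled from each buffer, then the lists are concatenated.
train_data = new_data + expert_data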