
ding.entry.serial_entry_dqfd

serial_pipeline_dqfd(input_cfg, expert_cfg, seed=0, env_setting=None, model=None, expert_model=None, max_train_iter=int(1e10), max_env_step=int(1e10))

Overview

Serial pipeline DQfD entry: this serial pipeline implements DQfD (Deep Q-learning from Demonstrations) in DI-engine. Currently supported environments: CartPole, LunarLander, Pong, and SpaceInvaders. The demonstration data come from an expert model: a well-trained model is used to generate demonstration data online.

Arguments:

- input_cfg (:obj:Union[str, Tuple[dict, dict]]): Config in dict type. str type means config file path. Tuple[dict, dict] type means [user_config, create_cfg].
- expert_cfg (:obj:Union[str, Tuple[dict, dict]]): Expert config in dict type, in the same format as input_cfg; it configures the expert policy that generates the demonstration data.
- seed (:obj:int): Random seed.
- env_setting (:obj:Optional[List[Any]]): A list with 3 elements: BaseEnv subclass, collector env config, and evaluator env config.
- model (:obj:Optional[torch.nn.Module]): Instance of torch.nn.Module.
- expert_model (:obj:Optional[torch.nn.Module]): Instance of torch.nn.Module. The default model is DQN(**cfg.policy.model).
- max_train_iter (:obj:Optional[int]): Maximum policy update iterations in training.
- max_env_step (:obj:Optional[int]): Maximum collected environment interaction steps.

Returns:

- policy (:obj:Policy): Converged policy.
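Example

The following sketch shows one way to invoke serial_pipeline_dqfd with a pair of (user_config, create_cfg) tuples. The config modules (cartpole_dqfd_config, cartpole_dqn_config) and the expert checkpoint path are illustrative assumptions, not guaranteed by this entry; the only requirement from the source below is that cfg.policy.collect.model_path points to a valid expert checkpoint.

# A hedged usage sketch: config module names and the checkpoint path below are
# illustrative assumptions, not part of this entry's API.
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import main_config, create_config
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config as expert_main_config, \
    create_config as expert_create_config

# The pipeline loads the expert checkpoint from cfg.policy.collect.model_path (see source below).
main_config.policy.collect.model_path = './cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar'  # assumed path

if __name__ == '__main__':
    serial_pipeline_dqfd(
        (main_config, create_config),
        (expert_main_config, expert_create_config),
        seed=0,
        max_env_step=int(1e5),
    )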

Full Source Code

../ding/entry/serial_entry_dqfd.py

from typing import Union, Optional, List, Any, Tuple
import os
import torch
import numpy as np
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect, mark_not_expert


def serial_pipeline_dqfd(
        input_cfg: Union[str, Tuple[dict, dict]],
        expert_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        expert_model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline DQfD entry: this serial pipeline implements DQfD in DI-engine. \
        Currently supported environments: CartPole, LunarLander, Pong, and SpaceInvaders. \
        The demonstration data come from an expert model: a well-trained model is used to \
        generate demonstration data online.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type, \
            in the same format as ``input_cfg``; it configures the expert policy that \
            generates the demonstration data.
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. \
            The default model is DQN(**cfg.policy.model).
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
        expert_cfg, expert_create_cfg = read_config(expert_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
        expert_cfg, expert_create_cfg = expert_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    expert_cfg = compile_config(
        expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
    )
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    expert_collector_env = create_env_manager(
        expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
    )
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    expert_collector_env.seed(cfg.seed)
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    expert_collector = create_serial_collector(
        expert_cfg.policy.collect.collector,
        env=expert_collector_env,
        policy=expert_policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=expert_cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    expert_commander = BaseSerialCommander(
        expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
        expert_policy.command_mode
    )  # We create this to avoid the eps issue caused by the sample collector part.
    expert_collect_kwargs = expert_commander.step()
    if 'eps' in expert_collect_kwargs:
        expert_collect_kwargs['eps'] = -1
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')
    if cfg.policy.learn.expert_replay_buffer_size != 0:  # for ablation study
        dummy_variable = deepcopy(cfg.policy.other.replay_buffer)
        dummy_variable['replay_buffer_size'] = cfg.policy.learn.expert_replay_buffer_size
        expert_buffer = create_buffer(dummy_variable, tb_logger=tb_logger, exp_name=cfg.exp_name)
        expert_data = expert_collector.collect(
            n_sample=cfg.policy.learn.expert_replay_buffer_size, policy_kwargs=expert_collect_kwargs
        )
        for i in range(len(expert_data)):
            expert_data[i]['is_expert'] = 1  # set is_expert flag (expert 1, agent 0)
        expert_buffer.push(expert_data, cur_collector_envstep=0)
        for _ in range(cfg.policy.learn.per_train_iter_k):  # pretrain
            if evaluator.should_eval(learner.train_iter):
                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            # Learn policy from collected data
            # Expert_learner will train ``update_per_collect == 1`` times in one iteration.
            train_data = expert_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            learner.train(train_data, collector.envstep)
            if learner.policy.get_attribute('priority'):
                expert_buffer.update(learner.priority_info)
        learner.priority_info = {}
    # Accumulate plenty of data at the beginning of training.
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(
            cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
        )
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
        for i in range(len(new_data)):
            new_data[i]['is_expert'] = 0  # set is_expert flag (expert 1, agent 0)
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # Learn policy from collected data
        for i in range(cfg.policy.learn.update_per_collect):
            if cfg.policy.learn.expert_replay_buffer_size != 0:
                # Learner will train ``update_per_collect`` times in one iteration.
                # The hyperparameter pho, the demonstration ratio, controls the proportion of data coming
                # from expert demonstrations versus from the agent's own experience.
                stats = np.random.choice(
                    (learner.policy.get_attribute('batch_size')), size=(learner.policy.get_attribute('batch_size'))
                ) < (
                    learner.policy.get_attribute('batch_size')
                ) * cfg.policy.collect.pho  # torch.rand((learner.policy.get_attribute('batch_size')))
                # < cfg.policy.collect.pho
                expert_batch_size = stats[stats].shape[0]
                demo_batch_size = (learner.policy.get_attribute('batch_size')) - expert_batch_size
                train_data = replay_buffer.sample(demo_batch_size, learner.train_iter)
                train_data_demonstration = expert_buffer.sample(expert_batch_size, learner.train_iter)
                if train_data is None:
                    # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                    logging.warning(
                        "Replay buffer's data can only train for {} steps. ".format(i) +
                        "You can modify data collect config, e.g. increasing n_sample, n_episode."
                    )
                    break
                train_data = train_data + train_data_demonstration
                learner.train(train_data, collector.envstep)
                if learner.policy.get_attribute('priority'):
                    # In the collector, replay_buffer_idx and replay_unique_id are set for each data item,
                    # with priority = 1. In the learner, a priority is assigned to each data item according
                    # to its loss.
                    learner.priority_info_agent = deepcopy(learner.priority_info)
                    learner.priority_info_expert = deepcopy(learner.priority_info)
                    learner.priority_info_agent['priority'] = learner.priority_info['priority'][0:demo_batch_size]
                    learner.priority_info_agent['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
                        0:demo_batch_size]
                    learner.priority_info_agent['replay_unique_id'] = learner.priority_info['replay_unique_id'][
                        0:demo_batch_size]
                    learner.priority_info_expert['priority'] = learner.priority_info['priority'][demo_batch_size:]
                    learner.priority_info_expert['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
                        demo_batch_size:]
                    learner.priority_info_expert['replay_unique_id'] = learner.priority_info['replay_unique_id'][
                        demo_batch_size:]
                    # Agent data and expert demonstration data update their priorities separately.
                    replay_buffer.update(learner.priority_info_agent)
                    expert_buffer.update(learner.priority_info_expert)
            else:
                # Learner will train ``update_per_collect`` times in one iteration.
                train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
                if train_data is None:
                    # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                    logging.warning(
                        "Replay buffer's data can only train for {} steps. ".format(i) +
                        "You can modify data collect config, e.g. increasing n_sample, n_episode."
                    )
                    break
                learner.train(train_data, collector.envstep)
                if learner.policy.get_attribute('priority'):
                    replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    return policy
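In the main loop above, each training batch mixes agent transitions from replay_buffer with expert transitions from expert_buffer according to cfg.policy.collect.pho, the demonstration ratio. The standalone sketch below reproduces only that random-mask split; the batch_size and pho values are illustrative assumptions, and the variable names other than pho are chosen for the example (the agent portion is called demo_batch_size in the pipeline code).

# Standalone sketch of the batch-mixing logic used in the main loop above.
import numpy as np

batch_size = 64   # illustrative; the pipeline reads it from learner.policy.get_attribute('batch_size')
pho = 0.25        # illustrative demonstration ratio (cfg.policy.collect.pho in the pipeline)

# Draw batch_size random integers in [0, batch_size) and compare with batch_size * pho,
# so each batch slot becomes an "expert" slot with probability roughly pho.
mask = np.random.choice(batch_size, size=batch_size) < batch_size * pho
expert_batch_size = int(mask.sum())                 # number of samples drawn from expert_buffer
agent_batch_size = batch_size - expert_batch_size   # drawn from replay_buffer (demo_batch_size in the source)

print(expert_batch_size, agent_batch_size)          # e.g. 14 50 for pho=0.25, batch_size=64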