ding.entry.serial_entry_bco
Full Source Code
../ding/entry/serial_entry_bco.py

import os
import pickle
import torch
from functools import partial
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from typing import Union, Optional, List, Any, Tuple, Dict

from ding.worker import BaseLearner, BaseSerialCommander, InteractionSerialEvaluator, create_serial_collector
from ding.config import read_config, compile_config
from ding.utils import set_pkg_seed
from ding.envs import get_vec_env_setting, create_env_manager
from ding.policy.common_utils import default_preprocess_learn
from ding.policy import create_policy
from ding.utils.data.dataset import BCODataset
from ding.world_model.idm import InverseDynamicsModel


def load_expertdata(data: List[List[Dict[str, torch.Tensor]]]) -> BCODataset:
    """
    Load demonstration data, which only contains obs and next_obs;
    actions must be inferred by the Inverse Dynamics Model (IDM).
    """
    post_data = list()
    for episode in range(len(data)):
        for transition in data[episode]:
            transition['episode_id'] = episode
            post_data.append(transition)
    post_data = default_preprocess_learn(post_data)
    return BCODataset(
        {
            'obs': torch.cat((post_data['obs'], post_data['next_obs']), 1),
            'episode_id': post_data['episode_id'],
            'action': post_data['action']
        }
    )


def load_agentdata(data: List[List[Dict[str, torch.Tensor]]]) -> BCODataset:
    """
    Load agent (policy) data, which contains obs and next_obs as features
    and action as the label.
    """
    post_data = list()
    for episode in range(len(data)):
        for transition in data[episode]:
            transition['episode_id'] = episode
            post_data.append(transition)
    post_data = default_preprocess_learn(post_data)
    return BCODataset(
        {
            'obs': torch.cat((post_data['obs'], post_data['next_obs']), 1),
            'action': post_data['action'],
            'episode_id': post_data['episode_id']
        }
    )
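
# NOTE (illustrative): both loaders expect episode-major data, i.e. a list of
# episodes, where each episode is a list of transition dicts carrying at least
# 'obs', 'action' and 'next_obs' tensors. Each dataset item ends up as the
# concatenation (obs, next_obs), which is exactly the input the IDM consumes
# below.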


def serial_pipeline_bco(
        input_cfg: Union[str, Tuple[dict, dict]],
        expert_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        expert_model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> None:
    """
    Overview:
        Serial pipeline entry of BCO (Behavioral Cloning from Observation).
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config of the imitation policy, \
            either a config file path or a ``(cfg, create_cfg)`` tuple.
        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config of the expert policy, in the same format.
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): Subclass of ``BaseEnv`` and its config lists: \
            ``[env_fn, collector_env_cfg, evaluator_env_cfg]``.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of the imitation policy's model.
        - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of the expert policy's model.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations during training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
        expert_cfg, expert_create_cfg = read_config(expert_cfg)
    else:
        cfg, create_cfg = input_cfg
        expert_cfg, expert_create_cfg = expert_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    expert_cfg = compile_config(
        expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
    )
    # Random seed
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting

    # Generate expert data
    if cfg.policy.collect.model_path is None:
        # No expert checkpoint: load pre-collected demonstrations from disk.
        with open(cfg.policy.collect.data_path, 'rb') as f:
            data = pickle.load(f)
        expert_learn_dataset = load_expertdata(data)
    else:
        # Otherwise, roll out a pretrained expert policy to collect demonstrations.
        expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect'])
        expert_collector_env = create_env_manager(
            expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
        )
        expert_collector_env.seed(expert_cfg.seed)
        expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))

        expert_collector = create_serial_collector(
            cfg.policy.collect.collector,  # an episode collector is expected here
            env=expert_collector_env,
            policy=expert_policy.collect_mode,
            exp_name=expert_cfg.exp_name
        )
        # If the expert policy is continuous (e.g. SAC), the eps kwarg is unexpected.
        if cfg.policy.continuous:
            expert_data = expert_collector.collect(n_episode=100)
        else:
            policy_kwargs = {'eps': 0}
            expert_data = expert_collector.collect(n_episode=100, policy_kwargs=policy_kwargs)
        expert_learn_dataset = load_expertdata(expert_data)
        expert_collector.reset_policy(expert_policy.collect_mode)

    # Main components
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, None, policy=policy.command_mode
    )
    learned_model = InverseDynamicsModel(
        cfg.policy.model.obs_shape, cfg.policy.model.action_shape, cfg.bco.model.idm_encoder_hidden_size_list,
        cfg.bco.model.action_space
    )
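    # NOTE (illustrative): the loop below follows BCO(alpha). After an initial
    # batch of cfg.policy.collect.n_episode agent episodes, each later iteration
    # collects only alpha * n_episode post-demonstration episodes, retrains the
    # IDM on the agent's own (obs, next_obs, action) transitions, relabels the
    # state-only expert demonstrations with IDM-inferred actions, and then
    # improves the policy by behavioral cloning on the relabeled data.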
    # ==========
    # Main loop
    # ==========
    learner.call_hook('before_run')
    collect_episode = int(cfg.policy.collect.n_episode * cfg.bco.alpha)
    init_episode = True
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break

        if init_episode:
            new_data = collector.collect(
                n_episode=cfg.policy.collect.n_episode, train_iter=learner.train_iter, policy_kwargs=collect_kwargs
            )
            init_episode = False
        else:
            new_data = collector.collect(
                n_episode=collect_episode, train_iter=learner.train_iter, policy_kwargs=collect_kwargs
            )
        # Train the IDM on the agent's own transitions
        learn_dataset = load_agentdata(new_data)
        learn_dataloader = DataLoader(learn_dataset, cfg.bco.learn.idm_batch_size)
        for i, train_data in enumerate(learn_dataloader):
            idm_loss = learned_model.train(
                train_data,
                cfg.bco.learn.idm_train_epoch,
                cfg.bco.learn.idm_learning_rate,
                cfg.bco.learn.idm_weight_decay,
            )
            # tb_logger.add_scalar("learner_iter/idm_loss", idm_loss, learner.train_iter)
            # tb_logger.add_scalar("learner_step/idm_loss", idm_loss, collector.envstep)
        # Infer actions for the demonstrated state transitions with the IDM
        expert_action_data = learned_model.predict_action(expert_learn_dataset.obs)['action']
        # Only obs and the inferred action are kept for BC; next_obs (the second
        # half of the concatenated observation tensor) is dropped.
        post_expert_dataset = BCODataset(
            {
                'obs': expert_learn_dataset.obs[:, 0:expert_learn_dataset.obs.shape[1] // 2],
                'action': expert_action_data,
                'expert_action': expert_learn_dataset.action
            }
        )
        expert_learn_dataloader = DataLoader(post_expert_dataset, cfg.policy.learn.batch_size)
        # Improve the policy using BC
        for epoch in range(cfg.policy.learn.train_epoch):
            for i, train_data in enumerate(expert_learn_dataloader):
                learner.train(train_data, collector.envstep)
            if cfg.policy.learn.lr_decay:
                learner.policy.get_attribute('lr_scheduler').step()
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook
    learner.call_hook('after_run')
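
Example Usage

The following is a minimal, hypothetical sketch of how this entry point can be invoked. The config modules and names below are placeholders, not part of DI-engine: any pair of DI-engine-style (cfg, create_cfg) tuples for the imitation policy and the expert policy can be passed in their place.

from ding.entry.serial_entry_bco import serial_pipeline_bco
# Hypothetical config modules, named here purely for illustration.
from my_bco_config import main_config, create_config
from my_expert_config import expert_main_config, expert_create_config

if __name__ == "__main__":
    serial_pipeline_bco(
        (main_config, create_config),                # imitation (BCO) side
        (expert_main_config, expert_create_config),  # expert side
        seed=0,
        max_env_step=int(1e6),  # stop after one million environment steps
    )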