from typing import Optional, Tuple, Union
import os
import torch
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
import numpy as np

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from ding.entry import collect_demo_data
from ding.utils import save_file
from .utils import random_collect


def save_reward_model(path, reward_model, weights_name='best'):
    """
    Overview:
        Save the reward model checkpoint to ``<path>/reward_model/ckpt/ckpt_<weights_name>.pth.tar``.
    """
    path = os.path.join(path, 'reward_model', 'ckpt')
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            pass
    path = os.path.join(path, 'ckpt_{}.pth.tar'.format(weights_name))
    state_dict = reward_model.state_dict()
    save_file(path, state_dict)
    print('Saved reward model ckpt in {}'.format(path))


def serial_pipeline_gail(
        input_cfg: Union[str, Tuple[dict, dict]],
        expert_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
        collect_data: bool = True,
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for GAIL reward model.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
        - collect_data (:obj:`bool`): Whether to collect expert demonstration data before training.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
61 """ 62 if isinstance(input_cfg, str): 63 cfg, create_cfg = read_config(input_cfg) 64 else: 65 cfg, create_cfg = deepcopy(input_cfg) 66 if isinstance(expert_cfg, str): 67 expert_cfg, expert_create_cfg = read_config(expert_cfg) 68 else: 69 expert_cfg, expert_create_cfg = expert_cfg 70 create_cfg.policy.type = create_cfg.policy.type + '_command' 71 cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg, save_cfg=True) 72 if 'data_path' not in cfg.reward_model: 73 cfg.reward_model.data_path = cfg.exp_name 74 # Load expert data 75 if collect_data: 76 if expert_cfg.policy.get('other', None) is not None and expert_cfg.policy.other.get('eps', None) is not None: 77 expert_cfg.policy.other.eps.collect = -1 78 if expert_cfg.policy.get('load_path', None) is None: 79 expert_cfg.policy.load_path = cfg.reward_model.expert_model_path 80 collect_demo_data( 81 (expert_cfg, expert_create_cfg), 82 seed, 83 state_dict_path=expert_cfg.policy.load_path, 84 expert_data_path=cfg.reward_model.data_path + '/expert_data.pkl', 85 collect_count=cfg.reward_model.collect_count 86 ) 87 # Create main components: env, policy 88 env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) 89 collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) 90 evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) 91 collector_env.seed(cfg.seed) 92 evaluator_env.seed(cfg.seed, dynamic_seed=False) 93 set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) 94 policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) 95 96 # Create worker components: learner, collector, evaluator, replay buffer, commander. 97 tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) 98 learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) 99 collector = create_serial_collector( 100 cfg.policy.collect.collector, 101 env=collector_env, 102 policy=policy.collect_mode, 103 tb_logger=tb_logger, 104 exp_name=cfg.exp_name 105 ) 106 evaluator = InteractionSerialEvaluator( 107 cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name 108 ) 109 replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name) 110 commander = BaseSerialCommander( 111 cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode 112 ) 113 reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger) 114 115 # ========== 116 # Main loop 117 # ========== 118 # Learner's before_run hook. 119 learner.call_hook('before_run') 120 121 # Accumulate plenty of data at the beginning of training. 
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
    # Track the best evaluation return so far; used to decide when to snapshot the reward model.
    best_reward = -np.inf
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
            if reward_mean >= best_reward:
                save_reward_model(cfg.exp_name, reward_model, 'best')
                best_reward = reward_mean
            if stop:
                break
        new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
        while new_data_count < target_new_data_count:
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            new_data_count += len(new_data)
            # Collect data for reward_model training.
            reward_model.collect_data(new_data)
            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # Train the GAIL reward model (discriminator) on the newly collected data, then clear its buffer.
        reward_model.train()
        reward_model.clear_data()
        # Learn policy from collected data
        for i in range(cfg.policy.learn.update_per_collect):
            # Learner will train ``update_per_collect`` times in one iteration.
            train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            if train_data is None:
                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times.
                logging.warning(
                    "Replay buffer's data can only train for {} steps. ".format(i) +
                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
                )
                break
            # Replace the environment reward in train_data with the reward model's estimate.
            train_data_augmented = reward_model.estimate(train_data)
            learner.train(train_data_augmented, collector.envstep)
            if learner.policy.get_attribute('priority'):
                replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    save_reward_model(cfg.exp_name, reward_model, 'last')
    # evaluate
    # evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
    return policy
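
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline): the config
# modules below are hypothetical placeholders; any DI-engine GAIL config pair
# exported as ``(main_config, create_config)``, plus a pretrained expert
# config pair whose ``policy.load_path`` (or ``reward_model.expert_model_path``)
# points at an expert checkpoint, would be passed in the same way.
#
#     from copy import deepcopy
#     from my_gail_config import main_config, create_config                    # hypothetical
#     from my_expert_config import expert_main_config, expert_create_config    # hypothetical
#
#     policy = serial_pipeline_gail(
#         (deepcopy(main_config), deepcopy(create_config)),
#         (deepcopy(expert_main_config), deepcopy(expert_create_config)),
#         seed=0,
#         max_env_step=int(1e6),
#         collect_data=True,  # first roll out the expert to build expert_data.pkl
#     )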