
ding.entry.serial_entry_r2d3

serial_pipeline_r2d3(input_cfg, expert_cfg, seed=0, env_setting=None, model=None, expert_model=None, max_train_iter=int(1e10), max_env_step=int(1e10))

Overview

Serial pipeline entry for R2D3 in DI-engine. The currently supported environments are LunarLander, Pong, and Qbert. The demonstration data comes from an expert model: a well-trained model is used to generate demonstrations online during the run. A usage sketch follows the argument list below.

Arguments:

- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. ``str`` type means config file path. ``Tuple[dict, dict]`` type means [user_config, create_cfg].
- expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config, in the same format as ``input_cfg``; its ``policy.collect.model_path`` must point to a trained expert checkpoint.
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: ``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. The default model is DQN(**cfg.policy.model).
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.

Returns:

- policy (:obj:`Policy`): Converged policy.
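
A minimal usage sketch is shown below. The config file names and the checkpoint path are hypothetical placeholders, not fixed values; any pair of R2D3-style main and expert configs, given either as file paths or as [user_config, create_cfg] tuples, can be passed.

# Hedged usage sketch: the file names below are assumptions for illustration only.
from ding.entry.serial_entry_r2d3 import serial_pipeline_r2d3

policy = serial_pipeline_r2d3(
    input_cfg='lunarlander_r2d3_config.py',     # main R2D3 config (hypothetical path)
    expert_cfg='lunarlander_expert_config.py',  # expert config; its policy.collect.model_path must point to a trained checkpoint
    seed=0,
    max_env_step=int(1e6),                      # optional cap on collected environment steps
)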

Full Source Code

../ding/entry/serial_entry_r2d3.py

from typing import Union, Optional, List, Any, Tuple
import os
import torch
import numpy as np
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from .utils import random_collect, mark_not_expert


def serial_pipeline_r2d3(
        input_cfg: Union[str, Tuple[dict, dict]],
        expert_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        expert_model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline r2d3 entry: we create this serial pipeline in order to\
        implement r2d3 in DI-engine. For now, we support the following envs\
        Lunarlander, Pong, Qbert. The demonstration\
        data comes from the expert model. We use a well-trained model to\
        generate demonstration data online.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. \
            The default model is DQN(**cfg.policy.model).
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
        expert_cfg, expert_create_cfg = read_config(expert_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
        expert_cfg, expert_create_cfg = expert_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    expert_cfg = compile_config(
        expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
    )
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    expert_collector_env = create_env_manager(
        expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
    )
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    expert_collector_env.seed(cfg.seed)
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    expert_policy.collect_mode.load_state_dict(torch.load(expert_cfg.policy.collect.model_path, map_location='cpu'))
    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    expert_collector = create_serial_collector(
        expert_cfg.policy.collect.collector,
        env=expert_collector_env,
        policy=expert_policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=expert_cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    expert_commander = BaseSerialCommander(
        expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
        expert_policy.command_mode
    )  # The expert commander is created only to supply the 'eps' kwarg that the sample collector expects.
    expert_collect_kwargs = expert_commander.step()
    if 'eps' in expert_collect_kwargs:
        expert_collect_kwargs['eps'] = -1
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')
    if expert_cfg.policy.learn.expert_replay_buffer_size != 0:  # for ablation study
        expert_buffer = create_buffer(expert_cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
        expert_data = expert_collector.collect(
            n_sample=expert_cfg.policy.learn.expert_replay_buffer_size,
            train_iter=learner.train_iter,
            policy_kwargs=expert_collect_kwargs
        )

        for i in range(len(expert_data)):
            # Set is_expert flag (expert: 1, agent: 0).
            # expert_data[i]['is_expert'] = 1  # for transition-based alg.
            expert_data[i]['is_expert'] = [1] * expert_cfg.policy.collect.unroll_len  # for rnn/sequence-based alg.
        expert_buffer.push(expert_data, cur_collector_envstep=0)
        for _ in range(cfg.policy.learn.per_train_iter_k):  # pretrain
            if evaluator.should_eval(learner.train_iter):
                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            # Learn policy from collected data.
            # Expert_learner will train ``update_per_collect == 1`` times in one iteration.
            train_data = expert_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
            learner.train(train_data, collector.envstep)
            if learner.policy.get_attribute('priority'):
                expert_buffer.update(learner.priority_info)
    learner.priority_info = {}
    # Accumulate plenty of data at the beginning of training.
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(
            cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
        )
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance.
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect data by default config n_sample/n_episode.
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

        for i in range(len(new_data)):
            # Set is_expert flag (expert: 1, agent: 0).
            new_data[i]['is_expert'] = [0] * expert_cfg.policy.collect.unroll_len

        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # Learn policy from collected data.
        for i in range(cfg.policy.learn.update_per_collect):
            if expert_cfg.policy.learn.expert_replay_buffer_size != 0:
                # Learner will train ``update_per_collect`` times in one iteration.

                # The hyperparameter pho, the demo ratio, controls the proportion of data coming
                # from expert demonstrations versus from the agent's own experience.
                expert_batch_size = int(
                    np.float32(np.random.rand(learner.policy.get_attribute('batch_size')) < cfg.policy.collect.pho
                               ).sum()
                )
                agent_batch_size = learner.policy.get_attribute('batch_size') - expert_batch_size
                train_data_agent = replay_buffer.sample(agent_batch_size, learner.train_iter)
                train_data_expert = expert_buffer.sample(expert_batch_size, learner.train_iter)
                if train_data_agent is None:
                    # It is possible that the replay buffer holds too little data to train ``update_per_collect`` times.
                    logging.warning(
                        "Replay buffer's data can only train for {} steps. ".format(i) +
                        "You can modify data collect config, e.g. increasing n_sample, n_episode."
                    )
                    break
                train_data = train_data_agent + train_data_expert
                learner.train(train_data, collector.envstep)
                if learner.policy.get_attribute('priority'):
                    # In the collector, replay_buffer_idx and replay_unique_id are set for each data item
                    # with priority = 1. In the learner, each item is assigned a priority according to its loss.
                    learner.priority_info_agent = deepcopy(learner.priority_info)
                    learner.priority_info_expert = deepcopy(learner.priority_info)
                    learner.priority_info_agent['priority'] = learner.priority_info['priority'][0:agent_batch_size]
                    learner.priority_info_agent['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
                        0:agent_batch_size]
                    learner.priority_info_agent['replay_unique_id'] = learner.priority_info['replay_unique_id'][
                        0:agent_batch_size]

                    learner.priority_info_expert['priority'] = learner.priority_info['priority'][agent_batch_size:]
                    learner.priority_info_expert['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
                        agent_batch_size:]
                    learner.priority_info_expert['replay_unique_id'] = learner.priority_info['replay_unique_id'][
                        agent_batch_size:]

                    # Agent data and expert (demo) data update their priorities separately.
                    replay_buffer.update(learner.priority_info_agent)
                    expert_buffer.update(learner.priority_info_expert)
            else:
                # Learner will train ``update_per_collect`` times in one iteration.
                train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
                if train_data is None:
                    # It is possible that the replay buffer holds too little data to train ``update_per_collect`` times.
                    logging.warning(
                        "Replay buffer's data can only train for {} steps. ".format(i) +
                        "You can modify data collect config, e.g. increasing n_sample, n_episode."
                    )
                    break
                learner.train(train_data, collector.envstep)
                if learner.policy.get_attribute('priority'):
                    replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break

    # Learner's after_run hook.
    learner.call_hook('after_run')
    return policy
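
In the main loop above, each training batch mixes agent experience and expert demonstrations according to the demo ratio cfg.policy.collect.pho. The standalone sketch below (the helper name and the numbers are illustrative, not part of the pipeline) reproduces how the two batch sizes are drawn at every update step:

import numpy as np

def split_batch(batch_size: int, pho: float):
    # Each of the batch_size slots independently picks the expert buffer with
    # probability pho (the demo ratio), mirroring the random-vector comparison
    # against cfg.policy.collect.pho in the pipeline above.
    expert_batch_size = int((np.random.rand(batch_size) < pho).sum())
    agent_batch_size = batch_size - expert_batch_size
    return agent_batch_size, expert_batch_size

# Example: with batch_size=64 and pho=0.25, roughly a quarter of each batch comes from
# expert demonstrations; the exact count varies from iteration to iteration.
agent_n, expert_n = split_batch(64, 0.25)

Each sampled item also carries a per-step is_expert flag ([1] * unroll_len for expert sequences, [0] * unroll_len for agent sequences), so downstream losses can treat demonstration data differently from the agent's own experience.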