from typing import Union, Optional, List, Any, Tuple
import os
import numpy as np
import torch
from ditk import logging
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from .utils import random_collect


def serial_pipeline_reward_model_onpolicy(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry for on-policy RL with reward model.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create main components: env, policy
    if env_setting is None:
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

    # Create worker components: learner, collector, evaluator, replay buffer, commander.
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = create_serial_collector(
        cfg.policy.collect.collector,
        env=collector_env,
        policy=policy.collect_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    commander = BaseSerialCommander(
        cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
    )
    reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)

    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')

    # Accumulate plenty of data at the beginning of training.
    if cfg.policy.get('random_collect_size', 0) > 0:
        random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
    count = 0
    best_return = -np.inf
    while True:
        collect_kwargs = commander.step()
        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
            stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            eval_return_mean = np.mean(eval_info['eval_episode_return'])
            if eval_return_mean >= best_return:
                reward_model.save(path=cfg.exp_name, name='best')
                best_return = eval_return_mean
            if stop:
                break
        new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
        while new_data_count < target_new_data_count:
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            new_data_count += len(new_data)
            # collect data for reward_model training
            reward_model.collect_data(new_data)
        # update reward_model
        reward_model.train()
        if count % cfg.reward_model.clear_buffer_per_iters == 0:
            reward_model.clear_data()
        # Learn policy from collected data
        for i in range(cfg.policy.learn.update_per_collect):
            # Learner will train ``update_per_collect`` times in one iteration.
            train_data = new_data
            if train_data is None:
                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
                logging.warning(
                    "Replay buffer's data can only train for {} steps. ".format(i) +
                    "You can modify data collect config, e.g. increasing n_sample, n_episode."
                )
                break
            # update train_data reward using the augmented reward
            train_data_augmented = reward_model.estimate(train_data)
            learner.train(train_data_augmented, collector.envstep)
            if learner.policy.get_attribute('priority'):
                replay_buffer.update(learner.priority_info)
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
            break
        count += 1

    # Learner's after_run hook.
    learner.call_hook('after_run')
    reward_model.save(path=cfg.exp_name, name='last')
    return policy
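

# Minimal usage sketch (not part of the original module): assuming a DI-engine-style
# experiment config file that defines a compatible (main_config, create_config) pair for an
# on-policy algorithm with a reward model. The config path below is hypothetical and should
# be replaced with an actual experiment config from your project or dizoo.
if __name__ == "__main__":
    serial_pipeline_reward_model_onpolicy(
        input_cfg='path/to/your_onpolicy_reward_model_config.py',  # hypothetical config path
        seed=0,
        max_env_step=int(1e6),
    )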