ding.entry.serial_entry_offline
serial_pipeline_offline(input_cfg, seed=0, env_setting=None, model=None, max_train_iter=int(1e10))
Overview
Serial pipeline entry for offline RL training.
Arguments:
- input_cfg (Union[str, Tuple[dict, dict]]): Config in dict type. str type means config file path. Tuple[dict, dict] type means [user_config, create_cfg].
- seed (int): Random seed.
- env_setting (Optional[List[Any]]): A list with 3 elements: BaseEnv subclass, collector env config, and evaluator env config.
- model (Optional[torch.nn.Module]): Instance of torch.nn.Module.
- max_train_iter (Optional[int]): Maximum number of policy update iterations in training.
Returns:
- policy (Policy): Converged policy.
- stop (bool): Whether the stop condition (evaluator convergence or max_train_iter) was reached.
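A minimal call sketch is shown below. The config path is a placeholder, and the sketch assumes serial_pipeline_offline is re-exported from ding.entry (as in recent DI-engine releases); otherwise import it from ding.entry.serial_entry_offline directly.

from ding.entry import serial_pipeline_offline

# The config path below is hypothetical; point it at a real DI-engine
# offline config, or pass a (user_config, create_cfg) tuple directly.
policy, stop = serial_pipeline_offline(
    'path/to/my_offline_config.py',
    seed=0,
    max_train_iter=100000,
)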
Full Source Code
../ding/entry/serial_entry_offline.py
from typing import Union, Optional, List, Any, Tuple
import os
import torch
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed, get_world_size, get_rank
from ding.utils.data import create_dataset


def serial_pipeline_offline(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        max_train_iter: Optional[int] = int(1e10),
) -> Tuple['Policy', bool]:  # noqa
    """
    Overview:
        Serial pipeline entry for offline RL training.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations in training.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
        - stop (:obj:`bool`): Whether the stop condition (evaluator convergence or \
            ``max_train_iter``) was reached.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

    # Dataset
    dataset = create_dataset(cfg)
    sampler, shuffle = None, True
    if get_world_size() > 1:
        sampler, shuffle = DistributedSampler(dataset), False
    dataloader = DataLoader(
        dataset,
        # Dividing by get_world_size() here simply makes the multi-GPU
        # setting mathematically equivalent to the single-GPU setting.
        # If training efficiency is the bottleneck, feel free to use the
        # original batch size per GPU and increase the learning rate
        # correspondingly.
        cfg.policy.learn.batch_size // get_world_size(),
        shuffle=shuffle,
        sampler=sampler,
        collate_fn=lambda x: x,
        pin_memory=cfg.policy.cuda,
    )
    # Env, Policy
    try:
        if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
            cfg.env.norm_obs.offline_stats.update({'mean': dataset.mean, 'std': dataset.std})
    except (KeyError, AttributeError):
        pass
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    # Random seed
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
    if cfg.policy.collect.data_type == 'diffuser_traj':
        policy.init_data_normalizer(dataset.normalizer)

    if hasattr(policy, 'set_statistic'):
        # Useful for setting action bounds for IBC.
        policy.set_statistic(dataset.statistics)

    # Only rank 0 writes TensorBoard logs; otherwise log directories may
    # conflict in multi-GPU settings.
    if get_rank() == 0:
        tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    else:
        tb_logger = None
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    # ==========
    # Main loop
    # ==========
    # Learner's before_run hook.
    learner.call_hook('before_run')
    # Initialize reward as well, so the final print cannot fail if the
    # evaluator never runs before the loop exits.
    stop, reward = False, None

    for epoch in range(cfg.policy.learn.train_epoch):
        if get_world_size() > 1:
            dataloader.sampler.set_epoch(epoch)
        for train_data in dataloader:
            learner.train(train_data)

        # Evaluate policy at most once per epoch.
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)

        if stop or learner.train_iter >= max_train_iter:
            stop = True
            break

    learner.call_hook('after_run')
    print('final reward is: {}'.format(reward))
    return policy, stop
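The batch-size division in the DataLoader above keeps data-parallel training arithmetically equivalent to the single-GPU run: each rank draws batch_size // world_size samples via DistributedSampler, and gradient averaging across ranks restores the full effective batch. A toy sketch of that bookkeeping (the variable names here are illustrative, not part of ding):

global_batch = 256  # cfg.policy.learn.batch_size in the pipeline above
for world_size in (1, 2, 4):
    per_rank = global_batch // world_size
    # With gradient averaging across ranks, the effective batch stays at
    # per_rank * world_size == global_batch, regardless of GPU count.
    print(f'world_size={world_size}: {per_rank} samples/rank, '
          f'{per_rank * world_size} effective')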