
ding.framework.middleware

multistep_trainer(policy, log_freq=100)

Overview

The middleware that executes training for a target number of steps.

Arguments:
- policy (:obj:Policy): The policy specialized for multi-step training.
- log_freq (:obj:int): The frequency (in iterations) of showing logs.
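
These middleware factories all share one shape: the factory takes configuration and returns a closure that reads and mutates a shared context on each pipeline tick. A self-contained sketch of that pattern follows; the `Context` class and `multistep_trainer_sketch` below are illustrative stand-ins, not ding's actual `OnlineRLContext` or middleware.

```python
class Context(dict):
    """Attribute-style dict standing in for ding's OnlineRLContext (illustrative)."""
    __getattr__ = dict.get  # missing attributes read as None

    def __setattr__(self, key, value):
        self[key] = value


def multistep_trainer_sketch(train_step, log_freq=100):
    """Factory returning a middleware closure, mirroring the factory pattern above."""
    def _train(ctx):
        ctx.train_iter = (ctx.train_iter or 0) + 1     # count iterations on the context
        ctx.train_output = train_step(ctx.train_data)  # run one training step
        if ctx.train_iter % log_freq == 0:
            print(f"iter {ctx.train_iter}: {ctx.train_output}")
    return _train


ctx = Context(train_data=[1, 2, 3])
trainer = multistep_trainer_sketch(lambda data: {"loss": sum(data)}, log_freq=10)
trainer(ctx)
```

In the real framework such closures are chained with `task.use(...)`, each one reading the fields earlier middleware wrote into the context.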

offpolicy_data_fetcher(cfg, buffer_, data_shortage_warning=False)

Overview

The returned function is a generator that mainly fetches a batch of data from a buffer, a list of buffers, or a dict of buffers.

Arguments:
- cfg (:obj:EasyDict): Config, which should contain the following key: cfg.policy.learn.batch_size.
- buffer_ (:obj:Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]): The buffer(s) from which the data is fetched. A Buffer means a single buffer. A List[Tuple[Buffer, float]] is a list of (buffer, float) tuples, where the float defines the fraction of batch_size to sample from the corresponding buffer. A Dict[str, Buffer] is a dict whose values are buffers; for each key-value pair, batch_size samples are drawn from the corresponding buffer and assigned to the same key of ctx.train_data.
- data_shortage_warning (:obj:bool): Whether to output a warning when a data shortage occurs during fetching.
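
The List[Tuple[Buffer, float]] case can be sketched in plain Python. The "buffers" here are ordinary lists and `split_batch_sketch` is an illustrative name, not ding's internal implementation:

```python
import random

def split_batch_sketch(buffers, batch_size):
    """Draw int(ratio * batch_size) samples from each (buffer, ratio) pair."""
    batch = []
    for buf, ratio in buffers:
        n = int(ratio * batch_size)   # this buffer's share of the batch
        batch.extend(random.sample(buf, n))
    return batch

expert = list(range(100, 110))   # stand-in "buffer": a plain list
online = list(range(10))
batch = split_batch_sketch([(expert, 0.25), (online, 0.75)], batch_size=8)
```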

data_pusher(cfg, buffer_, group_by_env=None)

Overview

Push episodes or trajectories into the buffer.

Arguments:
- cfg (:obj:EasyDict): Config.
- buffer_ (:obj:Buffer): Buffer to push the data into.

offline_data_fetcher(cfg, dataset, collate_fn=lambda x: x)

Overview

The outer function transforms a PyTorch Dataset into a DataLoader. The returned function is a generator that fetches a batch of data from that DataLoader each time it is called. Please refer to https://pytorch.org/tutorials/beginner/basics/data_tutorial.html and https://pytorch.org/docs/stable/data.html for more details.

Arguments:
- cfg (:obj:EasyDict): Config, which should contain the following key: cfg.policy.learn.batch_size.
- dataset (:obj:Dataset): The dataset of type torch.utils.data.Dataset which stores the data.
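
The dataset-to-batches step can be sketched without torch; this generator stands in for the DataLoader the real middleware builds (names are illustrative):

```python
from itertools import islice

def offline_fetch_sketch(dataset, batch_size, collate_fn=lambda x: x):
    """Yield batches of batch_size items from dataset, collated by collate_fn."""
    it = iter(dataset)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:               # dataset exhausted
            return
        yield collate_fn(batch)

batches = list(offline_fetch_sketch(range(7), batch_size=3))
```

The final batch may be smaller than batch_size, just as a DataLoader yields a partial last batch when drop_last is False.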

offline_data_saver(data_path, data_type='hdf5')

Overview

Save the expert data of offline RL in a directory.

Arguments:
- data_path (:obj:str): File path where the expert data will be written, usually './expert.pkl'.
- data_type (:obj:str): The type of the saved data: pkl if data_type == 'naive', hdf5 if data_type == 'hdf5'.
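
A minimal sketch of the data_type dispatch, using only the standard library; the hdf5 branch is stubbed out because it needs h5py, and `save_offline_data_sketch` is an illustrative name:

```python
import os
import pickle
import tempfile

def save_offline_data_sketch(data, data_path, data_type='naive'):
    """Dispatch on data_type: 'naive' -> pickle; 'hdf5' is stubbed (needs h5py)."""
    if data_type == 'naive':
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
    elif data_type == 'hdf5':
        raise NotImplementedError('hdf5 saving requires h5py')
    else:
        raise KeyError(data_type)

path = os.path.join(tempfile.gettempdir(), 'expert_sketch.pkl')
save_offline_data_sketch([{'obs': 0, 'action': 1}], path)
with open(path, 'rb') as f:
    restored = pickle.load(f)
```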

sqil_data_pusher(cfg, buffer_, expert)

Overview

Push trajectories into the buffer in the SQIL learning pipeline.

Arguments:
- cfg (:obj:EasyDict): Config.
- buffer_ (:obj:Buffer): Buffer to push the data into.
- expert (:obj:bool): Whether the pushed data is expert data. The reward of each pushed element is set to 1 if this attribute is True, otherwise 0.
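
The reward relabeling rule can be sketched directly (`sqil_relabel_sketch` is an illustrative helper, not ding's function):

```python
def sqil_relabel_sketch(transitions, expert):
    """Overwrite each transition's reward: 1 for expert data, 0 otherwise."""
    for t in transitions:
        t['reward'] = 1.0 if expert else 0.0
    return transitions

demo = sqil_relabel_sketch([{'reward': 0.7}, {'reward': -0.3}], expert=True)
agent = sqil_relabel_sketch([{'reward': 0.7}], expert=False)
```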

buffer_saver(cfg, buffer_, every_envstep=1000, replace=False)

Overview

Save current buffer data.

Arguments:
- cfg (:obj:EasyDict): Config.
- buffer_ (:obj:Buffer): Buffer whose data will be saved.
- every_envstep (:obj:int): Save the buffer every this many environment steps.
- replace (:obj:bool): Whether to replace the last saved file.

inferencer(seed, policy, env)

Overview

The middleware that executes the inference process.

Arguments:
- seed (:obj:int): Random seed.
- policy (:obj:Policy): The policy used for inference.
- env (:obj:BaseEnvManager): The env where the inference process is performed. env.ready_obs (:obj:tnp.array) will be used as the model input.

rolloutor(policy, env, transitions, collect_print_freq=100)

Overview

The middleware that executes the transition process in the env.

Arguments:
- policy (:obj:Policy): The policy to be used during the transition.
- env (:obj:BaseEnvManager): The env for the collection; a BaseEnvManager object or its derivatives are supported.
- transitions (:obj:TransitionList): The transition information filled in this process, including obs, next_obs, action, logit, value, reward and done.
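
A self-contained sketch of the rollout loop such a middleware performs; `ToyEnv` and the dict-based transition records are stand-ins for BaseEnvManager and TransitionList, not ding's classes:

```python
class ToyEnv:
    """Stand-in for an env manager: one episode of fixed length."""
    def __init__(self, length=5):
        self.length, self.t = length, 0

    def reset(self):
        self.t = 0
        return 0

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length   # obs, reward, done


def rollout_sketch(policy, env, transitions):
    """Collect one episode of transition dicts, mirroring what rolloutor fills in."""
    obs, done = env.reset(), False
    while not done:
        action = policy(obs)
        next_obs, reward, done = env.step(action)
        transitions.append({'obs': obs, 'action': action, 'reward': reward,
                            'next_obs': next_obs, 'done': done})
        obs = next_obs
    return transitions

traj = rollout_sketch(lambda obs: 0, ToyEnv(length=3), [])
```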

interaction_evaluator(cfg, policy, env, render=False, **kwargs)

Overview

The middleware that executes the evaluation.

Arguments:
- cfg (:obj:EasyDict): Config.
- policy (:obj:Policy): The policy to be evaluated.
- env (:obj:BaseEnvManager): The env for the evaluation.
- render (:obj:bool): Whether to render env images and policy logits.
- kwargs (:obj:Any): Other arguments for specific evaluation.

interaction_evaluator_ttorch(seed, policy, env, n_evaluator_episode=None, stop_value=np.inf, eval_freq=1000, render=False)

Overview

The middleware that executes the evaluation with ttorch data.

Arguments:
- policy (:obj:Policy): The policy to be evaluated.
- env (:obj:BaseEnvManager): The env for the evaluation.
- render (:obj:bool): Whether to render env images and policy logits.

online_logger(record_train_iter=False, train_show_freq=100)

Overview

Create an online RL tensorboard logger for recording training and evaluation metrics.

Arguments:
- record_train_iter (:obj:bool): Whether to record the training iteration. Default is False.
- train_show_freq (:obj:int): Frequency of showing training logs. Default is 100.

Returns:
- _logger (:obj:Callable): A logger function that takes an OnlineRLContext object as input.

Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If a key of train_output is not supported, such as "scalars".

Examples:

>>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))

offline_logger(train_show_freq=100)

Overview

Create an offline RL tensorboard logger for recording training and evaluation metrics.

Arguments:
- train_show_freq (:obj:int): Frequency of showing training logs. Defaults to 100.

Returns:
- _logger (:obj:Callable): A logger function that takes an OfflineRLContext object as input.

Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If a key of train_output is not supported, such as "scalars".

Examples:

>>> task.use(offline_logger(train_show_freq=1000))

wandb_online_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)

Overview

Wandb visualizer to track the experiment.

Arguments:
- record_path (:obj:str): The path to save the replay of the simulation.
- cfg (:obj:Union[dict, EasyDict]): Config, a dict of the following settings:
  - gradient_logger (:obj:bool): Whether to track the gradient.
  - plot_logger (:obj:bool): Whether to track metrics like reward and loss.
  - video_logger (:obj:bool): Whether to upload the rendered video replay.
  - action_logger (:obj:bool): Whether to track the q_value or action probability.
  - return_logger (:obj:bool): Whether to track the return value.
- metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies.
- env (:obj:BaseEnvManagerV2): Evaluator environment.
- model (:obj:nn.Module): Policy neural network model.
- anonymous (:obj:bool): Whether to open wandb's anonymous mode, which allows visualization of data without a wandb account.
- project_name (:obj:str): The name of the wandb project.
- run_name (:obj:str): The name of the wandb run.
- wandb_sweep (:obj:bool): Whether to use wandb sweep.

Returns:
- _plot (:obj:Callable): A logger function that takes an OnlineRLContext object as input.

wandb_offline_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)

Overview

Wandb visualizer to track the experiment.

Arguments:
- record_path (:obj:str): The path to save the replay of the simulation.
- cfg (:obj:Union[dict, EasyDict]): Config, a dict of the following settings:
  - gradient_logger (:obj:bool): Whether to track the gradient.
  - plot_logger (:obj:bool): Whether to track metrics like reward and loss.
  - video_logger (:obj:bool): Whether to upload the rendered video replay.
  - action_logger (:obj:bool): Whether to track the q_value or action probability.
  - return_logger (:obj:bool): Whether to track the return value.
  - vis_dataset (:obj:bool): Whether to visualize the dataset.
- metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies.
- env (:obj:BaseEnvManagerV2): Evaluator environment.
- model (:obj:nn.Module): Policy neural network model.
- anonymous (:obj:bool): Whether to open wandb's anonymous mode, which allows visualization of data without a wandb account.
- project_name (:obj:str): The name of the wandb project.
- run_name (:obj:str): The name of the wandb run.
- wandb_sweep (:obj:bool): Whether to use wandb sweep.

Returns:
- _plot (:obj:Callable): A logger function that takes an OfflineRLContext object as input.

eps_greedy_handler(cfg)

Overview

The middleware that computes the epsilon value according to env_step.

Arguments:
- cfg (:obj:EasyDict): Config.
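
A common epsilon schedule is linear annealing over env steps. This sketch assumes start/end/decay values of the kind typically configured under cfg.policy.other.eps; it is an assumption about a typical schedule, not the exact handler:

```python
def linear_eps_sketch(env_step, start=0.95, end=0.05, decay=10000):
    """Linearly anneal epsilon from `start` to `end` over `decay` env steps."""
    frac = min(env_step / decay, 1.0)   # clip so epsilon never passes `end`
    return start + (end - start) * frac

eps_begin = linear_eps_sketch(0)
eps_mid = linear_eps_sketch(5000)
eps_final = linear_eps_sketch(20000)   # past `decay`, clipped at `end`
```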

eps_greedy_masker()

Overview

The middleware that returns a masked epsilon value and stops the epsilon-greedy method from generating exploratory actions.

gae_estimator(cfg, policy, buffer_=None)

Overview

Calculate values using the observations of the input data, then call the gae function to get advantages. The processed data will be pushed into buffer_ if buffer_ is not None; otherwise it will be assigned to ctx.train_data.

Arguments:
- cfg (:obj:EasyDict): Config, which should contain the following keys: cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda.
- policy (:obj:Policy): Policy in policy.collect_mode, used to get the model to calculate values.
- buffer_ (:obj:Optional[Buffer]): The buffer to push the processed data into, if not None.
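
The GAE recursion itself is short. This sketch omits the done-flag masking and batching that the real estimator handles:

```python
def gae_sketch(rewards, values, next_value, gamma=0.99, lam=0.95):
    """Backward GAE recursion:
    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lam * A_{t+1}
    """
    values = values + [next_value]           # append bootstrap value
    advantages, gae = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# With gamma = lam = 1, A_t reduces to (sum of future rewards) - V(s_t).
adv = gae_sketch([1.0, 1.0], [0.5, 0.5], 0.0, gamma=1.0, lam=1.0)
```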

reward_estimator(cfg, reward_model)

Overview

Estimate the reward of train_data using reward_model.

Arguments:
- cfg (:obj:EasyDict): Config.
- reward_model (:obj:BaseRewardModel): Reward model.

her_data_enhancer(cfg, buffer_, her_reward_model)

Overview

Fetch a batch of data/episodes from buffer_, then use her_reward_model to produce HER-processed episodes from the original episodes.

Arguments:
- cfg (:obj:EasyDict): Config, which should contain the key cfg.policy.learn.batch_size if her_reward_model.episode_size is None.
- buffer_ (:obj:Buffer): Buffer to sample data from.
- her_reward_model (:obj:HerRewardModel): Hindsight Experience Replay (HER) model used to process episodes.
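
HER's "final" relabeling strategy can be sketched as follows. This is a simplification of what an HER reward model does, and the field names (`achieved`, `goal`) are illustrative:

```python
def her_final_sketch(episode):
    """'final' strategy: relabel the goal as the last achieved state and
    recompute a sparse reward (1 when achieved == new goal)."""
    new_goal = episode[-1]['achieved']
    return [{'achieved': t['achieved'],
             'goal': new_goal,
             'reward': 1.0 if t['achieved'] == new_goal else 0.0}
            for t in episode]

ep = [{'achieved': (0, 0)}, {'achieved': (1, 0)}, {'achieved': (1, 1)}]
her_ep = her_final_sketch(ep)
```

Relabeling turns a failed episode into a successful one for a different goal, which is what makes sparse-reward tasks learnable.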

priority_calculator(priority_calculation_fn)

Overview

The middleware that calculates the priority of the collected data.

Arguments:
- priority_calculation_fn (:obj:Callable): The function that calculates the priority of the collected data.
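
A typical priority_calculation_fn uses the absolute TD error, as in prioritized experience replay; this is a common choice, not necessarily ding's default:

```python
def td_error_priority_sketch(transitions, eps=1e-6):
    """Priority = |TD error| + eps, so no sample gets zero probability."""
    return [abs(t['td_error']) + eps for t in transitions]

prios = td_error_priority_sketch([{'td_error': -0.5}, {'td_error': 2.0}])
```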

epoch_timer(print_per=1, smooth_window=10)

Overview

Print time cost of each epoch.

Arguments:
- print_per (:obj:int): Print every N epochs.
- smooth_window (:obj:int): The window size for smoothing the mean.
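
The timer can be sketched as a closure over a fixed-size deque of recent epoch durations (illustrative names, not the real middleware):

```python
import time
from collections import deque

def epoch_timer_sketch(print_per=1, smooth_window=10):
    """Closure that records wall-clock time between calls and prints the
    smoothed mean every `print_per` epochs."""
    records = deque(maxlen=smooth_window)   # keeps only the last N durations
    state = {'epoch': 0, 'last': None}

    def _timer():
        now = time.perf_counter()
        if state['last'] is not None:
            records.append(now - state['last'])
            state['epoch'] += 1
            if state['epoch'] % print_per == 0:
                mean = sum(records) / len(records)
                print(f"epoch {state['epoch']}: mean {mean:.4f}s over {len(records)}")
        state['last'] = now
        return list(records)

    return _timer

timer = epoch_timer_sketch(print_per=2, smooth_window=3)
timer()
timer()
durations = timer()
```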

Full Source Code

../ding/framework/middleware/__init__.py

from .functional import *
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from .barrier import Barrier, BarrierRuntime
from .data_fetcher import OfflineMemoryDataFetcher