ding.framework.middleware
multistep_trainer(policy, log_freq=100)
Overview
The middleware that executes training for a target number of steps.
Arguments:
- policy (:obj:Policy): The policy specialized for multi-step training.
- log_freq (:obj:int): The frequency (iteration) of showing log.
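Most middlewares in this module follow the same closure pattern: an outer function captures its configuration and returns a callable that reads and mutates a shared context object on each pipeline pass. The following is a minimal, framework-free sketch of that pattern; the `Context` stand-in and the `train_fn` argument are hypothetical simplifications, not DI-engine's actual classes.

```python
class Context(dict):
    """Minimal stand-in for the pipeline context object (attribute-style dict)."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

def multistep_trainer_sketch(train_fn, log_freq=100):
    # Outer function captures configuration; the inner closure is the middleware.
    def _train(ctx):
        output = train_fn(ctx.train_data)  # one multi-step training update
        ctx.train_iter = ctx.get("train_iter", 0) + 1
        if ctx.train_iter % log_freq == 0:
            print(f"iter {ctx.train_iter}: output={output}")
        return output
    return _train

ctx = Context(train_data=[1.0, 2.0, 3.0])
trainer = multistep_trainer_sketch(lambda data: sum(data) / len(data), log_freq=1)
loss = trainer(ctx)
```

In the real pipeline, the returned callable is registered with `task.use(...)` so it runs once per iteration.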
offpolicy_data_fetcher(cfg, buffer_, data_shortage_warning=False)
Overview
The return function is a generator which mainly fetches a batch of data from a buffer, a list of buffers, or a dict of buffers.
Arguments:
- cfg (:obj:EasyDict): Config which should contain the following keys: cfg.policy.learn.batch_size.
- buffer (:obj:Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]): The buffer(s) the data is fetched from. A plain Buffer means a single buffer. List[Tuple[Buffer, float]] means a list of (buffer, float) tuples, where each float defines the fraction of batch_size that is sampled from the corresponding buffer. Dict[str, Buffer] means a dict whose values are buffers; for each key-value pair, batch_size of data will be sampled from the corresponding buffer and assigned to the same key of ctx.train_data.
- data_shortage_warning (:obj:bool): Whether to output warning when data shortage occurs in fetching.
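As an illustration of the List[Tuple[Buffer, float]] case, a single batch can be split across several buffers according to the per-buffer fractions. This is a hypothetical sketch using plain lists as stand-in buffers; the exact sampling semantics of the real middleware may differ.

```python
import random

def split_batch(buffers_with_ratio, batch_size):
    # Each tuple is (buffer_data, fraction): draw fraction * batch_size
    # samples from that buffer and concatenate the results.
    samples = []
    for buffer_data, ratio in buffers_with_ratio:
        n = int(batch_size * ratio)
        samples.extend(random.sample(buffer_data, n))
    return samples

main_buffer = list(range(100))          # stand-in agent buffer
expert_buffer = list(range(100, 200))   # stand-in expert buffer
batch = split_batch([(main_buffer, 0.75), (expert_buffer, 0.25)], batch_size=32)
```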
data_pusher(cfg, buffer_, group_by_env=None)
Overview
Push episodes or trajectories into the buffer.
Arguments:
- cfg (:obj:EasyDict): Config.
- buffer (:obj:Buffer): Buffer to push the data in.
offline_data_fetcher(cfg, dataset, collate_fn=lambda x: x)
Overview
The outer function transforms a PyTorch Dataset into a DataLoader. The return function is a generator which fetches a batch of data from that DataLoader each time it is invoked. Please refer to https://pytorch.org/tutorials/beginner/basics/data_tutorial.html and https://pytorch.org/docs/stable/data.html for more details.
Arguments:
- cfg (:obj:EasyDict): Config which should contain the following keys: cfg.policy.learn.batch_size.
- dataset (:obj:Dataset): The dataset of type torch.utils.data.Dataset which stores the data.
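The generator behaviour can be sketched without PyTorch: cycle over the dataset indefinitely, yielding batches assembled by collate_fn. This is a simplified stand-in; the real middleware delegates batching, shuffling, and worker management to torch.utils.data.DataLoader.

```python
def offline_fetcher_sketch(dataset, batch_size, collate_fn=lambda x: x):
    # Endless generator over fixed-size slices of the dataset.
    while True:
        for i in range(0, len(dataset), batch_size):
            yield collate_fn(dataset[i:i + batch_size])

gen = offline_fetcher_sketch(list(range(10)), batch_size=4)
first = next(gen)
second = next(gen)
```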
offline_data_saver(data_path, data_type='hdf5')
Overview
Save the expert data of offline RL in a directory.
Arguments:
- data_path (:obj:str): File path where the expert data will be written, usually ./expert.pkl.
- data_type (:obj:str): The format of the saved data: pkl if data_type == 'naive', hdf5 if data_type == 'hdf5'.
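The branch on data_type can be sketched as follows. This is a hypothetical simplification: only the pickle path is implemented here, since the hdf5 path requires the optional h5py dependency.

```python
import os
import pickle
import tempfile

def save_expert_data(data, data_path, data_type='hdf5'):
    # Dispatch on the requested serialization format.
    if data_type == 'naive':
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
    elif data_type == 'hdf5':
        # The real middleware would write via h5py here; omitted in this sketch.
        raise NotImplementedError("hdf5 saving requires h5py")
    else:
        raise ValueError(f"unknown data_type: {data_type}")

path = os.path.join(tempfile.gettempdir(), 'expert.pkl')
save_expert_data([{'obs': 0, 'action': 1}], path, data_type='naive')
with open(path, 'rb') as f:
    restored = pickle.load(f)
```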
sqil_data_pusher(cfg, buffer_, expert)
Overview
Push trajectories into the buffer in the SQIL learning pipeline.
Arguments:
- cfg (:obj:EasyDict): Config.
- buffer (:obj:Buffer): Buffer to push the data in.
- expert (:obj:bool): Whether the pushed data is expert data or not. In each element of the pushed data, the reward will be set to 1 if this attribute is True, otherwise 0.
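The relabeling rule described above is simple to state in code: before pushing, every transition's reward is overwritten with 1 for expert data and 0 for agent data. A minimal sketch, using plain dicts as stand-in transitions:

```python
def sqil_relabel(transitions, expert):
    # SQIL reward relabeling: expert transitions get reward 1, agent
    # transitions get reward 0, regardless of the environment reward.
    for t in transitions:
        t['reward'] = 1.0 if expert else 0.0
    return transitions

agent_batch = sqil_relabel([{'reward': 0.7}, {'reward': -0.2}], expert=False)
expert_batch = sqil_relabel([{'reward': 0.0}], expert=True)
```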
buffer_saver(cfg, buffer_, every_envstep=1000, replace=False)
Overview
Save current buffer data.
Arguments:
- cfg (:obj:EasyDict): Config.
- buffer (:obj:Buffer): Buffer whose data will be saved.
- every_envstep (:obj:int): Save the buffer every every_envstep env steps.
- replace (:obj:bool): Whether to overwrite the last saved file.
inferencer(seed, policy, env)
Overview
The middleware that executes the inference process.
Arguments:
- seed (:obj:int): Random seed.
- policy (:obj:Policy): The policy used for inference.
- env (:obj:BaseEnvManager): The env where the inference process is performed. The env.ready_obs (:obj:tnp.array) will be used as model input.
rolloutor(policy, env, transitions, collect_print_freq=100)
Overview
The middleware that executes the transition process in the env.
Arguments:
- policy (:obj:Policy): The policy to be used during transition.
- env (:obj:BaseEnvManager): The env for the collection, the BaseEnvManager object or its derivatives are supported.
- transitions (:obj:TransitionList): The transition information which will be filled in this process, including obs, next_obs, action, logit, value, reward and done.
interaction_evaluator(cfg, policy, env, render=False, **kwargs)
Overview
The middleware that executes the evaluation.
Arguments:
- cfg (:obj:EasyDict): Config.
- policy (:obj:Policy): The policy to be evaluated.
- env (:obj:BaseEnvManager): The env for the evaluation.
- render (:obj:bool): Whether to render env images and policy logits.
- kwargs (:obj:Any): Other arguments for specific evaluation.
interaction_evaluator_ttorch(seed, policy, env, n_evaluator_episode=None, stop_value=np.inf, eval_freq=1000, render=False)
Overview
The middleware that executes the evaluation with ttorch data.
Arguments:
- policy (:obj:Policy): The policy to be evaluated.
- env (:obj:BaseEnvManager): The env for the evaluation.
- render (:obj:bool): Whether to render env images and policy logits.
online_logger(record_train_iter=False, train_show_freq=100)
Overview
Create an online RL tensorboard logger for recording training and evaluation metrics.
Arguments:
- record_train_iter (:obj:bool): Whether to record training iteration. Default is False.
- train_show_freq (:obj:int): Frequency of showing training logs. Default is 100.
Returns:
- _logger (:obj:Callable): A logger function that takes an OnlineRLContext object as input.
Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If the key of train_output is not supported, such as "scalars".
Examples:
>>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
offline_logger(train_show_freq=100)
Overview
Create an offline RL tensorboard logger for recording training and evaluation metrics.
Arguments:
- train_show_freq (:obj:int): Frequency of showing training logs. Defaults to 100.
Returns:
- _logger (:obj:Callable): A logger function that takes an OfflineRLContext object as input.
Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If the key of train_output is not supported, such as "scalars".
Examples:
>>> task.use(offline_logger(train_show_freq=1000))
wandb_online_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)
Overview
Wandb visualizer to track the experiment.
Arguments:
- record_path (:obj:str): The path to save the replay of simulation.
- cfg (:obj:Union[dict, EasyDict]): Config, a dict of following settings:
- gradient_logger: boolean. Whether to track the gradient.
- plot_logger: boolean. Whether to track the metrics like reward and loss.
- video_logger: boolean. Whether to upload the rendering video replay.
- action_logger: boolean. Whether to track the action distribution (q_value or action probability).
- return_logger: boolean. Whether to track the return value.
- metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies.
- env (:obj:BaseEnvManagerV2): Evaluator environment.
- model (:obj:nn.Module): Policy neural network model.
- anonymous (:obj:bool): Whether to open the anonymous mode of wandb. The anonymous mode allows visualization of data without a wandb account.
- project_name (:obj:str): The name of wandb project.
- run_name (:obj:str): The name of wandb run.
- wandb_sweep (:obj:bool): Whether to use wandb sweep.
Returns:
- _plot (:obj:Callable): A logger function that takes an OnlineRLContext object as input.
wandb_offline_logger(record_path=None, cfg=None, exp_config=None, metric_list=None, env=None, model=None, anonymous=False, project_name='default-project', run_name=None, wandb_sweep=False)
Overview
Wandb visualizer to track the experiment.
Arguments:
- record_path (:obj:str): The path to save the replay of simulation.
- cfg (:obj:Union[dict, EasyDict]): Config, a dict of following settings:
- gradient_logger: boolean. Whether to track the gradient.
- plot_logger: boolean. Whether to track the metrics like reward and loss.
- video_logger: boolean. Whether to upload the rendering video replay.
- action_logger: boolean. Whether to track the action distribution (q_value or action probability).
- return_logger: boolean. Whether to track the return value.
- vis_dataset: boolean. Whether to visualize the dataset.
- metric_list (:obj:Optional[List[str]]): Logged metric list, specialized by different policies.
- env (:obj:BaseEnvManagerV2): Evaluator environment.
- model (:obj:nn.Module): Policy neural network model.
- anonymous (:obj:bool): Whether to open the anonymous mode of wandb. The anonymous mode allows visualization of data without a wandb account.
- project_name (:obj:str): The name of wandb project.
- run_name (:obj:str): The name of wandb run.
- wandb_sweep (:obj:bool): Whether to use wandb sweep.
Returns:
- _plot (:obj:Callable): A logger function that takes an OfflineRLContext object as input.
eps_greedy_handler(cfg)
Overview
The middleware that computes epsilon value according to the env_step.
Arguments:
- cfg (:obj:EasyDict): Config.
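The epsilon schedule can be sketched as a function of env_step. The exponential decay form and the start/end/decay parameter names below are assumptions for illustration; the actual schedule and config keys are determined by the policy config.

```python
import math

def epsilon_at(env_step, start=0.95, end=0.05, decay=10000):
    # Hypothetical exponential decay from `start` toward `end`,
    # driven by the number of environment steps taken so far.
    return end + (start - end) * math.exp(-env_step / decay)

eps_initial = epsilon_at(0)        # equals start
eps_late = epsilon_at(10_000_000)  # approaches end
```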
eps_greedy_masker()
Overview
The middleware that returns a masked epsilon value and stops the epsilon-greedy method from generating actions.
gae_estimator(cfg, policy, buffer_=None)
Overview
Calculate values from the observations of the input data, then call the gae function to compute advantages. The processed data is pushed into buffer_ if buffer_ is not None; otherwise it is assigned to ctx.train_data.
Arguments:
- cfg (:obj:EasyDict): Config which should contain the following keys: cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda.
- policy (:obj:Policy): Policy in policy.collect_mode, used to get model to calculate value.
- buffer_ (:obj:Optional[Buffer]): The buffer_ to push the processed data in if buffer_ is not None.
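The underlying GAE recursion is standard: with delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t), the advantage is A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, computed backwards over the trajectory. A self-contained sketch on plain Python lists (the real middleware operates on batched tensors):

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    # Backward recursion over the trajectory; `dones` masks bootstrapping
    # across episode boundaries.
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_values[t] * nonterminal - values[t]
        last_adv = delta + gamma * lam * nonterminal * last_adv
        advantages[t] = last_adv
    return advantages

# Two-step trajectory ending in a terminal state.
adv = gae([1.0, 1.0], [0.5, 0.5], [0.5, 0.0], [0.0, 1.0])
```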
reward_estimator(cfg, reward_model)
Overview
Estimate the reward of train_data using reward_model.
Arguments:
- cfg (:obj:EasyDict): Config.
- reward_model (:obj:BaseRewardModel): Reward model.
her_data_enhancer(cfg, buffer_, her_reward_model)
Overview
Fetch a batch of data/episode from buffer_, then use her_reward_model to get HER processed episodes from original episodes.
Arguments:
- cfg (:obj:EasyDict): Config which should contain the following keys if her_reward_model.episode_size is None: cfg.policy.learn.batch_size.
- buffer_ (:obj:Buffer): Buffer to sample data from.
- her_reward_model (:obj:HerRewardModel): Hindsight Experience Replay (HER) model which is used to process episodes.
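The core HER idea can be shown with a toy relabeling: replace the episode's desired goal with a goal that was actually achieved, then recompute the rewards, so even failed episodes yield useful learning signal. The field names below are hypothetical; the real relabeling is handled by the HerRewardModel.

```python
def her_relabel(episode):
    # Use the final achieved goal as the new desired goal ("final" strategy),
    # then recompute rewards against it.
    new_goal = episode[-1]['achieved_goal']
    relabeled = []
    for step in episode:
        s = dict(step, desired_goal=new_goal)
        s['reward'] = 1.0 if s['achieved_goal'] == new_goal else 0.0
        relabeled.append(s)
    return relabeled

episode = [
    {'achieved_goal': (0, 1), 'desired_goal': (5, 5), 'reward': 0.0},
    {'achieved_goal': (2, 2), 'desired_goal': (5, 5), 'reward': 0.0},
]
new_episode = her_relabel(episode)
```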
priority_calculator(priority_calculation_fn)
Overview
The middleware that calculates the priority of the collected data.
Arguments:
- priority_calculation_fn (:obj:Callable): The function that calculates the priority of the collected data.
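A typical priority_calculation_fn for prioritized replay maps collected data to per-sample priorities, for example the absolute TD error plus a small constant so no sample has zero probability. This is one common choice, not necessarily what any given pipeline uses:

```python
def td_error_priority(td_errors, eps=1e-6):
    # |TD error| + eps: larger errors get sampled more often, and eps
    # keeps every transition's priority strictly positive.
    return [abs(e) + eps for e in td_errors]

priorities = td_error_priority([0.5, -1.2, 0.0])
```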
epoch_timer(print_per=1, smooth_window=10)
Overview
Print time cost of each epoch.
Arguments:
- print_per (:obj:int): Print every print_per epochs.
- smooth_window (:obj:int): The window size to smooth the mean.
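The timing-with-smoothing pattern can be sketched as follows: record the wall-clock cost of each epoch and report a mean over the last smooth_window epochs. The structure is a hypothetical stand-in for the middleware, not its actual implementation.

```python
import time
from collections import deque

def make_epoch_timer(print_per=1, smooth_window=10):
    history = deque(maxlen=smooth_window)  # keeps only the last N costs

    def _timer(epoch_fn, epoch):
        start = time.perf_counter()
        epoch_fn()
        history.append(time.perf_counter() - start)
        mean_cost = sum(history) / len(history)
        if epoch % print_per == 0:
            print(f"epoch {epoch}: mean cost {mean_cost:.6f}s")
        return mean_cost

    return _timer

timer = make_epoch_timer(print_per=2, smooth_window=3)
means = [timer(lambda: sum(range(1000)), epoch) for epoch in range(5)]
```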
Full Source Code
../ding/framework/middleware/__init__.py