ding.worker¶
ISerialCollector
¶
Bases: ABC
Overview
Abstract base class for serial collectors.
Interfaces: default_config, reset_env, reset_policy, reset, collect. Property: envstep
default_config()
classmethod
¶
Overview
Get the collector's default config. We merge the collector's default config with other default configs and the user's config to obtain the final config.
Return:
- cfg (:obj:EasyDict): Collector's default config
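The merge described above (class defaults first, user config layered on top) can be sketched with a plain recursive dict merge. `deep_merge` and the config keys below are illustrative stand-ins, not DI-engine's actual implementation:

```python
def deep_merge(default: dict, user: dict) -> dict:
    """Recursively overlay ``user`` onto ``default``; user values win (sketch)."""
    merged = dict(default)
    for key, value in user.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Hypothetical collector defaults and user overrides:
collector_default = {'n_sample': 64, 'collect_print_freq': 100}
user_cfg = {'n_sample': 128}
final_cfg = deep_merge(collector_default, user_cfg)
print(final_cfg['n_sample'])  # 128 -- the user's value overrides the default
```

The real method additionally wraps the result in an EasyDict so fields are attribute-accessible.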
reset_env(_env=None)
abstractmethod
¶
Overview
Reset the collector's environment. In some cases, the collector needs to use the same policy to collect data in different environments; reset_env can be used to reset the environment.
reset_policy(_policy=None)
abstractmethod
¶
Overview
Reset the collector's policy. In some cases, the collector needs to work in the same environment but collect data with a different policy; reset_policy can be used to reset the policy.
reset(_policy=None, _env=None)
abstractmethod
¶
Overview
Reset the collector's policy and environment, then use the new policy and environment to collect data.
collect(per_collect_target)
abstractmethod
¶
Overview
Collect the corresponding data according to the specified target and return it. The target has different definitions in episode mode and sample mode.
envstep()
¶
Overview
Get the total envstep count.
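The interface above can be mimicked with a minimal ABC sketch; the class and method bodies here are stand-ins to show the contract, not the real ding implementation:

```python
from abc import ABC, abstractmethod

class ISerialCollectorSketch(ABC):
    """Minimal stand-in mirroring the ISerialCollector interface listed above."""

    @abstractmethod
    def reset_env(self, _env=None): ...

    @abstractmethod
    def reset_policy(self, _policy=None): ...

    @abstractmethod
    def reset(self, _policy=None, _env=None): ...

    @abstractmethod
    def collect(self, per_collect_target): ...

    @property
    def envstep(self):
        # Total env interaction steps accumulated by concrete collectors.
        return getattr(self, '_total_envstep', 0)

class DummyCollector(ISerialCollectorSketch):
    def reset_env(self, _env=None): pass
    def reset_policy(self, _policy=None): pass
    def reset(self, _policy=None, _env=None):
        self.reset_policy(_policy)
        self.reset_env(_env)
    def collect(self, per_collect_target):
        self._total_envstep = per_collect_target
        return [None] * per_collect_target

c = DummyCollector()
data = c.collect(8)
print(c.envstep)  # 8
```

Concrete collectors such as SampleSerialCollector below fill in these methods with real env-manager and policy interaction.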
SampleSerialCollector
¶
Bases: ISerialCollector
Overview
Sample collector (n_sample). A sample is one training sample used to update the model: it is usually a single transition, but it can also be a trajectory with many transitions, which is often used with RNN models.
Interfaces: __init__, reset, reset_env, reset_policy, collect, close. Property: envstep
envstep
property
writable
¶
Overview
Return the total envstep count.
Return:
- envstep (:obj:int): The total envstep count.
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
- env (:obj:BaseEnvManager): Instance of a subclass of the vectorized env manager (BaseEnvManager)
- policy (:obj:namedtuple): the api namedtuple of collect_mode policy
- tb_logger (:obj:SummaryWriter): tensorboard handle
reset_env(_env=None)
¶
Overview
Reset the environment. If _env is None, reset the current environment; otherwise, replace the collector's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the policy. If _policy is None, reset the current policy; otherwise, replace the collector's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of collect_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the environment and policy. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of collect_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the collector. __del__ is automatically called to destroy the collector instance when the collector finishes its work.
collect(n_sample=None, train_iter=0, drop_extra=True, random_collect=False, record_random_collect=True, policy_kwargs=None, level_seeds=None)
¶
Overview
Collect n_sample data with policy_kwargs, using a policy that has already been trained for train_iter iterations.
Arguments:
- n_sample (:obj:int): The number of data samples to collect.
- train_iter (:obj:int): The number of training iterations when calling the collect method.
- drop_extra (:obj:bool): Whether to drop extra return_data beyond n_sample.
- record_random_collect (:obj:bool): Whether to output logs of random collect.
- policy_kwargs (:obj:dict): The keyword args for policy forward.
- level_seeds (:obj:dict): Used in PLR; represents the seed of the environment that generates the data.
Returns:
- return_data (:obj:List): A list containing training samples.
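The drop_extra behavior can be illustrated with a small sketch: a vectorized collector often finishes with a few more samples than requested, and drop_extra=True truncates the surplus. This is illustrative code, not the actual collector internals:

```python
def finalize_samples(return_data, n_sample, drop_extra=True):
    """Truncate surplus samples beyond ``n_sample`` when drop_extra is set (sketch)."""
    if drop_extra and len(return_data) > n_sample:
        return return_data[:n_sample]
    return return_data

# Say 8 parallel envs each produced 13 samples -> 104 total, but only 100 requested.
collected = list(range(104))
print(len(finalize_samples(collected, 100)))                    # 100
print(len(finalize_samples(collected, 100, drop_extra=False)))  # 104
```

Keeping the extras (drop_extra=False) can be useful when every transition should reach the replay buffer.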
EpisodeSerialCollector
¶
Bases: ISerialCollector
Overview
Episode collector (n_episode).
Interfaces: __init__, reset, reset_env, reset_policy, collect, close. Property: envstep
envstep
property
writable
¶
Overview
Return the total envstep count.
Return:
- envstep (:obj:int): The total envstep count.
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
- env (:obj:BaseEnvManager): Instance of a subclass of the vectorized env manager (BaseEnvManager)
- policy (:obj:namedtuple): the api namedtuple of collect_mode policy
- tb_logger (:obj:SummaryWriter): tensorboard handle
reset_env(_env=None)
¶
Overview
Reset the environment. If _env is None, reset the current environment; otherwise, replace the collector's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the policy. If _policy is None, reset the current policy; otherwise, replace the collector's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of collect_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the environment and policy. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of collect_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the collector. __del__ is automatically called to destroy the collector instance when the collector finishes its work.
collect(n_episode=None, train_iter=0, policy_kwargs=None)
¶
Overview
Collect n_episode data with policy_kwargs, using a policy that has already been trained for train_iter iterations.
Arguments:
- n_episode (:obj:int): The number of episodes to collect
- train_iter (:obj:int): The number of training iterations
- policy_kwargs (:obj:dict): The keyword args for policy forward
Returns:
- return_data (:obj:List): A list containing collected episodes if get_train_sample is False; otherwise, train samples split by unroll_len.
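When get_train_sample is enabled, an episode is split into fixed-length pieces of unroll_len. A simplified sketch of that splitting follows; the real helper additionally pads or overlaps the last chunk, which is omitted here:

```python
def split_by_unroll_len(episode, unroll_len):
    """Split one episode (a list of transitions) into unroll_len-sized chunks (sketch)."""
    return [episode[i:i + unroll_len] for i in range(0, len(episode), unroll_len)]

episode = [{'step': t} for t in range(10)]
chunks = split_by_unroll_len(episode, unroll_len=4)
print([len(c) for c in chunks])  # [4, 4, 2] -- last chunk is short; real code pads it
```

Fixed-length unrolls are what make the episodes usable as RNN training samples.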
BattleEpisodeSerialCollector
¶
Bases: ISerialCollector
Overview
Episode collector (n_episode) with a two-policy battle.
Interfaces: __init__, reset, reset_env, reset_policy, collect, close. Property: envstep
envstep
property
writable
¶
Overview
Return the total envstep count.
Return:
- envstep (:obj:int): The total envstep count.
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
- env (:obj:BaseEnvManager): Instance of a subclass of the vectorized env manager (BaseEnvManager)
- policy (:obj:List[namedtuple]): the api namedtuple of collect_mode policy
- tb_logger (:obj:SummaryWriter): tensorboard handle
reset_env(_env=None)
¶
Overview
Reset the environment. If _env is None, reset the current environment; otherwise, replace the collector's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the policy. If _policy is None, reset the current policy; otherwise, replace the collector's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of collect_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the environment and policy. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of collect_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the collector. __del__ is automatically called to destroy the collector instance when the collector finishes its work.
collect(n_episode=None, train_iter=0, policy_kwargs=None)
¶
Overview
Collect n_episode data with policy_kwargs, using a policy that has already been trained for train_iter iterations.
Arguments:
- n_episode (:obj:int): The number of episodes to collect
- train_iter (:obj:int): The number of training iterations
- policy_kwargs (:obj:dict): The keyword args for policy forward
Returns:
- return_data (:obj:Tuple[List, List]): A tuple of training samples (data) and episode info; the former is a list containing collected episodes if get_train_sample is False, otherwise train samples split by unroll_len.
BattleSampleSerialCollector
¶
Bases: ISerialCollector
Overview
Sample collector (n_sample) with a multiple (n vs n) policy battle.
Interfaces: __init__, reset, reset_env, reset_policy, collect, close. Property: envstep
envstep
property
writable
¶
Overview
Return the total envstep count.
Return:
- envstep (:obj:int): The total envstep count.
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='collector')
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
- env (:obj:BaseEnvManager): Instance of a subclass of the vectorized env manager (BaseEnvManager)
- policy (:obj:List[namedtuple]): the api namedtuple of collect_mode policy
- tb_logger (:obj:SummaryWriter): tensorboard handle
reset_env(_env=None)
¶
Overview
Reset the environment. If _env is None, reset the current environment; otherwise, replace the collector's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the policy. If _policy is None, reset the current policy; otherwise, replace the collector's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of collect_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the environment and policy. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of collect_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the collector. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the collector. __del__ is automatically called to destroy the collector instance when the collector finishes its work.
collect(n_sample=None, train_iter=0, drop_extra=True, policy_kwargs=None)
¶
Overview
Collect n_sample data with policy_kwargs, using a policy that has already been trained for train_iter iterations.
Arguments:
- n_sample (:obj:int): The number of data samples to collect.
- train_iter (:obj:int): The number of training iterations when calling the collect method.
- drop_extra (:obj:bool): Whether to drop extra return_data beyond n_sample.
- policy_kwargs (:obj:dict): The keyword args for policy forward.
Returns:
- return_data (:obj:List): A list containing training samples.
ISerialEvaluator
¶
Bases: ABC
Overview
Basic interface class for serial evaluators.
Interfaces: reset, reset_policy, reset_env, close, should_eval, eval. Property: env, policy
default_config()
classmethod
¶
Overview
Get the evaluator's default config. We merge the evaluator's default config with other default configs and the user's config to obtain the final config.
Return:
- cfg (:obj:EasyDict): Evaluator's default config
VectorEvalMonitor
¶
Bases: object
Overview
In some cases, different environments in the evaluator may collect episodes of different lengths. For example, suppose we want to collect 12 episodes in the evaluator but only have 5 environments; if we did nothing, we would likely get more short episodes than long ones. As a result, the average reward would be biased and possibly inaccurate. VectorEvalMonitor is used to solve this problem.
Interfaces: __init__, is_finished, update_info, update_reward, get_episode_return, get_latest_reward, get_current_episode, get_episode_info
__init__(env_num, n_episode)
¶
Overview
Init method. According to the number of episodes and the number of environments, determine how many episodes need to be run in each environment, and initialize the reward, info and other statistics.
Arguments:
- env_num (:obj:int): The number of environments
- n_episode (:obj:int): The number of episodes to collect
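A sketch of the allocation the monitor performs: with 12 episodes over 5 environments, each env gets a fixed quota so fast, short-episode envs cannot dominate the average. This is an illustrative reimplementation, not the exact ding code:

```python
def assign_episodes(env_num, n_episode):
    """Give each env a fixed episode quota; the remainder is spread over the first envs (sketch)."""
    base, extra = divmod(n_episode, env_num)
    return [base + (1 if i < extra else 0) for i in range(env_num)]

quota = assign_episodes(env_num=5, n_episode=12)
print(quota)       # [3, 3, 2, 2, 2]
print(sum(quota))  # 12
```

Once an environment has filled its quota, further episodes from it are ignored, which removes the short-episode bias described above.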
is_finished()
¶
Overview
Determine whether the evaluator has completed the work.
Return:
- result (:obj:bool): Whether the evaluator has completed the work
update_info(env_id, info)
¶
Overview
Update the information of the environment indicated by env_id.
Arguments:
- env_id (:obj:int): The id of the environment whose information needs updating
- info (:obj:Any): The information to update
update_reward(env_id, reward)
¶
Overview
Update the reward indicated by env_id.
Arguments:
- env_id (:obj:int): The id of the environment whose reward needs updating
- reward (:obj:Any): The reward to update
get_video()
¶
Overview
Convert list of videos into [N, T, C, H, W] tensor, containing worst, median, best evaluation trajectories for video logging.
get_episode_return()
¶
Overview
Sum up all rewards and get the total return of one episode.
get_latest_reward(env_id)
¶
Overview
Get the latest reward of a certain environment.
Arguments:
- env_id (:obj:int): The id of the environment to get the reward from.
get_current_episode()
¶
Overview
Get the current episode index, so we can know which episode the evaluator is executing now.
get_episode_info()
¶
Overview
Get all episode information, such as total return of one episode.
InteractionSerialEvaluator
¶
Bases: ISerialEvaluator
Overview
Interaction serial evaluator class, policy interacts with env.
Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval. Property: env, policy
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='evaluator')
¶
Overview
Init method. Load config and use self._cfg setting to build common serial evaluator components, e.g. logger helper, timer.
Arguments:
- cfg (:obj:EasyDict): Configuration EasyDict.
reset_env(_env=None)
¶
Overview
Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy in different environments; reset_env can be used to reset the environment. If _env is None, reset the current environment; otherwise, replace the evaluator's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment but with a different policy; reset_policy can be used to reset the policy. If _policy is None, reset the current policy; otherwise, replace the evaluator's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of eval_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the evaluator's policy and environment, then use the new policy and environment to collect data. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of eval_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the evaluator. __del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work.
should_eval(train_iter)
¶
Overview
Determine whether evaluation should be started: if enough training iterations have passed since the last evaluation, return True.
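The trigger logic can be sketched as a frequency check on train_iter; the eval_freq parameter and last-eval bookkeeping below are assumptions based on the description above, not the exact ding implementation:

```python
class EvalTrigger:
    """Illustrative should_eval: fire every ``eval_freq`` training iterations."""

    def __init__(self, eval_freq):
        self.eval_freq = eval_freq
        self.last_eval_iter = -1  # no evaluation has happened yet

    def should_eval(self, train_iter):
        if train_iter == self.last_eval_iter:
            return False  # already evaluated at this iteration
        if (train_iter - self.last_eval_iter) < self.eval_freq and train_iter != 0:
            return False  # too soon since the last evaluation
        self.last_eval_iter = train_iter
        return True

trig = EvalTrigger(eval_freq=100)
print(trig.should_eval(0))    # True: always evaluate at the start
print(trig.should_eval(50))   # False: only 50 iterations since last eval
print(trig.should_eval(100))  # True
```

Calling should_eval inside the training loop and invoking eval only when it returns True keeps evaluation cost bounded.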
eval(save_ckpt_fn=None, train_iter=-1, envstep=-1, n_episode=None, force_render=False, policy_kwargs={})
¶
Overview
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Arguments:
- save_ckpt_fn (:obj:Callable): Checkpoint saving function, triggered when a new best reward is reached.
- train_iter (:obj:int): Current training iteration.
- envstep (:obj:int): Current env interaction step.
- n_episode (:obj:int): Number of evaluation episodes.
Returns:
- stop_flag (:obj:bool): Whether this training program can be ended.
- episode_info (:obj:Dict[str, List]): Current evaluation episode information.
BattleInteractionSerialEvaluator
¶
Bases: ISerialEvaluator
Overview
Multi-player battle evaluator class.
Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval. Property: env, policy
default_config()
classmethod
¶
Overview
Get the evaluator's default config. We merge the evaluator's default config with other default configs and the user's config to obtain the final config.
Return:
- cfg (:obj:EasyDict): Evaluator's default config
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='evaluator')
¶
Overview
Init method. Load config and use the self._cfg setting to build common serial evaluator components, e.g. logger helper and timer. The policy is not initialized here, but set afterwards through the policy setter.
Arguments:
- cfg (:obj:EasyDict): Configuration EasyDict.
reset_env(_env=None)
¶
Overview
Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy in different environments; reset_env can be used to reset the environment. If _env is None, reset the current environment; otherwise, replace the evaluator's environment with the newly passed-in one and launch it.
Arguments:
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
reset_policy(_policy=None)
¶
Overview
Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment but with a different policy; reset_policy can be used to reset the policy. If _policy is None, reset the current policy; otherwise, replace the evaluator's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of eval_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the evaluator's policy and environment, then use the new policy and environment to collect data. If _env is None, reset the current environment; otherwise, replace it with the newly passed-in one and launch it. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[List[namedtuple]]): The API namedtuple of eval_mode policy
- env (:obj:Optional[BaseEnvManager]): Instance of a subclass of the vectorized env manager (BaseEnvManager)
close()
¶
Overview
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the evaluator. __del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work.
should_eval(train_iter)
¶
Overview
Determine whether evaluation should be started: if enough training iterations have passed since the last evaluation, return True.
eval(save_ckpt_fn=None, train_iter=-1, envstep=-1, n_episode=None)
¶
Overview
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Arguments:
- save_ckpt_fn (:obj:Callable): Checkpoint saving function, triggered when a new best reward is reached.
- train_iter (:obj:int): Current training iteration.
- envstep (:obj:int): Current env interaction step.
- n_episode (:obj:int): Number of evaluation episodes.
Returns:
- stop_flag (:obj:bool): Whether this training program can be ended.
- return_info (:obj:list): Environment information of each finished episode.
MetricSerialEvaluator
¶
Bases: ISerialEvaluator
Overview
Metric serial evaluator class; the policy is evaluated by an objective metric (env).
Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval. Property: env, policy
__init__(cfg, env=None, policy=None, tb_logger=None, exp_name='default_experiment', instance_name='evaluator')
¶
Overview
Init method. Load config and use self._cfg setting to build common serial evaluator components,
e.g. logger helper, timer.
Arguments:
- cfg (:obj:EasyDict): Configuration EasyDict.
reset_env(_env=None)
¶
Overview
Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy in different environments; reset_env can be used to reset the environment. If _env is not None, replace the evaluator's environment with the new one.
Arguments:
- env (:obj:Optional[Tuple[DataLoader, IMetric]]): Instance of the DataLoader and the Metric
reset_policy(_policy=None)
¶
Overview
Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment but with a different policy; reset_policy can be used to reset the policy. If _policy is None, reset the current policy; otherwise, replace the evaluator's policy with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of eval_mode policy
reset(_policy=None, _env=None)
¶
Overview
Reset the evaluator's policy and environment, then use the new policy and environment to collect data. If _env is not None, replace the evaluator's environment with the new one. If _policy is None, reset the current policy; otherwise, replace it with the newly passed-in one.
Arguments:
- policy (:obj:Optional[namedtuple]): The API namedtuple of eval_mode policy
- env (:obj:Optional[Tuple[DataLoader, IMetric]]): Instance of the DataLoader and the Metric
close()
¶
Overview
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
__del__()
¶
Overview
Execute the close command and close the evaluator. __del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work.
should_eval(train_iter)
¶
Overview
Determine whether evaluation should be started: if enough training iterations have passed since the last evaluation, return True.
eval(save_ckpt_fn=None, train_iter=-1, envstep=-1)
¶
Overview
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Arguments:
- save_ckpt_fn (:obj:Callable): Checkpoint saving function, triggered when a new best reward is reached.
- train_iter (:obj:int): Current training iteration.
- envstep (:obj:int): Current env interaction step.
Returns:
- stop_flag (:obj:bool): Whether this training program can be ended.
- eval_metric (:obj:float): Current evaluation metric result.
IMetric
¶
Bases: ABC
gt(metric1, metric2)
abstractmethod
¶
Overview
Whether metric1 is greater than or equal to metric2 (>=).
.. note:: If metric2 is None, return True.
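A minimal concrete IMetric following the contract above; the interface stand-in and the accuracy wrapper are hypothetical examples, not classes shipped with the library:

```python
from abc import ABC, abstractmethod

class IMetricSketch(ABC):
    """Stand-in for the IMetric interface described above."""

    @abstractmethod
    def gt(self, metric1, metric2):
        """Whether metric1 >= metric2; True when metric2 is None."""

class AccuracyMetric(IMetricSketch):
    def gt(self, metric1, metric2):
        if metric2 is None:  # first evaluation: anything beats "no result yet"
            return True
        return metric1 >= metric2

m = AccuracyMetric()
print(m.gt(0.91, None))  # True
print(m.gt(0.91, 0.93))  # False
```

MetricSerialEvaluator uses gt in this way to decide whether the current evaluation result beats the best one so far.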
BaseParallelCollector
¶
Bases: ABC
Overview
Abstract base class for parallel collectors.
Interfaces: __init__, info, error, debug, get_finish_info, start, close, _setup_timer, _setup_logger, _iter_after_hook, _policy_inference, _env_step, _process_timestep, _finish_task, _update_policy, _start_thread, _join_thread. Property: policy
__init__(cfg)
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
ZerglingParallelCollector
¶
Bases: BaseParallelCollector
Feature
- one policy, many envs
- async envs (step + reset)
- batch network eval
- envs with different episode lengths
- periodic policy update
- metadata + stepdata
MarineParallelCollector
¶
Bases: BaseParallelCollector
Feature
- one policy or two policies, many envs
- async envs (step + reset)
- batch network eval
- envs with different episode lengths
- periodic policy update
- metadata + stepdata
BaseCommCollector
¶
Bases: ABC
Overview
Abstract base class for the comm collector.
Interfaces: __init__, get_policy_update_info, send_metadata, send_stepdata, start, close, _create_collector. Property: collector_uid
__init__(cfg)
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
get_policy_update_info(path)
abstractmethod
¶
Overview
Get policy information in corresponding path. Will be registered in base collector.
Arguments:
- path (:obj:str): path to policy update information.
send_metadata(metadata)
abstractmethod
¶
Overview
Store meta data in queue, which will be retrieved by callback function "deal_with_collector_data" in collector slave, then will be sent to coordinator. Will be registered in base collector.
Arguments:
- metadata (:obj:Any): meta data.
send_stepdata(stepdata)
abstractmethod
¶
Overview
Save step data in corresponding path. Will be registered in base collector.
Arguments:
- stepdata (:obj:Any): step data.
start()
¶
Overview
Start comm collector.
close()
¶
Overview
Close comm collector.
FlaskFileSystemCollector
¶
Bases: BaseCommCollector
Overview
An implementation of the comm collector, using Flask and the file system.
Interfaces: __init__, deal_with_resource, deal_with_collector_start, deal_with_collector_data, deal_with_collector_close, get_policy_update_info, send_stepdata, send_metadata, start, close
__init__(cfg)
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
deal_with_resource()
¶
Overview
Callback function in CollectorSlave. Return how many resources are needed to start current collector.
Returns:
- resource (:obj:dict): Resource info dict, including ['gpu', 'cpu'].
deal_with_collector_start(task_info)
¶
Overview
Callback function in CollectorSlave.
Create a collector and start a collector thread of the created one.
Arguments:
- task_info (:obj:dict): Task info dict.
Note:
In _create_collector method in base class BaseCommCollector, 4 methods
'send_metadata', 'send_stepdata', 'get_policy_update_info', and policy are set.
You can refer to it for details.
deal_with_collector_data()
¶
Overview
Callback function in CollectorSlave. Get data sample dict from _metadata_queue,
which will be sent to coordinator afterwards.
Returns:
- data (:obj:Any): Data sample dict.
get_policy_update_info(path)
¶
Overview
Get policy information in corresponding path.
Arguments:
- path (:obj:str): path to policy update information.
send_stepdata(path, stepdata)
¶
Overview
Save collector's step data in corresponding path.
Arguments:
- path (:obj:str): Path to save data.
- stepdata (:obj:Any): Data of one step.
send_metadata(metadata)
¶
Overview
Store learn info dict in queue, which will be retrieved by callback function "deal_with_collector_learn" in collector slave, then will be sent to coordinator.
Arguments:
- metadata (:obj:Any): meta data.
start()
¶
Overview
Start comm collector itself and the collector slave.
close()
¶
Overview
Close comm collector itself and the collector slave.
NaiveCollector
¶
Bases: Slave
Overview
A slave whose master is the coordinator. Used to pass messages between the comm collector and the coordinator.
Interfaces: _process_task, _get_timestep
BaseLearner
¶
Bases: object
Overview
Base class for policy learning.
Interfaces: train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close. Property: learn_info, priority_info, last_iter, train_iter, rank, world_size, policy, monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name
collector_envstep
property
writable
¶
Overview
Get current collector envstep.
Returns:
- collector_envstep (:obj:int): Current collector envstep.
learn_info
property
¶
Overview
Get current info dict, which will be sent to commander, e.g. replay buffer priority update, current iteration, hyper-parameter adjustment, whether task is finished, etc.
Returns:
- info (:obj:dict): Current learner info dict.
__init__(cfg, policy=None, tb_logger=None, dist_info=None, exp_name='default_experiment', instance_name='learner')
¶
Overview
Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
Arguments:
- cfg (:obj:EasyDict): Learner config; refer to cls.config for details. It should include is_multitask_pipeline (whether the pipeline is multitask, default False) and only_monitor_rank0 (whether only rank 0 needs the monitor and tb_logger, default True).
- policy (:obj:namedtuple): A collection of policy functions of learn mode. The policy can also be initialized at runtime.
- tb_logger (:obj:SummaryWriter): Tensorboard summary writer.
- dist_info (:obj:Tuple[int, int]): Multi-GPU distributed training information.
- exp_name (:obj:str): Experiment name, which is used to indicate output directory.
- instance_name (:obj:str): Instance name, which should be unique among different learners.
Notes:
If you want to debug in sync CUDA mode, please add the following code at the beginning of __init__.
.. code:: python
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # launch CUDA kernels synchronously for easier debugging
register_hook(hook)
¶
Overview
Add a new learner hook.
Arguments:
- hook (:obj:LearnerHook): The hook to be added.
train(data, envstep=-1, policy_kwargs=None)
¶
Overview
Given training data, implement network update for one iteration and update related variables.
Learner's API for serial entry.
Also called in start for each iteration's training.
Arguments:
- data (:obj:dict): Training data retrieved from the replay buffer.
.. note::
``_policy`` must be set before calling this method.
``_policy.forward`` method contains: forward, backward, grad sync(if in multi-gpu mode) and
parameter update.
``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
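The before_iter / after_iter hook flow around one training step can be sketched like this; the learner internals below are stand-ins for illustration, not BaseLearner's actual code:

```python
class MiniLearner:
    """Illustrative hook dispatch around one train() call."""

    def __init__(self):
        self._hooks = {'before_iter': [], 'after_iter': []}
        self.train_iter = 0

    def register_hook(self, position, fn):
        self._hooks[position].append(fn)

    def call_hook(self, position):
        # Invoke every hook registered at this position, passing the learner.
        for fn in self._hooks[position]:
            fn(self)

    def train(self, data):
        self.call_hook('before_iter')
        loss = sum(data) / len(data)  # stand-in for policy.forward + parameter update
        self.train_iter += 1
        self.call_hook('after_iter')
        return loss

log = []
learner = MiniLearner()
learner.register_hook('before_iter', lambda lr: log.append('before'))
learner.register_hook('after_iter', lambda lr: log.append(f'after iter {lr.train_iter}'))
learner.train([1.0, 2.0, 3.0])
print(log)  # ['before', 'after iter 1']
```

The real learner additionally supports before_run / after_run positions, which wrap the whole training run rather than a single iteration.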
start()
¶
Overview
[Only Used In Parallel Mode] Learner's API for parallel entry.
For each iteration, learner will get data through _next_data and call train to train.
.. note::
``before_run`` and ``after_run`` hooks are called at the beginning and ending.
setup_dataloader()
¶
Overview
[Only Used In Parallel Mode] Setup learner's dataloader.
.. note::
Only in parallel mode will we use attributes ``get_data`` and ``_dataloader`` to get data from file system;
Instead, in serial version, we can fetch data from memory directly.
In parallel mode, ``get_data`` is set by ``LearnerCommHelper``, and should be callable.
Users don't need to know the related details if not necessary.
close()
¶
Overview
[Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
call_hook(name)
¶
Overview
Call the corresponding hook plugins according to position name.
Arguments:
- name (:obj:str): Hooks in which position to call, should be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
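The position-based dispatch that ``call_hook`` performs can be sketched as a dict keyed by position. This is an illustrative sketch, not ding's actual code; ``SimpleHook`` and ``SimpleLearner`` are hypothetical names:

```python
# Illustrative sketch of position-based hook dispatch (not ding's actual code).
# A learner keeps a dict mapping each position to a list of hooks and invokes
# them in registration order when call_hook(name) is executed.

POSITIONS = ['before_run', 'after_run', 'before_iter', 'after_iter']

class SimpleHook:
    def __init__(self, name, position):
        assert position in POSITIONS
        self.name = name
        self.position = position

    def __call__(self, engine_state):
        # Record the call so the dispatch order is observable.
        engine_state.setdefault('calls', []).append(self.name)

class SimpleLearner:
    def __init__(self):
        self._hooks = {p: [] for p in POSITIONS}
        self.state = {}

    def add_hook(self, hook):
        self._hooks[hook.position].append(hook)

    def call_hook(self, name):
        # Invoke every hook registered at the given position, in order.
        for hook in self._hooks[name]:
            hook(self.state)

learner = SimpleLearner()
learner.add_hook(SimpleHook('log', 'after_iter'))
learner.add_hook(SimpleHook('ckpt', 'after_iter'))
learner.call_hook('after_iter')
```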
info(s)
¶
Overview
Log string info by self._logger.info.
Arguments:
- s (:obj:str): The message to add into the logger.
save_checkpoint(ckpt_name=None)
¶
Overview
Directly call save_ckpt_after_run hook to save checkpoint.
Note: Must guarantee that "save_ckpt_after_run" is registered in "after_run" hook. This method is called in:
- ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for saving checkpoint whenever an exception raises.
- ``serial_pipeline`` (``entry/serial_entry.py``). Used to save checkpoint when reaching new highest episode return.
BaseCommLearner
¶
Bases: ABC
Overview
Abstract baseclass for CommLearner.
Interfaces: init, send_policy, get_data, send_learn_info, start, close Property: hooks4call
__init__(cfg)
¶
Overview
Initialization method.
Arguments:
- cfg (:obj:EasyDict): Config dict
send_policy(state_dict)
abstractmethod
¶
Overview
Save learner's policy in corresponding path. Will be registered in base learner.
Arguments:
- state_dict (:obj:dict): State dict of the runtime policy.
get_data(batch_size)
abstractmethod
¶
Overview
Get batched meta data from coordinator. Will be registered in base learner.
Arguments:
- batch_size (:obj:int): Batch size.
Returns:
- stepdata (:obj:list): A list of training data, each element is one trajectory.
send_learn_info(learn_info)
abstractmethod
¶
Overview
Send learn info to coordinator. Will be registered in base learner.
Arguments:
- learn_info (:obj:dict): Learn info in dict type.
start()
¶
Overview
Start comm learner.
close()
¶
Overview
Close comm learner.
hooks4call()
¶
Returns:
- hooks (:obj:list): The hooks which the comm learner has.
FlaskFileSystemLearner
¶
Bases: BaseCommLearner
Overview
An implementation of CommLearner, using flask and the file system.
Interfaces: init, send_policy, get_data, send_learn_info, start, close Property: hooks4call
hooks4call
property
¶
Overview
Return the hooks that are related to message passing with coordinator.
Returns:
- hooks (:obj:list): The hooks which comm learner has. Will be registered in learner as well.
__init__(cfg)
¶
Overview
Init method.
Arguments:
- cfg (:obj:EasyDict): Config dict.
start()
¶
Overview
Start comm learner itself and the learner slave.
close()
¶
Overview
Join learner thread and close learner if still running. Then close learner slave and comm learner itself.
__del__()
¶
Overview
Call close for deletion.
deal_with_resource()
¶
Overview
Callback function. Return how many resources are needed to start current learner.
Returns:
- resource (:obj:dict): Resource info dict, including ["gpu"].
deal_with_learner_start(task_info)
¶
Overview
Callback function. Create a learner and help register its hooks. Start a learner thread of the created one.
Arguments:
- task_info (:obj:dict): Task info dict.
.. note::
In _create_learner method in base class BaseCommLearner, 3 methods
('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
You can refer to it for details.
deal_with_get_data()
¶
Overview
Callback function. Get data demand info dict from _data_demand_queue,
which will be sent to coordinator afterwards.
Returns:
- data_demand (:obj:Any): Data demand info dict.
deal_with_learner_learn(data)
¶
Overview
Callback function. Put training data info dict (i.e. meta data), which is received from coordinator, into
_data_result_queue, and wait for get_data to retrieve. Wait for learner training and
get learn info dict from _learn_info_queue. If task is finished, join the learner thread and
close the learner.
Returns:
- learn_info (:obj:Any): Learn info dict.
send_policy(state_dict)
¶
Overview
Save learner's policy in corresponding path, called by SendPolicyHook.
Arguments:
- state_dict (:obj:dict): State dict of the policy.
load_data_fn(path, meta, decompressor)
staticmethod
¶
Overview
The function that is used to load data file.
Arguments:
- path (:obj:str): Path of the data file.
- meta (:obj:Dict[str, Any]): Meta data info dict.
- decompressor (:obj:Callable): Decompress function.
Returns:
- s (:obj:Any): Data which is read from file.
get_data(batch_size)
¶
Overview
Get a list of data loading function, which can be implemented by dataloader to read data from files.
Arguments:
- batch_size (:obj:int): Batch size.
Returns:
- data (:obj:List[Callable]): A list of callable data loading function.
send_learn_info(learn_info)
¶
Overview
Store learn info dict in queue, which will be retrieved by callback function "deal_with_learner_learn" in learner slave, then will be sent to coordinator.
Arguments:
- learn_info (:obj:dict): Learn info in dict type. Keys are like 'learner_step', 'priority_info', 'finished_task', etc. You can refer to learn_info(worker/learner/base_learner.py) for details.
LearnerHook
¶
Bases: Hook
Overview
Abstract class for hooks used in Learner.
Interfaces: init Property: name, priority, position
.. note::
Subclass should implement ``self.__call__``.
__init__(*args, position, **kwargs)
¶
Overview
Init LearnerHook.
Arguments:
- position (:obj:str): The position to call hook in learner. Must be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
IBuffer
¶
Bases: ABC
Overview
Buffer interface
Interfaces: default_config, push, update, sample, clear, count, state_dict, load_state_dict
default_config()
classmethod
¶
Overview
Default config of this buffer class.
Returns:
- default_config (:obj:EasyDict)
push(data, cur_collector_envstep)
abstractmethod
¶
Overview
Push a data into buffer.
Arguments:
- data (:obj:Union[List[Any], Any]): The data which will be pushed into the buffer. Can be one \
(in Any type), or many (in List[Any] type).
- cur_collector_envstep (:obj:int): Collector's current env step.
update(info)
abstractmethod
¶
Overview
Update data info, e.g. priority.
Arguments:
- info (:obj:Dict[str, list]): Info dict. Keys depend on the specific buffer type.
sample(batch_size, cur_learner_iter)
abstractmethod
¶
Overview
Sample data with length batch_size.
Arguments:
- batch_size (:obj:int): The number of the data that will be sampled.
- cur_learner_iter (:obj:int): Learner's current iteration.
Returns:
- sampled_data (:obj:list): A list of data with length batch_size.
clear()
abstractmethod
¶
Overview
Clear all the data and reset the related variables.
count()
abstractmethod
¶
Overview
Count how many valid data items there are in the buffer.
Returns:
- count (:obj:int): Number of valid data.
save_data(file_name)
abstractmethod
¶
Overview
Save buffer data into a file.
Arguments:
- file_name (:obj:str): file name of buffer data
load_data(file_name)
abstractmethod
¶
Overview
Load buffer data from a file.
Arguments:
- file_name (:obj:str): file name of buffer data
state_dict()
abstractmethod
¶
Overview
Provide a state dict to keep a record of current buffer.
Returns:
- state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer. With the dict, one can easily reproduce the buffer.
load_state_dict(_state_dict)
abstractmethod
¶
Overview
Load state dict to reproduce the buffer.
Arguments:
- _state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer.
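The whole IBuffer contract can be illustrated with a minimal in-memory implementation. This is a didactic sketch only, assuming a plain ``deque`` store; the ``TinyBuffer`` name and its simplifications (no priority, no logging) are not part of ding:

```python
# Minimal in-memory sketch of the IBuffer contract (illustrative only).
import random
from collections import deque

class TinyBuffer:
    def __init__(self, replay_buffer_size=4):
        # Oldest data is evicted automatically once the buffer is full.
        self._data = deque(maxlen=replay_buffer_size)

    def push(self, data, cur_collector_envstep=-1):
        # Accept a single item or a list of items, as in IBuffer.push.
        if isinstance(data, list):
            self._data.extend(data)
        else:
            self._data.append(data)

    def sample(self, batch_size, cur_learner_iter=-1):
        return random.sample(list(self._data), batch_size)

    def clear(self):
        self._data.clear()

    def count(self):
        return len(self._data)

    def state_dict(self):
        # Enough information to reproduce the buffer later.
        return {'data': list(self._data)}

    def load_state_dict(self, _state_dict):
        self._data = deque(_state_dict['data'], maxlen=self._data.maxlen)

buf = TinyBuffer()
buf.push([1, 2, 3])
buf.push(4)
```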
NaiveReplayBuffer
¶
Bases: IBuffer
Overview
Naive replay buffer, can store and sample data.
A naive implementation of replay buffer with no priority or any other advanced features.
This buffer supports multi-thread/multi-process access and guarantees thread safety, which means that methods like
sample, push, clear are mutually exclusive with each other.
Interface: start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config Property: replay_buffer_size, push_count
__init__(cfg, tb_logger=None, exp_name='default_experiment', instance_name='buffer')
¶
Overview
Initialize the buffer
Arguments:
- cfg (:obj:dict): Config dict.
- tb_logger (:obj:Optional['SummaryWriter']): Outer tb logger. Usually get this argument in serial mode.
- exp_name (:obj:Optional[str]): Name of this experiment.
- instance_name (:obj:Optional[str]): Name of this instance.
start()
¶
Overview
Start the buffer's used_data_remover thread if track_used_data is enabled.
close()
¶
Overview
Clear the buffer; join the buffer's used_data_remover thread if track_used_data is enabled.
push(data, cur_collector_envstep)
¶
Overview
Push a data into buffer.
Arguments:
- data (:obj:Union[List[Any], Any]): The data which will be pushed into the buffer. Can be one \
(in Any type), or many (in List[Any] type).
- cur_collector_envstep (:obj:int): Collector's current env step. \
Not used in naive buffer, but preserved for compatibility.
sample(size, cur_learner_iter, sample_range=None, replace=False)
¶
Overview
Sample data with length size.
Arguments:
- size (:obj:int): The number of the data that will be sampled.
- cur_learner_iter (:obj:int): Learner's current iteration. Not used in naive buffer, but preserved for compatibility.
- sample_range (:obj:slice): Buffer slice for sampling, such as slice(-10, None), which means only sampling among the last 10 data items
- replace (:obj:bool): Whether to sample with replacement
Returns:
- sample_data (:obj:list): A list of data with length size.
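The ``sample_range`` and ``replace`` semantics can be sketched with a list-backed store. This is an illustrative sketch under that assumption, not ding's actual implementation; ``naive_sample`` is a hypothetical helper:

```python
# Illustrative sketch of sample_range / replace semantics (not ding's code).
import random

def naive_sample(storage, size, sample_range=None, replace=False):
    # Restrict candidates to the given slice, e.g. slice(-10, None)
    # samples only among the most recent 10 items.
    candidates = storage[sample_range] if sample_range is not None else storage
    if replace:
        return random.choices(candidates, k=size)   # with replacement
    return random.sample(candidates, size)          # without replacement

storage = list(range(100))
recent = naive_sample(storage, 5, sample_range=slice(-10, None))
```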
update(info)
¶
Overview
Naive Buffer does not need to update any info, but this method is preserved for compatibility.
clear()
¶
Overview
Clear all the data and reset the related variables.
__del__()
¶
Overview
Call close to delete the object.
count()
¶
Overview
Count how many valid data items there are in the buffer.
Returns:
- count (:obj:int): Number of valid data.
state_dict()
¶
Overview
Provide a state dict to keep a record of current buffer.
Returns:
- state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer. With the dict, one can easily reproduce the buffer.
load_state_dict(_state_dict)
¶
Overview
Load state dict to reproduce the buffer.
Arguments:
- _state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer.
SequenceReplayBuffer
¶
Bases: NaiveReplayBuffer
Overview: Interface: start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config Property: replay_buffer_size, push_count
sample(batch, sequence, cur_learner_iter, sample_range=None, replace=False)
¶
Overview
Sample data with length batch.
Arguments:
- batch (:obj:int): The number of the data that will be sampled.
- sequence (:obj:int): The length of the sequence of a data that will be sampled.
- cur_learner_iter (:obj:int): Learner's current iteration. Not used in naive buffer, but preserved for compatibility.
- sample_range (:obj:slice): Buffer slice for sampling, such as slice(-10, None), which means only sampling among the last 10 data items
- replace (:obj:bool): Whether to sample with replacement
Returns:
- sample_data (:obj:list): A list of data with length batch.
AdvancedReplayBuffer
¶
Bases: IBuffer
Overview
Prioritized replay buffer derived from NaiveReplayBuffer.
This replay buffer adds:
1) Prioritized experience replay implemented by segment tree.
2) Data quality monitor. Monitors the use count and staleness of each data item.
3) Throughput monitor and control.
4) Logger. Log 2) and 3) in tensorboard or text.
Interface: start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config Property: beta, replay_buffer_size, push_count
__init__(cfg, tb_logger=None, exp_name='default_experiment', instance_name='buffer')
¶
Overview
Initialize the buffer
Arguments:
- cfg (:obj:dict): Config dict.
- tb_logger (:obj:Optional['SummaryWriter']): Outer tb logger. Usually get this argument in serial mode.
- exp_name (:obj:Optional[str]): Name of this experiment.
- instance_name (:obj:Optional[str]): Name of this instance.
start()
¶
Overview
Start the buffer's used_data_remover thread if track_used_data is enabled.
close()
¶
Overview
Clear the buffer; join the buffer's used_data_remover thread if track_used_data is enabled. Join the periodic throughput monitor and flush the tensorboard logger.
sample(size, cur_learner_iter, sample_range=None)
¶
Overview
Sample data with length size.
Arguments:
- size (:obj:int): The number of the data that will be sampled.
- cur_learner_iter (:obj:int): Learner's current iteration, used to calculate staleness.
- sample_range (:obj:slice): Buffer slice for sampling, such as slice(-10, None), which means only sampling among the last 10 data items
Returns:
- sample_data (:obj:list): A list of data with length size
ReturnsKeys:
- necessary: original keys(e.g. obs, action, next_obs, reward, info), replay_unique_id, replay_buffer_idx
- optional(if use priority): IS, priority
push(data, cur_collector_envstep)
¶
Overview
Push a data into buffer.
Arguments:
- data (:obj:Union[List[Any], Any]): The data which will be pushed into the buffer. Can be one \
(in Any type), or many (in List[Any] type).
- cur_collector_envstep (:obj:int): Collector's current env step.
update(info)
¶
Overview
Update a data's priority. Use replay_buffer_idx to locate, and use replay_unique_id to verify.
Arguments:
- info (:obj:dict): Info dict containing all necessary keys for priority update.
ArgumentsKeys:
- necessary: replay_unique_id, replay_buffer_idx, priority. All values are lists with the same length.
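The locate-by-index, verify-by-id scheme can be sketched as follows. This is an illustrative sketch, not ding's actual code; the flat ``storage`` list and the ``update`` helper are stand-ins for the buffer's internal circular storage:

```python
# Illustrative sketch of the update's locate/verify scheme (not ding's code):
# replay_buffer_idx gives O(1) access into the storage, while
# replay_unique_id guards against updating a slot whose original data
# has since been overwritten by newer data.

storage = [
    {'replay_unique_id': 'buffer_0', 'priority': 1.0},
    {'replay_unique_id': 'buffer_1', 'priority': 1.0},
    {'replay_unique_id': 'buffer_2', 'priority': 1.0},
]

def update(info):
    updated = 0
    triples = zip(info['replay_unique_id'], info['replay_buffer_idx'], info['priority'])
    for uid, idx, prio in triples:
        entry = storage[idx]
        if entry['replay_unique_id'] == uid:  # slot still holds the same data
            entry['priority'] = prio
            updated += 1
    return updated

# 'buffer_9' no longer lives at index 2, so that update is silently skipped.
n = update({
    'replay_unique_id': ['buffer_0', 'buffer_9'],
    'replay_buffer_idx': [0, 2],
    'priority': [5.0, 3.0],
})
```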
clear()
¶
Overview
Clear all the data and reset the related variables.
__del__()
¶
Overview
Call close to delete the object.
count()
¶
Overview
Count how many valid data items there are in the buffer.
Returns:
- count (:obj:int): Number of valid data.
state_dict()
¶
Overview
Provide a state dict to keep a record of current buffer.
Returns:
- state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer. With the dict, one can easily reproduce the buffer.
load_state_dict(_state_dict, deepcopy=False)
¶
Overview
Load state dict to reproduce the buffer.
Arguments:
- _state_dict (:obj:Dict[str, Any]): A dict containing all important values in the buffer.
- deepcopy (:obj:bool): Whether to deepcopy the state dict when loading.
EpisodeReplayBuffer
¶
Bases: NaiveReplayBuffer
Overview
Episode replay buffer is a buffer that stores complete episodes, i.e. each element in the episode buffer is an episode.
Some algorithms do not want to sample batch_size complete episodes; instead, they want transitions with
some fixed length. As a result, sample should be overwritten for those requirements.
Interface: start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
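Sampling fixed-length transition slices out of stored episodes can be sketched as below. This is an illustrative sketch, not ding's actual code; ``sample_fixed_length`` and ``unroll_len`` are hypothetical names:

```python
# Illustrative sketch of fixed-length sampling from an episode buffer
# (not ding's actual code). Each stored element is a whole episode; a
# sample draws a random contiguous slice of unroll_len transitions.
import random

def sample_fixed_length(episodes, batch_size, unroll_len):
    eligible = [e for e in episodes if len(e) >= unroll_len]
    out = []
    for _ in range(batch_size):
        ep = random.choice(eligible)
        start = random.randrange(len(ep) - unroll_len + 1)
        out.append(ep[start:start + unroll_len])
    return out

# Two toy "episodes" of consecutive integer "transitions".
episodes = [list(range(8)), list(range(100, 105))]
batch = sample_fixed_length(episodes, batch_size=4, unroll_len=3)
```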
BaseSerialCommander
¶
Bases: object
Overview
Base serial commander class.
Interface: init, step Property: policy
__init__(cfg, learner, collector, evaluator, replay_buffer, policy=None)
¶
Overview
Init the BaseSerialCommander
Arguments:
- cfg (:obj:dict): the config of commander
- learner (:obj:BaseLearner): the learner
- collector (:obj:BaseSerialCollector): the collector
- evaluator (:obj:InteractionSerialEvaluator): the evaluator
- replay_buffer (:obj:IBuffer): the buffer
step()
¶
Overview
Step the commander
Coordinator
¶
Bases: object
Overview
The coordinator manages parallel tasks and data.
Interface: init, start, close, del, state_dict, load_state_dict, deal_with_collector_send_data, deal_with_collector_finish_task, deal_with_learner_get_data, deal_with_learner_send_info, deal_with_learner_finish_task Property: system_shutdown_flag
system_shutdown_flag
property
¶
Overview
Return whether the system is shutdown
Returns:
- system_shutdown_flag (:obj:bool): whether the system is shutdown
__init__(cfg)
¶
Overview
init method of the coordinator
Arguments:
- cfg (:obj:dict): the config file to init the coordinator
state_dict()
¶
Overview
Return empty state_dict.
load_state_dict(state_dict)
¶
Overview
Pass when load state_dict.
start()
¶
Overview
Start the coordinator, including launching the interaction thread and the collector and learner threads.
close()
¶
Overview
Close the coordinator, including closing the interaction thread, the collector and learner threads, and the buffers.
__del__()
¶
Overview
del method will close the coordinator.
deal_with_collector_send_data(task_id, buffer_id, data_id, data)
¶
Overview
deal with the data sent from the collector
Arguments:
- task_id (:obj:str): the collector task_id
- buffer_id (:obj:str): the buffer_id
- data_id (:obj:str): the data_id
- data (:obj:str): the data to be dealt with
deal_with_collector_finish_task(task_id, finished_task)
¶
Overview
finish the collector task
Arguments:
- task_id (:obj:str): the collector task_id
- finished_task (:obj:dict): the finished_task
deal_with_learner_get_data(task_id, buffer_id, batch_size, cur_learner_iter)
¶
Overview
the learner gets the data from the buffer
Arguments:
- task_id (:obj:str): the learner task_id
- buffer_id (:obj:str): the buffer_id
- batch_size (:obj:int): the batch_size to sample
- cur_learner_iter (:obj:int): the current learner iter num
deal_with_learner_send_info(task_id, buffer_id, info)
¶
Overview
the learner sends the info and updates the priority in the buffer
Arguments:
- task_id (:obj:str): the learner task id
- buffer_id (:obj:str): the buffer_id of buffer to add info to
- info (:obj:dict): the info to add
deal_with_learner_finish_task(task_id, finished_task)
¶
Overview
finish the learner task and close the corresponding buffer
Arguments:
- task_id (:obj:str): the learner task_id
- finished_task (:obj:dict): the dict of task to finish
deal_with_increase_collector()
¶
" Overview: Increase task space when a new collector has added dynamically.
deal_with_decrease_collector()
¶
" Overview: Decrease task space when a new collector has removed dynamically.
info(s)
¶
Overview
Log the info string.
Arguments:
- s (:obj:str): the info string to log
error(s)
¶
Overview
Log the error info.
Arguments:
- s (:obj:str): the error info to log
LearnerAggregator
¶
Bases: object
Overview
Aggregate multiple learners.
Interfaces: init, start, close, merge_info
create_serial_collector(cfg, **kwargs)
¶
Overview
Create a specific collector instance based on the config.
get_serial_collector_cls(cfg)
¶
Overview
Get the specific collector class according to the config.
to_tensor_transitions(data, shallow_copy_next_obs=True)
¶
Overview
Transform the original transitions returned from env to tensor format.
Argument:
- data (:obj:List[Dict[str, Any]]): The data that will be transformed to tensor.
- shallow_copy_next_obs (:obj:bool): Whether to shallow copy next_obs. Default: True.
Return:
- data (:obj:List[Dict[str, Any]]): The transformed tensor-like data.
.. tip::
In order to save memory, if next_obs exists in the passed data, we do special treatment on it so that the next_obs of each state in the data fragment is the next state's obs, while the next_obs of the last state keeps its own next_obs. Besides, we set transform_scalar to False to avoid the extra .item() operation.
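The memory-saving treatment of next_obs can be sketched as below. This is an illustrative sketch of the linking step only (the real to_tensor_transitions also converts data to tensors); ``link_next_obs`` is a hypothetical helper:

```python
# Sketch of the next_obs memory-saving treatment (illustrative only).
# Each transition's next_obs becomes a shallow reference to the following
# transition's obs, so the observation is not stored twice; only the last
# transition keeps its own next_obs.
def link_next_obs(transitions):
    for i in range(len(transitions) - 1):
        transitions[i]['next_obs'] = transitions[i + 1]['obs']
    return transitions

data = [
    {'obs': [0.0], 'next_obs': [1.0]},
    {'obs': [1.0], 'next_obs': [2.0]},
    {'obs': [2.0], 'next_obs': [3.0]},
]
data = link_next_obs(data)
```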
create_serial_evaluator(cfg, **kwargs)
¶
Overview
Create a specific evaluator instance based on the config.
create_comm_collector(cfg)
¶
Overview
Given the key(comm_collector_name), create a new comm collector instance if in comm_map's values,
or raise a KeyError. In other words, a derived comm collector must first register,
then can call create_comm_collector to get the instance.
Arguments:
- cfg (:obj:EasyDict): Collector config. Necessary keys: [import_names, comm_collector_type].
Returns:
- collector (:obj:BaseCommCollector): The created new comm collector, should be an instance of one of comm_map's values.
create_learner(cfg, **kwargs)
¶
Overview
Given the key(learner_name), create a new learner instance if in learner_mapping's values,
or raise a KeyError. In other words, a derived learner must first register, then can call create_learner
to get the instance.
Arguments:
- cfg (:obj:EasyDict): Learner config. Necessary keys: [learner.import_module, learner.learner_type].
Returns:
- learner (:obj:BaseLearner): The created new learner, should be an instance of one of learner_mapping's values.
create_comm_learner(cfg)
¶
Overview
Given the key(comm_learner_name), create a new comm learner instance if in comm_map's values,
or raise a KeyError. In other words, a derived comm learner must first register,
then can call create_comm_learner to get the instance.
Arguments:
- cfg (:obj:dict): Learner config. Necessary keys: [import_names, comm_learner_type].
Returns:
- learner (:obj:BaseCommLearner): The created new comm learner, should be an instance of one of comm_map's values.
register_learner_hook(name, hook_type)
¶
Overview
Add a new LearnerHook class to hook_mapping, so you can build one instance with build_learner_hook_by_cfg.
Arguments:
- name (:obj:str): name of the register hook
- hook_type (:obj:type): the register hook_type you implemented that realize LearnerHook
Examples:
>>> class HookToRegister(LearnerHook):
>>>     def __init__(self, *args, **kwargs):
>>>         ...
>>>     def __call__(self, *args, **kwargs):
>>>         ...
>>> register_learner_hook('name_of_hook', HookToRegister)
>>> hooks = build_learner_hook_by_cfg(cfg)
add_learner_hook(hooks, hook)
¶
Overview
Add a learner hook (:obj:LearnerHook) to hooks (:obj:Dict[str, List[Hook]])
Arguments:
- hooks (:obj:Dict[str, List[Hook]]): You can refer to build_learner_hook_by_cfg's return hooks.
- hook (:obj:LearnerHook): The LearnerHook which will be added to hooks.
merge_hooks(hooks1, hooks2)
¶
Overview
Merge two hooks dict, which have the same keys, and each value is sorted by hook priority with stable method.
Arguments:
- hooks1 (:obj:Dict[str, List[Hook]]): hooks1 to be merged.
- hooks2 (:obj:Dict[str, List[Hook]]): hooks2 to be merged.
Returns:
- new_hooks (:obj:Dict[str, List[Hook]]): New merged hooks dict.
Note:
This merge function uses stable sort method without disturbing the same priority hook.
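The stable priority merge can be sketched with Python's ``sorted``, which guarantees stability. This is an illustrative sketch, not ding's actual code; plain dicts with a 'priority' key stand in for real hook objects, and we assume lower value means higher priority (as noted under build_learner_hook_by_cfg):

```python
# Illustrative sketch of merging two position-keyed hook dicts with a
# stable sort on priority (not ding's actual code). Assumes both dicts
# have the same keys, as stated in the docstring above.
def merge_hooks(hooks1, hooks2):
    new_hooks = {}
    for position in hooks1:
        merged = hooks1[position] + hooks2[position]
        # sorted() is stable: equal-priority hooks keep their relative order.
        new_hooks[position] = sorted(merged, key=lambda h: h['priority'])
    return new_hooks

h1 = {'after_iter': [{'name': 'log', 'priority': 20}]}
h2 = {'after_iter': [{'name': 'ckpt', 'priority': 10}, {'name': 'eval', 'priority': 20}]}
merged = merge_hooks(h1, h2)
```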
build_learner_hook_by_cfg(cfg)
¶
Overview
Build the learner hooks in hook_mapping by config.
This function is often used to initialize hooks according to cfg,
while add_learner_hook() is often used to add an existing LearnerHook to hooks.
Arguments:
- cfg (:obj:EasyDict): Config dict. Should be like {'hook': xxx}.
Returns:
- hooks (:obj:Dict[str, List[Hook]]): Keys should be in ['before_run', 'after_run', 'before_iter', 'after_iter'], each value should be a list containing all hooks in this position.
Note:
Lower value means higher priority.
create_buffer(cfg, *args, **kwargs)
¶
Overview
Create a buffer according to cfg and other arguments.
Arguments:
- cfg (:obj:EasyDict): Buffer config.
ArgumentsKeys:
- necessary: type
get_buffer_cls(cfg)
¶
Overview
Get a buffer class according to cfg.
Arguments:
- cfg (:obj:EasyDict): Buffer config.
ArgumentsKeys:
- necessary: type
create_parallel_commander(cfg)
¶
Overview
create the commander according to cfg
Arguments:
- cfg (:obj:dict): the commander cfg to create, should include import_names and parallel_commander_type
Full Source Code
../ding/worker/__init__.py