
ding.framework.middleware.learner

OffPolicyLearner

Overview

The class of the off-policy learner, including data fetching and model training. Use the __call__ method to execute the whole learning process.

__init__(cfg, policy, buffer_, reward_model=None, log_freq=100)

Parameters:

  • cfg (:obj:`EasyDict`): Config. Required.
  • policy (:obj:`Policy`): The policy to be trained. Required.
  • buffer_ (:obj:`Buffer`): The replay buffer that stores the training data; per the signature it may also be a list of (buffer, float) pairs or a dict of named buffers, as shown in the sketch below. Required.
  • reward_model (:obj:`BaseRewardModel`): An additional reward estimator such as RND or ICM. Defaults to None.
  • log_freq (:obj:`int`): The frequency (in training iterations) at which logs are written. Defaults to 100.
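Per the type annotation in the source below, `buffer_` accepts three forms. A minimal sketch, assuming the `DequeBuffer` from `ding.data`; the buffer sizes are illustrative, and the meaning of the float in the list form (a sampling weight is assumed here) is decided by `offpolicy_data_fetcher`:

```python
from ding.data import DequeBuffer

# The common case: a single replay buffer.
buffer_ = DequeBuffer(size=10000)

# A list of (buffer, float) pairs; the float is interpreted by
# offpolicy_data_fetcher (assumed here to weight sampling between buffers).
weighted = [(DequeBuffer(size=10000), 0.8), (DequeBuffer(size=2000), 0.2)]

# A dict of named buffers.
named = {"main": DequeBuffer(size=10000)}
```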

__call__(ctx)

Output of ctx
  • train_output (:obj:`Deque`): The training output, collected in a deque.
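In practice the learner is registered as one middleware stage in a task pipeline, and each pipeline pass invokes its `__call__(ctx)` on the shared `OnlineRLContext`. The sketch below is an assumption-laden outline, not code from this page: `cfg`, `policy`, and `collector_env` are presumed to be built elsewhere (e.g. from a compiled DQN config), while `StepCollector`, `data_pusher`, and the `task` API are standard DI-engine middleware pieces:

```python
from ding.data import DequeBuffer
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, data_pusher

def main(cfg, policy, collector_env):  # hypothetical entry point
    buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Collect transitions, push them into the buffer, then let the learner
        # fetch and train update_per_collect times per collect cycle.
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        task.run(max_step=100)
```

After each pass, the learner leaves its results in `ctx.train_output` for downstream middleware (e.g. loggers) to consume.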

HERLearner

Overview

The class of the learner with Hindsight Experience Replay (HER). Use the __call__ method to execute the data fetching and training process.

__init__(cfg, policy, buffer_, her_reward_model)

Parameters:

  • cfg (:obj:`EasyDict`): Config. Required.
  • policy (:obj:`Policy`): The policy to be trained. Required.
  • buffer_ (:obj:`Buffer`): The replay buffer that stores the training data. Required.
  • her_reward_model (:obj:`HerRewardModel`): The HER reward model. Required.

__call__(ctx)

Output of ctx
  • train_output (:obj:`Deque`): The deque of training output.
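HERLearner is wired the same way as OffPolicyLearner, with the HER reward model passed at construction so that `her_data_enhancer` can apply hindsight relabeling to the sampled data before training. A minimal sketch under the same assumptions as above; the `HerRewardModel` construction and import path are assumptions, not shown on this page:

```python
from ding.data import DequeBuffer
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import HERLearner

def train_with_her(cfg, policy, her_reward_model):  # hypothetical entry point
    buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Collector / data_pusher middleware would be registered here as well.
        task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_reward_model))
        task.run(max_step=100)
```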

Full Source Code

../ding/framework/middleware/learner.py

```python
from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional
from easydict import EasyDict
from collections import deque

from ding.framework import task
from ding.data import Buffer
from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer

if TYPE_CHECKING:
    from ding.framework import Context, OnlineRLContext
    from ding.policy import Policy
    from ding.reward_model import BaseRewardModel


class OffPolicyLearner:
    """
    Overview:
        The class of the off-policy learner, including data fetching and model training. Use \
        the `__call__` method to execute the whole learning process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()
        return super(OffPolicyLearner, cls).__new__(cls)

    def __init__(
            self,
            cfg: EasyDict,
            policy: 'Policy',
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            reward_model: Optional['BaseRewardModel'] = None,
            log_freq: int = 100,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer_ (:obj:`Buffer`): The replay buffer to store the data for training.
            - reward_model (:obj:`BaseRewardModel`): Additional reward estimator like RND, ICM, etc. \
                Defaults to None.
            - log_freq (:obj:`int`): The frequency (in iterations) of logging.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
        self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
        if reward_model is not None:
            self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
        else:
            self._reward_estimator = None

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`Deque`): The training output in deque.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            if self._reward_estimator:
                self._reward_estimator(ctx)
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue


class HERLearner:
    """
    Overview:
        The class of the learner with Hindsight Experience Replay (HER). \
        Use the `__call__` method to execute the data fetching and training \
        process.
    """

    def __init__(
            self,
            cfg: EasyDict,
            policy,
            buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
            her_reward_model,
    ) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be trained.
            - buffer_ (:obj:`Buffer`): The replay buffer to store the data for training.
            - her_reward_model (:obj:`HerRewardModel`): HER reward model.
        """
        self.cfg = cfg
        self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model))
        self._trainer = task.wrap(trainer(cfg, policy))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Output of ctx:
            - train_output (:obj:`Deque`): The deque of training output.
        """
        train_output_queue = []
        for _ in range(self.cfg.policy.learn.update_per_collect):
            self._fetcher(ctx)
            if ctx.train_data is None:
                break
            self._trainer(ctx)
            train_output_queue.append(ctx.train_output)
        ctx.train_output = train_output_queue
```