ding.framework.middleware.functional.explorer
Full Source Code
../ding/framework/middleware/functional/explorer.py
from typing import TYPE_CHECKING, Callable
from easydict import EasyDict
from ding.rl_utils import get_epsilon_greedy_fn
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


def eps_greedy_handler(cfg: EasyDict) -> Callable:
    """
    Overview:
        The middleware that computes the epsilon value according to ``ctx.env_step`` and \
        writes it into ``ctx.collect_kwargs['eps']``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. Reads ``cfg.policy.other.eps.start``, ``.end``, \
          ``.decay`` and ``.type`` to build the epsilon schedule.
    Returns:
        - middleware (:obj:`Callable`): The ``_eps_greedy`` middleware, or ``task.void()`` \
          when the router is active and this process does not have the collector role.
    """
    # When running with an active router, only processes holding the collector role
    # install this middleware; every other role gets ``task.void()`` instead.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    eps_cfg = cfg.policy.other.eps
    handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _eps_greedy(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - env_step (:obj:`int`): The env steps count.
        Output of ctx:
            - collect_kwargs['eps'] (:obj:`float`): The eps conditioned on env_step and cfg. \
              The key is removed again when this middleware resumes after ``yield``.
        """
        ctx.collect_kwargs['eps'] = handle(ctx.env_step)
        yield
        # Clean up after resuming. Another middleware may already have removed the key,
        # so tolerate its absence explicitly rather than swallowing every exception
        # (including KeyboardInterrupt) with a bare ``except``.
        ctx.collect_kwargs.pop('eps', None)

    return _eps_greedy


def eps_greedy_masker():
    """
    Overview:
        The middleware that overwrites ``ctx.collect_kwargs['eps']`` with the masked value \
        ``-1``, so that actions are no longer generated by the epsilon-greedy method.
    Returns:
        - middleware (:obj:`Callable`): The ``_masker`` middleware.
    """

    def _masker(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - collect_kwargs['eps'] (:obj:`float`): The masked eps value, always ``-1``.
        """
        ctx.collect_kwargs['eps'] = -1

    return _masker