ding.world_model.utils
get_rollout_length_scheduler(cfg)
Overview
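Build a scheduler that maps the current number of environment steps (envstep) to a rollout length. cfg.type selects the schedule. 'linear' interpolates from rollout_length_min to rollout_length_max as envstep goes from rollout_start_step to rollout_end_step, and clamps the result to that range outside it. 'constant' always returns cfg.rollout_length. Any other type raises a KeyError.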
Returns:
- scheduler (:obj:`Callable`): A function that takes the current envstep and returns the rollout length to use.
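For the 'linear' type, the source code computes the rollout length at environment step $x$ as shown here. $x_0$ and $x_1$ stand for rollout_start_step and rollout_end_step, and $y_0$ and $y_1$ stand for rollout_length_min and rollout_length_max. These symbols mirror the variable names in the code. The floor matches Python's int() truncation only when the clamped value is non-negative, which holds whenever rollout_length_min is non-negative.

$$
L(x) = \left\lfloor \min\left( \max\left( y_0 + \frac{y_1 - y_0}{x_1 - x_0}\,(x - x_0),\; y_0 \right),\; y_1 \right) \right\rfloor
$$

Before $x_0$ the length therefore stays at $y_0$, and after $x_1$ it stays at $y_1$. For the 'constant' type, $L(x)$ equals rollout_length for every $x$.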
Full Source Code
../ding/world_model/utils.py
```python
from easydict import EasyDict
from typing import Callable


def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
    """
    Overview:
        Get the rollout length scheduler that adapts rollout length based\
        on the current environment steps.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and\
            returns the current rollout length.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        # Slope and intercept of the line through (x0, y0) and (x1, y1).
        w = (y1 - y0) / (x1 - x0)
        b = y0
        # Interpolate, clamp to [y0, y1], then truncate to an int.
        return lambda x: int(min(max(w * (x - x0) + b, y0), y1))
    elif cfg.type == 'constant':
        # Ignore envstep and always use the configured length.
        return lambda x: cfg.rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))
```
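The following self-contained sketch shows how the scheduler behaves. It copies the logic above but swaps EasyDict for types.SimpleNamespace. This substitution is an assumption made for illustration: the function only reads attributes from cfg, so any attribute-style config object should work. The step and length values are hypothetical and were chosen only to make the interpolation easy to follow.

```python
from types import SimpleNamespace
from typing import Callable


def get_rollout_length_scheduler(cfg) -> Callable[[int], int]:
    # Same logic as ding.world_model.utils; SimpleNamespace stands in for EasyDict here.
    if cfg.type == 'linear':
        x0, x1 = cfg.rollout_start_step, cfg.rollout_end_step
        y0, y1 = cfg.rollout_length_min, cfg.rollout_length_max
        w = (y1 - y0) / (x1 - x0)  # slope; raises ZeroDivisionError if x1 == x0
        # Interpolate, clamp to [y0, y1], then truncate to an int.
        return lambda x: int(min(max(w * (x - x0) + y0, y0), y1))
    elif cfg.type == 'constant':
        return lambda x: cfg.rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))


# Hypothetical values: grow rollouts from 1 to 9 steps between envstep 2000 and 10000.
linear = get_rollout_length_scheduler(SimpleNamespace(
    type='linear',
    rollout_start_step=2000,
    rollout_end_step=10000,
    rollout_length_min=1,
    rollout_length_max=9,
))
for step in (0, 2000, 6500, 10000, 50000):
    print(step, linear(step))
# 0 1        (before the start step: clamped to the minimum)
# 2000 1
# 6500 5     (5.5 truncated by int())
# 10000 9
# 50000 9    (after the end step: clamped to the maximum)

constant = get_rollout_length_scheduler(SimpleNamespace(type='constant', rollout_length=5))
print(constant(0), constant(10 ** 6))  # 5 5
```

Because the result passes through int(), intermediate lengths are truncated rather than rounded. A 'linear' config whose rollout_start_step equals rollout_end_step fails with ZeroDivisionError as soon as the scheduler is built, because the slope is computed eagerly.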