ding.framework.middleware.functional.enhancer

reward_estimator(cfg, reward_model)

Overview

Estimate the reward of train_data using reward_model.

Arguments:
- cfg (:obj:`EasyDict`): Config.
- reward_model (:obj:`BaseRewardModel`): Reward model.
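In a pipeline, the returned closure is registered with task.use so that it rewrites the rewards of ctx.train_data in place before training. A minimal sketch, assuming cfg, my_reward_model (any BaseRewardModel) and the neighbouring collect/train middleware are built elsewhere; they are placeholders, not part of this module:

from ding.framework import task, OnlineRLContext
from ding.framework.middleware.functional.enhancer import reward_estimator

# `collect_middleware`, `my_reward_model` and `train_middleware` are hypothetical
# placeholders for objects constructed by the surrounding pipeline.
with task.start(async_mode=False, ctx=OnlineRLContext()):
    task.use(collect_middleware)                      # fills ctx.train_data
    task.use(reward_estimator(cfg, my_reward_model))  # overwrites rewards in ctx.train_data in place
    task.use(train_middleware)                        # consumes the re-labelled data
    task.run(max_step=1000)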

her_data_enhancer(cfg, buffer_, her_reward_model)

Overview

Fetch a batch of episodes from buffer_, then use her_reward_model to produce HER-processed episodes from the original ones.

Arguments:
- cfg (:obj:`EasyDict`): Config, which should contain `cfg.policy.learn.batch_size` when `her_reward_model.episode_size` is None.
- buffer_ (:obj:`Buffer`): Buffer to sample data from.
- her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model used to process episodes.
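The returned closure samples episodes from the buffer, relabels each with HER, and flattens the result into ctx.train_data; if the buffer cannot serve the request it logs a warning and sets ctx.train_data to None. A minimal sketch, assuming cfg, a replay_buffer (any ding Buffer) and a her_model (HerRewardModel) are constructed elsewhere; train_middleware is a placeholder that should tolerate ctx.train_data being None:

from ding.framework import task, OnlineRLContext
from ding.framework.middleware.functional.enhancer import her_data_enhancer

# `replay_buffer`, `her_model` and `train_middleware` are hypothetical stand-ins
# for objects built by the surrounding pipeline.
with task.start(async_mode=False, ctx=OnlineRLContext()):
    task.use(her_data_enhancer(cfg, replay_buffer, her_model))  # sets ctx.train_data (or None)
    task.use(train_middleware)                                  # should skip training when train_data is None
    task.run(max_step=1000)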

Full Source Code

../ding/framework/middleware/functional/enhancer.py

from typing import TYPE_CHECKING, Callable
from easydict import EasyDict
from ditk import logging
import torch
from ding.framework import task
if TYPE_CHECKING:
    from ding.framework import OnlineRLContext
    from ding.reward_model import BaseRewardModel, HerRewardModel
    from ding.data import Buffer


def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable:
    """
    Overview:
        Estimate the reward of `train_data` using `reward_model`.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - reward_model (:obj:`BaseRewardModel`): Reward model.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_data (:obj:`List`): The list of data used for estimation.
        """
        reward_model.estimate(ctx.train_data)  # inplace modification

    return _enhance


def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRewardModel") -> Callable:
    """
    Overview:
        Fetch a batch of data/episode from `buffer_`, \
        then use `her_reward_model` to get HER processed episodes from original episodes.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys \
            if her_reward_model.episode_size is None: `cfg.policy.learn.batch_size`.
        - buffer\_ (:obj:`Buffer`): Buffer to sample data from.
        - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \
            which is used to process episodes.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _fetch_and_enhance(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[treetensor.torch.Tensor]`): The HER processed episodes.
        """
        if her_reward_model.episode_size is None:
            size = cfg.policy.learn.batch_size
        else:
            size = her_reward_model.episode_size
        try:
            buffered_episode = buffer_.sample(size)
            train_episode = [d.data for d in buffered_episode]
        except (ValueError, AssertionError):
            # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
            logging.warning(
                "Replay buffer's data is not enough to support training, so skip this training for waiting more data."
            )
            ctx.train_data = None
            return

        her_episode = sum([her_reward_model.estimate(e) for e in train_episode], [])
        ctx.train_data = sum(her_episode, [])

    return _fetch_and_enhance


def nstep_reward_enhancer(cfg: EasyDict) -> Callable:

    if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
        return task.void()

    def _enhance(ctx: "OnlineRLContext"):
        nstep = cfg.policy.nstep
        gamma = cfg.policy.discount_factor
        L = len(ctx.trajectories)
        reward_template = ctx.trajectories[0].reward
        nstep_rewards = []
        value_gamma = []
        for i in range(L):
            valid = min(nstep, L - i)
            for j in range(1, valid):
                if ctx.trajectories[j + i].done:
                    valid = j
                    break
            value_gamma.append(torch.FloatTensor([gamma ** valid]))
            nstep_reward = [ctx.trajectories[j].reward for j in range(i, i + valid)]
            if nstep > valid:
                nstep_reward.extend([torch.zeros_like(reward_template) for j in range(nstep - valid)])
            nstep_reward = torch.cat(nstep_reward)  # (nstep, )
            nstep_rewards.append(nstep_reward)
        for i in range(L):
            ctx.trajectories[i].reward = nstep_rewards[i]
            ctx.trajectories[i].value_gamma = value_gamma[i]

    return _enhance


# TODO MBPO
# TODO SIL
# TODO TD3 VAE
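The module also defines nstep_reward_enhancer(cfg), which is not documented above: for each transition in ctx.trajectories it replaces the scalar reward with a length-nstep vector of the subsequent rewards (zero-padded past the end of the trajectory or a done step) and attaches value_gamma = gamma ** valid for the bootstrap term. A toy, illustration-only sketch of this behaviour, using a hypothetical Transition dataclass in place of the treetensor transitions produced by the real collector, and assuming the factory is called outside a parallel (router-active) context:

from dataclasses import dataclass
from easydict import EasyDict
import torch

from ding.framework import OnlineRLContext
from ding.framework.middleware.functional.enhancer import nstep_reward_enhancer

@dataclass
class Transition:  # hypothetical stand-in for a collected transition
    reward: torch.Tensor
    done: bool
    value_gamma: torch.Tensor = None

cfg = EasyDict(dict(policy=dict(nstep=3, discount_factor=0.99)))
ctx = OnlineRLContext()
ctx.trajectories = [
    Transition(reward=torch.FloatTensor([1.0]), done=False),
    Transition(reward=torch.FloatTensor([0.5]), done=False),
    Transition(reward=torch.FloatTensor([2.0]), done=False),
]

nstep_reward_enhancer(cfg)(ctx)
# Step 0 sees all three future rewards; later steps are zero-padded:
#   trajectories[0].reward -> tensor([1.0, 0.5, 2.0]), value_gamma -> 0.99 ** 3
#   trajectories[2].reward -> tensor([2.0, 0.0, 0.0]), value_gamma -> 0.99 ** 1
print(ctx.trajectories[0].reward, ctx.trajectories[0].value_gamma)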