
ding.framework.middleware.collector


StepCollector

Overview

The class of the collector running by steps, including model inference and transition process. Use the __call__ method to execute the whole collection process.

__init__(cfg, policy, env, random_collect_size=0)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `EasyDict` | Config. | required |
| `policy` | `Policy` | The policy to be collected. | required |
| `env` | `BaseEnvManager` | The env for the collection; the `BaseEnvManager` object or its derivatives are supported. | required |
| `random_collect_size` | `int` | The count of samples that will be collected randomly, typically used in initial runs. | `0` |

__call__(ctx)

Overview

An encapsulation of inference and rollout middleware. Stop when completing the target number of steps.

Input of ctx:

- env_step (:obj:`int`): The env steps which will increase during collection.
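The stopping rule can be sketched in isolation: the collector remembers the step count at entry, then alternates inference and rollout until the delta reaches the target. In the sketch below, `Ctx`, `fake_inferencer`, and `fake_rolloutor` are illustrative stand-ins, not DI-engine APIs.

```python
# Illustrative sketch of StepCollector's stopping rule; Ctx and the fake
# middleware below are stand-ins, not DI-engine APIs.
class Ctx:
    def __init__(self):
        self.env_step = 0

def fake_inferencer(ctx):
    pass  # a real inferencer would compute actions from observations

def fake_rolloutor(ctx, env_num=4):
    ctx.env_step += env_num  # each rollout advances one step in every env

def collect(ctx, target_size):
    old = ctx.env_step  # remember the step count at entry
    while True:
        fake_inferencer(ctx)
        fake_rolloutor(ctx)
        if ctx.env_step - old >= target_size:
            break

ctx = Ctx()
collect(ctx, target_size=10)
# env_step advances in multiples of env_num, so it may overshoot the target
```

Because `env_step` advances by the number of parallel envs each iteration, the loop can overshoot `target_size` slightly, just as the real collector does.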

PPOFStepCollector

Overview

The class of the collector running by steps, including model inference and transition process. Use the __call__ method to execute the whole collection process.

__init__(seed, policy, env, n_sample, unroll_len=1)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `int` | Random seed. | required |
| `policy` | `Policy` | The policy to be collected. | required |
| `env` | `BaseEnvManager` | The env for the collection; the `BaseEnvManager` object or its derivatives are supported. | required |
| `n_sample` | `int` | The number of training samples to collect per call. | required |
| `unroll_len` | `int` | The unroll length of a training sample, in env steps. | `1` |
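The two sizing parameters combine multiplicatively: one training sample spans `unroll_len` consecutive env steps, so a collection call targets `n_sample * unroll_len` steps in total. A minimal sketch (the helper name is illustrative, not a DI-engine API):

```python
# Illustrative: PPOFStepCollector stops after n_sample * unroll_len env steps,
# since one training sample spans unroll_len consecutive env steps.
def target_env_steps(n_sample, unroll_len=1):
    return n_sample * unroll_len

# e.g. 64 samples with unroll_len 4 means 256 env steps per collection call
```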

__call__(ctx)

Overview

An encapsulation of inference and rollout middleware. Stop when completing the target number of steps.

Input of ctx:

- env_step (:obj:`int`): The env steps which will increase during collection.
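Beyond the step counting shared with `StepCollector`, this collector tags every transition with a globally unique episode id per parallel env: each env starts with its own id, and whenever an env finishes an episode it is handed the next unused id. The sketch below isolates that bookkeeping; `make_episode_tracker` and `on_done` are illustrative stand-ins, not DI-engine APIs.

```python
# Illustrative sketch of PPOFStepCollector's per-env episode bookkeeping:
# each parallel env begins with its own episode id, and when an env finishes
# an episode it receives the next unused id, keeping ids globally unique.
def make_episode_tracker(env_num):
    env_episode_id = list(range(env_num))  # episode id currently running in each env
    state = {"current_id": env_num}        # next unused id

    def on_done(env_id):
        env_episode_id[env_id] = state["current_id"]  # start a fresh episode
        state["current_id"] += 1

    return env_episode_id, on_done

ids, on_done = make_episode_tracker(3)  # ids start as [0, 1, 2]
on_done(1)                              # env 1 finished; it now runs episode 3
```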

EpisodeCollector

Overview

The class of the collector running by episodes, including model inference and transition process. Use the __call__ method to execute the whole collection process.

__init__(cfg, policy, env, random_collect_size=0)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `EasyDict` | Config. | required |
| `policy` | `Policy` | The policy to be collected. | required |
| `env` | `BaseEnvManager` | The env for the collection; the `BaseEnvManager` object or its derivatives are supported. | required |
| `random_collect_size` | `int` | The count of samples that will be collected randomly, typically used in initial runs. | `0` |

__call__(ctx)

Overview

An encapsulation of inference and rollout middleware. Stop when completing the target number of episodes.

Input of ctx:

- env_episode (:obj:`int`): The count of env episodes, which will increase during collection.
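Both `StepCollector` and `EpisodeCollector` share a warm-up rule: until `random_collect_size` is reached, a random policy drives collection, and only the remaining warm-up budget is targeted; afterwards the learned policy collects the configured amount. The helper below is an illustrative stand-in for that branching, not a DI-engine API.

```python
# Illustrative sketch of the warm-up branching: while env_episode is still
# below random_collect_size, collect the remaining random episodes with a
# random policy; afterwards collect n_episode with the learned policy.
def choose_target_and_policy(env_episode, random_collect_size, n_episode):
    if random_collect_size > 0 and env_episode < random_collect_size:
        # still warming up: only the remaining random episodes are targeted
        return random_collect_size - env_episode, "random"
    return n_episode, "learned"

# warm-up: 5 random episodes requested, 2 already done -> 3 more, random policy
# afterwards: the configured n_episode with the learned policy
```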

Full Source Code

../ding/framework/middleware/collector.py

```python
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch

from ding.policy import get_random_policy
from ding.envs import BaseEnvManager
from ding.framework import task
from .functional import inferencer, rolloutor, TransitionList

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class StepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
        process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(StepCollector, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
                typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
            the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
        """
        old = ctx.env_step
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            # compatible with old config, a train sample = unroll_len step
            target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break


class PPOFStepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
        process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(PPOFStepCollector, cls).__new__(cls)

    def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
        """
        Arguments:
            - seed (:obj:`int`): Random seed.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - n_sample (:obj:`int`): The number of training samples to collect per call.
            - unroll_len (:obj:`int`): The unroll length of a training sample, in env steps.
        """
        self.env = env
        self.env.seed(seed)
        self.policy = policy
        self.n_sample = n_sample
        self.unroll_len = unroll_len
        self._transitions = TransitionList(self.env.env_num)
        self._env_episode_id = [_ for _ in range(env.env_num)]
        self._current_id = env.env_num

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
            the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
        """
        device = self.policy._device
        old = ctx.env_step
        target_size = self.n_sample * self.unroll_len

        if self.env.closed:
            self.env.launch()

        while True:
            obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
            obs = obs.to(device)
            inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
            inference_output = inference_output.cpu()
            action = inference_output.action.numpy()
            timesteps = self.env.step(action)
            ctx.env_step += len(timesteps)

            obs = obs.cpu()
            for i, timestep in enumerate(timesteps):
                transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([self._env_episode_id[timestep.env_id]])
                self._transitions.append(timestep.env_id, transition)
                if timestep.done:
                    self.policy.reset([timestep.env_id])
                    self._env_episode_id[timestep.env_id] = self._current_id
                    self._current_id += 1
                    ctx.env_episode += 1

            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break


class EpisodeCollector:
    """
    Overview:
        The class of the collector running by episodes, including model inference and transition \
        process. Use the `__call__` method to execute the whole collection process.
    """

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
                typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing the \
            target number of episodes.
        Input of ctx:
            - env_episode (:obj:`int`): The count of env episodes, which will increase during collection.
        """
        old = ctx.env_episode
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            # pass cfg.seed, matching inferencer's (seed, policy, env) signature
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            target_size = self.cfg.policy.collect.n_episode
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_episode - old >= target_size:
                ctx.episodes = self._transitions.to_episodes()
                self._transitions.clear()
                break


# TODO battle collector
```
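All three collectors follow the same middleware pattern: each is a callable taking a shared context object, and a runner invokes the registered middleware in order every iteration. The sketch below emulates that composition with plain Python; `run_pipeline` and the stub middleware are stand-ins for DI-engine's task pipeline, not its actual API.

```python
# Illustrative sketch of how collectors plug into a middleware pipeline:
# each middleware is a callable over the shared context, and a runner calls
# them in order each iteration. run_pipeline is a stand-in, not DI-engine API.
class Ctx:
    def __init__(self):
        self.env_step = 0
        self.train_iter = 0

def collector(ctx):
    ctx.env_step += 8  # stand-in for StepCollector.__call__

def trainer(ctx):
    ctx.train_iter += 1  # stand-in for a trainer middleware

def run_pipeline(middlewares, ctx, max_iter):
    for _ in range(max_iter):
        for mw in middlewares:
            mw(ctx)  # every middleware sees the same mutable context
    return ctx

ctx = run_pipeline([collector, trainer], Ctx(), max_iter=3)
# after 3 iterations: env_step == 24, train_iter == 3
```

This is why each `__call__` above only reads and mutates `ctx` fields (`env_step`, `env_episode`, `trajectories`, ...): the context is the sole channel between collection, training, and evaluation middleware.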