Skip to content

ding.framework.middleware.functional.collector

ding.framework.middleware.functional.collector

inferencer(seed, policy, env)

Overview

The middleware that executes the inference process.

Arguments: - seed (:obj:int): Random seed. - policy (:obj:Policy): The policy to be inferred. - env (:obj:BaseEnvManager): The env where the inference process is performed. The env.ready_obs (:obj:tnp.array) will be used as model input.

rolloutor(policy, env, transitions, collect_print_freq=100)

Overview

The middleware that executes the transition process in the env.

Arguments: - policy (:obj:Policy): The policy to be used during transition. - env (:obj:BaseEnvManager): The env for the collection, the BaseEnvManager object or its derivatives are supported. - transitions (:obj:TransitionList): The transition information which will be filled in this process, including obs, next_obs, action, logit, value, reward and done.

output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)

Overview

Print the output log information. You can refer to the docs of Best Practice to understand the training generated logs and tensorboards.

Arguments: - episode_info (:obj:`list`): List of per-episode info dicts (reward, time, step, train_sample) collected since the last log; it is cleared after logging. - total_episode_count (:obj:`int`): Cumulative number of collected episodes. - total_envstep_count (:obj:`int`): Cumulative number of environment steps. - total_train_sample_count (:obj:`int`): Cumulative number of train samples.

Full Source Code

../ding/framework/middleware/functional/collector.py

from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce
from itertools import accumulate
import treetensor.torch as ttorch
import numpy as np
from ditk import logging
from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class TransitionList:
    """
    Overview:
        Per-env container for collected transitions. Stores transitions grouped by
        env_id and remembers episode boundaries (via ``transition.done``), so the
        data can be exported either as flat trajectories or as per-episode lists.
    """

    def __init__(self, env_num: int) -> None:
        """
        Arguments:
            - env_num (:obj:`int`): Number of parallel environments to track.
        """
        self.env_num = env_num
        self._transitions = [[] for _ in range(env_num)]
        # For each env: exclusive end index of every finished episode in that
        # env's transition list.
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        """
        Overview:
            Append one transition for ``env_id``; record an episode boundary
            when the transition is terminal.
        """
        self._transitions[env_id].append(transition)
        if transition.done:
            # len(...) after the append gives the exclusive end index of the
            # episode that just finished.
            self._done_idx[env_id].append(len(self._transitions[env_id]))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        """
        Overview:
            Flatten all per-env transition lists into one list.
        Returns:
            - trajectories (:obj:`List[Any]`): All transitions, env by env.
            - trajectory_end_idx (:obj:`List[int]`): Inclusive index of the last
              transition of each env's segment within the flattened list.
        """
        trajectories = sum(self._transitions, [])
        lengths = [len(t) for t in self._transitions]
        # O(n) running prefix sums (the original reduce-per-prefix form was
        # accidentally O(n^2)); "- 1" converts exclusive ends to inclusive.
        trajectory_end_idx = [s - 1 for s in accumulate(lengths)]
        return trajectories, trajectory_end_idx

    def to_episodes(self) -> List[List[Any]]:
        """
        Overview:
            Slice the stored transitions into complete episodes.
        Returns:
            - episodes (:obj:`List[List[Any]]`): One list of transitions per
              finished episode. NOTE: transitions collected after an env's last
              ``done`` are not included.
        """
        episodes = []
        for env_id in range(self.env_num):
            last_idx = 0
            for done_idx in self._done_idx[env_id]:
                episodes.append(self._transitions[env_id][last_idx:done_idx])
                last_idx = done_idx
        return episodes

    def clear(self) -> None:
        """
        Overview:
            Drop all stored transitions and episode boundaries in place.
        """
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()


def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        The middleware that executes the inference process.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - policy (:obj:`Policy`): The policy to be inferred.
        - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \
            The env.ready_obs (:obj:`tnp.array`) will be used as model input.
    """

    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs (:obj:`Union[torch.Tensor, Dict[torch.Tensor]]`): The input observations collected \
                from all collector environments.
            - action: (:obj:`List[np.ndarray]`): The inferred actions listed by env_id.
            - inference_output (:obj:`Dict[int, Dict]`): The dict of which the key is env_id (int), \
                and the value is inference result (Dict).
        """

        # Lazily launch the env manager on first use.
        if env.closed:
            env.launch()

        obs = ttorch.as_tensor(env.ready_obs)
        # Keep the raw (un-cast) observation on the context for transition building.
        ctx.obs = obs
        obs = obs.to(dtype=ttorch.float32)
        # TODO mask necessary rollout

        # Re-key the batch by env index for the policy's dict-style forward. # TBD
        obs = {i: obs[i] for i in range(get_shape0(obs))}
        inference_output = policy.forward(obs, **ctx.collect_kwargs)
        ctx.action = [to_ndarray(v['action']) for v in inference_output.values()]  # TBD
        ctx.inference_output = inference_output

    return _inference


def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq=100,
) -> Callable:
    """
    Overview:
        The middleware that executes the transition process in the env.
    Arguments:
        - policy (:obj:`Policy`): The policy to be used during transition.
        - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
            its derivatives are supported.
        - transitions (:obj:`TransitionList`): The transition information which will be filled \
            in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
            and `done`.
        - collect_print_freq (:obj:`int`): Log collection statistics every this many train iterations.
    """

    # Globally unique id for each episode across all envs.
    env_episode_id = list(range(env.env_num))
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action: (:obj:`List[np.ndarray]`): The inferred actions from previous inference process.
            - obs (:obj:`Dict[Tensor]`): The states fed into the transition dict.
            - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
                transition dict.
            - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
            - env_step (:obj:`int`): The count of env step, which will increase by 1 for a single \
                transition call.
            - env_episode (:obj:`int`): The count of env episode, which will increase by 1 if the \
                trajectory stops.
        """

        nonlocal current_id, env_info, episode_info, timer, \
            total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter
        timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]

        collected_sample = 0
        collected_step = 0
        collected_episode = 0
        # NOTE(review): timer.value here is the duration measured in the
        # previous call (0 on the first call) — confirm this is intentional.
        interaction_duration = timer.value / len(timesteps)
        for i, timestep in enumerate(timesteps):
            with timer:
                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
                transition = ttorch.as_tensor(transition)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
                transitions.append(timestep.env_id, transition)

                collected_step += 1
                collected_sample += len(transition.obs)
                env_info[timestep.env_id.item()]['step'] += 1
                env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)

            env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration
            if timestep.done:
                info = {
                    'reward': timestep.info['eval_episode_return'],
                    'time': env_info[timestep.env_id.item()]['time'],
                    'step': env_info[timestep.env_id.item()]['step'],
                    'train_sample': env_info[timestep.env_id.item()]['train_sample'],
                }
                # reset corresponding env info
                env_info[timestep.env_id.item()] = {'time': 0., 'step': 0, 'train_sample': 0}

                episode_info.append(info)
                policy.reset([timestep.env_id.item()])
                # Assign a fresh globally unique episode id to the env that just reset.
                env_episode_id[timestep.env_id.item()] = current_id
                collected_episode += 1
                current_id += 1
                ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout


def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None:
    """
    Overview:
        Print the output log information. You can refer to the docs of `Best Practice` to understand \
        the training generated logs and tensorboards.
    Arguments:
        - episode_info (:obj:`list`): Per-episode info dicts collected since the last log, each \
            with keys `reward`, `time`, `step` and `train_sample`. Cleared in place after logging.
        - total_episode_count (:obj:`int`): Cumulative number of collected episodes.
        - total_envstep_count (:obj:`int`): Cumulative number of env steps.
        - total_train_sample_count (:obj:`int`): Cumulative number of train samples.
    """
    episode_count = len(episode_info)
    envstep_count = sum([d['step'] for d in episode_info])
    train_sample_count = sum([d['train_sample'] for d in episode_info])
    # Clamp to avoid ZeroDivisionError when all recorded times are 0.
    duration = max(sum([d['time'] for d in episode_info]), 1e-8)
    episode_return = [d['reward'].item() for d in episode_info]
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'avg_train_sample_per_sec': train_sample_count / duration,
        'avg_episode_per_sec': episode_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
        # 'each_reward': episode_return,
    }
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))