from typing import Any, Optional, Callable, Tuple
from abc import ABC, abstractmethod
from collections import namedtuple, deque
from easydict import EasyDict
import copy
import numpy as np
import torch

from ding.utils import SERIAL_EVALUATOR_REGISTRY, import_module, lists_to_dicts
from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list


class ISerialEvaluator(ABC):
    """
    Overview:
        Basic interface class for serial evaluator.
    Interfaces:
        reset, reset_policy, reset_env, close, should_eval, eval
    Property:
        env, policy
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Get the evaluator's default config. We merge the evaluator's default config with other \
            default configs and the user's config to get the final config.
        Returns:
            - cfg (:obj:`EasyDict`): Evaluator's default config.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def reset_env(self, _env: Optional[Any] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Any] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def should_eval(self, train_iter: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def eval(
            self,
            save_ckpt_fn: Optional[Callable] = None,
            train_iter: int = -1,
            envstep: int = -1,
            n_episode: Optional[int] = None
    ) -> Any:
        raise NotImplementedError


def create_serial_evaluator(cfg: EasyDict, **kwargs) -> ISerialEvaluator:
    """
    Overview:
        Create a specific evaluator instance based on the config. If ``type`` is not specified in the \
        config, the default 'interaction' evaluator is used.
    """
    import_module(cfg.get('import_names', []))
    if 'type' not in cfg:
        cfg.type = 'interaction'
    return SERIAL_EVALUATOR_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
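
# A minimal usage sketch (hypothetical names: ``my_env``, ``my_policy`` and the config values
# below are placeholders; the kwargs accepted by the built evaluator depend on its concrete class):
#
#     cfg = EasyDict({'type': 'interaction', 'import_names': []})
#     evaluator = create_serial_evaluator(cfg, env=my_env, policy=my_policy)
#     if evaluator.should_eval(train_iter):
#         result = evaluator.eval(save_ckpt_fn, train_iter, envstep)  # return value depends on the subclass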


class VectorEvalMonitor(object):
    """
    Overview:
        In some cases, different environments in the evaluator may collect episodes of different \
        lengths. For example, suppose we want to collect 12 episodes in the evaluator but only have \
        5 environments; if we do not intervene, we are likely to get more short episodes than long \
        ones. As a result, the average reward will be biased and may not be accurate. We use \
        VectorEvalMonitor to solve this problem.
    Interfaces:
        __init__, is_finished, update_info, update_reward, update_video, get_episode_return, \
        get_latest_reward, get_current_episode, get_episode_info, get_video
    """

    def __init__(self, env_num: int, n_episode: int) -> None:
        """
        Overview:
            Init method. According to the number of episodes and the number of environments, determine \
            how many episodes need to be opened for each environment, and initialize the reward, info \
            and other information.
        Arguments:
            - env_num (:obj:`int`): The number of evaluation environments.
            - n_episode (:obj:`int`): The total number of episodes to collect.
        """
        assert n_episode >= env_num, "n_episode < env_num, please decrease the number of eval env"
        self._env_num = env_num
        self._n_episode = n_episode
        # Distribute n_episode across env_num environments as evenly as possible.
        each_env_episode = [n_episode // env_num for _ in range(env_num)]
        for i in range(n_episode % env_num):
            each_env_episode[i] += 1
        self._video = {
            env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
            for env_id, maxlen in enumerate(each_env_episode)
        }
        self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
        self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
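
    # Illustrative allocation example (values chosen for illustration, not taken from any config):
    # with n_episode=12 and env_num=5, each_env_episode becomes [3, 3, 2, 2, 2] -- the
    # 12 % 5 = 2 remainder episodes go to the first two environments, and each per-env
    # deque is bounded by that environment's own episode quota.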

    def is_finished(self) -> bool:
        """
        Overview:
            Determine whether the evaluator has collected all required episodes.
        Returns:
            - result (:obj:`bool`): Whether the evaluator has completed the work.
        """
        return all([len(v) == v.maxlen for v in self._reward.values()])

    def update_info(self, env_id: int, info: Any) -> None:
        """
        Overview:
            Update the information of the environment indicated by env_id.
        Arguments:
            - env_id (:obj:`int`): The id of the environment whose information needs to be updated.
            - info (:obj:`Any`): The information to record.
        """
        info = tensor_to_list(info)
        self._info[env_id].append(info)

    def update_reward(self, env_id: int, reward: Any) -> None:
        """
        Overview:
            Update the reward indicated by env_id.
        Arguments:
            - env_id (:obj:`int`): The id of the environment whose reward needs to be updated.
            - reward (:obj:`Any`): The episode reward to record.
        """
        if isinstance(reward, torch.Tensor):
            reward = reward.item()
        self._reward[env_id].append(reward)

    def update_video(self, imgs: dict) -> None:
        """
        Overview:
            Append the latest rendered frame of each environment to the video buffer of the episode \
            that the environment is currently running. Environments that have already finished all \
            their episodes are skipped.
        Arguments:
            - imgs (:obj:`dict`): A dict mapping env_id to the latest rendered frame.
        """
        for env_id, img in imgs.items():
            if len(self._reward[env_id]) == self._reward[env_id].maxlen:
                continue
            self._video[env_id][len(self._reward[env_id])].append(img)

    def get_video(self):
        """
        Overview:
            Convert the list of videos into a [N, T, C, H, W] tensor, containing the worst, \
            median and best evaluation trajectories for video logging.
        """
        videos = sum([list(v) for v in self._video.values()], [])
        videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
        sortarg = np.argsort(self.get_episode_return())
        # worst, median(s), best
        if len(sortarg) == 1:
            idxs = [sortarg[0]]
        elif len(sortarg) == 2:
            idxs = [sortarg[0], sortarg[-1]]
        elif len(sortarg) == 3:
            idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
        else:
            # TensorboardX pads the number of videos to an even number with black frames,
            # so providing an even number of videos prevents black frames from being rendered.
            idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
        videos = [videos[idx] for idx in idxs]
        # pad videos to the same length with their last frames
        max_length = max(video.shape[0] for video in videos)
        for i in range(len(videos)):
            if videos[i].shape[0] < max_length:
                padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
                videos[i] = np.concatenate([videos[i], padding], 0)
        videos = np.stack(videos, 0)
        assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
        return videos

    def get_episode_return(self) -> list:
        """
        Overview:
            Get the total return of each episode collected so far.
        """
        return sum([list(v) for v in self._reward.values()], [])  # flatten the per-env deques into one list

    def get_latest_reward(self, env_id: int) -> int:
        """
        Overview:
            Get the latest reward of a certain environment.
        Arguments:
            - env_id (:obj:`int`): The id of the environment whose reward is queried.
        """
        return self._reward[env_id][-1]

    def get_current_episode(self) -> int:
        """
        Overview:
            Get the number of episodes collected so far, i.e. which episode the evaluator is executing now.
        """
        return sum([len(v) for v in self._reward.values()])

    def get_episode_info(self) -> Optional[dict]:
        """
        Overview:
            Get all episode information, such as the total return of each episode.
        """
        if len(self._info[0]) == 0:
            return None
        else:
            total_info = sum([list(v) for v in self._info.values()], [])
            total_info = lists_to_dicts(total_info)
            new_dict = {}
            for k in total_info.keys():
                # For scalar fields, additionally log the mean over all episodes.
                if np.isscalar(total_info[k][0]):
                    new_dict[k + '_mean'] = np.mean(total_info[k])
            total_info.update(new_dict)
            return total_info
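

# A minimal, self-contained sketch of how VectorEvalMonitor is typically driven. The env ids
# and reward values below are synthetic placeholders; a real evaluator feeds episode returns
# from vectorized environment steps and calls update_info/update_video alongside update_reward.
if __name__ == '__main__':
    monitor = VectorEvalMonitor(env_num=2, n_episode=5)
    # With 5 episodes over 2 envs, env 0 gets a quota of 3 episodes and env 1 a quota of 2.
    for env_id, episode_return in [(0, 1.0), (1, 3.0), (0, 2.0), (1, 5.0), (0, 4.0)]:
        monitor.update_reward(env_id, episode_return)
    assert monitor.is_finished()
    print(monitor.get_current_episode())  # 5
    print(monitor.get_episode_return())  # [1.0, 2.0, 4.0, 3.0, 5.0]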