from typing import List, Optional, Union, Dict
from easydict import EasyDict
import gym
import gymnasium
import copy
import numpy as np
import treetensor.numpy as tnp

from ding.envs.common.common_function import affine_transform
from ding.envs.env_wrappers import create_env_wrapper, GymToGymnasiumWrapper
from ding.torch_utils import to_ndarray
from ding.utils import CloudPickleWrapper
from .base_env import BaseEnv, BaseEnvTimestep
from .default_wrapper import get_default_wrappers


class DingEnvWrapper(BaseEnv):
    """
    Overview:
        This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
    Interfaces:
        __init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg,
        create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
    """

    def __init__(
            self,
            env: Union[gym.Env, gymnasium.Env] = None,
            cfg: dict = None,
            seed_api: bool = True,
            caller: str = 'collector',
            is_gymnasium: bool = False
    ) -> None:
        """
        Overview:
            Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
            instance should be passed in. For the former, i.e., an environment instance: The `env` parameter must not \
            be `None`, but should be the instance. It does not support subprocess environment manager. Thus, it is \
            usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
            The `cfg` parameter must contain `env_id`.
        Arguments:
            - env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
            - cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
            - seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
            - caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
                ``evaluator``. Different caller may need different wrappers. Default is 'collector'.
            - is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, \
                i.e., the environment is a gym environment. Only used in the lazy-creation (cfg) path; when an \
                instance is passed in, the flag is inferred from the instance type instead.
        """
        self._env = None
        self._raw_env = env
        self._cfg = cfg
        self._seed_api = seed_api  # some env may disable `env.seed` api
        self._caller = caller

        if self._cfg is None:
            self._cfg = {}
        self._cfg = EasyDict(self._cfg)
        # Fill in default config entries so the rest of the class can rely on their presence.
        if 'act_scale' not in self._cfg:
            self._cfg.act_scale = False
        if 'rew_clip' not in self._cfg:
            self._cfg.rew_clip = False
        if 'env_wrapper' not in self._cfg:
            self._cfg.env_wrapper = 'default'
        if 'env_id' not in self._cfg:
            self._cfg.env_id = None
        if env is not None:
            # Eager path: wrap the provided instance immediately and cache its spaces.
            self._is_gymnasium = isinstance(env, gymnasium.Env)
            self._env = env
            self._wrap_env(caller)
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._action_space.seed(0)  # default seed
            self._reward_space = self._make_reward_space()
            self._init_flag = True
        else:
            # Lazy path: the real env is created on the first ``reset`` from ``cfg.env_id``.
            assert 'env_id' in self._cfg
            self._is_gymnasium = is_gymnasium
            self._init_flag = False
            self._observation_space = None
            self._action_space = None
            self._reward_space = None
        # Only if user specifies the replay_path, will the video be saved. So its initial value is None.
        self._replay_path = None

    def _make_reward_space(self) -> gym.spaces.Box:
        """
        Overview:
            Build the reward space from ``self._env.reward_range``. High-version gymnasium environments no \
            longer expose ``reward_range``, in which case fall back to ``(-1, 1)``.
        Returns:
            - reward_space (:obj:`gym.spaces.Box`): A 1-dim float32 box covering the env's reward range.
        """
        try:
            low, high = self._env.reward_range
        except Exception:  # for compatibility with gymnasium high-version API
            low, high = -1, 1
        return gym.spaces.Box(low=low, high=high, shape=(1, ), dtype=np.float32)

    # override
    def reset(self) -> np.ndarray:
        """
        Overview:
            Resets the state of the environment. If the environment is not initialized, it will be created first.
        Returns:
            - obs (:obj:`Dict`): The new observation after reset.
        """
        if not self._init_flag:
            gym_proxy = gymnasium if self._is_gymnasium else gym
            self._env = gym_proxy.make(self._cfg.env_id)
            self._wrap_env(self._caller)
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            # Same gymnasium-compat fallback as in ``__init__``.
            self._reward_space = self._make_reward_space()
            self._init_flag = True
        if self._replay_path is not None:
            self._env = gym.wrappers.RecordVideo(
                self._env,
                video_folder=self._replay_path,
                episode_trigger=lambda episode_id: True,
                name_prefix='rl-video-{}'.format(id(self))
            )
            # Apply the video wrapper only once.
            self._replay_path = None
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            # Dynamic seed: perturb the base seed each episode for collector diversity.
            np_seed = 100 * np.random.randint(1, 1000)
            if self._seed_api:
                self._env.seed(self._seed + np_seed)
            self._action_space.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            if self._seed_api:
                self._env.seed(self._seed)
            self._action_space.seed(self._seed)
        obs = self._env.reset()
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        return obs

    # override
    def close(self) -> None:
        """
        Overview:
            Clean up the environment by closing and deleting it.
            This method should be called when the environment is no longer needed.
            Failing to call this method can lead to memory leaks.
        """
        try:
            self._env.close()
            del self._env
        except Exception:  # best-effort cleanup: the env may be uninitialized or already closed
            pass

    # override
    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Set the seed for the environment.
        Arguments:
            - seed (:obj:`int`): The seed to set.
            - dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
        """
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    # override
    def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
        """
        Overview:
            Execute the given action in the environment, and return the timestep (observation, reward, done, info).
        Arguments:
            - action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
        """
        action = self._judge_action_type(action)
        if self._cfg.act_scale:
            # Map the normalized action into the env's real action bounds.
            action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
        obs, rew, done, info = self._env.step(action)
        if self._cfg.rew_clip:
            rew = max(-10, rew)
            rew = np.float32(rew)
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        # Reward is always returned as a shape-(1,) float32 array.
        rew = to_ndarray([rew], np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
        """
        Overview:
            Ensure the action taken by the agent is of the correct type.
            This method is used to standardize different action types to a common format.
        Arguments:
            - action (Union[np.ndarray, dict]): The action taken by the agent.
        Returns:
            - action (Union[np.ndarray, dict]): The formatted action.
        """
        if isinstance(action, int):
            return action
        elif isinstance(action, np.int64):
            return int(action)
        elif isinstance(action, np.ndarray):
            # Unwrap 0-dim arrays and 1-element int arrays to Python scalars.
            if action.shape == ():
                action = action.item()
            elif action.shape == (1, ) and action.dtype == np.int64:
                action = action.item()
            return action
        elif isinstance(action, dict):
            for k, v in action.items():
                action[k] = self._judge_action_type(v)
            return action
        elif isinstance(action, tnp.ndarray):
            return self._judge_action_type(action.json())
        else:
            raise TypeError(
                '`action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                    type(action), action
                )
            )

    def random_action(self) -> np.ndarray:
        """
        Overview:
            Return a random action from the action space of the environment.
        Returns:
            - action (:obj:`np.ndarray`): The random action.
        """
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, (int, np.int64)):
            random_action = to_ndarray([random_action], dtype=np.int64)
        elif isinstance(random_action, dict):
            random_action = to_ndarray(random_action)
        else:
            raise TypeError(
                '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                    type(random_action), random_action
                )
            )
        return random_action

    def _wrap_env(self, caller: str = 'collector') -> None:
        """
        Overview:
            Wrap the environment according to the configuration.
        Arguments:
            - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
                Different caller may need different wrappers. Default is 'collector'.
        """
        if self._is_gymnasium:
            # Adapt gymnasium envs to the gym-style interface expected by the wrappers below.
            self._env = GymToGymnasiumWrapper(self._env)
        # wrapper_cfgs: Union[str, List]
        wrapper_cfgs = self._cfg.env_wrapper
        if isinstance(wrapper_cfgs, str):
            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
        # self._wrapper_cfgs: List[Union[Callable, Dict]]
        self._wrapper_cfgs = wrapper_cfgs
        for wrapper in self._wrapper_cfgs:
            # wrapper: Union[Callable, Dict]
            if isinstance(wrapper, dict):
                self._env = create_env_wrapper(self._env, wrapper)
            else:  # Callable, such as lambda anonymous function
                self._env = wrapper(self._env)

    def __repr__(self) -> str:
        """
        Overview:
            Return the string representation of the instance.
        Returns:
            - str (:obj:`str`): The string representation of the instance.
        """
        return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configuration for collectors based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
        """
        actor_env_num = cfg.pop('collector_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = True
        return [cfg for _ in range(actor_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configuration for evaluators based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = False
        return [cfg for _ in range(evaluator_env_num)]

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """
        Overview:
            Enable the save replay functionality. The replay will be saved at the specified path.
        Arguments:
            - replay_path (:obj:`Optional[str]`): The path to save the replay, default is None.
        """
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    @property
    def observation_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the observation space of the wrapped environment.
            The observation space represents the range and shape of possible observations
            that the environment can provide to the agent.
        Note:
            If the data type of the observation space is float64, it's converted to float32
            for better compatibility with most machine learning libraries.
        Returns:
            - observation_space (gym.spaces.Space): The observation space of the environment.
        """
        if self._observation_space.dtype == np.float64:
            self._observation_space.dtype = np.float32
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the action space of the wrapped environment.
            The action space represents the range and shape of possible actions
            that the agent can take in the environment.
        Returns:
            - action_space (gym.spaces.Space): The action space of the environment.
        """
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the reward space of the wrapped environment.
            The reward space represents the range and shape of possible rewards
            that the agent can receive as a result of its actions.
        Returns:
            - reward_space (gym.spaces.Space): The reward space of the environment.
        """
        return self._reward_space

    def clone(self, caller: str = 'collector') -> BaseEnv:
        """
        Overview:
            Clone the current environment wrapper, creating a new environment with the same settings.
        Arguments:
            - caller (str): A string representing the caller of this method, including ``collector`` or \
                ``evaluator``. Different caller may need different wrappers. Default is 'collector'.
        Returns:
            - DingEnvWrapper: A new instance of the environment with the same settings.
        """
        try:
            # Deep-copy the raw env via CloudPickleWrapper; re-attach the spec, which may not
            # survive pickling.
            spec = copy.deepcopy(self._raw_env.spec)
            raw_env = CloudPickleWrapper(self._raw_env)
            raw_env = copy.deepcopy(raw_env).data
            raw_env.__setattr__('spec', spec)
        except Exception:
            # Fall back to sharing the original raw env if it is not picklable.
            raw_env = self._raw_env
        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium)