ding.envs.env.env_implementation_check
ding.envs.env.env_implementation_check
¶
Full Source Code
../ding/envs/env/env_implementation_check.py
from typing import Any, Callable, List, Tuple, Union, Dict, TYPE_CHECKING
import numpy as np
from collections.abc import Sequence
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary
from ding.envs.env.tests import DemoEnv
# from dizoo.atari.envs import AtariEnv

if TYPE_CHECKING:
    from ding.envs.env import BaseEnv


def check_space_dtype(env: 'BaseEnv') -> None:
    """Assert that the obs/act/rew spaces use np.float32 or np.int64 dtypes.

    The env is reset first so that the spaces are available on lazily
    initialized envs.

    Args:
        env: The environment whose ``observation_space``, ``action_space`` and
            ``reward_space`` are inspected.

    Raises:
        AssertionError: If a float space is not np.float32 or an int space is
            not np.int64.
    """
    print("== 0. Test obs/act/rew space's dtype")
    env.reset()
    spaces = {'obs': env.observation_space, 'act': env.action_space, 'rew': env.reward_space}
    for name, space in spaces.items():
        dtype_repr = repr(space.dtype)
        if 'float' in dtype_repr:
            assert space.dtype == np.float32, "If float, then must be np.float32, but get {} for {} space".format(
                space.dtype, name
            )
        if 'int' in dtype_repr:
            assert space.dtype == np.int64, "If int, then must be np.int64, but get {} for {} space".format(
                space.dtype, name
            )


# Util function
def check_array_space(data: Union[np.ndarray, Sequence, Dict], space: Union['Space', Dict], name: str) -> None:
    """Recursively check that ``data`` matches ``space`` in dtype, shape and value range.

    Args:
        data: An np.ndarray, or a (possibly nested) sequence/dict of np.ndarray.
        space: The space (or container of spaces, indexed the same way as
            ``data``) that ``data`` must conform to.
        name: Label used in the assertion messages (e.g. ``'obs'``).

    Raises:
        AssertionError: If dtype, shape or value range does not match.
        TypeError: If ``data`` is not an np.ndarray or a sequence/dict.
    """
    if isinstance(data, np.ndarray):
        # print("{}'s type should be np.ndarray".format(name))
        assert data.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(name, data.dtype, space.dtype)
        assert data.shape == space.shape, "{}'s shape is {}, but requires {}".format(name, data.shape, space.shape)
        if isinstance(space, Box):
            assert (space.low <= data).all() and (data <= space.high).all(
            ), "{}'s value is {}, but requires in range ({},{})".format(name, data, space.low, space.high)
        elif isinstance(space, Discrete):
            # Valid values are start, ..., start + n - 1. Older gym versions
            # have no ``start`` attribute, in which case the range begins at 0.
            start = getattr(space, 'start', 0)
            assert (data >= start) and (data < start + space.n), \
                "{}'s value is {}, but requires in range [{},{})".format(name, data, start, start + space.n)
        elif isinstance(space, MultiDiscrete):
            assert (data >= 0).all(), "{}'s value is {}, but requires non-negative values".format(name, data)
            assert (data < space.nvec).all(), "{}'s value is {}, but requires each element < {}".format(
                name, data, space.nvec
            )
        elif isinstance(space, MultiBinary):
            # MultiBinary has no ``nvec``; every element must be 0 or 1.
            assert ((data == 0) | (data == 1)).all(), "{}'s value is {}, but requires each element in {{0, 1}}".format(
                name, data
            )
    elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
        # str/bytes are Sequences whose elements are again str/int, so they
        # must not be treated as containers of arrays.
        for i, item in enumerate(data):
            try:
                check_array_space(item, space[i], name)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(data, dict):
        for k, item in data.items():
            try:
                check_array_space(item, space[k], name)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(data))
        )


def check_reset(env: 'BaseEnv') -> None:
    """Check that the observation returned by ``env.reset()`` conforms to ``env.observation_space``."""
    print('== 1. Test reset method')
    obs = env.reset()
    check_array_space(obs, env.observation_space, 'obs')


def _sample_action(env: 'BaseEnv') -> Any:
    """Return a random action, preferring ``env.random_action()`` when the env provides it."""
    if hasattr(env, "random_action"):
        return env.random_action()
    return env.action_space.sample()


def check_step(env: 'BaseEnv') -> None:
    """Step the env with random actions for three episodes and validate obs/rew.

    Every returned obs and rew is checked against the corresponding space, and
    the ``info`` returned when an episode ends must contain the key
    ``'eval_episode_return'``.

    Args:
        env: The environment to check.

    Raises:
        AssertionError: If obs/rew violate their spaces or the final ``info``
            lacks ``'eval_episode_return'``.
    """
    done_times = 0
    print('== 2. Test step method')
    _ = env.reset()
    while True:
        # Sample a fresh action every step: replaying one fixed action may
        # never terminate an episode and exercises only a single action.
        obs, rew, done, info = env.step(_sample_action(env))
        for ndarray, space, name in zip([obs, rew], [env.observation_space, env.reward_space], ['obs', 'rew']):
            check_array_space(ndarray, space, name)
        if done:
            assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
            done_times += 1
            _ = env.reset()
        if done_times == 3:
            break


# Util function
def check_different_memory(
        array1: Union[np.ndarray, Sequence, Dict], array2: Union[np.ndarray, Sequence, Dict], step_times: int
) -> None:
    """Recursively assert that two observations are not the same np.ndarray object.

    Args:
        array1: Observation of the previous frame (np.ndarray or nested
            sequence/dict of np.ndarray).
        array2: Observation of the current frame, with the same structure.
        step_times: Current step count, used only in the error messages.

    Raises:
        AssertionError: If the two inputs differ in type/length/keys, or if a
            pair of corresponding arrays is the very same object.
        TypeError: If the inputs are not np.ndarray or sequence/dict.
    """
    assert type(array1) is type(
        array2
    ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
        step_times, type(array1), type(array2)
    )
    if isinstance(array1, np.ndarray):
        assert array1 is not array2, \
            "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(step_times)
    elif isinstance(array1, Sequence) and not isinstance(array1, (str, bytes)):
        # str/bytes would recurse forever on their own one-character elements.
        assert len(array1) == len(
            array2
        ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
            step_times, len(array1), len(array2)
        )
        for i, (item1, item2) in enumerate(zip(array1, array2)):
            try:
                check_different_memory(item1, item2, step_times)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(array1, dict):
        # Implicit literal concatenation keeps the message free of the stray
        # indentation a backslash continuation would embed in the string.
        assert array1.keys() == array2.keys(), (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have "
            "different dict keys"
        ).format(step_times, array1.keys(), array2.keys())
        for k in array1.keys():
            try:
                check_different_memory(array1[k], array2[k], step_times)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(
                type(array1), type(array2)
            )
        )


def check_obs_deepcopy(env: 'BaseEnv') -> None:
    """Run one episode and assert consecutive observations are distinct objects.

    Args:
        env: The environment to check.

    Raises:
        AssertionError: If two consecutive observations are (or contain) the
            same np.ndarray object.
    """
    step_times = 0
    print('== 3. Test observation deepcopy')
    obs_1 = env.reset()
    while True:
        step_times += 1
        # Fresh random action per step, see ``check_step``-style sampling.
        obs_2, _, done, _ = env.step(_sample_action(env))
        check_different_memory(obs_1, obs_2, step_times)
        obs_1 = obs_2
        if done:
            break


def check_all(env: 'BaseEnv') -> None:
    """Run all implementation checks (space dtype, reset, step, obs deepcopy) on ``env``."""
    check_space_dtype(env)
    check_reset(env)
    check_step(env)
    check_obs_deepcopy(env)


def demonstrate_correct_procedure(env_fn: Callable[[Dict], 'BaseEnv']) -> None:
    """Demonstrate the expected init -> seed -> reset -> step loop on an env.

    Args:
        env_fn: Callable that builds an env from a config dict; it is called
            with an empty dict.
    """
    print('== 4. Demonstrate the correct procudures')
    done_times = 0
    # Init the env.
    env = env_fn({})
    # Lazy init. The real env is not initialized until `reset` method is called
    assert not hasattr(env, "_env")
    # Must set seed before `reset` method is called.
    env.seed(4)
    assert env._seed == 4
    # Reset the env. The real env is initialized here.
    obs = env.reset()
    while True:
        # Using the policy to get the action from obs. But here we use `random_action` instead.
        action = env.random_action()
        obs, rew, done, info = env.step(action)
        if done:
            assert 'eval_episode_return' in info
            done_times += 1
            obs = env.reset()
            # Seed will not change unless `seed` method is called again.
            assert env._seed == 4
        if done_times == 3:
            break


if __name__ == "__main__":
    # Methods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
    # You can replace `AtariEnv` with your own env, e.g.:
    #
    #     atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
    #     check_reset(atari_env)
    #     check_step(atari_env)
    #     check_obs_deepcopy(atari_env)
    #
    # Method `demonstrate_correct_procedure` is to demonstrate the correct procedure to
    # use an env to generate trajectories.
    # You can check whether your env's design is similar to `DemoEnv`
    demonstrate_correct_procedure(DemoEnv)