from collections import defaultdict
import math
import queue
from time import sleep, time
import gym
from ding.framework import Supervisor
from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable
from ding.framework.supervisor import ChildType, RecvPayload, SendPayload
from ding.utils import make_key_as_identifier
from ditk import logging
from ding.data import ShmBufferContainer
import enum
import treetensor.numpy as tnp
import numbers
if TYPE_CHECKING:
    from gym.spaces import Space


class EnvState(enum.IntEnum):
    """
    Lifecycle state of a single environment as tracked by the supervisor.

    VOID -> RUN -> DONE
    """
    VOID = 0
    INIT = 1
    RUN = 2
    RESET = 3
    DONE = 4
    ERROR = 5
    NEED_RESET = 6


class EnvRetryType(str, enum.Enum):
    """
    What to do with an env whose reset failed: send ``reset`` again, or restart the child first (``renew``).
    """
    RESET = "reset"
    RENEW = "renew"


class EnvSupervisor(Supervisor):
    """
    Manage multiple envs with supervisor.

    New features (compared to env manager):
      - Consistent interface in multi-process and multi-threaded mode.
      - Add asynchronous features and recommend using asynchronous methods.
      - Reset is performed after an error is encountered in the step method.

    Breaking changes (compared to env manager):
      - Without some states.
    """

    def __init__(
            self,
            type_: ChildType = ChildType.PROCESS,
            env_fn: List[Callable] = None,
            retry_type: EnvRetryType = EnvRetryType.RESET,
            max_try: Optional[int] = None,
            max_retry: Optional[int] = None,
            auto_reset: bool = True,
            reset_timeout: Optional[int] = None,
            step_timeout: Optional[int] = None,
            retry_waiting_time: Optional[int] = None,
            episode_num: int = float("inf"),
            shared_memory: bool = True,
            copy_on_get: bool = True,
            **kwargs
    ) -> None:
        """
        Overview:
            Supervisor that manage a group of envs.
        Arguments:
            - type_ (:obj:`ChildType`): Type of child process.
            - env_fn (:obj:`List[Callable]`): The function to create environment
            - retry_type (:obj:`EnvRetryType`): Retry reset or renew env.
            - max_try (:obj:`Optional[int]`): Max try times for reset or step action. Falls back to \
                ``max_retry`` and then to 1 when not given.
            - max_retry (:obj:`Optional[int]`): Alias of max_try.
            - auto_reset (:obj:`bool`): Auto reset env if reach done.
            - reset_timeout (:obj:`Optional[int]`): Timeout in seconds for reset.
            - step_timeout (:obj:`Optional[int]`): Timeout in seconds for step.
            - retry_waiting_time (:obj:`Optional[float]`): Wait time on each retry.
            - episode_num (:obj:`int`): An env is only auto reset while its finished episode count is \
                below this number.
            - shared_memory (:obj:`bool`): Use shared memory in multiprocessing.
            - copy_on_get (:obj:`bool`): Use copy on get in multiprocessing.
        """
        if kwargs:
            logging.warning("Unknown parameters on env supervisor: {}".format(kwargs))
        super().__init__(type_=type_)
        if type_ is not ChildType.PROCESS and (shared_memory or copy_on_get):
            logging.warning("shared_memory and copy_on_get only works in process mode.")
        self._shared_memory = type_ is ChildType.PROCESS and shared_memory
        self._copy_on_get = type_ is ChildType.PROCESS and copy_on_get
        self._env_fn = env_fn
        self._create_env_ref()
        self._obs_buffers = None
        if env_fn:
            if self._shared_memory:
                obs_space = self._observation_space
                if isinstance(obs_space, gym.spaces.Dict):
                    # For multi_agent case, such as multiagent_mujoco and petting_zoo mpe.
                    # Now only for the case that each agent in the team have the same obs structure
                    # and corresponding shape.
                    shape = {k: v.shape for k, v in obs_space.spaces.items()}
                    dtype = {k: v.dtype for k, v in obs_space.spaces.items()}
                else:
                    shape = obs_space.shape
                    dtype = obs_space.dtype
                # One observation buffer per env, keyed by env id.
                self._obs_buffers = {
                    env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get)
                    for env_id in range(len(self._env_fn))
                }
                for env_init in env_fn:
                    self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback)
            else:
                for env_init in env_fn:
                    self.register(env_init)
        self._retry_type = retry_type
        self._auto_reset = auto_reset
        if max_retry:
            logging.warning("The `max_retry` is going to be deprecated, use `max_try` instead!")
        self._max_try = max_try or max_retry or 1
        self._reset_timeout = reset_timeout
        self._step_timeout = step_timeout
        self._retry_waiting_time = retry_waiting_time
        self._env_replay_path = None
        self._episode_num = episode_num
        self._init_states()

    def _init_states(self):
        """
        Overview:
            (Re)initialize all per-env bookkeeping: seeds, replay path, states, reset parameters, \
            ready observations, episode counters, retry counters and last-called timestamps.
        """
        self._env_seed = {}
        self._env_dynamic_seed = None
        self._env_replay_path = None
        self._env_states = {}
        self._reset_param = {}
        self._ready_obs = {}
        self._env_episode_count = {i: 0 for i in range(self.env_num)}
        self._retry_times = defaultdict(lambda: 0)
        # ``math.inf`` means "no request in flight", so the elapsed time is never positive.
        self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf})

    def _shm_callback(self, payload: RecvPayload, obs_buffers: Any):
        """
        Overview:
            This method will be called in child worker, so we can put large data into shared memory
            and replace the original payload data to none, then reduce the serialization/deserialization cost.
        Arguments:
            - payload (:obj:`RecvPayload`): The payload produced by the child for a ``reset`` or ``step`` call.
            - obs_buffers (:obj:`Any`): Mapping from env id to the buffer that receives the observation.
        """
        if payload.method == "reset" and payload.data is not None:
            obs_buffers[payload.proc_id].fill(payload.data)
            payload.data = None
        elif payload.method == "step" and payload.data is not None:
            obs_buffers[payload.proc_id].fill(payload.data.obs)
            # ``_replace`` returns a new object (namedtuple semantics are assumed for the timestep),
            # so the result must be assigned back, otherwise the obs is still serialized.
            payload.data = payload.data._replace(obs=None)

    def _create_env_ref(self):
        """
        Overview:
            Create one reference env from the first env function, reset it once, record its \
            observation/action/reward spaces and close it.
        """
        # env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
        self._env_ref = self._env_fn[0]()
        self._env_ref.reset()
        self._observation_space = self._env_ref.observation_space
        self._action_space = self._env_ref.action_space
        self._reward_space = self._env_ref.reward_space
        self._env_ref.close()

    def step(self, actions: Union[Dict[int, List[Any]], List[Any]], block: bool = True) -> Optional[List[tnp.ndarray]]:
        """
        Overview:
            Execute env step according to input actions. And reset an env if done.
        Arguments:
            - actions (:obj:`List[tnp.ndarray]`): Actions came from outer caller like policy, \
                in structure of {env_id: actions}. A list is interpreted as actions for env 0..n-1.
            - block (:obj:`bool`): If block, return timesteps, else return none.
        Returns:
            - timesteps (:obj:`List[tnp.ndarray]`): Each timestep is a tnp.array with observation, reward, done, \
                info, env_id.
        """
        assert not self.closed, "Env supervisor has closed."
        if isinstance(actions, list):
            actions = dict(enumerate(actions))
        assert actions, "Action is empty!"

        send_payloads = []
        for env_id, act in actions.items():
            payload = SendPayload(proc_id=env_id, method="step", args=[act])
            send_payloads.append(payload)
            self.send(payload)

        if not block:
            # Retrieve the data for these steps from the recv method
            return

        # Wait for all steps returns
        recv_payloads = self.recv_all(
            send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._step_timeout
        )
        return [payload.data for payload in recv_payloads]

    def recv(self, ignore_err: bool = False) -> RecvPayload:
        """
        Overview:
            Wait for recv payload, this function will block the thread.
        Arguments:
            - ignore_err (:obj:`bool`): Kept for interface compatibility. Payloads carrying an error are \
                always passed to the callback and then skipped here, whatever this flag is. \
                This option will not catch the exception.
        Returns:
            - recv_payload (:obj:`RecvPayload`): Recv payload.
        """
        # Poll in a loop rather than by recursion: with a 0.1s poll interval a long idle wait
        # would otherwise exceed the interpreter recursion limit.
        while True:
            self._detect_timeout()
            try:
                payload = super().recv(ignore_err=True, timeout=0.1)
                payload = self._recv_callback(payload=payload)
            except queue.Empty:
                continue
            if not payload.err:
                return payload

    def _detect_timeout(self):
        """
        Overview:
            Detect step/reset requests that have been pending longer than the configured timeout and \
            put a payload carrying a ``TimeoutError`` into the receive queue for each of them, so that \
            they go through the same error handling as any other failed request.
        """
        for env_id in self._last_called:
            if self._step_timeout and time() - self._last_called[env_id]["step"] > self._step_timeout:
                payload = RecvPayload(
                    proc_id=env_id, method="step", err=TimeoutError("Step timeout on env {}".format(env_id))
                )
                self._recv_queue.put(payload)
                # A step timeout takes precedence; do not also report a reset timeout for this env.
                continue
            if self._reset_timeout and time() - self._last_called[env_id]["reset"] > self._reset_timeout:
                payload = RecvPayload(
                    proc_id=env_id, method="reset", err=TimeoutError("Reset timeout on env {}".format(env_id))
                )
                self._recv_queue.put(payload)

    @property
    def env_num(self) -> int:
        """Number of registered child envs."""
        return len(self._children)

    @property
    def observation_space(self) -> 'Space':
        """Observation space recorded from the reference env."""
        return self._observation_space

    @property
    def action_space(self) -> 'Space':
        """Action space recorded from the reference env."""
        return self._action_space

    @property
    def reward_space(self) -> 'Space':
        """Reward space recorded from the reference env."""
        return self._reward_space

    @property
    def ready_obs(self) -> tnp.array:
        """
        Overview:
            Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
        Return:
            - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
        Example:
            >>> obs = env_manager.ready_obs
            >>> action = model(obs)  # model input np obs and output np action
            >>> timesteps = env_manager.step(action)
        """
        active_env = sorted(i for i, s in self._env_states.items() if s == EnvState.RUN)
        obs = [self._ready_obs.get(i) for i in active_env]
        if not obs:
            return tnp.array([])
        return tnp.stack(obs)

    @property
    def ready_obs_id(self) -> List[int]:
        """Ids of the envs currently in ``RUN`` state."""
        return [i for i, s in self.env_states.items() if s == EnvState.RUN]

    @property
    def done(self) -> bool:
        """Whether every env is in ``DONE`` state."""
        return all(s == EnvState.DONE for s in self.env_states.values())

    @property
    def method_name_list(self) -> List[str]:
        """Names of the env methods handled by this supervisor."""
        return ['reset', 'step', 'seed', 'close', 'enable_save_replay']

    @property
    def env_states(self) -> Dict[int, EnvState]:
        """State of every env by id; envs without a recorded state are reported as ``VOID``."""
        return {env_id: self._env_states.get(env_id) or EnvState.VOID for env_id in range(self.env_num)}

    def env_state_done(self, env_id: int) -> bool:
        """Whether the env with the given id is in ``DONE`` state."""
        return self.env_states[env_id] == EnvState.DONE

    def launch(self, reset_param: Optional[Dict] = None, block: bool = True) -> None:
        """
        Overview:
            Set up the environments and their parameters.
        Arguments:
            - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
                value is the corresponding reset parameters.
            - block (:obj:`block`): Whether will block the process and wait for reset states.
        """
        assert self.closed, "Please first close the env supervisor before launch it"
        if reset_param is not None:
            assert len(reset_param) == self.env_num
        self.start_link()
        self._send_seed(self._env_seed, self._env_dynamic_seed, block=block)
        self.reset(reset_param, block=block)
        self._enable_env_replay()

    def reset(self, reset_param: Optional[Dict[int, List[Any]]] = None, block: bool = True) -> None:
        """
        Overview:
            Reset an environment.
        Arguments:
            - reset_param (:obj:`Optional[Dict[int, List[Any]]]`): Dict of reset parameters for each environment, \
                key is the env_id, value is the corresponding reset parameters. When empty or none, every \
                env is reset with empty parameters; a list is interpreted as parameters for env 0..n-1.
            - block (:obj:`block`): Whether will block the process and wait for reset states.
        """
        if not reset_param:
            reset_param = {i: {} for i in range(self.env_num)}
        elif isinstance(reset_param, list):
            reset_param = dict(enumerate(reset_param))

        send_payloads = []
        for env_id, kw_param in reset_param.items():
            self._reset_param[env_id] = kw_param  # For auto reset
            send_payloads += self._reset(env_id, kw_param=kw_param)

        if not block:
            return

        self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)

    def _recv_callback(
            self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            The callback function for each received payload, within this method will modify the state of \
            each environment, replace objects in shared memory, and determine if a retry is needed due to an error.
        Arguments:
            - payload (:obj:`RecvPayload`): The received payload.
            - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): The callback may be called many times \
                until remain_payloads be cleared, you can append new payload into remain_payloads to call this \
                callback recursively.
        Returns:
            - payload (:obj:`RecvPayload`): The (possibly rewritten) payload.
        """
        self._set_shared_obs(payload=payload)
        self.change_state(payload=payload)
        if payload.method == "reset":
            return self._recv_reset_callback(payload=payload, remain_payloads=remain_payloads)
        elif payload.method == "step":
            return self._recv_step_callback(payload=payload, remain_payloads=remain_payloads)
        return payload

    def _set_shared_obs(self, payload: RecvPayload):
        """
        Overview:
            Counterpart of ``_shm_callback`` on the receiving side: put the observation stored in the \
            env's buffer back into the payload. Does nothing when no buffers were created or when the \
            payload carries an error.
        """
        if self._obs_buffers is None:
            return
        if payload.method == "reset" and payload.err is None:
            payload.data = self._obs_buffers[payload.proc_id].get()
        elif payload.method == "step" and payload.err is None:
            # ``_replace`` returns a new object, so the result must be assigned back.
            payload.data = payload.data._replace(obs=self._obs_buffers[payload.proc_id].get())

    def _recv_reset_callback(
            self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            Handle a received ``reset`` payload. On error, retry the reset (restarting the child first \
            when the retry type is ``RENEW``) until ``max_try`` is exhausted; on success, clear the \
            retry counter and store the observation as the env's ready obs.
        Arguments:
            - payload (:obj:`RecvPayload`): The received reset payload.
            - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): Pending requests, keyed by req_id; \
                retry requests are added to it.
        Returns:
            - payload (:obj:`RecvPayload`): The received payload.
        Raises:
            - RuntimeError: When the reset of an env has failed ``max_try`` times; the supervisor is \
                shut down before raising.
        """
        assert payload.method == "reset", "Recv error callback({}) in reset callback!".format(payload.method)
        if remain_payloads is None:
            remain_payloads = {}
        env_id = payload.proc_id
        if payload.err:
            self._retry_times[env_id] += 1
            if self._retry_times[env_id] > self._max_try - 1:
                self.shutdown(5)
                raise RuntimeError(
                    "Env {} reset has exceeded max_try({}), and the latest exception is: {}".format(
                        env_id, self._max_try, payload.err
                    )
                )
            if self._retry_waiting_time:
                sleep(self._retry_waiting_time)
            if self._retry_type == EnvRetryType.RENEW:
                self._children[env_id].restart()
            for p in self._reset(env_id):
                remain_payloads[p.req_id] = p
        else:
            self._retry_times[env_id] = 0
            self._ready_obs[env_id] = payload.data
        return payload

    def _recv_step_callback(
            self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            Handle a received ``step`` payload and rewrite its data into a ``tnp.array`` timestep with \
            keys obs, reward, done, info and env_id. On error the env is reset and the timestep is \
            marked abnormal; on a finished episode the env is auto reset if allowed.
        Arguments:
            - payload (:obj:`RecvPayload`): The received step payload.
            - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): Pending requests, keyed by req_id; \
                reset requests issued here are added to it.
        Returns:
            - payload (:obj:`RecvPayload`): The payload with rewritten data.
        """
        assert payload.method == "step", "Recv error callback({}) in step callback!".format(payload.method)
        if remain_payloads is None:
            remain_payloads = {}

        if payload.err:
            for p in self._reset(payload.proc_id):
                remain_payloads[p.req_id] = p
            info = {"abnormal": True, "err": payload.err}
            payload.data = tnp.array(
                {
                    'obs': None,
                    'reward': None,
                    'done': None,
                    'info': info,
                    'env_id': payload.proc_id
                }
            )
        else:
            obs, reward, done, info, *_ = payload.data
            if done:
                self._env_episode_count[payload.proc_id] += 1
                if self._env_episode_count[payload.proc_id] < self._episode_num and self._auto_reset:
                    for p in self._reset(payload.proc_id):
                        remain_payloads[p.req_id] = p
            # make the type and content of key as similar as identifier,
            # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
            info = make_key_as_identifier(info)
            payload.data = tnp.array(
                {
                    'obs': obs,
                    'reward': reward,
                    'done': done,
                    'info': info,
                    'env_id': payload.proc_id
                }
            )
            self._ready_obs[payload.proc_id] = obs
        return payload

    def _reset(self, env_id: int, kw_param: Optional[Dict[str, Any]] = None) -> List[SendPayload]:
        """
        Overview:
            Reset an environment. This method does not wait for the result to be returned.
        Arguments:
            - env_id (:obj:`int`): Environment id.
            - kw_param (:obj:`Optional[Dict[str, Any]]`): Reset parameters for the environment. When empty, \
                the parameters recorded by the last ``reset`` call for this env are used.
        Returns:
            - send_payloads (:obj:`List[SendPayload]`): The request payloads sent for the reset; empty when \
                the reset was skipped.
        """
        assert not self.closed, "Env supervisor has closed."
        send_payloads = []
        kw_param = kw_param or self._reset_param[env_id]

        if self._env_replay_path is not None and self.env_states[env_id] == EnvState.RUN:
            logging.warning("Please don't reset an unfinished env when you enable save replay, we just skip it")
            return send_payloads

        # Reset env
        payload = SendPayload(proc_id=env_id, method="reset", kwargs=kw_param)
        send_payloads.append(payload)
        self.send(payload)

        return send_payloads

    def _send_seed(self, env_seed: Dict[int, int], env_dynamic_seed: Optional[bool] = None, block: bool = True) -> None:
        """
        Overview:
            Send the ``seed`` call to every env that has a seed, optionally together with the dynamic \
            seed flag, and wait for the replies when ``block`` is true.
        Arguments:
            - env_seed (:obj:`Dict[int, int]`): Seed of each env by id; envs whose seed is none are skipped.
            - env_dynamic_seed (:obj:`Optional[bool]`): Passed as second argument when it is not none.
            - block (:obj:`bool`): Whether to wait for all seed calls to return.
        """
        send_payloads = []
        for env_id, seed in env_seed.items():
            if seed is None:
                continue
            args = [seed]
            if env_dynamic_seed is not None:
                args.append(env_dynamic_seed)
            payload = SendPayload(proc_id=env_id, method="seed", args=args)
            send_payloads.append(payload)
            self.send(payload)
        if not block or not send_payloads:
            return
        self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)

    def change_state(self, payload: RecvPayload):
        """
        Overview:
            Update the env state from a received payload: ``ERROR`` on error, ``RUN`` after a reset, \
            ``DONE`` after a step whose done flag (third element of the data) is true. Also clears the \
            pending timestamp of the corresponding method.
        """
        self._last_called[payload.proc_id][payload.method] = math.inf  # Have received
        if payload.err:
            self._env_states[payload.proc_id] = EnvState.ERROR
        elif payload.method == "reset":
            self._env_states[payload.proc_id] = EnvState.RUN
        elif payload.method == "step":
            if payload.data[2]:
                self._env_states[payload.proc_id] = EnvState.DONE

    def send(self, payload: SendPayload) -> None:
        """
        Overview:
            Record the send time of the request (used for timeout detection) and forward it to the \
            parent class.
        """
        self._last_called[payload.proc_id][payload.method] = time()
        return super().send(payload)

    def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: Optional[bool] = None) -> None:
        """
        Overview:
            Set the seed for each environment. The seed function will not be called until supervisor.launch \
            was called.
        Arguments:
            - seed (:obj:`Union[Dict[int, int], List[int], int]`): List of seeds for each environment; \
                Or one seed for the first environment and other seeds are generated automatically. \
                Note that in threading mode, no matter how many seeds are given, only the last one will take effect. \
                Because the execution in the thread is asynchronous, the results of each experiment \
                are different even if a fixed seed is used.
            - dynamic_seed (:obj:`Optional[bool]`): Dynamic seed is used in the training environment, \
                trying to make the random seed of each episode different, they are all generated in the reset \
                method by a random generator 100 * np.random.randint(1 , 1000) (but the seed of this random \
                number generator is fixed by the environmental seed method, guaranteeing the reproducibility \
                of the experiment). You need not pass the dynamic_seed parameter in the seed method, or pass \
                the parameter as True.
        Raises:
            - TypeError: When ``seed`` is neither an integer, a list nor a dict.
        """
        self._env_seed = {}
        if isinstance(seed, numbers.Integral):
            self._env_seed = {i: seed + i for i in range(self.env_num)}
        elif isinstance(seed, list):
            assert len(seed) == self.env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self.env_num)
            self._env_seed = dict(enumerate(seed))
        elif isinstance(seed, dict):
            self._env_seed = dict(seed)
        else:
            raise TypeError("Invalid seed arguments type: {}".format(type(seed)))
        self._env_dynamic_seed = dynamic_seed

    def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
        """
        Overview:
            Set each env's replay save path.
        Arguments:
            - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
                Or one path for all environments.
        """
        if isinstance(replay_path, str):
            replay_path = [replay_path] * self.env_num
        self._env_replay_path = replay_path

    def _enable_env_replay(self):
        """
        Overview:
            Send the recorded replay path to every env through ``enable_save_replay`` and wait for the \
            replies. Does nothing when no replay path was set.
        """
        if self._env_replay_path is None:
            return
        send_payloads = []
        for env_id, s in enumerate(self._env_replay_path):
            payload = SendPayload(proc_id=env_id, method="enable_save_replay", args=[s])
            send_payloads.append(payload)
            self.send(payload)
        self.recv_all(send_payloads=send_payloads)

    def __getattr__(self, key: str) -> List[Any]:
        """
        Overview:
            Fallback attribute lookup: only names that exist on the reference env are forwarded to the \
            parent class lookup, anything else raises ``AttributeError``.
        """
        if key == "_env_ref":
            # ``_env_ref`` itself is not set yet; reading it below would re-enter this method forever.
            raise AttributeError(key)
        if not hasattr(self._env_ref, key):
            raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
        return super().__getattr__(key)

    def close(self, timeout: Optional[float] = None) -> None:
        """
        In order to be compatible with BaseEnvManager, the new version can use `shutdown` directly.
        """
        self.shutdown(timeout=timeout)

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            If running, send ``close`` to every env, wait for the replies (errors ignored), shut down \
            the supervisor and reinitialize the bookkeeping states. Does nothing when not running.
        Arguments:
            - timeout (:obj:`Optional[float]`): Timeout passed to both the wait for the close replies \
                and the parent shutdown.
        """
        if self._running:
            send_payloads = []
            for env_id in range(self.env_num):
                payload = SendPayload(proc_id=env_id, method="close")
                send_payloads.append(payload)
                self.send(payload)
            self.recv_all(send_payloads=send_payloads, ignore_err=True, timeout=timeout)
            super().shutdown(timeout=timeout)
            self._init_states()

    @property
    def closed(self) -> bool:
        """Whether the supervisor is currently not running."""
        return not self._running