
ding.envs.env_manager.subprocess_env_manager


AsyncSubprocessEnvManager

Bases: BaseEnvManager

Overview

Create an AsyncSubprocessEnvManager to manage multiple environments. Each environment runs in its own subprocess.

Interfaces: seed, launch, ready_obs, step, reset, active_env

ready_obs property

Overview

Get the next observations.

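Return: A dictionary mapping environment IDs to their next observations. Only environments that are running and have no pending step result are included. If every unfinished environment is still resetting, the property blocks until one of them is ready.

Note: The observations are returned as np.ndarray.

Example:
>>> obs_dict = env_manager.ready_obs
>>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}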
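The waiting logic behind ready_obs can be sketched in plain Python. This is a minimal illustration, not the library code: EnvState is a hypothetical stand-in enum, and the max_polls cap is an assumption added so the sketch cannot spin forever (the source keeps polling and only logs a warning every 10000 sleeps). In the real manager, env states are flipped back to RUN by background reset threads, which is why polling is needed at all.

```python
import time
from enum import Enum, auto


class EnvState(Enum):
    # Hypothetical stand-in for ding's EnvState; only the members used below.
    RUN = auto()
    RESET = auto()
    DONE = auto()


def collect_ready_obs(env_states, ready_obs, waiting, poll_interval=0.001, max_polls=10000):
    """Block until some unfinished env is RUN, then return {env_id: obs} for ready envs."""
    not_done = [i for i, s in env_states.items() if s != EnvState.DONE]
    polls = 0
    while not any(env_states[i] == EnvState.RUN for i in not_done):
        if polls >= max_polls:  # assumption: the source polls indefinitely
            raise TimeoutError("no environment became ready")
        time.sleep(poll_interval)
        polls += 1
    # "ready" means running and not waiting for a pending step result
    ready_env = [i for i, s in env_states.items() if s == EnvState.RUN and i not in waiting]
    return {i: ready_obs[i] for i in ready_env}


states = {0: EnvState.RUN, 1: EnvState.RESET, 2: EnvState.RUN}
obs = collect_ready_obs(states, {0: 'obs0', 1: None, 2: 'obs2'}, waiting={2})
```

Env 1 is skipped because it is resetting, and env 2 because its previous step result has not been collected yet.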

ready_imgs property

Overview

Get the next rendered frames.

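Return: A dictionary mapping the IDs of the ready environments to their rendered frames. Note: The rendered frames are returned as np.ndarray. Because ready_imgs is a property, the render_mode argument in its signature cannot be passed by the caller, so frames are always requested with render_mode='rgb_array'.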

__init__(env_fn, cfg=EasyDict({}))

Overview

Initialize the AsyncSubprocessEnvManager.

Arguments:
- env_fn (:obj:List[Callable]): A list of functions, each of which creates one environment.
- cfg (:obj:EasyDict): The env manager config.

.. note::

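- wait_num: the minimum number of environment step results that each step call gathers before it may return (capped at the number of environments).
- step_wait_timeout: how long, in seconds, a step call keeps collecting results before it returns with at least wait_num of them. If None, step waits for every stepped environment.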
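How these two options are turned into the arguments used by step can be shown with a small helper that mirrors the _async_args computation in __init__. The default values are copied from the config dicts in the source listing; step_wait_args is a hypothetical name used only for illustration. Note that step additionally caps wait_num by the number of actions passed in that call.

```python
# Default values copied from the ``config`` dicts in the source listing.
ASYNC_DEFAULTS = {'wait_num': 2, 'step_wait_timeout': 0.01}
SYNC_DEFAULTS = {'wait_num': float('inf'), 'step_wait_timeout': None}


def step_wait_args(cfg, env_num):
    """Hypothetical helper mirroring how __init__ builds self._async_args['step']."""
    return {
        'wait_num': min(cfg['wait_num'], env_num),  # never wait for more envs than exist
        'timeout': cfg['step_wait_timeout'],
    }


async_args = step_wait_args(ASYNC_DEFAULTS, env_num=8)
sync_args = step_wait_args(SYNC_DEFAULTS, env_num=8)  # inf is capped to env_num
```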

launch(reset_param=None)

Overview

Set up the environments and their parameters.

Arguments:
- reset_param (:obj:Optional[Dict]): Reset parameters for each environment, keyed by env_id; each value holds the corresponding reset parameters. If given, it must contain one entry per environment.
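The expected shape of reset_param can be made concrete with a validator that mirrors the assertions in launch and _reset from the source listing. This is a sketch: validate_reset_param is a hypothetical helper that raises ValueError where the source would fail an assert, and the 'difficulty' key is an invented reset kwarg, since valid keys depend entirely on the wrapped env. Unlike launch, reset also accepts a dict that covers only some environments.

```python
from typing import Dict, Optional


def validate_reset_param(reset_param: Optional[Dict[int, Optional[dict]]], env_num: int) -> None:
    """Raise ValueError where launch()/_reset() in the source would fail an assertion."""
    if reset_param is None:
        return  # every env is reset with its stored parameters (initially {})
    if len(reset_param) != env_num:
        # launch() asserts one entry per environment
        raise ValueError("expected {} entries, got {}".format(env_num, len(reset_param)))
    for env_id, kwargs in reset_param.items():
        # _reset() forwards a dict as **kwargs to env.reset; None means "no kwargs"
        if kwargs is not None and not isinstance(kwargs, dict):
            raise ValueError("reset_param[{}] must be a dict or None".format(env_id))


# 'difficulty' is a hypothetical reset kwarg; valid keys depend on the wrapped env.
reset_param = {0: {'difficulty': 1}, 1: {'difficulty': 2}}
validate_reset_param(reset_param, env_num=2)
```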

reset(reset_param=None)

Overview

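Reset the environments, optionally updating their reset parameters.

Arguments: - reset_param (:obj:Optional[Dict]): Reset parameters keyed by env_id; each value holds the corresponding reset parameters. If None, all environments are reset with their stored parameters; otherwise only the listed environments are reset.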
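Each environment is reset in its own thread, and a failed reset is retried according to max_retry and retry_waiting_time. The retry policy can be sketched on its own; reset_with_retry and flaky_reset are hypothetical names, and the sketch leaves out what the source additionally does between attempts: when retry_type='renew', or when the pipe data cannot be unpickled, it terminates the subprocess and creates a new one before retrying.

```python
import time


def reset_with_retry(reset_fn, max_retry=1, retry_waiting_time=0.1):
    """Call ``reset_fn`` up to ``max_retry`` times, sleeping between failed attempts."""
    if max_retry < 1:
        raise ValueError("max_retry must be at least 1")
    exceptions = []
    for _ in range(max_retry):
        try:
            return reset_fn()
        except Exception as e:  # the source catches BaseException
            exceptions.append(e)
            time.sleep(retry_waiting_time)
    # Same message shape as the source; the latest failure is chained as the cause.
    raise RuntimeError(
        "reset has exceeded max retries({}), and the latest exception is: {}".format(max_retry, exceptions[-1])
    ) from exceptions[-1]


attempts = []


def flaky_reset():
    """Hypothetical reset that times out twice before succeeding."""
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("env reset connection timeout")
    return 'initial obs'


obs = reset_with_retry(flaky_reset, max_retry=3, retry_waiting_time=0.0)
```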

step(actions)

Overview

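Step the environments listed in actions. An environment whose episode is done is reset (when auto_reset is enabled).

Arguments:
- actions (:obj:Dict[int, Any]): {env_id: action}

Returns:
- timesteps (:obj:Dict[int, namedtuple]): {env_id: timestep}. Each timestep is a BaseEnvTimestep tuple with observation, reward, done, env_info.

Example:
>>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
>>> timesteps = env_manager.step(actions_dict)
>>> for env_id, timestep in timesteps.items():
...     pass

.. note::

- An env_id that appears in ``actions`` is returned in ``timesteps`` only if its subprocess answers within the wait window set by wait_num and step_wait_timeout. Otherwise its result is collected by a later step call, so ``timesteps`` can also contain env_ids that were stepped in an earlier call.
- Each environment runs in its own subprocess. When auto_reset is enabled and the environment has not yet finished episode_num episodes, a done environment is reset immediately.
- The async subprocess env manager uses ``connection.wait`` to poll the subprocesses.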
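The bookkeeping that makes partial returns possible can be isolated in a few lines. The sketch below restates the loop at the end of the async step that updates the set of pending environments; update_waiting is a hypothetical name. Pending environments are excluded from ready_obs, so a caller that builds actions from ready_obs never sends a second action to an environment whose previous result is still unread.

```python
def update_waiting(sent_env_ids, waiting, ready_env_ids):
    """Return the new set of envs whose step result is still pending.

    Mirrors the bookkeeping after the wait loop in AsyncSubprocessEnvManager.step:
    every env stepped now or earlier that has not answered yet stays pending.
    """
    pending = set(waiting)
    for env_id in set(sent_env_ids) | set(waiting):
        if env_id in ready_env_ids:
            pending.discard(env_id)
        else:
            pending.add(env_id)
    return pending


# Env 3 was still pending from an earlier call; envs 0-2 are stepped now.
# Only envs 0 and 3 answer within the wait window.
pending = update_waiting(sent_env_ids=[0, 1, 2], waiting={3}, ready_env_ids=[0, 3])
```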

worker_fn(p, c, env_fn_wrapper, obs_buffer, method_name_list, reset_inplace=False) staticmethod

Overview

Target function run in each subprocess. This simpler variant is kept for reference; worker_fn_robust is used by default.
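Both worker functions serve the same request protocol: the manager sends a [cmd, args, kwargs] list over a pipe and the worker answers with the return value, or with an exception object if the env raised. The sketch below shows that round trip with a hypothetical ToyEnv and toy_worker. As an assumption to keep it self-contained, the worker runs in a thread rather than a subprocess, and shared-memory observation buffers and timeouts are left out.

```python
import threading
from multiprocessing import Pipe


class ToyEnv:
    """Hypothetical env: the observation is a step counter; an episode lasts 3 steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}  # (obs, reward, done, info)

    def close(self):
        pass


def toy_worker(child, env_fn, method_name_list):
    """Serve [cmd, args, kwargs] requests on ``child`` until 'close' or EOF."""
    env = env_fn()  # the env is created inside the worker, as in the source
    while True:
        try:
            cmd, args, kwargs = child.recv()
        except EOFError:  # the other end was closed
            break
        try:
            if cmd == 'getattr':
                ret = getattr(env, args[0])
            elif cmd in method_name_list:
                ret = getattr(env, cmd)(*(args or []), **(kwargs or {}))
            else:
                raise KeyError("not support env cmd: {}".format(cmd))
            child.send(ret)
        except Exception as e:
            child.send(e)  # errors travel back as exception objects
        if cmd == 'close':
            child.close()
            break


parent, child = Pipe()
worker = threading.Thread(target=toy_worker, args=(child, ToyEnv, ['reset', 'step', 'close']), daemon=True)
worker.start()
parent.send(['reset', [], None])
obs = parent.recv()
parent.send(['step', [0], None])
timestep = parent.recv()
parent.send(['jump', [], None])  # unsupported command
err = parent.recv()
parent.send(['close', None, None])
parent.recv()
worker.join()
```

Creating the env inside the worker, rather than passing an env object to it, matches the source comment that this avoids env pickling errors.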

worker_fn_robust(parent, child, env_fn_wrapper, obs_buffer, method_name_list, reset_timeout=None, step_timeout=None, reset_inplace=False) staticmethod

Overview

A more robust version of the subprocess target function, which wraps env reset and step calls with reset_timeout and step_timeout. Used by default.
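The reference above does not show how timeout_wrapper (imported from base_env_manager) is implemented. The decorator below is a hypothetical thread-based stand-in that illustrates the idea of bounding a reset or step call; run_with_timeout, fast_step and hung_step are invented names, and the real wrapper may work differently. Whatever exception a timed-out call raises inside worker_fn_robust is sent back to the manager like any other env error.

```python
import concurrent.futures
import functools
import time


def run_with_timeout(timeout):
    """Hypothetical decorator: raise concurrent.futures.TimeoutError if fn exceeds ``timeout`` seconds.

    ``None`` disables the limit, matching the default reset_timeout/step_timeout.
    The timed-out call is abandoned, not killed; its thread keeps running.
    """
    def decorator(fn):
        if timeout is None:
            return fn

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                return pool.submit(fn, *args, **kwargs).result(timeout=timeout)
            finally:
                pool.shutdown(wait=False)  # do not block on a hung call
        return wrapper
    return decorator


@run_with_timeout(timeout=1.0)
def fast_step():
    return 'timestep'


@run_with_timeout(timeout=0.05)
def hung_step():
    time.sleep(0.2)  # simulates an env.step that takes too long
    return 'late timestep'
```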

enable_save_replay(replay_path)

Overview

Set each env's replay save path.

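Arguments:
- replay_path (:obj:Union[List[str], str]): A list with one path per environment, or a single path shared by all environments. The paths are forwarded to each env when its subprocess is created, so call this before launch.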
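The path handling in enable_save_replay amounts to broadcasting a single string to every environment. A minimal sketch, with normalize_replay_path as a hypothetical name; the length check is an extra safeguard that the source does not perform.

```python
from typing import List, Union


def normalize_replay_path(replay_path: Union[List[str], str], env_num: int) -> List[str]:
    """Expand ``replay_path`` into one path per environment, as enable_save_replay does."""
    if isinstance(replay_path, str):
        # one directory shared by all environments
        return [replay_path] * env_num
    if len(replay_path) != env_num:
        # extra check, not performed by the source
        raise ValueError("expected {} paths, got {}".format(env_num, len(replay_path)))
    return list(replay_path)


shared = normalize_replay_path('./replays', env_num=3)
per_env = normalize_replay_path(['./replay_0', './replay_1'], env_num=2)
```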

close()

Overview

Close the env manager and release all related resources.

wait(rest_conn, wait_num, timeout=None) staticmethod

Overview

Wait until at least wait_num of the connections in rest_conn are ready to be read. If timeout is None, the method returns only when every connection is ready (sync mode). Otherwise it returns once at least wait_num connections are ready and at least timeout seconds have elapsed, or once every connection is ready. It returns the ready connections together with their indices in rest_conn.
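The two modes of wait can be tried out with ordinary pipes. wait_ready below restates the loop from the source listing for illustration (raising ValueError instead of asserting); the pipe setup and messages are hypothetical.

```python
import time
from multiprocessing import Pipe, connection


def wait_ready(rest_conn, wait_num, timeout=None):
    """Same loop as AsyncSubprocessEnvManager.wait, restated for illustration."""
    if not 1 <= wait_num <= len(rest_conn):
        raise ValueError("wait_num must be in [1, len(rest_conn)]")
    rest, ready = set(rest_conn), set()
    start = time.time()
    while rest:
        # Partial return: enough connections are ready and the timeout has elapsed.
        if len(ready) >= wait_num and timeout and time.time() - start >= timeout:
            break
        finished = set(connection.wait(rest, timeout=timeout))
        ready |= finished
        rest -= finished
    return list(ready), [rest_conn.index(c) for c in ready]


pipes = [Pipe() for _ in range(3)]
parents = [p for p, _ in pipes]
children = [c for _, c in pipes]
children[0].send('timestep 0')
children[2].send('timestep 2')
# Async mode: returns envs 0 and 2 after roughly 0.05 s without waiting for env 1.
ready_conn, ready_ids = wait_ready(parents, wait_num=2, timeout=0.05)
```

With timeout=None the same call would block until env 1 also answered, which is the sync behaviour.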

SyncSubprocessEnvManager

Bases: AsyncSubprocessEnvManager

step(actions)

Overview

Step the environments listed in actions and wait for every one of them to return. An environment whose episode is done is reset (when auto_reset is enabled).

Arguments:
- actions (:obj:Dict[int, Any]): {env_id: action}

Returns:
- timesteps (:obj:Dict[int, namedtuple]): {env_id: timestep}. Each timestep is a BaseEnvTimestep tuple with observation, reward, done, env_info.

Example:
>>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
>>> timesteps = env_manager.step(actions_dict)
>>> for env_id, timestep in timesteps.items():
...     pass

.. note::

- Every env_id that appears in ``actions`` is also returned in ``timesteps``.
- Each environment runs in its own subprocess. When auto_reset is enabled and the environment has not yet finished episode_num episodes, a done environment is reset immediately.
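After the results are collected, the sync and async step methods apply the same per-environment state transitions. The sketch below restates those branches as a pure function; EnvState is a hypothetical stand-in enum and next_state an invented name. In the real code the RESET branch also starts a background reset thread, which later sets the state back to RUN.

```python
from enum import Enum, auto


class EnvState(Enum):
    # Hypothetical stand-in for ding's EnvState; only the members used below.
    RUN = auto()
    RESET = auto()
    NEED_RESET = auto()
    DONE = auto()
    ERROR = auto()


def next_state(done, abnormal, episode_count, episode_num, auto_reset, reset_inplace):
    """Return (new_state, new_episode_count), following the branches at the end of step()."""
    if abnormal:
        return EnvState.ERROR, episode_count
    if not done:
        # the env keeps running; step() only stores the new observation
        return EnvState.RUN, episode_count
    episode_count += 1
    if episode_count >= episode_num:
        return EnvState.DONE, episode_count
    if not auto_reset:
        # the caller must call env_manager.reset for this env
        return EnvState.NEED_RESET, episode_count
    if reset_inplace:
        # the subprocess already reset the env; the timestep carries the new obs
        return EnvState.RUN, episode_count
    # a background thread resets the env and later sets it back to RUN
    return EnvState.RESET, episode_count
```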

SubprocessEnvManagerV2

Bases: SyncSubprocessEnvManager

Overview

A variant of SyncSubprocessEnvManager for the new task pipeline, whose interfaces are coupled with treetensor.

ready_obs property

Overview

Get the ready (next) observation in tnp.array type, which is uniform for both async/sync scenarios.

Return:
- ready_obs (:obj:tnp.array): The observations of the ready environments, stacked into one treenumpy array.

Example:
>>> obs = env_manager.ready_obs
>>> action = model(obs)  # the model takes np obs and outputs np actions
>>> timesteps = env_manager.step(action)
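The stacking step can be illustrated with plain NumPy. As an assumption, the sketch uses np.stack on array observations in place of tnp.stack, which in the source also handles nested dict observations; stack_ready_obs and the sample data are hypothetical. The key property is that the batch order follows the ready environment IDs.

```python
import numpy as np


def stack_ready_obs(ready_obs, ready_env):
    """numpy stand-in for ``tnp.stack([tnp.array(ready_obs[i]) for i in ready_env])``.

    Row k of the result is the observation of env ready_env[k], so actions
    computed from this batch must be sent back in the same env order.
    """
    return np.stack([np.asarray(ready_obs[i]) for i in ready_env])


# Hypothetical per-env observations; env 1 is assumed to be resetting.
ready_obs = {0: [0.0, 1.0], 1: None, 2: [4.0, 5.0]}
batch = stack_ready_obs(ready_obs, ready_env=[0, 2])
```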

step(actions)

Overview

Execute one env step with the input actions, and reset any environment that is done.

Arguments: - actions (:obj:Union[List[tnp.ndarray], tnp.ndarray]): Actions from the caller, such as a policy. A stacked tnp.ndarray is split along its first axis, and the actions are matched, in order, with the IDs in ready_obs_id. Returns: - timesteps (:obj:List[tnp.ndarray]): One tnp.array per stepped environment, with the fields obs, reward, done, info and env_id.
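The action-splitting logic can be sketched with NumPy arrays standing in for tnp arrays (an assumption; np.split and squeeze play the roles of tnp.split and squeeze in the source). split_actions is a hypothetical name, and ready_env_ids stands in for ready_obs_id.

```python
import numpy as np


def split_actions(actions, ready_env_ids):
    """Map a batch of actions to env IDs, mirroring SubprocessEnvManagerV2.step."""
    if isinstance(actions, np.ndarray):
        # Split into per-env slices of shape (1, ...), then drop the leading axis;
        # the source notes that zip misbehaves on unsplit data.
        per_env = [a.squeeze(0) for a in np.split(actions, actions.shape[0])]
    else:
        per_env = list(actions)
    # zip stops at the shorter input, so a size mismatch silently drops actions
    return {env_id: a for env_id, a in zip(ready_env_ids, per_env)}


actions = np.array([[1], [0]])  # two envs, one action dimension each
mapping = split_actions(actions, ready_env_ids=[3, 5])
```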

Full Source Code

../ding/envs/env_manager/subprocess_env_manager.py

1from typing import Any, Union, List, Tuple, Dict, Callable, Optional 2from multiprocessing import connection, get_context 3from collections import namedtuple 4from ditk import logging 5import platform 6import time 7import copy 8import errno 9import gymnasium 10import gym 11import traceback 12import torch 13import pickle 14import numpy as np 15import treetensor.numpy as tnp 16from easydict import EasyDict 17from types import MethodType 18from ding.data import ShmBufferContainer, ShmBuffer 19 20from ding.envs.env import BaseEnvTimestep 21from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY, make_key_as_identifier, \ 22 remove_illegal_item, CloudPickleWrapper 23from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper 24 25 26def is_abnormal_timestep(timestep: namedtuple) -> bool: 27 if isinstance(timestep.info, dict): 28 return timestep.info.get('abnormal', False) 29 elif isinstance(timestep.info, list) or isinstance(timestep.info, tuple): 30 return timestep.info[0].get('abnormal', False) or timestep.info[1].get('abnormal', False) 31 else: 32 raise TypeError("invalid env timestep type: {}".format(type(timestep.info))) 33 34 35@ENV_MANAGER_REGISTRY.register('async_subprocess') 36class AsyncSubprocessEnvManager(BaseEnvManager): 37 """ 38 Overview: 39 Create an AsyncSubprocessEnvManager to manage multiple environments. 40 Each Environment is run by a respective subprocess. 
41 Interfaces: 42 seed, launch, ready_obs, step, reset, active_env 43 """ 44 45 config = dict( 46 episode_num=float("inf"), 47 max_retry=1, 48 step_timeout=None, 49 auto_reset=True, 50 retry_type='reset', 51 reset_timeout=None, 52 retry_waiting_time=0.1, 53 # subprocess specified args 54 shared_memory=True, 55 copy_on_get=True, 56 context='spawn' if platform.system().lower() == 'windows' else 'fork', 57 wait_num=2, 58 step_wait_timeout=0.01, 59 connect_timeout=60, 60 reset_inplace=False, 61 ) 62 63 def __init__( 64 self, 65 env_fn: List[Callable], 66 cfg: EasyDict = EasyDict({}), 67 ) -> None: 68 """ 69 Overview: 70 Initialize the AsyncSubprocessEnvManager. 71 Arguments: 72 - env_fn (:obj:`List[Callable]`): The function to create environment 73 - cfg (:obj:`EasyDict`): Config 74 75 .. note:: 76 77 - wait_num: for each time the minimum number of env return to gather 78 - step_wait_timeout: for each time the minimum number of env return to gather 79 """ 80 super().__init__(env_fn, cfg) 81 self._shared_memory = self._cfg.shared_memory 82 self._copy_on_get = self._cfg.copy_on_get 83 self._context = self._cfg.context 84 self._wait_num = self._cfg.wait_num 85 self._step_wait_timeout = self._cfg.step_wait_timeout 86 87 self._lock = LockContext(LockContextType.THREAD_LOCK) 88 self._connect_timeout = self._cfg.connect_timeout 89 self._async_args = { 90 'step': { 91 'wait_num': min(self._wait_num, self._env_num), 92 'timeout': self._step_wait_timeout 93 } 94 } 95 self._reset_inplace = self._cfg.reset_inplace 96 if not self._auto_reset: 97 assert not self._reset_inplace, "reset_inplace is unavailable when auto_reset=False." 98 99 def _create_state(self) -> None: 100 r""" 101 Overview: 102 Fork/spawn sub-processes(Call ``_create_env_subprocess``) and create pipes to transfer the data. 
103 """ 104 self._env_episode_count = {env_id: 0 for env_id in range(self.env_num)} 105 self._ready_obs = {env_id: None for env_id in range(self.env_num)} 106 self._reset_param = {i: {} for i in range(self.env_num)} 107 if self._shared_memory: 108 obs_space = self._observation_space 109 if isinstance(obs_space, (gym.spaces.Dict, gymnasium.spaces.Dict)): 110 # For multi_agent case, such as multiagent_mujoco and petting_zoo mpe. 111 # Now only for the case that each agent in the team have the same obs structure 112 # and corresponding shape. 113 shape = {k: v.shape for k, v in obs_space.spaces.items()} 114 dtype = {k: v.dtype for k, v in obs_space.spaces.items()} 115 else: 116 shape = obs_space.shape 117 dtype = obs_space.dtype 118 self._obs_buffers = { 119 env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get) 120 for env_id in range(self.env_num) 121 } 122 else: 123 self._obs_buffers = {env_id: None for env_id in range(self.env_num)} 124 self._pipe_parents, self._pipe_children = {}, {} 125 self._subprocesses = {} 126 for env_id in range(self.env_num): 127 self._create_env_subprocess(env_id) 128 self._waiting_env = {'step': set()} 129 self._closed = False 130 131 def _create_env_subprocess(self, env_id): 132 # start a new one 133 ctx = get_context(self._context) 134 self._pipe_parents[env_id], self._pipe_children[env_id] = ctx.Pipe() 135 self._subprocesses[env_id] = ctx.Process( 136 # target=self.worker_fn, 137 target=self.worker_fn_robust, 138 args=( 139 self._pipe_parents[env_id], 140 self._pipe_children[env_id], 141 CloudPickleWrapper(self._env_fn[env_id]), 142 self._obs_buffers[env_id], 143 self.method_name_list, 144 self._reset_timeout, 145 self._step_timeout, 146 self._reset_inplace, 147 ), 148 daemon=True, 149 name='subprocess_env_manager{}_{}'.format(env_id, time.time()) 150 ) 151 self._subprocesses[env_id].start() 152 self._pipe_children[env_id].close() 153 self._env_states[env_id] = EnvState.INIT 154 155 if self._env_replay_path is not 
None: 156 self._pipe_parents[env_id].send(['enable_save_replay', [self._env_replay_path[env_id]], {}]) 157 self._pipe_parents[env_id].recv() 158 159 @property 160 def ready_env(self) -> List[int]: 161 active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] 162 return [i for i in active_env if i not in self._waiting_env['step']] 163 164 @property 165 def ready_obs(self) -> Dict[int, Any]: 166 """ 167 Overview: 168 Get the next observations. 169 Return: 170 A dictionary with observations and their environment IDs. 171 Note: 172 The observations are returned in np.ndarray. 173 Example: 174 >>> obs_dict = env_manager.ready_obs 175 >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items())} 176 """ 177 no_done_env_idx = [i for i, s in self._env_states.items() if s != EnvState.DONE] 178 sleep_count = 0 179 while not any([self._env_states[i] == EnvState.RUN for i in no_done_env_idx]): 180 if sleep_count != 0 and sleep_count % 10000 == 0: 181 logging.warning( 182 'VEC_ENV_MANAGER: all the not done envs are resetting, sleep {} times'.format(sleep_count) 183 ) 184 time.sleep(0.001) 185 sleep_count += 1 186 return {i: self._ready_obs[i] for i in self.ready_env} 187 188 @property 189 def ready_imgs(self, render_mode: Optional[str] = 'rgb_array') -> Dict[int, Any]: 190 """ 191 Overview: 192 Get the next renderd frames. 193 Return: 194 A dictionary with rendered frames and their environment IDs. 195 Note: 196 The rendered frames are returned in np.ndarray. 197 """ 198 for i in self.ready_env: 199 self._pipe_parents[i].send(['render', None, {'render_mode': render_mode}]) 200 data = {i: self._pipe_parents[i].recv() for i in self.ready_env} 201 self._check_data(data) 202 return data 203 204 def launch(self, reset_param: Optional[Dict] = None) -> None: 205 """ 206 Overview: 207 Set up the environments and their parameters. 
208 Arguments: 209 - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \ 210 value is the cooresponding reset parameters. 211 """ 212 assert self._closed, "please first close the env manager" 213 if reset_param is not None: 214 assert len(reset_param) == len(self._env_fn) 215 self._create_state() 216 self.reset(reset_param) 217 218 def reset(self, reset_param: Optional[Dict] = None) -> None: 219 """ 220 Overview: 221 Reset the environments their parameters. 222 Arguments: 223 - reset_param (:obj:`List`): Dict of reset parameters for each environment, key is the env_id, \ 224 value is the cooresponding reset parameters. 225 """ 226 self._check_closed() 227 228 if reset_param is None: 229 reset_env_list = [env_id for env_id in range(self._env_num)] 230 else: 231 reset_env_list = reset_param.keys() 232 for env_id in reset_param: 233 self._reset_param[env_id] = reset_param[env_id] 234 235 # clear previous info 236 for env_id in reset_env_list: 237 if env_id in self._waiting_env['step']: 238 self._pipe_parents[env_id].recv() 239 self._waiting_env['step'].remove(env_id) 240 241 sleep_count = 0 242 while any([self._env_states[i] == EnvState.RESET for i in reset_env_list]): 243 if sleep_count != 0 and sleep_count % 10000 == 0: 244 logging.warning( 245 'VEC_ENV_MANAGER: not all the envs finish resetting, sleep {} times'.format(sleep_count) 246 ) 247 time.sleep(0.001) 248 sleep_count += 1 249 250 # reset env 251 reset_thread_list = [] 252 for i, env_id in enumerate(reset_env_list): 253 # set seed 254 if self._env_seed[env_id] is not None: 255 try: 256 if self._env_dynamic_seed is not None: 257 self._pipe_parents[env_id].send(['seed', [self._env_seed[env_id], self._env_dynamic_seed], {}]) 258 else: 259 self._pipe_parents[env_id].send(['seed', [self._env_seed[env_id]], {}]) 260 ret = self._pipe_parents[env_id].recv() 261 self._check_data({env_id: ret}) 262 self._env_seed[env_id] = None # seed only use once 263 except 
BaseException as e: 264 logging.warning( 265 "subprocess reset set seed failed, ignore and continue... \n subprocess exception traceback: \n" 266 + traceback.format_exc() 267 ) 268 self._env_states[env_id] = EnvState.RESET 269 reset_thread = PropagatingThread(target=self._reset, args=(env_id, )) 270 reset_thread.daemon = True 271 reset_thread_list.append(reset_thread) 272 273 for t in reset_thread_list: 274 t.start() 275 for t in reset_thread_list: 276 t.join() 277 278 def _reset(self, env_id: int) -> None: 279 280 def reset_fn(): 281 if self._pipe_parents[env_id].poll(): 282 recv_data = self._pipe_parents[env_id].recv() 283 raise RuntimeError("unread data left before sending to the pipe: {}".format(repr(recv_data))) 284 # if self._reset_param[env_id] is None, just reset specific env, not pass reset param 285 if self._reset_param[env_id] is not None: 286 assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id]) 287 self._pipe_parents[env_id].send(['reset', [], self._reset_param[env_id]]) 288 else: 289 self._pipe_parents[env_id].send(['reset', [], None]) 290 291 if not self._pipe_parents[env_id].poll(self._connect_timeout): 292 raise ConnectionError("env reset connection timeout") # Leave it to try again 293 294 obs = self._pipe_parents[env_id].recv() 295 self._check_data({env_id: obs}, close=False) 296 if self._shared_memory: 297 obs = self._obs_buffers[env_id].get() 298 # it is necessary to add lock for the updates of env_state 299 with self._lock: 300 self._env_states[env_id] = EnvState.RUN 301 self._ready_obs[env_id] = obs 302 303 exceptions = [] 304 for _ in range(self._max_retry): 305 try: 306 reset_fn() 307 return 308 except BaseException as e: 309 # During teardown, reset threads may race with ``close`` and hit closed pipes. 
310 if self._closed and isinstance(e, (OSError, EOFError, BrokenPipeError, ConnectionResetError)): 311 return 312 if isinstance(e, OSError) and getattr(e, 'errno', None) == errno.EBADF and self._closed: 313 return 314 logging.info("subprocess exception traceback: \n" + traceback.format_exc()) 315 if self._retry_type == 'renew' or isinstance(e, pickle.UnpicklingError): 316 self._pipe_parents[env_id].close() 317 if self._subprocesses[env_id].is_alive(): 318 self._subprocesses[env_id].terminate() 319 self._create_env_subprocess(env_id) 320 exceptions.append(e) 321 time.sleep(self._retry_waiting_time) 322 323 logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry)) 324 runtime_error = RuntimeError( 325 "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format( 326 env_id, self._max_retry, str(exceptions[-1]) 327 ) 328 ) 329 runtime_error.__traceback__ = exceptions[-1].__traceback__ 330 if self._closed: # exception cased by main thread closing parent_remote 331 return 332 else: 333 self.close() 334 raise runtime_error 335 336 def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]: 337 """ 338 Overview: 339 Step all environments. Reset an env if done. 340 Arguments: 341 - actions (:obj:`Dict[int, Any]`): {env_id: action} 342 Returns: 343 - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \ 344 ``BaseEnvTimestep`` tuple with observation, reward, done, env_info. 345 Example: 346 >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items())} 347 >>> timesteps = env_manager.step(actions_dict): 348 >>> for env_id, timestep in timesteps.items(): 349 >>> pass 350 351 .. note: 352 353 - The env_id that appears in ``actions`` will also be returned in ``timesteps``. 354 - Each environment is run by a subprocess separately. Once an environment is done, it is reset immediately. 355 - Async subprocess env manager use ``connection.wait`` to poll. 
356 """ 357 self._check_closed() 358 env_ids = list(actions.keys()) 359 assert all([self._env_states[env_id] == EnvState.RUN for env_id in env_ids] 360 ), 'current env state are: {}, please check whether the requested env is in reset or done'.format( 361 {env_id: self._env_states[env_id] 362 for env_id in env_ids} 363 ) 364 365 for env_id, act in actions.items(): 366 self._pipe_parents[env_id].send(['step', [act], None]) 367 368 timesteps = {} 369 step_args = self._async_args['step'] 370 wait_num, timeout = min(step_args['wait_num'], len(env_ids)), step_args['timeout'] 371 rest_env_ids = list(set(env_ids).union(self._waiting_env['step'])) 372 ready_env_ids = [] 373 cur_rest_env_ids = copy.deepcopy(rest_env_ids) 374 while True: 375 rest_conn = [self._pipe_parents[env_id] for env_id in cur_rest_env_ids] 376 ready_conn, ready_ids = AsyncSubprocessEnvManager.wait(rest_conn, min(wait_num, len(rest_conn)), timeout) 377 cur_ready_env_ids = [cur_rest_env_ids[env_id] for env_id in ready_ids] 378 assert len(cur_ready_env_ids) == len(ready_conn) 379 # timesteps.update({env_id: p.recv() for env_id, p in zip(cur_ready_env_ids, ready_conn)}) 380 for env_id, p in zip(cur_ready_env_ids, ready_conn): 381 try: 382 timesteps.update({env_id: p.recv()}) 383 except pickle.UnpicklingError as e: 384 timestep = BaseEnvTimestep(None, None, None, {'abnormal': True}) 385 timesteps.update({env_id: timestep}) 386 self._pipe_parents[env_id].close() 387 if self._subprocesses[env_id].is_alive(): 388 self._subprocesses[env_id].terminate() 389 self._create_env_subprocess(env_id) 390 self._check_data(timesteps) 391 ready_env_ids += cur_ready_env_ids 392 cur_rest_env_ids = list(set(cur_rest_env_ids).difference(set(cur_ready_env_ids))) 393 # At least one not done env timestep, or all envs' steps are finished 394 if any([not t.done for t in timesteps.values()]) or len(ready_conn) == len(rest_conn): 395 break 396 self._waiting_env['step']: set 397 for env_id in rest_env_ids: 398 if env_id in 
ready_env_ids: 399 if env_id in self._waiting_env['step']: 400 self._waiting_env['step'].remove(env_id) 401 else: 402 self._waiting_env['step'].add(env_id) 403 404 if self._shared_memory: 405 for i, (env_id, timestep) in enumerate(timesteps.items()): 406 timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get()) 407 408 for env_id, timestep in timesteps.items(): 409 if is_abnormal_timestep(timestep): 410 self._env_states[env_id] = EnvState.ERROR 411 continue 412 if timestep.done: 413 self._env_episode_count[env_id] += 1 414 if self._env_episode_count[env_id] < self._episode_num: 415 if self._auto_reset: 416 if self._reset_inplace: # reset in subprocess at once 417 self._env_states[env_id] = EnvState.RUN 418 self._ready_obs[env_id] = timestep.obs 419 else: 420 # in this case, ready_obs is updated in ``self._reset`` 421 self._env_states[env_id] = EnvState.RESET 422 reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset') 423 reset_thread.daemon = True 424 reset_thread.start() 425 else: 426 # in the case that auto_reset=False, caller should call ``env_manager.reset`` manually 427 self._env_states[env_id] = EnvState.NEED_RESET 428 else: 429 self._env_states[env_id] = EnvState.DONE 430 else: 431 self._ready_obs[env_id] = timestep.obs 432 return timesteps 433 434 # This method must be staticmethod, otherwise there will be some resource conflicts(e.g. port or file) 435 # Env must be created in worker, which is a trick of avoiding env pickle errors. 436 # A more robust version is used by default. But this one is also preserved. 437 @staticmethod 438 def worker_fn( 439 p: connection.Connection, 440 c: connection.Connection, 441 env_fn_wrapper: 'CloudPickleWrapper', 442 obs_buffer: ShmBuffer, 443 method_name_list: list, 444 reset_inplace: bool = False, 445 ) -> None: # noqa 446 """ 447 Overview: 448 Subprocess's target function to run. 
449 """ 450 torch.set_num_threads(1) 451 env_fn = env_fn_wrapper.data 452 env = env_fn() 453 p.close() 454 try: 455 while True: 456 try: 457 cmd, args, kwargs = c.recv() 458 except EOFError: # for the case when the pipe has been closed 459 c.close() 460 break 461 try: 462 if cmd == 'getattr': 463 ret = getattr(env, args[0]) 464 elif cmd in method_name_list: 465 if cmd == 'step': 466 timestep = env.step(*args, **kwargs) 467 if is_abnormal_timestep(timestep): 468 ret = timestep 469 else: 470 if reset_inplace and timestep.done: 471 obs = env.reset() 472 timestep = timestep._replace(obs=obs) 473 if obs_buffer is not None: 474 obs_buffer.fill(timestep.obs) 475 timestep = timestep._replace(obs=None) 476 ret = timestep 477 elif cmd == 'reset': 478 ret = env.reset(*args, **kwargs) # obs 479 if obs_buffer is not None: 480 obs_buffer.fill(ret) 481 ret = None 482 elif args is None and kwargs is None: 483 ret = getattr(env, cmd)() 484 else: 485 ret = getattr(env, cmd)(*args, **kwargs) 486 else: 487 raise KeyError("not support env cmd: {}".format(cmd)) 488 c.send(ret) 489 except Exception as e: 490 # when there are some errors in env, worker_fn will send the errors to env manager 491 # directly send error to another process will lose the stack trace, so we create a new Exception 492 logging.warning("subprocess exception traceback: \n" + traceback.format_exc()) 493 c.send( 494 e.__class__( 495 '\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e) 496 ) 497 ) 498 if cmd == 'close': 499 c.close() 500 break 501 except KeyboardInterrupt: 502 c.close() 503 504 @staticmethod 505 def worker_fn_robust( 506 parent, 507 child, 508 env_fn_wrapper, 509 obs_buffer, 510 method_name_list, 511 reset_timeout=None, 512 step_timeout=None, 513 reset_inplace=False, 514 ) -> None: 515 """ 516 Overview: 517 A more robust version of subprocess's target function to run. Used by default. 
518 """ 519 torch.set_num_threads(1) 520 env_fn = env_fn_wrapper.data 521 env = env_fn() 522 parent.close() 523 524 @timeout_wrapper(timeout=step_timeout) 525 def step_fn(*args, **kwargs): 526 timestep = env.step(*args, **kwargs) 527 if is_abnormal_timestep(timestep): 528 ret = timestep 529 else: 530 if reset_inplace and timestep.done: 531 obs = env.reset() 532 timestep = timestep._replace(obs=obs) 533 if obs_buffer is not None: 534 obs_buffer.fill(timestep.obs) 535 timestep = timestep._replace(obs=None) 536 ret = timestep 537 return ret 538 539 @timeout_wrapper(timeout=reset_timeout) 540 def reset_fn(*args, **kwargs): 541 try: 542 ret = env.reset(*args, **kwargs) 543 if obs_buffer is not None: 544 obs_buffer.fill(ret) 545 ret = None 546 return ret 547 except BaseException as e: 548 logging.warning("subprocess exception traceback: \n" + traceback.format_exc()) 549 env.close() 550 raise e 551 552 while True: 553 try: 554 cmd, args, kwargs = child.recv() 555 except EOFError: # for the case when the pipe has been closed 556 child.close() 557 break 558 try: 559 if cmd == 'getattr': 560 ret = getattr(env, args[0]) 561 elif cmd in method_name_list: 562 if cmd == 'step': 563 ret = step_fn(*args) 564 elif cmd == 'reset': 565 if kwargs is None: 566 kwargs = {} 567 ret = reset_fn(*args, **kwargs) 568 elif cmd == 'render': 569 from ding.utils import render 570 ret = render(env, **kwargs) 571 elif args is None and kwargs is None: 572 ret = getattr(env, cmd)() 573 else: 574 ret = getattr(env, cmd)(*args, **kwargs) 575 else: 576 raise KeyError("not support env cmd: {}".format(cmd)) 577 child.send(ret) 578 except BaseException as e: 579 logging.debug("Sub env '{}' error when executing {}".format(str(env), cmd)) 580 # when there are some errors in env, worker_fn will send the errors to env manager 581 # directly send error to another process will lose the stack trace, so we create a new Exception 582 logging.warning("subprocess exception traceback: \n" + traceback.format_exc()) 
583 child.send( 584 e.__class__('\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)) 585 ) 586 if cmd == 'close': 587 child.close() 588 break 589 590 def _check_data(self, data: Dict, close: bool = True) -> None: 591 exceptions = [] 592 for i, d in data.items(): 593 if isinstance(d, BaseException): 594 self._env_states[i] = EnvState.ERROR 595 exceptions.append(d) 596 # when receiving env Exception, env manager will safely close and raise this Exception to caller 597 if len(exceptions) > 0: 598 if close: 599 self.close() 600 raise exceptions[0] 601 602 # override 603 def __getattr__(self, key: str) -> Any: 604 self._check_closed() 605 # we suppose that all the envs has the same attributes, if you need different envs, please 606 # create different env managers. 607 if not hasattr(self._env_ref, key): 608 raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key)) 609 if isinstance(getattr(self._env_ref, key), MethodType) and key not in self.method_name_list: 610 raise RuntimeError("env getattr doesn't supports method({}), please override method_name_list".format(key)) 611 for _, p in self._pipe_parents.items(): 612 p.send(['getattr', [key], {}]) 613 data = {i: p.recv() for i, p in self._pipe_parents.items()} 614 self._check_data(data) 615 ret = [data[i] for i in self._pipe_parents.keys()] 616 return ret 617 618 # override 619 def enable_save_replay(self, replay_path: Union[List[str], str]) -> None: 620 """ 621 Overview: 622 Set each env's replay save path. 623 Arguments: 624 - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \ 625 Or one path for all environments. 626 """ 627 if isinstance(replay_path, str): 628 replay_path = [replay_path] * self.env_num 629 self._env_replay_path = replay_path 630 631 # override 632 def close(self) -> None: 633 """ 634 Overview: 635 CLose the env manager and release all related resources. 
636 """ 637 if self._closed: 638 return 639 self._closed = True 640 for _, p in self._pipe_parents.items(): 641 try: 642 p.send(['close', None, None]) 643 except (OSError, EOFError, BrokenPipeError): 644 continue 645 for env_id, p in self._pipe_parents.items(): 646 try: 647 if not p.poll(5): 648 continue 649 p.recv() 650 except (OSError, EOFError, BrokenPipeError): 651 continue 652 for i in range(self._env_num): 653 self._env_states[i] = EnvState.VOID 654 # disable process join for avoiding hang 655 # for p in self._subprocesses: 656 # p.join() 657 for _, p in self._subprocesses.items(): 658 p.terminate() 659 for _, p in self._pipe_parents.items(): 660 try: 661 p.close() 662 except OSError: 663 continue 664 665 @staticmethod 666 def wait(rest_conn: list, wait_num: int, timeout: Optional[float] = None) -> Tuple[list, list]: 667 """ 668 Overview: 669 Wait at least enough(len(ready_conn) >= wait_num) connections within timeout constraint. 670 If timeout is None and wait_num == len(ready_conn), means sync mode; 671 If timeout is not None, will return when len(ready_conn) >= wait_num and 672 this method takes more than timeout seconds. 
673 """ 674 assert 1 <= wait_num <= len(rest_conn 675 ), 'please indicate proper wait_num: <wait_num: {}, rest_conn_num: {}>'.format( 676 wait_num, len(rest_conn) 677 ) 678 rest_conn_set = set(rest_conn) 679 ready_conn = set() 680 start_time = time.time() 681 while len(rest_conn_set) > 0: 682 if len(ready_conn) >= wait_num and timeout: 683 if (time.time() - start_time) >= timeout: 684 break 685 finish_conn = set(connection.wait(rest_conn_set, timeout=timeout)) 686 ready_conn = ready_conn.union(finish_conn) 687 rest_conn_set = rest_conn_set.difference(finish_conn) 688 ready_ids = [rest_conn.index(c) for c in ready_conn] 689 return list(ready_conn), ready_ids 690 691 692@ENV_MANAGER_REGISTRY.register('subprocess') 693class SyncSubprocessEnvManager(AsyncSubprocessEnvManager): 694 config = dict( 695 episode_num=float("inf"), 696 max_retry=1, 697 step_timeout=None, 698 auto_reset=True, 699 reset_timeout=None, 700 retry_type='reset', 701 retry_waiting_time=0.1, 702 # subprocess specified args 703 shared_memory=True, 704 copy_on_get=True, 705 context='spawn' if platform.system().lower() == 'windows' else 'fork', 706 wait_num=float("inf"), # inf mean all the environments 707 step_wait_timeout=None, 708 connect_timeout=60, 709 reset_inplace=False, # if reset_inplace=True in SyncSubprocessEnvManager, the interaction can be reproducible. 710 ) 711 712 def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]: 713 """ 714 Overview: 715 Step all environments. Reset an env if done. 716 Arguments: 717 - actions (:obj:`Dict[int, Any]`): {env_id: action} 718 Returns: 719 - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \ 720 ``BaseEnvTimestep`` tuple with observation, reward, done, env_info. 721 Example: 722 >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items())} 723 >>> timesteps = env_manager.step(actions_dict): 724 >>> for env_id, timestep in timesteps.items(): 725 >>> pass 726 727 .. 
note::

            - The env_id that appears in ``actions`` will also be returned in ``timesteps``.
            - Each environment is run by a subprocess separately. Once an environment is done, it is reset immediately.
        """
        self._check_closed()
        env_ids = list(actions.keys())
        assert all([self._env_states[env_id] == EnvState.RUN for env_id in env_ids]
                   ), 'current env state are: {}, please check whether the requested env is in reset or done'.format(
                       {env_id: self._env_states[env_id]
                        for env_id in env_ids}
                   )
        for env_id, act in actions.items():
            # it is necessary to set kwargs as None for saving cost of serialization in some env like cartpole,
            # and step method never uses kwargs in known envs.
            self._pipe_parents[env_id].send(['step', [act], None])

        # === This part is different from async one. ===
        # === Because operate in this way is more efficient. ===
        timesteps = {}
        ready_conn = [self._pipe_parents[env_id] for env_id in env_ids]
        # timesteps.update({env_id: p.recv() for env_id, p in zip(env_ids, ready_conn)})
        for env_id, p in zip(env_ids, ready_conn):
            try:
                timesteps.update({env_id: p.recv()})
            except pickle.UnpicklingError as e:
                timestep = BaseEnvTimestep(None, None, None, {'abnormal': True})
                timesteps.update({env_id: timestep})
                self._pipe_parents[env_id].close()
                if self._subprocesses[env_id].is_alive():
                    self._subprocesses[env_id].terminate()
                self._create_env_subprocess(env_id)
        self._check_data(timesteps)
        # ======================================================

        if self._shared_memory:
            # TODO(nyz) optimize sync shm
            for i, (env_id, timestep) in enumerate(timesteps.items()):
                timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
        for env_id, timestep in timesteps.items():
            if is_abnormal_timestep(timestep):
                self._env_states[env_id] = EnvState.ERROR
                continue
            if timestep.done:
                self._env_episode_count[env_id] += 1
                if self._env_episode_count[env_id] < self._episode_num:
                    if self._auto_reset:
                        if self._reset_inplace:  # reset in subprocess at once
                            self._env_states[env_id] = EnvState.RUN
                            self._ready_obs[env_id] = timestep.obs
                        else:
                            # in this case, ready_obs is updated in ``self._reset``
                            self._env_states[env_id] = EnvState.RESET
                            reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset')
                            reset_thread.daemon = True
                            reset_thread.start()
                    else:
                        # in the case that auto_reset=False, caller should call ``env_manager.reset`` manually
                        self._env_states[env_id] = EnvState.NEED_RESET
                else:
                    self._env_states[env_id] = EnvState.DONE
            else:
                self._ready_obs[env_id] = timestep.obs
        return timesteps


@ENV_MANAGER_REGISTRY.register('subprocess_v2')
class SubprocessEnvManagerV2(SyncSubprocessEnvManager):
    """
    Overview:
        SyncSubprocessEnvManager for new task pipeline and interfaces coupled with treetensor.
    """

    @property
    def ready_obs(self) -> tnp.array:
        """
        Overview:
            Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
        Return:
            - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
        Example:
            >>> obs = env_manager.ready_obs
            >>> action = model(obs)  # model input np obs and output np action
            >>> timesteps = env_manager.step(action)
        """
        no_done_env_idx = [i for i, s in self._env_states.items() if s != EnvState.DONE]
        sleep_count = 0
        while not any([self._env_states[i] == EnvState.RUN for i in no_done_env_idx]):
            if sleep_count != 0 and sleep_count % 10000 == 0:
                logging.warning(
                    'VEC_ENV_MANAGER: all the not done envs are resetting, sleep {} times'.format(sleep_count)
                )
            time.sleep(0.001)
            sleep_count += 1
        return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env])

    def step(self, actions: Union[List[tnp.ndarray], tnp.ndarray]) -> List[tnp.ndarray]:
        """
        Overview:
            Execute env step according to input actions. And reset an env if done.
        Arguments:
            - actions (:obj:`Union[List[tnp.ndarray], tnp.ndarray]`): actions came from outer caller like policy.
        Returns:
            - timesteps (:obj:`List[tnp.ndarray]`): Each timestep is a tnp.array with observation, reward, done, \
                info, env_id.
        """
        if isinstance(actions, tnp.ndarray):
            # zip operation will lead to wrong behaviour if not split data
            split_action = tnp.split(actions, actions.shape[0])
            split_action = [s.squeeze(0) for s in split_action]
        else:
            split_action = actions
        actions = {env_id: a for env_id, a in zip(self.ready_obs_id, split_action)}
        timesteps = super().step(actions)
        new_data = []
        for env_id, timestep in timesteps.items():
            obs, reward, done, info = timestep
            # make the type and content of key as similar as identifier,
            # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
            info = make_key_as_identifier(info)
            info = remove_illegal_item(info)
            new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id}))
        return new_data
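The done-handling branch of the synchronous step listed above chooses each environment's next state from a handful of flags (abnormal timestep, done, episode budget, auto_reset, reset_inplace). The block below is a minimal, simplified sketch of that decision order only. It assumes a hypothetical stand-in EnvState enum and a hypothetical next_state helper, neither of which is part of DI-engine's API, and it deliberately ignores pipes, shared memory and the background reset thread.

```python
from enum import Enum, auto
from typing import Tuple


class EnvState(Enum):
    # Hypothetical stand-in for DI-engine's EnvState; only the members used by step() are mirrored.
    RUN = auto()
    RESET = auto()
    NEED_RESET = auto()
    DONE = auto()
    ERROR = auto()


def next_state(
        done: bool,
        abnormal: bool,
        episode_count: int,
        episode_num: float,
        auto_reset: bool,
        reset_inplace: bool,
) -> Tuple[EnvState, int]:
    """Hypothetical helper: return (new_state, new_episode_count) for one env after a step."""
    if abnormal:
        # A broken timestep marks the env as ERROR without counting an episode.
        return EnvState.ERROR, episode_count
    if not done:
        # The env keeps running; step() only refreshes its ready observation.
        return EnvState.RUN, episode_count
    episode_count += 1
    if episode_count >= episode_num:
        # Episode budget exhausted: the env is finished for good.
        return EnvState.DONE, episode_count
    if not auto_reset:
        # The caller is expected to call env_manager.reset manually.
        return EnvState.NEED_RESET, episode_count
    if reset_inplace:
        # The subprocess already reset itself, so the env can run again at once.
        return EnvState.RUN, episode_count
    # Otherwise a background reset is pending and ready_obs is filled in later.
    return EnvState.RESET, episode_count
```

For example, with episode_num=3, auto_reset=True and reset_inplace=False, the first done yields (EnvState.RESET, 1) while the third yields (EnvState.DONE, 3). Returning the updated count instead of mutating shared state keeps the sketch easy to test in isolation.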