ding.envs.env_manager.envpool_env_manager

PoolEnvManager

Overview
EnvPool currently supports Atari, Classic Control, Toy Text and ViZDoom environments. Some commonly used env_ids are listed below; for more examples, see https://envpool.readthedocs.io/en/latest/api/atari.html.
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
Full Source Code
../ding/envs/env_manager/envpool_env_manager.py
import gym
from easydict import EasyDict
from copy import deepcopy
import numpy as np
from collections import namedtuple
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
from ditk import logging
try:
    import envpool
except ImportError:
    import sys
    logging.warning("Please install envpool first, use 'pip install envpool'")
    envpool = None

from ding.envs import BaseEnvTimestep
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts
from ding.torch_utils import to_ndarray


@ENV_MANAGER_REGISTRY.register('env_pool')
class PoolEnvManager:
    '''
    Overview:
        Env manager that runs a batch of environments through envpool.
        EnvPool currently supports Atari, Classic Control, Toy Text and ViZDoom.
        Some commonly used env_ids are listed below; for more examples, see
        <https://envpool.readthedocs.io/en/latest/api/atari.html>.

        - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
        - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"

        ``batch_size == env_num`` runs envpool in sync mode; ``batch_size < env_num`` runs it in
        async mode, where each ``step`` only returns the environments that finished first.
    '''

    # Optional keyword arguments forwarded from ``cfg`` to ``envpool.make`` only when the config
    # defines them. They are Atari-style options, so configs for other env families can omit them
    # and envpool's own defaults apply.
    _OPTIONAL_MAKE_KWARGS = ('episodic_life', 'reward_clip', 'stack_num', 'gray_scale', 'frame_skip')

    @classmethod
    def default_config(cls) -> EasyDict:
        '''
        Overview:
            Return a deep copy of the class-level default config as an ``EasyDict``.
        '''
        return EasyDict(deepcopy(cls.config))

    # NOTE(review): the registry key above is 'env_pool' while this default ``type`` is 'envpool';
    # confirm which spelling the config/entry code actually looks up before relying on ``type``.
    config = dict(
        type='envpool',
        # Sync mode: batch_size == env_num
        # Async mode: batch_size < env_num
        env_num=8,
        batch_size=8,
    )

    def __init__(self, cfg: EasyDict) -> None:
        '''
        Overview:
            Store the config. No environment is created until ``launch`` is called.
        Arguments:
            - cfg (:obj:`EasyDict`): Must provide ``env_num``, ``batch_size`` and ``env_id``; may \
                provide the optional envpool options listed in ``_OPTIONAL_MAKE_KWARGS``.
        '''
        self._cfg = cfg
        self._env_num = cfg.env_num
        self._batch_size = cfg.batch_size
        self._ready_obs = {}
        self._closed = True
        self._seed = None

    def launch(self) -> None:
        '''
        Overview:
            Create the envpool environments (gym API) and reset all of them.
            Uses the seed set via ``seed`` or 0 if none was set, so ``seed`` must be called
            before ``launch`` to take effect.
        Raises:
            - ImportError: If envpool is not installed.
        '''
        assert self._closed, "Please first close the env manager"
        if envpool is None:
            raise ImportError("Please install envpool first, use 'pip install envpool'")
        seed = 0 if self._seed is None else self._seed
        # Only forward the options the config actually defines; reading a missing key
        # would otherwise crash for non-Atari configs.
        make_kwargs = {k: self._cfg[k] for k in self._OPTIONAL_MAKE_KWARGS if k in self._cfg}
        self._envs = envpool.make(
            task_id=self._cfg.env_id,
            env_type="gym",
            num_envs=self._env_num,
            batch_size=self._batch_size,
            seed=seed,
            **make_kwargs,
        )
        self._closed = False
        self.reset()

    def reset(self) -> None:
        '''
        Overview:
            Reset all environments asynchronously and block until every environment has
            reported an initial observation, which is stored (as float32) in ``ready_obs``.
            Also zeroes the per-environment episode return accumulators.
        '''
        self._ready_obs = {}
        self._envs.async_reset()
        while len(self._ready_obs) < self._env_num:
            obs, _, _, info = self._envs.recv()
            env_id = info['env_id']
            obs = obs.astype(np.float32)
            # Keep the first observation received for each env (same priority as the previous
            # deep_merge_dicts call) without re-copying all accumulated observations every batch.
            for i, o in zip(env_id, obs):
                self._ready_obs.setdefault(i, o)
        self._eval_episode_return = [0. for _ in range(self._env_num)]

    def step(self, action: dict) -> Dict[int, namedtuple]:
        '''
        Overview:
            Send actions to the given environments and receive the next batch of results.
        Arguments:
            - action (:obj:`dict`): Mapping from env id to that environment's action.
        Returns:
            - timesteps (:obj:`Dict[int, namedtuple]`): ``BaseEnvTimestep`` keyed by env id for \
                every environment returned by this ``recv``. ``info['env_id']`` holds the env id; \
                finished episodes additionally carry ``info['eval_episode_return']``.
        '''
        env_id = np.array(list(action.keys()))
        action = np.array(list(action.values()))
        # Drop only a singleton trailing action dim, e.g. (N, 1) -> (N,); genuinely
        # multi-dimensional actions (N, k>1) are sent unchanged instead of raising.
        if action.ndim == 2 and action.shape[1] == 1:
            action = action.squeeze(1)
        self._envs.send(action, env_id)

        obs, rew, done, info = self._envs.recv()
        obs = obs.astype(np.float32)
        rew = rew.astype(np.float32)
        # The env ids of this batch; in async mode they can differ from the ids just sent.
        env_id = info['env_id']
        timesteps = {}
        self._ready_obs = {}
        for i, eid in enumerate(env_id):
            d = bool(done[i])
            r = to_ndarray([rew[i]])
            self._eval_episode_return[eid] += r
            # Report the real environment id, not the position within this batch.
            timesteps[eid] = BaseEnvTimestep(obs[i], r, d, info={'env_id': int(eid)})
            if d:
                timesteps[eid].info['eval_episode_return'] = self._eval_episode_return[eid]
                self._eval_episode_return[eid] = 0.
            self._ready_obs[eid] = obs[i]
        return timesteps

    def close(self) -> None:
        '''
        Overview:
            Mark the manager as closed so it can be launched again.
        '''
        if self._closed:
            return
        # Envpool has no `close` API
        self._closed = True

    def seed(self, seed: int, dynamic_seed=False) -> None:
        '''
        Overview:
            Record the base seed passed to ``envpool.make`` on the next ``launch``.
        Arguments:
            - seed (:obj:`int`): Base seed.
            - dynamic_seed (:obj:`bool`): Not supported by envpool; a warning is logged if requested.
        '''
        # The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here
        self._seed = seed
        if dynamic_seed:
            logging.warning("envpool doesn't support dynamic_seed in different episode")

    @property
    def env_num(self) -> int:
        '''Number of environments managed by envpool.'''
        return self._env_num

    @property
    def ready_obs(self) -> Dict[int, Any]:
        '''Latest observations keyed by env id, for the environments ready to receive an action.'''
        return self._ready_obs