Skip to content

ding.bonus

ding.bonus

env_supported(algo=None)

Return the list of environments supported by DI-engine.

algo_supported(env_id=None)

Return the list of algorithms supported by DI-engine.

is_supported(env_id=None, algo=None)

Check whether the env-algo pair is supported by DI-engine.

Full Source Code

../ding/bonus/__init__.py

import ding.config
from .a2c import A2CAgent
from .c51 import C51Agent
from .ddpg import DDPGAgent
from .dqn import DQNAgent
from .pg import PGAgent
from .ppof import PPOF
from .ppo_offpolicy import PPOOffPolicyAgent
from .sac import SACAgent
from .sql import SQLAgent
from .td3 import TD3Agent

# Public registry: canonical algorithm name -> agent class.
# Each key is also the attribute name of the matching config module under
# ``ding.config.example`` (e.g. ``ding.config.example.PPOOffPolicy``).
supported_algo = dict(
    A2C=A2CAgent,
    C51=C51Agent,
    DDPG=DDPGAgent,
    DQN=DQNAgent,
    PG=PGAgent,
    PPOF=PPOF,
    PPOOffPolicy=PPOOffPolicyAgent,
    SAC=SACAgent,
    SQL=SQLAgent,
    TD3=TD3Agent,
)

supported_algo_list = list(supported_algo.keys())

# Upper-cased algo name -> canonical name, so user input is matched
# case-insensitively (e.g. "ppooffpolicy" -> "PPOOffPolicy").
_algo_name_by_upper = {name.upper(): name for name in supported_algo_list}


def _canonical_algo(algo: str):
    """Return the canonical spelling of ``algo`` (case-insensitive), or None if it is unknown."""
    return _algo_name_by_upper.get(algo.upper())


def _algo_envs(algo_name: str) -> list:
    """Return the env ids listed in ``ding.config.example.<algo_name>.supported_env``."""
    return list(getattr(ding.config.example, algo_name).supported_env.keys())


def env_supported(algo: str = None) -> list:
    """
    Overview:
        Return the list of envs supported by DI-engine.
    Arguments:
        - algo (:obj:`str`): Optional algorithm name, matched case-insensitively. When given, \
            only the envs supported by that algorithm are returned; otherwise the union over \
            all supported algorithms is returned.
    Returns:
        - envs (:obj:`list`): Env ids.
    Raises:
        - ValueError: If ``algo`` is given but is not a supported algorithm.
    """
    if algo is not None:
        algo_name = _canonical_algo(algo)
        if algo_name is None:
            raise ValueError("The algo {} is not supported by di-engine.".format(algo))
        return _algo_envs(algo_name)
    # dict.fromkeys de-duplicates while keeping a deterministic first-seen order;
    # a set of str would iterate in a different order on each interpreter run.
    all_envs = dict.fromkeys(env for name in supported_algo_list for env in _algo_envs(name))
    return list(all_envs)


supported_env = env_supported()


def algo_supported(env_id: str = None) -> list:
    """
    Overview:
        Return the list of algorithms supported by DI-engine.
    Arguments:
        - env_id (:obj:`str`): Optional env id, matched case-insensitively. When given, only \
            the algorithms that support this env are returned; otherwise all supported algorithms.
    Returns:
        - algos (:obj:`list`): Canonical algorithm names, in registry order. A fresh list is \
            returned, so callers may mutate it without affecting module state.
    Raises:
        - ValueError: If ``env_id`` is given but no algorithm supports it.
    """
    if env_id is None:
        return list(supported_algo_list)
    target = env_id.upper()
    algos = [
        name for name in supported_algo_list
        if any(env.upper() == target for env in _algo_envs(name))
    ]
    if not algos:
        raise ValueError("The env {} is not supported by di-engine.".format(env_id))
    return algos


def is_supported(env_id: str = None, algo: str = None) -> bool:
    """
    Overview:
        Check whether the env-algo pair is supported by DI-engine. Both names are matched \
        case-insensitively.
    Arguments:
        - env_id (:obj:`str`): Optional env id.
        - algo (:obj:`str`): Optional algorithm name.
    Returns:
        - supported (:obj:`bool`): With both given, whether ``algo`` supports ``env_id``; with \
            only one given, whether that env / algorithm is supported at all.
    Raises:
        - ValueError: If neither ``env_id`` nor ``algo`` is given.
    """
    if env_id is None and algo is None:
        raise ValueError("Please specify the env or algo.")
    algo_name = None
    if algo is not None:
        algo_name = _canonical_algo(algo)
        if algo_name is None:
            return False
    if env_id is None:
        return True
    target = env_id.upper()
    # With an algo, restrict the search to that algo's envs; otherwise any supported env counts.
    candidate_envs = supported_env if algo_name is None else _algo_envs(algo_name)
    return any(env.upper() == target for env in candidate_envs)