ding.example.dqn_new_env
Full Source Code
../ding/example/dqn_new_env.py
import gym
from ditk import logging
from ding.framework.supervisor import ChildType
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, EnvSupervisor
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    eps_greedy_handler, CkptSaver
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config


def main():
    """Train a DQN agent on CartPole-v0 using DI-engine's middleware pipeline.

    Builds thread-based environment supervisors for collection and
    evaluation, then runs the standard off-policy loop: evaluate,
    schedule epsilon, collect steps, push to a replay buffer, learn,
    and periodically save checkpoints.
    """
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)

    def make_supervisor(num_envs):
        # Both collector and evaluator use the same thread-based supervisor
        # wrapping CartPole instances; only the env count differs.
        return EnvSupervisor(
            type_=ChildType.THREAD,
            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(num_envs)],
            **cfg.env.manager
        )

    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = make_supervisor(cfg.env.collector_env_num)
        evaluator_env = make_supervisor(cfg.env.evaluator_env_num)

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = DQN(**cfg.policy.model)
        replay_buffer = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        policy = DQNPolicy(cfg.policy, model=model)

        # Middleware order defines one training iteration:
        # evaluate -> update epsilon -> collect -> store -> learn -> checkpoint.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(data_pusher(cfg, replay_buffer))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, replay_buffer))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()