Skip to content

ding.example.dqn_per

ding.example.dqn_per

Full Source Code

../ding/example/dqn_per.py

"""Train a DQN agent with prioritized experience replay (PER) on CartPole.

Source: ``ding/example/dqn_per.py``. The script reuses the CartPole DQN
config from ``dizoo``, turns on priority sampling with importance-sampling
(IS) weights, and wires up the DI-engine middleware pipeline (evaluate,
explore, collect, push to buffer, learn, checkpoint).
"""
import gym
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
    eps_greedy_handler, CkptSaver
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config

# Gym environment id shared by the collector and evaluator environments.
ENV_ID = "CartPole-v0"


def _make_env():
    """Create one wrapped CartPole environment; used as the per-slot env factory."""
    return DingEnvWrapper(gym.make(ENV_ID))


def main():
    """Configure and run DQN + PER training on CartPole until the task stops.

    The shared ``main_config`` is mutated in place. PER is enabled
    (``priority``) together with importance-sampling weighting
    (``priority_IS_weight``) before the config is compiled. The replay-buffer
    middleware then reads the compiled IS flag, so the buffer and the policy
    cannot disagree about IS weighting.
    """
    logging.getLogger().setLevel(logging.INFO)
    main_config.exp_name = 'cartpole_dqn_per'
    # Turn on prioritized sampling and IS-weight correction for the policy.
    main_config.policy.priority = True
    main_config.policy.priority_IS_weight = True
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # One factory entry per environment slot; the manager instantiates them.
        collector_env = BaseEnvManagerV2(
            env_fn=[_make_env for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[_make_env for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = DQN(**cfg.policy.model)
        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
        # Take the IS flag from the compiled config rather than hard-coding it,
        # so the buffer stays consistent with ``priority_IS_weight`` above.
        buffer_.use(PriorityExperienceReplay(buffer_, IS_weight=cfg.policy.priority_IS_weight))
        policy = DQNPolicy(cfg.policy, model=model)

        # Middleware is registered in pipeline order.
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(eps_greedy_handler(cfg))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(data_pusher(cfg, buffer_))
        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
        # NOTE(review): train_freq=100 presumably means "save every 100 train
        # iterations" — confirm against the CkptSaver signature in this version.
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.run()


if __name__ == "__main__":
    main()