ding.bonus.sql

SQLAgent

Overview

Class of agent for training, evaluation and deployment of the reinforcement learning algorithm Soft Q-Learning (SQL). For more information about the system design of RL agents, please refer to https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html.

Interface: ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``

``supported_env_list = list(supported_env_cfg.keys())`` (class attribute)

Overview

List of supported envs.

Examples:
    >>> from ding.bonus.sql import SQLAgent
    >>> print(SQLAgent.supported_env_list)

best property

Overview

Load the best model from the checkpoint directory, which is by default in folder exp_name/ckpt/eval.pth.tar. The return value is the agent with the best model.

Returns:
    - (:obj:`SQLAgent`): The agent with the best model.

Examples:
    >>> agent = SQLAgent(env_id='LunarLander-v2')
    >>> agent.train()
    >>> agent = agent.best

.. note:: The best model is the model with the highest evaluation return. If this property is accessed, the current model will be replaced by the best model.
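To make the default checkpoint location concrete, here is a minimal sketch (not DI-engine code) of the path that ``best`` reads from; the experiment name ``'LunarLander-v2-SQL'`` is a hypothetical example following the ``env_id``-``algorithm`` naming convention described above:

```python
import os


def best_checkpoint_path(exp_name: str) -> str:
    # `best` loads the checkpoint saved by the evaluator,
    # which by default lives at <exp_name>/ckpt/eval.pth.tar
    return os.path.join(exp_name, "ckpt", "eval.pth.tar")


print(best_checkpoint_path("LunarLander-v2-SQL"))
```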

__init__(env_id=None, env=None, seed=0, exp_name=None, model=None, cfg=None, policy_state_dict=None)

Overview

Initialize agent for SQL algorithm.

Arguments:
    - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
    - env (:obj:`BaseEnv`): The environment instance for training and evaluation. If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified and will be used to create the environment instance. If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
    - seed (:obj:`int`): The random seed, which is set before running the program. Default to 0.
    - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
    - model (:obj:`torch.nn.Module`): The model of the SQL algorithm, which should be an instance of class :class:`ding.model.DQN`. If not specified, a default model will be generated according to the configuration.
    - cfg (:obj:`Union[EasyDict, dict]`): The configuration of the SQL algorithm, which is a dict. Default to None. If not specified, the default configuration will be used. The default configuration can be found in ``ding/config/example/SQL/gym_lunarlander_v2.py``.
    - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch in a local file. If specified, the policy will be loaded from this file. Default to None.

.. note::
    An RL agent instance can be initialized in two basic ways. For example, suppose we have an environment with id ``LunarLander-v2`` registered in gym, and we want to train an agent with the SQL algorithm under the default configuration. Then we can initialize the agent in the following ways:
        >>> agent = SQLAgent(env_id='LunarLander-v2')
    or, if we want to specify the env_id in the configuration:
        >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
        >>> agent = SQLAgent(cfg=cfg)
    There are also other arguments to specify the agent when initializing. For example, if we want to specify the environment instance:
        >>> env = CustomizedEnv('LunarLander-v2')
        >>> agent = SQLAgent(cfg=cfg, env=env)
    or, if we want to specify the model:
        >>> model = DQN(**cfg.policy.model)
        >>> agent = SQLAgent(cfg=cfg, model=model)
    or, if we want to reload the policy from a saved policy state dict:
        >>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
    Make sure that the configuration is consistent with the saved policy state dict.

train(step=int(1e7), collector_env_num=None, evaluator_env_num=None, n_iter_save_ckpt=1000, context=None, debug=False, wandb_sweep=False)

Overview

Train the agent with SQL algorithm for step iterations with collector_env_num collector environments and evaluator_env_num evaluator environments. Information during training will be recorded and saved by wandb.

Arguments:
    - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
    - collector_env_num (:obj:`int`): The collector environment number. Default to None. If not specified, it will be set according to the configuration.
    - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. If not specified, it will be set according to the configuration.
    - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoints, in training iterations. Default to 1000.
    - context (:obj:`str`): The multi-process context of the environment manager. Default to None. It can be specified as ``spawn``, ``fork`` or ``forkserver``.
    - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. If set True, the base environment manager will be used for easy debugging. Otherwise, the subprocess environment manager will be used.
    - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, a hyper-parameter optimization process for seeking the best configuration. Default to False. If True, the wandb sweep id will be used as the experiment name.

Returns:
    - (:obj:`TrainingReturn`): The training result, of which the attributions are:
        - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
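Note that ``step`` counts *environment* steps summed over all collector environments, not training iterations. The toy sketch below (plain Python, not DI-engine's middleware) illustrates the bookkeeping: each collection iteration advances every collector environment by a hypothetical ``n_sample_per_env`` transitions, and the loop stops once the total step budget is exhausted, which is roughly what a step-based termination check does:

```python
def count_training_iterations(step: int, collector_env_num: int, n_sample_per_env: int) -> int:
    """Number of collection iterations before the total env-step budget is hit."""
    total_env_steps = 0
    iterations = 0
    while total_env_steps < step:
        # every iteration, each of the parallel collector envs
        # contributes n_sample_per_env transitions to the total count
        total_env_steps += collector_env_num * n_sample_per_env
        iterations += 1
    return iterations


# 8 collector envs gathering 64 transitions each per iteration (512 steps/iter):
# a budget of 10_000 env steps is exhausted after 20 iterations.
print(count_training_iterations(10_000, 8, 64))  # -> 20
```

This is why, with more collector environments, the same ``step`` budget is consumed in fewer iterations.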

deploy(enable_save_replay=False, concatenate_all_replay=False, replay_save_path=None, seed=None, debug=False)

Overview

Deploy the agent with SQL algorithm by interacting with the environment, during which the replay video can be saved if enable_save_replay is True. The evaluation result will be returned.

Arguments:
    - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
    - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. Default to False. If ``enable_save_replay`` is False, this argument will be ignored. If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, the replay video of each episode will be saved separately.
    - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. If not specified, the video will be saved in ``exp_name/videos``.
    - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. Default to None. If not specified, ``self.seed`` will be used. If ``seed`` is an integer, the agent will be deployed once. If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
    - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. If set True, the base environment manager will be used for easy debugging. Otherwise, the subprocess environment manager will be used.

Returns:
    - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
        - eval_value (:obj:`np.float32`): The mean of evaluation return.
        - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
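The ``seed`` handling and the aggregation into ``EvalReturn`` can be sketched in plain Python (a toy illustration, not DI-engine code; the source uses ``np.mean``/``np.std``, which the ``statistics`` calls below mirror):

```python
from statistics import mean, pstdev
from typing import List, Optional, Union


def normalize_seeds(seed: Optional[Union[int, List[int]]], default_seed: int = 0) -> List[int]:
    """An int seed means one deployment; a list means one deployment per seed;
    None falls back to the agent's own seed."""
    if isinstance(seed, int):
        return [seed]
    if isinstance(seed, list):
        return seed
    return [default_seed]


def aggregate_returns(returns: List[float]) -> dict:
    # pstdev is the population standard deviation, matching np.std's default
    return {"eval_value": mean(returns), "eval_value_std": pstdev(returns)}


print(normalize_seeds(42))                        # -> [42]
print(normalize_seeds([0, 1, 2]))                 # -> [0, 1, 2]
print(aggregate_returns([100.0, 200.0]))          # -> {'eval_value': 150.0, 'eval_value_std': 50.0}
```

Passing a list of seeds is the simplest way to get a less noisy estimate from ``deploy``, since the reported standard deviation is taken over the per-seed episode returns.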

collect_data(env_num=8, save_data_path=None, n_sample=None, n_episode=None, context=None, debug=False)

Overview

Collect data with SQL algorithm for n_episode episodes with env_num collector environments. The collected data will be saved in save_data_path if specified, otherwise it will be saved in exp_name/demo_data.

Arguments:
    - env_num (:obj:`int`): The number of collector environments. Default to 8.
    - save_data_path (:obj:`str`): The path to save the collected data. Default to None. If not specified, the data will be saved in ``exp_name/demo_data``.
    - n_sample (:obj:`int`): The number of samples to collect. Default to None. If not specified, ``n_episode`` must be specified.
    - n_episode (:obj:`int`): The number of episodes to collect. Default to None. If not specified, ``n_sample`` must be specified.
    - context (:obj:`str`): The multi-process context of the environment manager. Default to None. It can be specified as ``spawn``, ``fork`` or ``forkserver``.
    - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. If set True, the base environment manager will be used for easy debugging. Otherwise, the subprocess environment manager will be used.
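The contract between ``n_sample`` and ``n_episode`` can be made explicit with a small validator (a toy sketch, not DI-engine code; note that in the source shown below, passing ``n_episode`` currently raises ``NotImplementedError``):

```python
from typing import Optional


def validate_collect_args(n_sample: Optional[int] = None, n_episode: Optional[int] = None) -> str:
    """Exactly one of n_sample / n_episode must drive the collection."""
    if n_sample is None and n_episode is None:
        raise ValueError("Please specify n_sample or n_episode.")
    if n_episode is not None:
        # mirrors the current SQLAgent source: episode-based collection is unimplemented
        raise NotImplementedError("episode-based collection is not implemented for SQLAgent")
    return f"collect {n_sample} samples"


print(validate_collect_args(n_sample=1024))  # -> collect 1024 samples
```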

batch_evaluate(env_num=4, n_evaluator_episode=4, context=None, debug=False)

Overview

Evaluate the agent with SQL algorithm for n_evaluator_episode episodes with env_num evaluator environments. The evaluation result will be returned. The difference between methods batch_evaluate and deploy is that batch_evaluate will create multiple evaluator environments to evaluate the agent to get an average performance, while deploy will only create one evaluator environment to evaluate the agent and save the replay video.

Arguments:
    - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
    - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
    - context (:obj:`str`): The multi-process context of the environment manager. Default to None. It can be specified as ``spawn``, ``fork`` or ``forkserver``.
    - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. If set True, the base environment manager will be used for easy debugging. Otherwise, the subprocess environment manager will be used.

Returns:
    - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
        - eval_value (:obj:`np.float32`): The mean of evaluation return.
        - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
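To see how episodes relate to parallel environments, here is a toy sketch of one common scheduling scheme (an assumption for illustration, not necessarily DI-engine's exact evaluator scheduling): ``n_evaluator_episode`` episodes are spread round-robin over ``env_num`` environments, so per-env loads differ by at most one episode.

```python
def episodes_per_env(n_evaluator_episode: int, env_num: int) -> list:
    """Round-robin split: the first `extra` envs run one more episode each."""
    base, extra = divmod(n_evaluator_episode, env_num)
    return [base + (1 if i < extra else 0) for i in range(env_num)]


print(episodes_per_env(4, 4))   # -> [1, 1, 1, 1]
print(episodes_per_env(10, 4))  # -> [3, 3, 2, 2]
```

With the defaults (``env_num=4``, ``n_evaluator_episode=4``), each evaluator environment contributes exactly one episode to the averaged result.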

Full Source Code

../ding/bonus/sql.py

```python
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
import os
import numpy as np
import torch
import treetensor.torch as ttorch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import CkptSaver, \
    wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
    OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
from ding.envs import BaseEnv
from ding.envs import setup_ding_env_manager
from ding.policy import SQLPolicy
from ding.utils import set_pkg_seed
from ding.utils import get_env_fps, render
from ding.config import save_config_py, compile_config
from ding.model import DQN
from ding.model import model_wrap
from ding.data import DequeBuffer
from ding.bonus.common import TrainingReturn, EvalReturn
from ding.config.example.SQL import supported_env_cfg
from ding.config.example.SQL import supported_env


class SQLAgent:
    """
    Overview:
        Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
        Soft Q-Learning(SQL).
        For more information about the system design of RL agent, please refer to \
        <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
    Interface:
        ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
    """
    supported_env_list = list(supported_env_cfg.keys())
    """
    Overview:
        List of supported envs.
    Examples:
        >>> from ding.bonus.sql import SQLAgent
        >>> print(SQLAgent.supported_env_list)
    """

    def __init__(
            self,
            env_id: str = None,
            env: BaseEnv = None,
            seed: int = 0,
            exp_name: str = None,
            model: Optional[torch.nn.Module] = None,
            cfg: Optional[Union[EasyDict, dict]] = None,
            policy_state_dict: str = None,
    ) -> None:
        """
        Overview:
            Initialize agent for SQL algorithm.
        Arguments:
            - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
                If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
                If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
                ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
            - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
                ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
                If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
            - seed (:obj:`int`): The random seed, which is set before running the program. \
                Default to 0.
            - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
                log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
            - model (:obj:`torch.nn.Module`): The model of SQL algorithm, which should be an instance of class \
                :class:`ding.model.DQN`. \
                If not specified, a default model will be generated according to the configuration.
            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SQL algorithm, which is a dict. \
                Default to None. If not specified, the default configuration will be used. \
                The default configuration can be found in ``ding/config/example/SQL/gym_lunarlander_v2.py``.
            - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch in a local file. \
                If specified, the policy will be loaded from this file. Default to None.

        .. note::
            An RL Agent Instance can be initialized in two basic ways. \
            For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
            and we want to train an agent with SQL algorithm with default configuration. \
            Then we can initialize the agent in the following ways:
                >>> agent = SQLAgent(env_id='LunarLander-v2')
            or, if we want to specify the env_id in the configuration:
                >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
                >>> agent = SQLAgent(cfg=cfg)
            There are also other arguments to specify the agent when initializing.
            For example, if we want to specify the environment instance:
                >>> env = CustomizedEnv('LunarLander-v2')
                >>> agent = SQLAgent(cfg=cfg, env=env)
            or, if we want to specify the model:
                >>> model = DQN(**cfg.policy.model)
                >>> agent = SQLAgent(cfg=cfg, model=model)
            or, if we want to reload the policy from a saved policy state dict:
                >>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
            Make sure that the configuration is consistent with the saved policy state dict.
        """

        assert env_id is not None or cfg is not None, "Please specify env_id or cfg."

        if cfg is not None and not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)

        if env_id is not None:
            assert env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
                SQLAgent.supported_env_list
            )
            if cfg is None:
                cfg = supported_env_cfg[env_id]
            else:
                assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
        else:
            assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
            assert cfg.env.env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
                SQLAgent.supported_env_list
            )
        default_policy_config = EasyDict({"policy": SQLPolicy.default_config()})
        default_policy_config.update(cfg)
        cfg = default_policy_config

        if exp_name is not None:
            cfg.exp_name = exp_name
        self.cfg = compile_config(cfg, policy=SQLPolicy)
        self.exp_name = self.cfg.exp_name
        if env is None:
            self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
        else:
            assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
            self.env = env

        logging.getLogger().setLevel(logging.INFO)
        self.seed = seed
        set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
        if not os.path.exists(self.exp_name):
            os.makedirs(self.exp_name)
        save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
        if model is None:
            model = DQN(**self.cfg.policy.model)
        self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
        self.policy = SQLPolicy(self.cfg.policy, model=model)
        if policy_state_dict is not None:
            self.policy.learn_mode.load_state_dict(policy_state_dict)
        self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")

    def train(
            self,
            step: int = int(1e7),
            collector_env_num: int = None,
            evaluator_env_num: int = None,
            n_iter_save_ckpt: int = 1000,
            context: Optional[str] = None,
            debug: bool = False,
            wandb_sweep: bool = False,
    ) -> TrainingReturn:
        """
        Overview:
            Train the agent with SQL algorithm for ``step`` iterations with ``collector_env_num`` collector \
            environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
            recorded and saved by wandb.
        Arguments:
            - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
            - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
                If not specified, it will be set according to the configuration.
            - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
                If not specified, it will be set according to the configuration.
            - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
                Default to 1000.
            - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
                It can be specified as ``spawn``, ``fork`` or ``forkserver``.
            - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
                If set True, base environment manager will be used for easy debugging. Otherwise, \
                subprocess environment manager will be used.
            - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
                which is a hyper-parameter optimization process for seeking the best configurations. \
                Default to False. If True, the wandb sweep id will be used as the experiment name.
        Returns:
            - (:obj:`TrainingReturn`): The training result, of which the attributions are:
                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
        """

        if debug:
            logging.getLogger().setLevel(logging.DEBUG)
            logging.debug(self.policy._model)
        # define env and policy
        collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
        evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
        collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
        evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')

        with task.start(ctx=OnlineRLContext()):
            task.use(
                interaction_evaluator(
                    self.cfg,
                    self.policy.eval_mode,
                    evaluator_env,
                    render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
                )
            )
            task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
            task.use(eps_greedy_handler(self.cfg))
            task.use(
                StepCollector(
                    self.cfg,
                    self.policy.collect_mode,
                    collector_env,
                    random_collect_size=self.cfg.policy.random_collect_size
                    if hasattr(self.cfg.policy, 'random_collect_size') else 0,
                )
            )
            if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
                task.use(nstep_reward_enhancer(self.cfg))
            task.use(data_pusher(self.cfg, self.buffer_))
            task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
            task.use(
                wandb_online_logger(
                    metric_list=self.policy._monitor_vars_learn(),
                    model=self.policy._model,
                    anonymous=True,
                    project_name=self.exp_name,
                    wandb_sweep=wandb_sweep,
                )
            )
            task.use(termination_checker(max_env_step=step))
            task.use(final_ctx_saver(name=self.exp_name))
            task.run()

        return TrainingReturn(wandb_url=task.ctx.wandb_url)

    def deploy(
            self,
            enable_save_replay: bool = False,
            concatenate_all_replay: bool = False,
            replay_save_path: str = None,
            seed: Optional[Union[int, List]] = None,
            debug: bool = False
    ) -> EvalReturn:
        """
        Overview:
            Deploy the agent with SQL algorithm by interacting with the environment, during which the replay video \
            can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
        Arguments:
            - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
            - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
                Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
                If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
                the replay video of each episode will be saved separately.
            - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
                If not specified, the video will be saved in ``exp_name/videos``.
            - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
                Default to None. If not specified, ``self.seed`` will be used. \
                If ``seed`` is an integer, the agent will be deployed once. \
                If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
            - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
                If set True, base environment manager will be used for easy debugging. Otherwise, \
                subprocess environment manager will be used.
        Returns:
            - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
                - eval_value (:obj:`np.float32`): The mean of evaluation return.
                - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
        """

        if debug:
            logging.getLogger().setLevel(logging.DEBUG)
        # define env and policy
        env = self.env.clone(caller='evaluator')

        if seed is not None and isinstance(seed, int):
            seeds = [seed]
        elif seed is not None and isinstance(seed, list):
            seeds = seed
        else:
            seeds = [self.seed]

        returns = []
        images = []
        if enable_save_replay:
            replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
            env.enable_save_replay(replay_path=replay_save_path)
        else:
            logging.warning('No video would be generated during the deploy.')
            if concatenate_all_replay:
                logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
                concatenate_all_replay = False

        def single_env_forward_wrapper(forward_fn, cuda=True):

            forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward

            def _forward(obs):
                # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
                obs = ttorch.as_tensor(obs).unsqueeze(0)
                if cuda and torch.cuda.is_available():
                    obs = obs.cuda()
                action = forward_fn(obs)["action"]
                # squeeze means delete batch dim, i.e. (1, A) -> (A, )
                action = action.squeeze(0).detach().cpu().numpy()
                return action

            return _forward

        forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)

        # reset first to make sure the env is in the initial state
        # env will be reset again in the main loop
        env.reset()

        for seed in seeds:
            env.seed(seed, dynamic_seed=False)
            return_ = 0.
            step = 0
            obs = env.reset()
            images.append(render(env)[None]) if concatenate_all_replay else None
            while True:
                action = forward_fn(obs)
                obs, rew, done, info = env.step(action)
                images.append(render(env)[None]) if concatenate_all_replay else None
                return_ += rew
                step += 1
                if done:
                    break
            logging.info(f'SQL deploy is finished, final episode return with {step} steps is: {return_}')
            returns.append(return_)

        env.close()

        if concatenate_all_replay:
            images = np.concatenate(images, axis=0)
            import imageio
            imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))

        return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))

    def collect_data(
            self,
            env_num: int = 8,
            save_data_path: Optional[str] = None,
            n_sample: Optional[int] = None,
            n_episode: Optional[int] = None,
            context: Optional[str] = None,
            debug: bool = False
    ) -> None:
        """
        Overview:
            Collect data with SQL algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
            The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
            ``exp_name/demo_data``.
        Arguments:
            - env_num (:obj:`int`): The number of collector environments. Default to 8.
            - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
                If not specified, the data will be saved in ``exp_name/demo_data``.
            - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
                If not specified, ``n_episode`` must be specified.
            - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
                If not specified, ``n_sample`` must be specified.
            - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
                It can be specified as ``spawn``, ``fork`` or ``forkserver``.
            - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
                If set True, base environment manager will be used for easy debugging. Otherwise, \
                subprocess environment manager will be used.
        """

        if debug:
            logging.getLogger().setLevel(logging.DEBUG)
        if n_episode is not None:
            raise NotImplementedError
        # define env and policy
        env_num = env_num if env_num else self.cfg.env.collector_env_num
        env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')

        if save_data_path is None:
            save_data_path = os.path.join(self.exp_name, 'demo_data')

        # main execution task
        with task.start(ctx=OnlineRLContext()):
            task.use(
                StepCollector(
                    self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
                )
            )
            task.use(offline_data_saver(save_data_path, data_type='hdf5'))
            task.run(max_step=1)
        logging.info(
            f'SQL collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
        )

    def batch_evaluate(
            self,
            env_num: int = 4,
            n_evaluator_episode: int = 4,
            context: Optional[str] = None,
            debug: bool = False
    ) -> EvalReturn:
        """
        Overview:
            Evaluate the agent with SQL algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
            environments. The evaluation result will be returned.
            The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
            multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
            will only create one evaluator environment to evaluate the agent and save the replay video.
        Arguments:
            - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
            - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
            - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
                It can be specified as ``spawn``, ``fork`` or ``forkserver``.
            - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
                If set True, base environment manager will be used for easy debugging. Otherwise, \
                subprocess environment manager will be used.
        Returns:
            - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
                - eval_value (:obj:`np.float32`): The mean of evaluation return.
                - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
        """

        if debug:
            logging.getLogger().setLevel(logging.DEBUG)
        # define env and policy
        env_num = env_num if env_num else self.cfg.env.evaluator_env_num
        env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')

        # reset first to make sure the env is in the initial state
        # env will be reset again in the main loop
        env.launch()
        env.reset()

        evaluate_cfg = self.cfg
        evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode

        # main execution task
        with task.start(ctx=OnlineRLContext()):
            task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
            task.run(max_step=1)

        return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)

    @property
    def best(self) -> 'SQLAgent':
        """
        Overview:
            Load the best model from the checkpoint directory, \
            which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
            The return value is the agent with the best model.
        Returns:
            - (:obj:`SQLAgent`): The agent with the best model.
        Examples:
            >>> agent = SQLAgent(env_id='LunarLander-v2')
            >>> agent.train()
            >>> agent = agent.best

        .. note::
            The best model is the model with the highest evaluation return. If this property is accessed, \
            the current model will be replaced by the best model.
        """

        best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
        # Load best model if it exists
        if os.path.exists(best_model_file_path):
            policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
            self.policy.learn_mode.load_state_dict(policy_state_dict)
        return self
```