from typing import Union, Optional, List, Any, Tuple
import os
import pickle
import numpy as np
import torch
from functools import partial
from copy import deepcopy

from ding.config import compile_config, read_config
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device, to_ndarray
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.rl_utils import get_nstep_return_data
from ding.utils.data import default_collate


# NOTE: this function intentionally shadows the builtin ``eval`` — the name is
# part of the module's public API and must be preserved for callers.
def eval(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        load_path: Optional[str] = None,
        replay_path: Optional[str] = None,
) -> float:
    """
    Overview:
        Pure policy evaluation entry. Evaluate mean episode return and save replay videos.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - load_path (:obj:`Optional[str]`): Path to load ckpt, used only when ``state_dict`` is None.
        - replay_path (:obj:`Optional[str]`): Path to save replay; overrides ``cfg.env.replay_path``.
    Returns:
        - episode_return (:obj:`float`): Mean episode return over the evaluation episodes.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        # Deep-copy so that compile_config does not mutate the caller's dicts.
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, save_path='eval_config.py'
    )

    # Create components: env, policy, evaluator
    if env_setting is None:
        env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
    else:
        env_fn, _, evaluator_env_cfg = env_setting
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    # dynamic_seed=False keeps every evaluation episode reproducible.
    evaluator_env.seed(seed, dynamic_seed=False)
    if replay_path is None:  # argument > config
        replay_path = cfg.env.get('replay_path', None)
    if replay_path:
        evaluator_env.enable_save_replay(replay_path)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['eval'])
    if state_dict is None:
        if load_path is None:
            load_path = cfg.policy.learn.learner.load_path
        # Load to CPU first; the policy moves weights to its own device.
        state_dict = torch.load(load_path, map_location='cpu')
    policy.eval_mode.load_state_dict(state_dict)
    evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)

    # Evaluate
    _, episode_info = evaluator.eval()
    episode_return = np.mean(episode_info['eval_episode_return'])
    print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
    return episode_return


def collect_demo_data(
        input_cfg: Union[str, dict],
        seed: int,
        collect_count: int,
        expert_data_path: Optional[str] = None,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
) -> None:
    """
    Overview:
        Collect demonstration data (sample-level transitions) by the trained policy.
    Arguments:
        - input_cfg (:obj:`Union[str, dict]`): Config in dict type. \
            ``str`` type means config file path, dict type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The count of collected data (number of samples).
        - expert_data_path (:obj:`Optional[str]`): File path the expert demo data will be written to. \
            Defaults to ``cfg.policy.collect.save_path`` when None.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model; \
            required when ``state_dict`` is None.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )
    if expert_data_path is None:
        expert_data_path = cfg.policy.collect.save_path

    # Create components: env, policy, collector
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        # One of state_dict / state_dict_path must be provided.
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)

    # For policies like DQN whose collect_mode uses eps-greedy, disable exploration
    # so that the collected data is pure expert behavior.
    if hasattr(cfg.policy.other, 'eps'):
        policy_kwargs = {'eps': 0.}
    else:
        policy_kwargs = None

    # Let's collect some expert demonstrations
    exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    # Save data transitions.
    offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
    print('Collect demo data successfully')


def collect_episodic_demo_data(
        input_cfg: Union[str, dict],
        seed: int,
        collect_count: int,
        expert_data_path: str,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
) -> None:
    """
    Overview:
        Collect episodic demonstration data by the trained policy.
    Arguments:
        - input_cfg (:obj:`Union[str, dict]`): Config in dict type. \
            ``str`` type means config file path, dict type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The count of collected data (number of episodes).
        - expert_data_path (:obj:`str`): File path the expert demo data will be written to.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model; \
            required when ``state_dict`` is None.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg,
        collector=EpisodeSerialCollector,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )

    # Create components: env, policy, collector
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        # One of state_dict / state_dict_path must be provided.
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)

    # For policies like DQN whose collect_mode uses eps-greedy, disable exploration
    # so that the collected data is pure expert behavior.
    if hasattr(cfg.policy.other, 'eps'):
        policy_kwargs = {'eps': 0.}
    else:
        policy_kwargs = None

    # Let's collect some expert demonstrations
    exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    # Save data transitions.
    offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
    print('Collect episodic demo data successfully')


def episode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
    """
    Overview:
        Transfer episodic data into nstep transitions.
    Arguments:
        - data_path (:obj:`str`): Data path that stores the pkl file (a list of episodes).
        - expert_data_path (:obj:`str`): File path the expert demo data will be written to.
        - nstep (:obj:`int`): Number of steps for n-step return: {s_{t}, a_{t}, s_{t+n}}.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only load trusted data.
    with open(data_path, 'rb') as f:
        _dict = pickle.load(f)  # a list whose length is cfg.reward_model.collect_count
    post_process_data = []
    for episode in _dict:
        data = get_nstep_return_data(episode, nstep)
        post_process_data.extend(data)
    offline_data_save_type(
        post_process_data,
        expert_data_path,
    )


def episode_to_transitions_filter(data_path: str, expert_data_path: str, nstep: int, min_episode_return: int) -> None:
    """
    Overview:
        Transfer episodic data into n-step transitions and only keep episodes whose total return
        is at least ``min_episode_return``.
    Arguments:
        - data_path (:obj:`str`): Data path that stores the pkl file (a list of episodes).
        - expert_data_path (:obj:`str`): File path the expert demo data will be written to.
        - nstep (:obj:`int`): Number of steps for n-step return: {s_{t}, a_{t}, s_{t+n}}.
        - min_episode_return (:obj:`int`): Minimum episode return an episode must reach to be kept.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only load trusted data.
    with open(data_path, 'rb') as f:
        _dict = pickle.load(f)  # a list whose length is cfg.reward_model.collect_count
    post_process_data = []
    for episode in _dict:
        # Each step's 'reward' is a tensor; stack along dim 0 and sum for the episode return.
        episode_returns = torch.stack([step['reward'] for step in episode], dim=0)
        if episode_returns.sum() < min_episode_return:
            continue
        data = get_nstep_return_data(episode, nstep)
        post_process_data.extend(data)
    offline_data_save_type(
        post_process_data,
        expert_data_path,
    )