ding.entry.application_entry_trex_collect_data

collect_episodic_demo_data_for_trex(input_cfg, seed, collect_count, rank, env_setting=None, model=None, state_dict=None, state_dict_path=None)

Overview

Collect episodic demonstration data with a trained policy, specifically for TREX (Trajectory-ranked Reward EXtrapolation).

Arguments:

- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. ``str`` type means config file path. ``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- collect_count (:obj:`int`): The number of episodes to collect.
- rank (:obj:`int`): The episode ranking.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: ``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of ``torch.nn.Module``.
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
- state_dict_path (:obj:`Optional[str]`): The absolute path of the state dict.
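As a quick illustration, the entry point can also be called directly from Python. The sketch below assumes the cartpole DQN config pair from ``dizoo`` and a hypothetical checkpoint path; adapt both to your own experiment:

```python
# Minimal usage sketch (assumed config module and checkpoint path; adapt to your setup).
from copy import deepcopy

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import (
    cartpole_dqn_config, cartpole_dqn_create_config
)
from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex

exp_data = collect_episodic_demo_data_for_trex(
    (deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)),  # [user_config, create_cfg]
    seed=0,
    collect_count=1,  # collect one episode from this checkpoint
    rank=1,  # this episode's position in the overall ranking
    state_dict_path='./cartpole_dqn/ckpt/iteration_1000.pth.tar',  # assumed checkpoint path
)
```

When `state_dict` is omitted, the function loads the checkpoint itself via `torch.load(state_dict_path, map_location='cpu')`, so only the path is needed.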

Full Source Code

../ding/entry/application_entry_trex_collect_data.py

```python
import argparse
import torch
import os
from typing import Union, Optional, List, Any, Tuple
from functools import partial
from copy import deepcopy

from ding.config import compile_config, read_config
from ding.worker import EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.utils.data import default_collate


def collect_episodic_demo_data_for_trex(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int,
        collect_count: int,
        rank: int,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
):
    """
    Overview:
        Collect episodic demonstration data with a trained policy, specifically for TREX.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The number of episodes to collect.
        - rank (:obj:`int`): The episode ranking.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The absolute path of the state dict.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type += '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg.env.collector_env_num = 1
    cfg = compile_config(
        cfg,
        collector=EpisodeSerialCollector,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )

    # Create components: env, policy, collector
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = EpisodeSerialCollector(
        cfg.policy.collect.collector, collector_env, collect_demo_policy, exp_name=cfg.exp_name
    )

    policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
        else {'eps': cfg.policy.other.eps.get('collect', 0.2)}

    # Let's collect some sub-optimal demonstrations
    exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    # Save data transitions.
    print('Collect {}th episodic demo data successfully'.format(rank))
    return exp_data


def trex_get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='abs path for a config')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def trex_collecting_data(args=None):
    if args is None:
        args = trex_get_args()  # TODO(nyz) use sub-command in cli
    if isinstance(args.cfg, str):
        cfg, create_cfg = read_config(args.cfg)
    else:
        cfg, create_cfg = deepcopy(args.cfg)
    data_path = cfg.exp_name
    expert_model_path = cfg.reward_model.expert_model_path  # directory path
    checkpoint_min = cfg.reward_model.checkpoint_min
    checkpoint_max = cfg.reward_model.checkpoint_max
    checkpoint_step = cfg.reward_model.checkpoint_step
    checkpoints = []
    for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step):
        checkpoints.append(str(i))
    data_for_save = {}
    learning_returns = []
    learning_rewards = []
    episodes_data = []
    for checkpoint in checkpoints:
        num_per_ckpt = 1
        model_path = expert_model_path + \
            '/ckpt/iteration_' + checkpoint + '.pth.tar'
        seed = args.seed + (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)
        exp_data = collect_episodic_demo_data_for_trex(
            deepcopy(args.cfg),
            seed,
            state_dict_path=model_path,
            collect_count=num_per_ckpt,
            rank=(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + 1
        )
        data_for_save[(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)] = exp_data
        obs = [list(default_collate(exp_data[i])['obs'].numpy()) for i in range(len(exp_data))]
        rewards = [default_collate(exp_data[i])['reward'].tolist() for i in range(len(exp_data))]
        sum_rewards = [torch.sum(default_collate(exp_data[i])['reward']).item() for i in range(len(exp_data))]

        learning_rewards.append(rewards)
        learning_returns.append(sum_rewards)
        episodes_data.append(obs)
    offline_data_save_type(
        data_for_save, data_path + '/suboptimal_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    # if not compiled_cfg.reward_model.auto: more feature
    offline_data_save_type(
        episodes_data, data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        learning_returns, data_path + '/learning_returns.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        learning_rewards, data_path + '/learning_rewards.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        checkpoints, data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    return checkpoints, episodes_data, learning_returns, learning_rewards


if __name__ == '__main__':
    trex_collecting_data()
```
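`trex_collecting_data` can be driven programmatically by passing an `argparse.Namespace` instead of parsing the command line. A minimal sketch, assuming a placeholder config file that defines `exp_name` plus the `reward_model.expert_model_path`, `checkpoint_min`, `checkpoint_max`, and `checkpoint_step` fields the function reads:

```python
# Hypothetical invocation sketch; 'path/to/your_trex_config.py' is a placeholder for a
# config that defines exp_name and the reward_model checkpoint fields listed above.
import argparse

from ding.entry.application_entry_trex_collect_data import trex_collecting_data

args = argparse.Namespace(cfg='path/to/your_trex_config.py', seed=0, device='cpu')
checkpoints, episodes_data, learning_returns, learning_rewards = trex_collecting_data(args)
```

Because the module ends with an `if __name__ == '__main__'` guard, the same flow is available from the shell, e.g. `python -m ding.entry.application_entry_trex_collect_data --cfg <abs_config_path> --seed 0`.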