Skip to content

ding.example.qgpo

ding.example.qgpo

QGPOD4RLDataset

Bases: Dataset

Overview

Dataset for QGPO algorithm. The training of QGPO algorithm is based on contrastive energy prediction, which needs true action and fake action. The true action is sampled from the dataset, and the fake action is sampled from the action support generated by the behavior policy.

Interface: __init__, __getitem__, __len__.

__init__(cfg, device='cpu')

Overview

Initialization method of QGPOD4RLDataset class

Arguments:
- cfg (:obj:`EasyDict`): Config dict.
- device (:obj:`str`): Device name (defaults to "cpu").

__getitem__(index)

Overview

Get data by index

Arguments:
- index (:obj:`int`): Index of the data item.

Returns:
- data (:obj:`dict`): Data dict.

.. note:: The data dict contains the following keys:
- s (:obj:`torch.Tensor`): State.
- a (:obj:`torch.Tensor`): Action.
- r (:obj:`torch.Tensor`): Reward.
- s_ (:obj:`torch.Tensor`): Next state.
- d (:obj:`torch.Tensor`): Is finished.
- fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and QGPO training (sampled from the action support generated by the behavior policy).
- fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and QGPO training (sampled from the action support generated by the behavior policy).

Full Source Code

../ding/example/qgpo.py

import torch
import gym
import d4rl
from easydict import EasyDict
from ditk import logging
from ding.model import QGPO
from ding.policy import QGPOPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import trainer, CkptSaver, offline_logger, wandb_offline_logger, termination_checker
from ding.framework.middleware.functional.evaluator import interaction_evaluator
from ding.framework.middleware.functional.data_processor import qgpo_support_data_generator, qgpo_offline_data_fetcher
from ding.utils import set_pkg_seed

from dizoo.d4rl.config.halfcheetah_medium_expert_qgpo_config import main_config, create_config


class QGPOD4RLDataset(torch.utils.data.Dataset):
    """
    Overview:
        Dataset for QGPO algorithm. The training of QGPO algorithm is based on contrastive energy prediction, \
        which needs true action and fake action. The true action is sampled from the dataset, and the fake action \
        is sampled from the action support generated by the behavior policy.
    Interface:
        ``__init__``, ``__getitem__``, ``__len__``.
    """

    def __init__(self, cfg, device="cpu"):
        """
        Overview:
            Initialization method of QGPOD4RLDataset class. Loads the D4RL q-learning dataset for
            ``cfg.env_id``, moves all tensors to ``device``, and tunes the rewards depending on the
            environment family.
        Arguments:
            - cfg (:obj:`EasyDict`): Config dict
            - device (:obj:`str`): Device name
        """

        self.cfg = cfg
        data = d4rl.qlearning_dataset(gym.make(cfg.env_id))
        self.device = device
        self.states = torch.from_numpy(data['observations']).float().to(self.device)
        self.actions = torch.from_numpy(data['actions']).float().to(self.device)
        self.next_states = torch.from_numpy(data['next_observations']).float().to(self.device)
        reward = torch.from_numpy(data['rewards']).view(-1, 1).float().to(self.device)
        self.is_finished = torch.from_numpy(data['terminals']).view(-1, 1).float().to(self.device)

        # Only the two IQL-style tunings are reachable here; the other branches are kept for
        # reference to alternative reward-tuning schemes.
        reward_tune = "iql_antmaze" if "antmaze" in cfg.env_id else "iql_locomotion"
        if reward_tune == 'normalize':
            reward = (reward - reward.mean()) / reward.std()
        elif reward_tune == 'iql_antmaze':
            reward = reward - 1.0
        elif reward_tune == 'iql_locomotion':
            # Scale rewards so that the best-minus-worst episode return over 1000-step episodes is 1000.
            min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
            reward /= (max_ret - min_ret)
            reward *= 1000
        elif reward_tune == 'cql_antmaze':
            reward = (reward - 0.5) * 4.0
        elif reward_tune == 'antmaze':
            reward = (reward - 0.25) * 2.0
        self.rewards = reward
        self.len = self.states.shape[0]
        logging.info(f"{self.len} data loaded in QGPOD4RLDataset")

    def __getitem__(self, index):
        """
        Overview:
            Get data by index
        Arguments:
            - index (:obj:`int`): Index of data
        Returns:
            - data (:obj:`dict`): Data dict

        .. note:: The data dict contains the following keys:
            - s (:obj:`torch.Tensor`): State
            - a (:obj:`torch.Tensor`): Action
            - r (:obj:`torch.Tensor`): Reward
            - s_ (:obj:`torch.Tensor`): Next state
            - d (:obj:`torch.Tensor`): Is finished
            - fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \
                (fake action is sampled from the action support generated by the behavior policy)
            - fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \
                (fake action is sampled from the action support generated by the behavior policy)
        """

        data = {
            's': self.states[index % self.len],
            'a': self.actions[index % self.len],
            'r': self.rewards[index % self.len],
            's_': self.next_states[index % self.len],
            'd': self.is_finished[index % self.len],
            # fake actions are attached later by qgpo_support_data_generator; before that, 0.0 placeholders
            'fake_a': self.fake_actions[index % self.len]
            if hasattr(self, "fake_actions") else 0.0,  # self.fake_actions <D, 16, A>
            'fake_a_': self.fake_next_actions[index % self.len]
            if hasattr(self, "fake_next_actions") else 0.0,  # self.fake_next_actions <D, 16, A>
        }
        return data

    def __len__(self):
        return self.len

    @staticmethod
    def return_range(dataset, max_episode_steps):
        """
        Overview:
            Compute the minimum and maximum episode return in ``dataset``, splitting episodes at
            terminals or at ``max_episode_steps``.
        Arguments:
            - dataset (:obj:`dict`): D4RL-style dict with ``rewards`` and ``terminals`` arrays.
            - max_episode_steps (:obj:`int`): Maximum number of steps per episode.
        Returns:
            - (min_return, max_return) (:obj:`tuple`): Smallest and largest episode return.
        """
        returns, lengths = [], []
        ep_ret, ep_len = 0., 0
        for r, d in zip(dataset['rewards'], dataset['terminals']):
            ep_ret += float(r)
            ep_len += 1
            if d or ep_len == max_episode_steps:
                returns.append(ep_ret)
                lengths.append(ep_len)
                ep_ret, ep_len = 0., 0
        # returns.append(ep_ret)  # incomplete trajectory
        lengths.append(ep_len)  # but still keep track of number of steps
        assert sum(lengths) == len(dataset['rewards'])
        return min(returns), max(returns)


def main():
    # If you don't have offline data, you need to prepare it first and set the data_path in config
    # For demonstration, we also can train a RL policy (e.g. SAC) and collect some data
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, policy=QGPOPolicy)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OfflineRLContext()):
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
        model = QGPO(cfg=cfg.policy.model)
        policy = QGPOPolicy(cfg.policy, model=model)
        dataset = QGPOD4RLDataset(cfg=cfg.dataset, device=policy._device)
        if hasattr(cfg.policy, "load_path") and cfg.policy.load_path is not None:
            policy_state_dict = torch.load(cfg.policy.load_path, map_location=torch.device("cpu"))
            policy.learn_mode.load_state_dict(policy_state_dict)

        task.use(qgpo_support_data_generator(cfg, dataset, policy))
        task.use(qgpo_offline_data_fetcher(cfg, dataset, collate_fn=None))
        task.use(trainer(cfg, policy.learn_mode))
        # One evaluator per guidance scale so that each scale is evaluated and logged separately.
        for guidance_scale in cfg.policy.eval.guidance_scale:
            evaluator_env = BaseEnvManagerV2(
                env_fn=[
                    lambda: DingEnvWrapper(env=gym.make(cfg.env.env_id), cfg=cfg.env, caller="evaluator")
                    for _ in range(cfg.env.evaluator_env_num)
                ],
                cfg=cfg.env.manager
            )
            task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env, guidance_scale=guidance_scale))
            task.use(
                wandb_offline_logger(
                    cfg=EasyDict(
                        dict(
                            gradient_logger=False,
                            plot_logger=True,
                            video_logger=False,
                            action_logger=False,
                            return_logger=False,
                            vis_dataset=False,
                        )
                    ),
                    exp_config=cfg,
                    metric_list=policy._monitor_vars_learn(),
                    project_name=cfg.exp_name
                )
            )
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100000))
        task.use(offline_logger())
        task.use(termination_checker(max_train_iter=500000 + cfg.policy.learn.q_value_stop_training_iter))
        task.run()


if __name__ == "__main__":
    main()