
ding.utils.data.dataset


DatasetStatistics dataclass

Overview

Dataset statistics.

NaiveRLDataset

Bases: Dataset

Overview

Naive RL dataset, which is used for offline RL algorithms.

Interfaces: __init__, __len__, __getitem__

__init__(cfg)

Overview

Initialization method.

Arguments: - cfg (:obj:dict): Config dict.

__len__()

Overview

Get the length of the dataset.

__getitem__(idx)

Overview

Get the item of the dataset.
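`NaiveRLDataset` simply exposes a pickled list of transition dicts (loaded from `cfg.policy.collect.data_path`, or from a raw path string) through `__len__` and `__getitem__`. A minimal sketch of the expected on-disk layout, using plain Python values in place of `torch.Tensor` (the keys and path here are illustrative):

```python
import pickle
import tempfile

# Each item is one transition; the real dataset stores torch.Tensor values
# under keys such as 'obs', 'action', 'reward', 'done'.
transitions = [
    {'obs': [0.0, 1.0], 'action': [0], 'reward': 1.0, 'done': False},
    {'obs': [1.0, 0.0], 'action': [1], 'reward': 0.0, 'done': True},
]

with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
    pickle.dump(transitions, f)
    data_path = f.name

# NaiveRLDataset(data_path) would expose this list through __len__/__getitem__;
# the same behavior can be mimicked by loading the pickle directly.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

print(len(data))        # dataset length
print(data[1]['done'])  # one transition field
```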

D4RLDataset

Bases: Dataset

Overview

D4RL dataset, which is used for offline RL algorithms.

Interfaces: __init__, __len__, __getitem__

Properties:

- mean (:obj:np.ndarray): Mean of the dataset.
- std (:obj:np.ndarray): Std of the dataset.
- action_bounds (:obj:np.ndarray): Action bounds of the dataset.
- statistics (:obj:dict): Statistics of the dataset.

mean property

Overview

Get the mean of the dataset.

std property

Overview

Get the std of the dataset.

action_bounds property

Overview

Get the action bounds of the dataset.

statistics property

Overview

Get the statistics of the dataset.

__init__(cfg)

Overview

Initialization method.

Arguments: - cfg (:obj:dict): Config dict.

__len__()

Overview

Get the length of the dataset.

__getitem__(idx)

Overview

Get the item of the dataset.
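The `mean`, `std`, and `action_bounds` properties come from an internal `_cal_statistics` step, which averages observations per dimension, adds a small epsilon to the std, and widens the per-dimension action range by a 5% buffer. A numpy sketch of that computation (the toy arrays are illustrative; the real version also clips the bounds to the env's action space):

```python
import numpy as np

eps = 1e-3
observations = np.array([[0.0, 2.0], [2.0, 4.0]])
actions = np.array([[-1.0, 0.5], [1.0, -0.5]])

mean = observations.mean(0)                # per-dimension observation mean
std = observations.std(0) + eps            # epsilon keeps later division safe
action_max = actions.max(0)
action_min = actions.min(0)
buffer = 0.05 * (action_max - action_min)  # widen bounds by 5% of the range
action_bounds = np.stack([action_min - buffer, action_max + buffer], axis=0)

print(mean)           # [1. 3.]
print(action_bounds)  # row 0: lower bounds, row 1: upper bounds
```

Normalized observations are then `(obs - mean) / std`, which is what the `norm_obs` config options enable.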

HDF5Dataset

Bases: Dataset

Overview

HDF5 dataset, which is stored in HDF5 format and used for offline RL algorithms. HDF5 is a common format for storing large numerical arrays in Python. For more details, please refer to https://support.hdfgroup.org/HDF5/.

Interfaces: __init__, __len__, __getitem__

Properties:

- mean (:obj:np.ndarray): Mean of the dataset.
- std (:obj:np.ndarray): Std of the dataset.
- action_bounds (:obj:np.ndarray): Action bounds of the dataset.
- statistics (:obj:dict): Statistics of the dataset.

mean property

Overview

Get the mean of the dataset.

std property

Overview

Get the std of the dataset.

action_bounds property

Overview

Get the action bounds of the dataset.

statistics property

Overview

Get the statistics of the dataset.

__init__(cfg)

Overview

Initialization method.

Arguments: - cfg (:obj:dict): Config dict.

__len__()

Overview

Get the length of the dataset.

__getitem__(idx)

Overview

Get the item of the dataset.

Arguments: - idx (:obj:int): The index of the dataset.
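When `context_len > 0` (the decision-transformer configuration), `__getitem__` returns a fixed-length window starting at `idx` rather than a single transition, which is also why `__len__` subtracts `context_len`. A numpy sketch of the windowing (toy data; the real version returns torch tensors plus an all-ones trajectory mask):

```python
import numpy as np

context_len = 3
obs = np.arange(10).reshape(10, 1)      # 10 steps of 1-dim observations
rewards = np.arange(10).reshape(10, 1)

idx = 4
done_idx = idx + context_len
states = obs[idx:done_idx]                   # window of context_len observations
rtgs = rewards[idx:done_idx, 0]              # matching per-step reward column
timesteps = np.arange(idx, done_idx)         # absolute timestep indices
mask = np.ones(context_len, dtype=np.int64)  # no padding in this path

print(states.shape)  # (3, 1)
print(timesteps)     # [4 5 6]
```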

D4RLTrajectoryDataset

Bases: Dataset

Overview

D4RL trajectory dataset, which is used for offline RL algorithms.

Interfaces: __init__, __len__, __getitem__

__init__(cfg)

Overview

Initialization method.

Arguments: - cfg (:obj:dict): Config dict.

get_max_timestep()

Overview

Get the max timestep of the dataset.

get_state_stats()

Overview

Get the state mean and std of the dataset.

get_d4rl_dataset_stats(env_d4rl_name)

Overview

Get the d4rl dataset stats.

Arguments: - env_d4rl_name (:obj:str): The d4rl env name.

__len__()

Overview

Get the length of the dataset.

__getitem__(idx)

Overview

Get the item of the dataset.

Arguments: - idx (:obj:int): The index of the dataset.
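This dataset precomputes `traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale` for every trajectory. A sketch of `discount_cumsum` under the assumption that it is the standard reverse discounted sum (with `gamma=1.0` it reduces to a suffix sum of rewards):

```python
import numpy as np

def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Reverse cumulative sum: out[t] = x[t] + gamma * out[t + 1]."""
    out = np.zeros_like(x, dtype=float)
    out[-1] = x[-1]
    for t in reversed(range(len(x) - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

rewards = np.array([1.0, 2.0, 3.0])
rtg_scale = 10.0
returns_to_go = discount_cumsum(rewards, 1.0) / rtg_scale

print(returns_to_go)  # [0.6 0.5 0.3]
```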

D4RLDiffuserDataset

Bases: Dataset

Overview

D4RL diffuser dataset, which is used for offline RL algorithms.

Interfaces: __init__, __len__, __getitem__

__init__(dataset_path, context_len, rtg_scale)

Overview

Initialization method of D4RLDiffuserDataset.

Arguments:

- dataset_path (:obj:str): The dataset path.
- context_len (:obj:int): The length of the context.
- rtg_scale (:obj:float): The scale of the returns to go.

FixedReplayBuffer

Bases: object

Overview

Object composed of a list of OutOfGraphReplayBuffer instances.

Interfaces: __init__, get_transition_elements, sample_transition_batch

__init__(data_dir, replay_suffix, *args, **kwargs)

Overview

Initialize the FixedReplayBuffer class.

Arguments:

- data_dir (:obj:str): Log directory from which to load the replay buffer.
- replay_suffix (:obj:int): If not None, then only load the replay buffer corresponding to the specific suffix in the data directory.
- args (:obj:list): Arbitrary extra arguments.
- kwargs (:obj:dict): Arbitrary keyword arguments.

load_single_buffer(suffix)

Overview

Load a single replay buffer.

Arguments: - suffix (:obj:int): The suffix of the replay buffer.

get_transition_elements()

Overview

Returns the transition elements.

sample_transition_batch(batch_size=None, indices=None)

Overview

Returns a batch of transitions (including any extra contents).

Arguments:

- batch_size (:obj:int): The batch size.
- indices (:obj:list): The indices of the batch.
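`FixedReplayBuffer` itself delegates to Dopamine-style out-of-graph buffers loaded from disk. As a rough in-memory analogue (the class name and layout are illustrative, not the real API), sampling a batch by index looks like:

```python
import numpy as np

class TinyFixedBuffer:
    """Illustrative stand-in: transitions are fixed at construction time."""

    def __init__(self, states, actions, rewards, terminals):
        self._data = (np.asarray(states), np.asarray(actions),
                      np.asarray(rewards), np.asarray(terminals))

    def sample_transition_batch(self, batch_size=None, indices=None):
        # Sample uniformly when no explicit indices are given.
        n = len(self._data[0])
        if indices is None:
            indices = np.random.randint(0, n, size=batch_size)
        indices = np.asarray(indices)
        states, actions, rewards, terminals = (a[indices] for a in self._data)
        return states, actions, rewards, terminals, indices

buf = TinyFixedBuffer([[0.1], [0.2], [0.3]], [0, 1, 2], [1.0, 0.0, 1.0], [0, 0, 1])
states, actions, rewards, terminals, idx = buf.sample_transition_batch(indices=[2, 0])
print(actions)  # [2 0]
```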

PCDataset

Bases: Dataset

Overview

Dataset for Procedure Cloning.

Interfaces: __init__, __len__, __getitem__

__init__(all_data)

Overview

Initialization method of PCDataset.

Arguments: - all_data (:obj:tuple): The tuple of all data.

__getitem__(item)

Overview

Get the item of the dataset.

Arguments: - item (:obj:int): The index of the dataset.

__len__()

Overview

Get the length of the dataset.

BCODataset

Bases: Dataset

Overview

Dataset for Behavioral Cloning from Observation.

Interfaces: __init__, __len__, __getitem__

Properties:

- obs (:obj:np.ndarray): The observation array.
- action (:obj:np.ndarray): The action array.

obs property

Overview

Get the observation array.

action property

Overview

Get the action array.

__init__(data=None)

Overview

Initialization method of BCODataset.

Arguments: - data (:obj:dict): The data dict.

__len__()

Overview

Get the length of the dataset.

__getitem__(idx)

Overview

Get the item of the dataset.

Arguments: - idx (:obj:int): The index of the dataset.

SequenceDataset

Bases: Dataset

Overview

Dataset for diffuser.

Interfaces: __init__, __len__, __getitem__

__init__(cfg)

Overview

Initialization method of SequenceDataset.

Arguments: - cfg (:obj:dict): The config dict.

sequence_dataset(env, dataset=None)

Overview

Split the dataset into per-episode sequences.

Arguments:

- env (:obj:gym.Env): The gym env.
- dataset (:obj:dict): The dataset dict (optional).

maze2d_set_terminals(env, dataset)

Overview

Set the terminals for maze2d.

Arguments:

- env (:obj:gym.Env): The gym env.
- dataset (:obj:dict): The dataset dict.

process_maze2d_episode(episode)

Overview

Process a maze2d episode by adding a next_observations field to it.

Arguments: - episode (:obj:dict): The episode dict.

normalize(keys=['observations', 'actions'])

Overview

Normalize the dataset fields that will be predicted by the diffusion model.

Arguments: - keys (:obj:list): The list of keys.

make_indices(path_lengths, horizon)

Overview

Make indices for sampling from dataset. Each index maps to a datapoint.

Arguments:

- path_lengths (:obj:np.ndarray): The path length array.
- horizon (:obj:int): The horizon.
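`make_indices` turns variable-length episodes into a flat list of `(path_idx, start, end)` windows of length `horizon`, so that a single integer index can address any horizon-length slice of any episode. A numpy sketch under the assumption that each window must fit entirely inside one episode:

```python
import numpy as np

def make_indices(path_lengths, horizon):
    """Every (path, start) pair whose horizon-length window fits in the path."""
    indices = []
    for path_idx, path_len in enumerate(path_lengths):
        max_start = path_len - horizon
        for start in range(max_start + 1):
            indices.append((path_idx, start, start + horizon))
    return np.array(indices)

indices = make_indices(np.array([4, 3]), horizon=3)
print(len(indices))  # 2 windows in the length-4 path + 1 in the length-3 path
print(indices[0])    # [0 0 3]
```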

get_conditions(observations)

Overview

Get the conditions on the current observation for planning.

Arguments: - observations (:obj:np.ndarray): The observation array.

__len__()

Overview

Get the length of the dataset.

normalize_value(value)

Overview

Normalize the value.

Arguments: - value (:obj:np.ndarray): The value array.

__getitem__(idx, eps=0.0001)

Overview

Get the item of the dataset.

Arguments:

- idx (:obj:int): The index of the dataset.
- eps (:obj:float): The epsilon.

load_bfs_datasets(train_seeds=1, test_seeds=5)

Overview

Load BFS datasets.

Arguments:

- train_seeds (:obj:int): The number of train seeds.
- test_seeds (:obj:int): The number of test seeds.

hdf5_save(exp_data, expert_data_path)

Overview

Save the data to hdf5.

naive_save(exp_data, expert_data_path)

Overview

Save the data to pickle.

offline_data_save_type(exp_data, expert_data_path, data_type='naive')

Overview

Save the offline data.
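`offline_data_save_type` dispatches on `data_type`: `'naive'` pickles the collected transitions, while `'hdf5'` writes them with h5py. A pickle-only sketch of the dispatch (the hdf5 branch is elided here since it needs h5py; the toy transition is illustrative):

```python
import pickle
import tempfile

def naive_save(exp_data, expert_data_path):
    # Serialize collected transitions with pickle.
    with open(expert_data_path, 'wb') as f:
        pickle.dump(exp_data, f)

def offline_data_save_type(exp_data, expert_data_path, data_type='naive'):
    # Dispatch on the requested storage backend.
    if data_type == 'naive':
        naive_save(exp_data, expert_data_path)
    else:
        raise KeyError(f"unsupported data_type in this sketch: {data_type}")

path = tempfile.NamedTemporaryFile(suffix='.pkl', delete=False).name
offline_data_save_type([{'obs': [0.0], 'reward': 1.0}], path)

with open(path, 'rb') as f:
    print(pickle.load(f)[0]['reward'])  # 1.0
```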

create_dataset(cfg, **kwargs)

Overview

Create dataset.
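`create_dataset` resolves the dataset type string from the config against `DATASET_REGISTRY` (the same registry the `@DATASET_REGISTRY.register(...)` decorators above populate) and instantiates the matching class. A minimal sketch of that decorator-based registry pattern, with simplified names:

```python
REGISTRY = {}

def register(name):
    # Decorator that records a dataset class under a string key.
    def wrapper(cls):
        REGISTRY[name] = cls
        return cls
    return wrapper

@register('naive')
class NaiveDataset:
    def __init__(self, cfg):
        self.cfg = cfg

def create_dataset(cfg):
    # Resolve the dataset type string to a registered class.
    return REGISTRY[cfg['data_type']](cfg)

ds = create_dataset({'data_type': 'naive'})
print(type(ds).__name__)  # NaiveDataset
```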

Full Source Code

../ding/utils/data/dataset.py

```python
from typing import List, Dict, Tuple
from ditk import logging
from copy import deepcopy
from easydict import EasyDict
from torch.utils.data import Dataset
from dataclasses import dataclass

import pickle
import easydict
import torch
import numpy as np

from ding.utils.bfs_helper import get_vi_sequence
from ding.utils import DATASET_REGISTRY, import_module, DatasetNormalizer
from ding.rl_utils import discount_cumsum


@dataclass
class DatasetStatistics:
    """
    Overview:
        Dataset statistics.
    """
    mean: np.ndarray  # obs
    std: np.ndarray  # obs
    action_bounds: np.ndarray


@DATASET_REGISTRY.register('naive')
class NaiveRLDataset(Dataset):
    """
    Overview:
        Naive RL dataset, which is used for offline RL algorithms.
    Interfaces:
        ``__init__``, ``__len__``, ``__getitem__``
    """

    def __init__(self, cfg) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`dict`): Config dict.
        """

        assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg))
        if isinstance(cfg, EasyDict):
            self._data_path = cfg.policy.collect.data_path
        elif isinstance(cfg, str):
            self._data_path = cfg
        with open(self._data_path, 'rb') as f:
            self._data: List[Dict[str, torch.Tensor]] = pickle.load(f)

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        """

        return len(self._data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Get the item of the dataset.
        """

        return self._data[idx]


@DATASET_REGISTRY.register('d4rl')
class D4RLDataset(Dataset):
    """
    Overview:
        D4RL dataset, which is used for offline RL algorithms.
    Interfaces:
        ``__init__``, ``__len__``, ``__getitem__``
    Properties:
        - mean (:obj:`np.ndarray`): Mean of the dataset.
        - std (:obj:`np.ndarray`): Std of the dataset.
        - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
        - statistics (:obj:`dict`): Statistics of the dataset.
    """

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`dict`): Config dict.
        """

        import gym
        try:
            import d4rl  # register d4rl enviroments with open ai gym
        except ImportError:
            import sys
            logging.warning("not found d4rl env, please install it, refer to https://github.com/rail-berkeley/d4rl")
            sys.exit(1)

        # Init parameters
        data_path = cfg.policy.collect.get('data_path', None)
        env_id = cfg.env.env_id

        # Create the environment
        if data_path:
            d4rl.set_dataset_path(data_path)
        env = gym.make(env_id)
        dataset = d4rl.qlearning_dataset(env)
        self._cal_statistics(dataset, env)
        try:
            if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
                dataset = self._normalize_states(dataset)
        except (KeyError, AttributeError):
            # do not normalize
            pass
        if hasattr(cfg.env, "reward_norm"):
            if cfg.env.reward_norm == "normalize":
                dataset['rewards'] = (dataset['rewards'] - dataset['rewards'].mean()) / dataset['rewards'].std()
            elif cfg.env.reward_norm == "iql_antmaze":
                dataset['rewards'] = dataset['rewards'] - 1.0
            elif cfg.env.reward_norm == "iql_locomotion":

                def return_range(dataset, max_episode_steps):
                    returns, lengths = [], []
                    ep_ret, ep_len = 0.0, 0
                    for r, d in zip(dataset["rewards"], dataset["terminals"]):
                        ep_ret += float(r)
                        ep_len += 1
                        if d or ep_len == max_episode_steps:
                            returns.append(ep_ret)
                            lengths.append(ep_len)
                            ep_ret, ep_len = 0.0, 0
                    # returns.append(ep_ret)  # incomplete trajectory
                    lengths.append(ep_len)  # but still keep track of number of steps
                    assert sum(lengths) == len(dataset["rewards"])
                    return min(returns), max(returns)

                min_ret, max_ret = return_range(dataset, 1000)
                dataset['rewards'] /= max_ret - min_ret
                dataset['rewards'] *= 1000
            elif cfg.env.reward_norm == "cql_antmaze":
                dataset['rewards'] = (dataset['rewards'] - 0.5) * 4.0
            elif cfg.env.reward_norm == "antmaze":
                dataset['rewards'] = (dataset['rewards'] - 0.25) * 2.0
            else:
                raise NotImplementedError

        self._data = []
        self._load_d4rl(dataset)

    @property
    def data(self) -> List:
        return self._data

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        """

        return len(self._data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Get the item of the dataset.
        """

        return self._data[idx]

    def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None:
        """
        Overview:
            Load the d4rl dataset.
        Arguments:
            - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
        """

        for i in range(len(dataset['observations'])):
            trans_data = {}
            trans_data['obs'] = torch.from_numpy(dataset['observations'][i])
            trans_data['next_obs'] = torch.from_numpy(dataset['next_observations'][i])
            trans_data['action'] = torch.from_numpy(dataset['actions'][i])
            trans_data['reward'] = torch.tensor(dataset['rewards'][i])
            trans_data['done'] = dataset['terminals'][i]
            self._data.append(trans_data)

    def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True):
        """
        Overview:
            Calculate the statistics of the dataset.
        Arguments:
            - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
            - env (:obj:`gym.Env`): The environment.
            - eps (:obj:`float`): Epsilon.
        """

        self._mean = dataset['observations'].mean(0)
        self._std = dataset['observations'].std(0) + eps
        action_max = dataset['actions'].max(0)
        action_min = dataset['actions'].min(0)
        if add_action_buffer:
            action_buffer = 0.05 * (action_max - action_min)
            action_max = (action_max + action_buffer).clip(max=env.action_space.high)
            action_min = (action_min - action_buffer).clip(min=env.action_space.low)
        self._action_bounds = np.stack([action_min, action_max], axis=0)

    def _normalize_states(self, dataset):
        """
        Overview:
            Normalize the states.
        Arguments:
            - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
        """

        dataset['observations'] = (dataset['observations'] - self._mean) / self._std
        dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std
        return dataset

    @property
    def mean(self):
        """
        Overview:
            Get the mean of the dataset.
        """

        return self._mean

    @property
    def std(self):
        """
        Overview:
            Get the std of the dataset.
        """

        return self._std

    @property
    def action_bounds(self) -> np.ndarray:
        """
        Overview:
            Get the action bounds of the dataset.
        """

        return self._action_bounds

    @property
    def statistics(self) -> dict:
        """
        Overview:
            Get the statistics of the dataset.
        """

        return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)


@DATASET_REGISTRY.register('hdf5')
class HDF5Dataset(Dataset):
    """
    Overview:
        HDF5 dataset is saved in hdf5 format, which is used for offline RL algorithms.
        The hdf5 format is a common format for storing large numerical arrays in Python.
        For more details, please refer to https://support.hdfgroup.org/HDF5/.
    Interfaces:
        ``__init__``, ``__len__``, ``__getitem__``
    Properties:
        - mean (:obj:`np.ndarray`): Mean of the dataset.
        - std (:obj:`np.ndarray`): Std of the dataset.
        - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
        - statistics (:obj:`dict`): Statistics of the dataset.
    """

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`dict`): Config dict.
        """

        try:
            import h5py
        except ImportError:
            import sys
            logging.warning("not found h5py package, please install it trough `pip install h5py ")
            sys.exit(1)
        data_path = cfg.policy.collect.get('data_path', None)
        if 'dataset' in cfg:
            self.context_len = cfg.dataset.context_len
        else:
            self.context_len = 0
        data = h5py.File(data_path, 'r')
        self._load_data(data)
        self._cal_statistics()
        try:
            if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
                self._normalize_states()
        except (KeyError, AttributeError):
            # do not normalize
            pass

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        """

        return len(self._data['obs']) - self.context_len

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Get the item of the dataset.
        Arguments:
            - idx (:obj:`int`): The index of the dataset.
        """

        if self.context_len == 0:  # for other offline RL algorithms
            return {k: self._data[k][idx] for k in self._data.keys()}
        else:  # for decision transformer
            block_size = self.context_len
            done_idx = idx + block_size
            idx = done_idx - block_size
            states = torch.as_tensor(
                np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32
            ).view(block_size, -1)
            actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
            rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
            timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
            traj_mask = torch.ones(self.context_len, dtype=torch.long)
            return timesteps, states, actions, rtgs, traj_mask

    def _load_data(self, dataset: Dict[str, np.ndarray]) -> None:
        """
        Overview:
            Load the dataset.
        Arguments:
            - dataset (:obj:`Dict[str, np.ndarray]`): The dataset.
        """

        self._data = {}
        for k in dataset.keys():
            logging.info(f'Load {k} data.')
            self._data[k] = dataset[k][:]

    def _cal_statistics(self, eps: float = 1e-3):
        """
        Overview:
            Calculate the statistics of the dataset.
        Arguments:
            - eps (:obj:`float`): Epsilon.
        """

        self._mean = self._data['obs'].mean(0)
        self._std = self._data['obs'].std(0) + eps
        action_max = self._data['action'].max(0)
        action_min = self._data['action'].min(0)
        buffer = 0.05 * (action_max - action_min)
        action_max = action_max.astype(float) + buffer
        action_min = action_max.astype(float) - buffer
        self._action_bounds = np.stack([action_min, action_max], axis=0)

    def _normalize_states(self):
        """
        Overview:
            Normalize the states.
        """

        self._data['obs'] = (self._data['obs'] - self._mean) / self._std
        self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std

    @property
    def mean(self):
        """
        Overview:
            Get the mean of the dataset.
        """

        return self._mean

    @property
    def std(self):
        """
        Overview:
            Get the std of the dataset.
        """

        return self._std

    @property
    def action_bounds(self) -> np.ndarray:
        """
        Overview:
            Get the action bounds of the dataset.
        """

        return self._action_bounds

    @property
    def statistics(self) -> dict:
        """
        Overview:
            Get the statistics of the dataset.
        """

        return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)


@DATASET_REGISTRY.register('d4rl_trajectory')
class D4RLTrajectoryDataset(Dataset):
    """
    Overview:
        D4RL trajectory dataset, which is used for offline RL algorithms.
    Interfaces:
        ``__init__``, ``__len__``, ``__getitem__``
    """

    # from infos.py from official d4rl github repo
    REF_MIN_SCORE = {
        'halfcheetah': -280.178953,
        'walker2d': 1.629008,
        'hopper': -20.272305,
    }

    REF_MAX_SCORE = {
        'halfcheetah': 12135.0,
        'walker2d': 4592.3,
        'hopper': 3234.3,
    }

    # calculated from d4rl datasets
    D4RL_DATASET_STATS = {
        'halfcheetah-medium-v2': {
            'state_mean': [
                -0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164,
                -0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436,
                5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435,
                0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445,
                0.013382787816226482
            ],
            'state_std': [
                0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184,
                0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577,
                1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098,
                5.671932697296143, 7.4982590675354
            ]
        },
        'halfcheetah-medium-replay-v2': {
            'state_mean': [
                -0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193,
                -0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682,
                3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752,
                0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994,
                -0.015839405357837677
            ],
            'state_std': [
                0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494,
                0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578,
                1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416,
                6.085654258728027, 7.25300407409668
            ]
        },
        'halfcheetah-medium-expert-v2': {
            'state_mean': [
                -0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338,
                -0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053,
                8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784,
                0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314
            ],
            'state_std': [
                0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533,
                0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467,
                1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797,
                6.4811787605285645, 6.378620147705078
            ]
        },
        'walker2d-medium-v2': {
            'state_mean': [
                1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026,
                -0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777,
                -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654,
                0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654
            ],
            'state_std': [
                0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724,
                0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583,
                1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145,
                3.7445690631866455, 5.5851287841796875
            ]
        },
        'walker2d-medium-replay-v2': {
            'state_mean': [
                1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221,
                -0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662,
                -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088,
                -0.08934258669614792, -0.2992438077926636, -0.5984178185462952
            ],
            'state_std': [
                0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303,
                0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276,
                2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096,
                3.845186948776245, 5.4768385887146
            ]
        },
        'walker2d-medium-expert-v2': {
            'state_mean': [
                1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075,
                0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122,
                3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811,
                -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786,
                -0.27366524934768677
            ],
            'state_std': [
                0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586,
                0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831,
                1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857,
                4.039782524108887, 5.891613960266113
            ]
        },
        'hopper-medium-v2': {
            'state_mean': [
                1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081,
                2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474,
                -0.18540096282958984, -0.28461286425590515
            ],
            'state_std': [
                0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535,
                0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754,
                5.607253551483154
            ]
        },
        'hopper-medium-replay-v2': {
            'state_mean': [
                1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224,
                0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328,
                -0.5287045240402222, -0.14465883374214172, -0.19652697443962097
            ],
            'state_std': [
                0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718,
                1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137,
                5.108601093292236
            ]
        },
        'hopper-medium-expert-v2': {
            'state_mean': [
                1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415,
                0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272,
                -0.1766270101070404, -0.11862941086292267, -0.12097819894552231
            ],
            'state_std': [
                0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771,
                0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893,
                5.725032806396484
            ]
        },
    }

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Initialization method.
        Arguments:
            - cfg (:obj:`dict`): Config dict.
        """

        dataset_path = cfg.dataset.data_dir_prefix
        rtg_scale = cfg.dataset.rtg_scale
        self.context_len = cfg.dataset.context_len
        self.env_type = cfg.dataset.env_type

        if 'hdf5' in dataset_path:  # for mujoco env
            try:
                import h5py
                import collections
            except ImportError:
                import sys
                logging.warning("not found h5py package, please install it trough `pip install h5py ")
                sys.exit(1)
            dataset = h5py.File(dataset_path, 'r')

            N = dataset['rewards'].shape[0]
            data_ = collections.defaultdict(list)

            use_timeouts = False
            if 'timeouts' in dataset:
                use_timeouts = True

            episode_step = 0
            paths = []
            for i in range(N):
                done_bool = bool(dataset['terminals'][i])
                if use_timeouts:
                    final_timestep = dataset['timeouts'][i]
                else:
                    final_timestep = (episode_step == 1000 - 1)
                for k in ['observations', 'actions', 'rewards', 'terminals']:
                    data_[k].append(dataset[k][i])
                if done_bool or final_timestep:
                    episode_step = 0
                    episode_data = {}
                    for k in data_:
                        episode_data[k] = np.array(data_[k])
                    paths.append(episode_data)
                    data_ = collections.defaultdict(list)
                episode_step += 1

            self.trajectories = paths

            # calculate state mean and variance and returns_to_go for all traj
            states = []
            for traj in self.trajectories:
                traj_len = traj['observations'].shape[0]
                states.append(traj['observations'])
                # calculate returns to go and rescale them
                traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

            # used for input normalization
            states = np.concatenate(states, axis=0)
            self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

            # normalize states
            for traj in self.trajectories:
                traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

        elif 'pkl' in dataset_path:
            if 'dqn' in dataset_path:
                # load dataset
                with open(dataset_path, 'rb') as f:
                    self.trajectories = pickle.load(f)

                if isinstance(self.trajectories[0], list):
                    # for our collected dataset, e.g. cartpole/lunarlander case
                    trajectories_tmp = []

                    original_keys = ['obs', 'next_obs', 'action', 'reward']
                    keys = ['observations', 'next_observations', 'actions', 'rewards']
                    trajectories_tmp = [
                        {
                            key: np.stack(
                                [
                                    self.trajectories[eps_index][transition_index][o_key]
                                    for transition_index in range(len(self.trajectories[eps_index]))
                                ],
                                axis=0
                            )
                            for key, o_key in zip(keys, original_keys)
                        } for eps_index in range(len(self.trajectories))
                    ]
                    self.trajectories = trajectories_tmp

                states = []
                for traj in self.trajectories:
                    # traj_len = traj['observations'].shape[0]
                    states.append(traj['observations'])
                    # calculate returns to go and rescale them
                    traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

                # used for input normalization
                states = np.concatenate(states, axis=0)
                self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

                # normalize states
                for traj in self.trajectories:
                    traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
            else:
                # load dataset
                with open(dataset_path, 'rb') as f:
                    self.trajectories = pickle.load(f)

                states = []
                for traj in self.trajectories:
                    states.append(traj['observations'])
                    # calculate returns to go and rescale them
                    traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

                # used for input normalization
                states = np.concatenate(states, axis=0)
                self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

                # normalize states
                for traj in self.trajectories:
                    traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

        else:
            # -- load data from memory (make more efficient)
            obss = []
            actions = []
            returns = [0]
            done_idxs = []
            stepwise_returns = []

            transitions_per_buffer = np.zeros(50, dtype=int)
            num_trajectories = 0
            while len(obss) < cfg.dataset.num_steps:
                buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0]
                i = transitions_per_buffer[buffer_num]
                frb = FixedReplayBuffer(
                    data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs',
                    replay_suffix=buffer_num,
                    observation_shape=(84, 84),
                    stack_size=4,
                    update_horizon=1,
                    gamma=0.99,
                    observation_dtype=np.uint8,
                    batch_size=32,
                    replay_capacity=100000
                )
                if frb._loaded_buffers:
                    done = False
                    curr_num_transitions = len(obss)
                    trajectories_to_load = cfg.dataset.trajectories_per_buffer
                    while not done:
                        states, ac, ret, next_states, next_action, next_reward, terminal, indices = \
                            frb.sample_transition_batch(batch_size=1, indices=[i])
                        states = states.transpose((0, 3, 1, 2))[0]  # (1, 84, 84, 4) --> (4, 84, 84)
                        obss.append(states)
                        actions.append(ac[0])
                        stepwise_returns.append(ret[0])
                        if terminal[0]:
                            done_idxs.append(len(obss))
                            returns.append(0)
                            if trajectories_to_load == 0:
                                done = True
                            else:
                                trajectories_to_load -= 1
                        returns[-1] += ret[0]
                        i += 1
                        if i >= 100000:
                            obss = obss[:curr_num_transitions]
                            actions = actions[:curr_num_transitions]
                            stepwise_returns = stepwise_returns[:curr_num_transitions]
                            returns[-1] = 0
                            i = transitions_per_buffer[buffer_num]
                            done = True
                    num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load)
                    transitions_per_buffer[buffer_num] = i

            actions = np.array(actions)
            returns = np.array(returns)
            stepwise_returns = np.array(stepwise_returns)
            done_idxs = np.array(done_idxs)

            # -- create reward-to-go dataset
            start_index = 0
            rtg = np.zeros_like(stepwise_returns)
            for i in done_idxs:
                i = int(i)
                curr_traj_returns = stepwise_returns[start_index:i]
                for j in range(i - 1, start_index - 1, -1):  # start from i-1
                    rtg_j = curr_traj_returns[j - start_index:i - start_index]
                    rtg[j] = sum(rtg_j)
                start_index = i

            # -- create timestep dataset
            start_index = 0
            timesteps = np.zeros(len(actions) + 1, dtype=int)
            for i in done_idxs:
                i = int(i)
                timesteps[start_index:i + 1] = np.arange(i + 1 - start_index)
                start_index = i + 1

            self.obss = obss
            self.actions = actions
            self.done_idxs = done_idxs
            self.rtgs = rtg
            self.timesteps = timesteps
            # return obss, actions, returns, done_idxs, rtg, timesteps

    def get_max_timestep(self) -> int:
        """
        Overview:
            Get the max timestep of the dataset.
        """

        return max(self.timesteps)

    def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Overview:
            Get the state mean and std of the dataset.
        """

        return deepcopy(self.state_mean), deepcopy(self.state_std)

    def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]:
        """
        Overview:
            Get the d4rl dataset stats.
        Arguments:
            - env_d4rl_name (:obj:`str`): The d4rl env name.
        """

        return self.D4RL_DATASET_STATS[env_d4rl_name]

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        """

        if self.env_type != 'atari':
            return len(self.trajectories)
        else:
            return len(self.obss) - self.context_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Overview:
            Get the item of the dataset.
        Arguments:
            - idx (:obj:`int`): The index of the dataset.
        """

        if self.env_type != 'atari':
            traj = self.trajectories[idx]
            traj_len = traj['observations'].shape[0]

            if traj_len > self.context_len:
                # sample random index to slice trajectory
                si = np.random.randint(0, traj_len - self.context_len)

                states = torch.from_numpy(traj['observations'][si:si + self.context_len])
                actions = torch.from_numpy(traj['actions'][si:si + self.context_len])
                returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len])
                timesteps = torch.arange(start=si, end=si + self.context_len, step=1)

                # all ones since no padding
                traj_mask = torch.ones(self.context_len, dtype=torch.long)

            else:
                padding_len = self.context_len - traj_len

                # padding with zeros
                states = torch.from_numpy(traj['observations'])
                states = torch.cat(
                    [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0
                )

                actions = torch.from_numpy(traj['actions'])
                actions = torch.cat(
                    [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0
                )

                returns_to_go = torch.from_numpy(traj['returns_to_go'])
                returns_to_go = torch.cat(
                    [
                        returns_to_go,
                        torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype)
                    ],
                    dim=0
                )

                timesteps = torch.arange(start=0, end=self.context_len, step=1)

                traj_mask = torch.cat(
                    [torch.ones(traj_len, dtype=torch.long),
                     torch.zeros(padding_len, dtype=torch.long)], dim=0
                )
            return timesteps, states, actions, returns_to_go, traj_mask
        else:  # mean cost less than 0.001s
            block_size = self.context_len
            done_idx = idx + block_size
            for i in self.done_idxs:
                if i > idx:  # first done_idx greater than idx
                    done_idx = min(int(i), done_idx)
                    break
            idx = done_idx - block_size
            states = torch.as_tensor(
                np.array(self.obss[idx:done_idx]), dtype=torch.float32
```
).view(block_size, -1) # (block_size, 4*84*84) 871 states = states / 255. 872 actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) 873 rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) 874 timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1) 875 traj_mask = torch.ones(self.context_len, dtype=torch.long) 876 return timesteps, states, actions, rtgs, traj_mask 877 878 879@DATASET_REGISTRY.register('d4rl_diffuser') 880class D4RLDiffuserDataset(Dataset): 881 """ 882 Overview: 883 D4RL diffuser dataset, which is used for offline RL algorithms. 884 Interfaces: 885 ``__init__``, ``__len__``, ``__getitem__`` 886 """ 887 888 def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None: 889 """ 890 Overview: 891 Initialization method of D4RLDiffuserDataset. 892 Arguments: 893 - dataset_path (:obj:`str`): The dataset path. 894 - context_len (:obj:`int`): The length of the context. 895 - rtg_scale (:obj:`float`): The scale of the returns to go. 896 """ 897 898 self.context_len = context_len 899 900 # load dataset 901 with open(dataset_path, 'rb') as f: 902 self.trajectories = pickle.load(f) 903 904 if isinstance(self.trajectories[0], list): 905 # for our collected dataset, e.g. 
cartpole/lunarlander case 906 trajectories_tmp = [] 907 908 original_keys = ['obs', 'next_obs', 'action', 'reward'] 909 keys = ['observations', 'next_observations', 'actions', 'rewards'] 910 for key, o_key in zip(keys, original_keys): 911 trajectories_tmp = [ 912 { 913 key: np.stack( 914 [ 915 self.trajectories[eps_index][transition_index][o_key] 916 for transition_index in range(len(self.trajectories[eps_index])) 917 ], 918 axis=0 919 ) 920 } for eps_index in range(len(self.trajectories)) 921 ] 922 self.trajectories = trajectories_tmp 923 924 states = [] 925 for traj in self.trajectories: 926 traj_len = traj['observations'].shape[0] 927 states.append(traj['observations']) 928 # calculate returns to go and rescale them 929 traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale 930 931 # used for input normalization 932 states = np.concatenate(states, axis=0) 933 self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 934 935 # normalize states 936 for traj in self.trajectories: 937 traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std 938 939 940class FixedReplayBuffer(object): 941 """ 942 Overview: 943 Object composed of a list of OutofGraphReplayBuffers. 944 Interfaces: 945 ``__init__``, ``get_transition_elements``, ``sample_transition_batch`` 946 """ 947 948 def __init__(self, data_dir: str, replay_suffix: int, *args, **kwargs): 949 """ 950 Overview: 951 Initialize the FixedReplayBuffer class. 952 Arguments: 953 - data_dir (:obj:`str`): Log directory from which to load the replay buffer. 954 - replay_suffix (:obj:`int`): If not None, then only load the replay buffer \ 955 corresponding to the specific suffix in data directory. 956 - args (:obj:`list`): Arbitrary extra arguments. 957 - kwargs (:obj:`dict`): Arbitrary keyword arguments. 
958 """ 959 960 self._args = args 961 self._kwargs = kwargs 962 self._data_dir = data_dir 963 self._loaded_buffers = False 964 self.add_count = np.array(0) 965 self._replay_suffix = replay_suffix 966 if not self._loaded_buffers: 967 if replay_suffix is not None: 968 assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' 969 self.load_single_buffer(replay_suffix) 970 else: 971 pass 972 # self._load_replay_buffers(num_buffers=50) 973 974 def load_single_buffer(self, suffix): 975 """ 976 Overview: 977 Load a single replay buffer. 978 Arguments: 979 - suffix (:obj:`int`): The suffix of the replay buffer. 980 """ 981 982 replay_buffer = self._load_buffer(suffix) 983 if replay_buffer is not None: 984 self._replay_buffers = [replay_buffer] 985 self.add_count = replay_buffer.add_count 986 self._num_replay_buffers = 1 987 self._loaded_buffers = True 988 989 def _load_buffer(self, suffix): 990 """ 991 Overview: 992 Loads a OutOfGraphReplayBuffer replay buffer. 993 Arguments: 994 - suffix (:obj:`int`): The suffix of the replay buffer. 
995 """ 996 997 try: 998 from dopamine.replay_memory import circular_replay_buffer 999 STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX1000 # pytype: disable=attribute-error1001 replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs)1002 replay_buffer.load(self._data_dir, suffix)1003 # pytype: enable=attribute-error1004 return replay_buffer1005 # except tf.errors.NotFoundError:1006 except:1007 raise ('can not load')10081009 def get_transition_elements(self):1010 """1011 Overview:1012 Returns the transition elements.1013 """10141015 return self._replay_buffers[0].get_transition_elements()10161017 def sample_transition_batch(self, batch_size=None, indices=None):1018 """1019 Overview:1020 Returns a batch of transitions (including any extra contents).1021 Arguments:1022 - batch_size (:obj:`int`): The batch size.1023 - indices (:obj:`list`): The indices of the batch.1024 """10251026 buffer_index = np.random.randint(self._num_replay_buffers)1027 return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices)102810291030class PCDataset(Dataset):1031 """1032 Overview:1033 Dataset for Procedure Cloning.1034 Interfaces:1035 ``__init__``, ``__len__``, ``__getitem__``1036 """10371038 def __init__(self, all_data):1039 """1040 Overview:1041 Initialization method of PCDataset.1042 Arguments:1043 - all_data (:obj:`tuple`): The tuple of all data.1044 """10451046 self._data = all_data10471048 def __getitem__(self, item):1049 """1050 Overview:1051 Get the item of the dataset.1052 Arguments:1053 - item (:obj:`int`): The index of the dataset.1054 """10551056 return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]}10571058 def __len__(self):1059 """1060 Overview:1061 Get the length of the dataset.1062 """10631064 return self._data[0].shape[0]106510661067def load_bfs_datasets(train_seeds=1, test_seeds=5):1068 """1069 Overview:1070 Load BFS 
datasets.1071 Arguments:1072 - train_seeds (:obj:`int`): The number of train seeds.1073 - test_seeds (:obj:`int`): The number of test seeds.1074 """10751076 from dizoo.maze.envs import Maze10771078 def load_env(seed):1079 ccc = easydict.EasyDict({'size': 16})1080 e = Maze(ccc)1081 e.seed(seed)1082 e.reset()1083 return e10841085 envs = [load_env(i) for i in range(train_seeds + test_seeds)]10861087 observations_train = []1088 observations_test = []1089 bfs_input_maps_train = []1090 bfs_input_maps_test = []1091 bfs_output_maps_train = []1092 bfs_output_maps_test = []1093 for idx, env in enumerate(envs):1094 if idx < train_seeds:1095 observations = observations_train1096 bfs_input_maps = bfs_input_maps_train1097 bfs_output_maps = bfs_output_maps_train1098 else:1099 observations = observations_test1100 bfs_input_maps = bfs_input_maps_test1101 bfs_output_maps = bfs_output_maps_test11021103 start_obs = env.process_states(env._get_obs(), env.get_maze_map())1104 _, track_back = get_vi_sequence(env, start_obs)1105 env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0)11061107 for i in range(env_observations.shape[0]):1108 bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) # [L, W, W]1109 bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.long)11101111 for j in range(bfs_sequence.shape[0]):1112 bfs_input_maps.append(torch.from_numpy(bfs_input_map))1113 bfs_output_maps.append(torch.from_numpy(bfs_sequence[j]))1114 observations.append(env_observations[i])1115 bfs_input_map = bfs_sequence[j]11161117 train_data = PCDataset(1118 (1119 torch.stack(observations_train, dim=0),1120 torch.stack(bfs_input_maps_train, dim=0),1121 torch.stack(bfs_output_maps_train, dim=0),1122 )1123 )1124 test_data = PCDataset(1125 (1126 torch.stack(observations_test, dim=0),1127 torch.stack(bfs_input_maps_test, dim=0),1128 torch.stack(bfs_output_maps_test, dim=0),1129 )1130 )11311132 return train_data, 
test_data113311341135@DATASET_REGISTRY.register('bco')1136class BCODataset(Dataset):1137 """1138 Overview:1139 Dataset for Behavioral Cloning from Observation.1140 Interfaces:1141 ``__init__``, ``__len__``, ``__getitem__``1142 Properties:1143 - obs (:obj:`np.ndarray`): The observation array.1144 - action (:obj:`np.ndarray`): The action array.1145 """11461147 def __init__(self, data=None):1148 """1149 Overview:1150 Initialization method of BCODataset.1151 Arguments:1152 - data (:obj:`dict`): The data dict.1153 """11541155 if data is None:1156 raise ValueError('Dataset can not be empty!')1157 else:1158 self._data = data11591160 def __len__(self):1161 """1162 Overview:1163 Get the length of the dataset.1164 """11651166 return len(self._data['obs'])11671168 def __getitem__(self, idx):1169 """1170 Overview:1171 Get the item of the dataset.1172 Arguments:1173 - idx (:obj:`int`): The index of the dataset.1174 """11751176 return {k: self._data[k][idx] for k in self._data.keys()}11771178 @property1179 def obs(self):1180 """1181 Overview:1182 Get the observation array.1183 """11841185 return self._data['obs']11861187 @property1188 def action(self):1189 """1190 Overview:1191 Get the action array.1192 """11931194 return self._data['action']119511961197@DATASET_REGISTRY.register('diffuser_traj')1198class SequenceDataset(torch.utils.data.Dataset):1199 """1200 Overview:1201 Dataset for diffuser.1202 Interfaces:1203 ``__init__``, ``__len__``, ``__getitem__``1204 """12051206 def __init__(self, cfg):1207 """1208 Overview:1209 Initialization method of SequenceDataset.1210 Arguments:1211 - cfg (:obj:`dict`): The config dict.1212 """12131214 import gym12151216 env_id = cfg.env.env_id1217 data_path = cfg.policy.collect.get('data_path', None)1218 env = gym.make(env_id)12191220 dataset = env.get_dataset()12211222 self.returns_scale = cfg.env.returns_scale1223 self.horizon = cfg.env.horizon1224 self.max_path_length = cfg.env.max_path_length1225 self.discount = 
cfg.policy.learn.discount_factor1226 self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]1227 self.use_padding = cfg.env.use_padding1228 self.include_returns = cfg.env.include_returns1229 self.env_id = cfg.env.env_id1230 itr = self.sequence_dataset(env, dataset)1231 self.n_episodes = 012321233 fields = {}1234 for k in dataset.keys():1235 if 'metadata' in k:1236 continue1237 fields[k] = []1238 fields['path_lengths'] = []12391240 for i, episode in enumerate(itr):1241 path_length = len(episode['observations'])1242 assert path_length <= self.max_path_length1243 fields['path_lengths'].append(path_length)1244 for key, val in episode.items():1245 if key not in fields:1246 fields[key] = []1247 if val.ndim < 2:1248 val = np.expand_dims(val, axis=-1)1249 shape = (self.max_path_length, val.shape[-1])1250 arr = np.zeros(shape, dtype=np.float32)1251 arr[:path_length] = val1252 fields[key].append(arr)1253 if episode['terminals'].any() and cfg.env.termination_penalty and 'timeouts' in episode:1254 assert not episode['timeouts'].any(), 'Penalized a timeout episode for early termination'1255 fields['rewards'][-1][path_length - 1] += cfg.env.termination_penalty1256 self.n_episodes += 112571258 for k in fields.keys():1259 fields[k] = np.array(fields[k])12601261 self.normalizer = DatasetNormalizer(fields, cfg.policy.normalizer, path_lengths=fields['path_lengths'])1262 self.indices = self.make_indices(fields['path_lengths'], self.horizon)12631264 self.observation_dim = cfg.env.obs_dim1265 self.action_dim = cfg.env.action_dim1266 self.fields = fields1267 self.normalize()1268 self.normed = False1269 if cfg.env.normed:1270 self.vmin, self.vmax = self._get_bounds()1271 self.normed = True12721273 # shapes = {key: val.shape for key, val in self.fields.items()}1274 # print(f'[ datasets/mujoco ] Dataset fields: {shapes}')12751276 def sequence_dataset(self, env, dataset=None):1277 """1278 Overview:1279 Sequence the dataset.1280 Arguments:1281 - env (:obj:`gym.Env`): The 
gym env.1282 """12831284 import collections1285 N = dataset['rewards'].shape[0]1286 if 'maze2d' in env.spec.id:1287 dataset = self.maze2d_set_terminals(env, dataset)1288 data_ = collections.defaultdict(list)12891290 # The newer version of the dataset adds an explicit1291 # timeouts field. Keep old method for backwards compatability.1292 use_timeouts = 'timeouts' in dataset12931294 episode_step = 01295 for i in range(N):1296 done_bool = bool(dataset['terminals'][i])1297 if use_timeouts:1298 final_timestep = dataset['timeouts'][i]1299 else:1300 final_timestep = (episode_step == env._max_episode_steps - 1)13011302 for k in dataset:1303 if 'metadata' in k:1304 continue1305 data_[k].append(dataset[k][i])13061307 if done_bool or final_timestep:1308 episode_step = 01309 episode_data = {}1310 for k in data_:1311 episode_data[k] = np.array(data_[k])1312 if 'maze2d' in env.spec.id:1313 episode_data = self.process_maze2d_episode(episode_data)1314 yield episode_data1315 data_ = collections.defaultdict(list)13161317 episode_step += 113181319 def maze2d_set_terminals(self, env, dataset):1320 """1321 Overview:1322 Set the terminals for maze2d.1323 Arguments:1324 - env (:obj:`gym.Env`): The gym env.1325 - dataset (:obj:`dict`): The dataset dict.1326 """13271328 goal = env.get_target()1329 threshold = 0.513301331 xy = dataset['observations'][:, :2]1332 distances = np.linalg.norm(xy - goal, axis=-1)1333 at_goal = distances < threshold1334 timeouts = np.zeros_like(dataset['timeouts'])13351336 # timeout at time t iff1337 # at goal at time t and1338 # not at goal at time t + 11339 timeouts[:-1] = at_goal[:-1] * ~at_goal[1:]13401341 timeout_steps = np.where(timeouts)[0]1342 path_lengths = timeout_steps[1:] - timeout_steps[:-1]13431344 print(1345 f'[ utils/preprocessing ] Segmented {env.spec.id} | {len(path_lengths)} paths | '1346 f'min length: {path_lengths.min()} | max length: {path_lengths.max()}'1347 )13481349 dataset['timeouts'] = timeouts1350 return dataset13511352 def 
process_maze2d_episode(self, episode):1353 """1354 Overview:1355 Process the maze2d episode, adds in `next_observations` field to episode.1356 Arguments:1357 - episode (:obj:`dict`): The episode dict.1358 """13591360 assert 'next_observations' not in episode1361 length = len(episode['observations'])1362 next_observations = episode['observations'][1:].copy()1363 for key, val in episode.items():1364 episode[key] = val[:-1]1365 episode['next_observations'] = next_observations1366 return episode13671368 def normalize(self, keys=['observations', 'actions']):1369 """1370 Overview:1371 Normalize the dataset, normalize fields that will be predicted by the diffusion model1372 Arguments:1373 - keys (:obj:`list`): The list of keys.1374 """13751376 for key in keys:1377 array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1)1378 normed = self.normalizer.normalize(array, key)1379 self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)13801381 def make_indices(self, path_lengths, horizon):1382 """1383 Overview:1384 Make indices for sampling from dataset. 
Each index maps to a datapoint.1385 Arguments:1386 - path_lengths (:obj:`np.ndarray`): The path length array.1387 - horizon (:obj:`int`): The horizon.1388 """13891390 indices = []1391 for i, path_length in enumerate(path_lengths):1392 max_start = min(path_length - 1, self.max_path_length - horizon)1393 if not self.use_padding:1394 max_start = min(max_start, path_length - horizon)1395 for start in range(max_start):1396 end = start + horizon1397 indices.append((i, start, end))1398 indices = np.array(indices)1399 return indices14001401 def get_conditions(self, observations):1402 """1403 Overview:1404 Get the conditions on current observation for planning.1405 Arguments:1406 - observations (:obj:`np.ndarray`): The observation array.1407 """14081409 if 'maze2d' in self.env_id:1410 return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]}1411 else:1412 return {'condition_id': [0], 'condition_val': [observations[0]]}14131414 def __len__(self):1415 """1416 Overview:1417 Get the length of the dataset.1418 """14191420 return len(self.indices)14211422 def _get_bounds(self):1423 """1424 Overview:1425 Get the bounds of the dataset.1426 """14271428 print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True)1429 vmin = np.inf1430 vmax = -np.inf1431 for i in range(len(self.indices)):1432 value = self.__getitem__(i)['returns'].item()1433 vmin = min(value, vmin)1434 vmax = max(value, vmax)1435 print('✓')1436 return vmin, vmax14371438 def normalize_value(self, value):1439 """1440 Overview:1441 Normalize the value.1442 Arguments:1443 - value (:obj:`np.ndarray`): The value array.1444 """14451446 # [0, 1]1447 normed = (value - self.vmin) / (self.vmax - self.vmin)1448 # [-1, 1]1449 normed = normed * 2 - 11450 return normed14511452 def __getitem__(self, idx, eps=1e-4):1453 """1454 Overview:1455 Get the item of the dataset.1456 Arguments:1457 - idx (:obj:`int`): The index of the dataset.1458 - eps (:obj:`float`): The 
epsilon.1459 """14601461 path_ind, start, end = self.indices[idx]14621463 observations = self.fields['normed_observations'][path_ind, start:end]1464 actions = self.fields['normed_actions'][path_ind, start:end]1465 done = self.fields['terminals'][path_ind, start:end]14661467 # conditions = self.get_conditions(observations)1468 trajectories = np.concatenate([actions, observations], axis=-1)14691470 if self.include_returns:1471 rewards = self.fields['rewards'][path_ind, start:]1472 discounts = self.discounts[:len(rewards)]1473 returns = (discounts * rewards).sum()1474 if self.normed:1475 returns = self.normalize_value(returns)1476 returns = np.array([returns / self.returns_scale], dtype=np.float32)1477 batch = {1478 'trajectories': trajectories,1479 'returns': returns,1480 'done': done,1481 'action': actions,1482 }1483 else:1484 batch = {1485 'trajectories': trajectories,1486 'done': done,1487 'action': actions,1488 }14891490 batch.update(self.get_conditions(observations))1491 return batch149214931494def hdf5_save(exp_data, expert_data_path):1495 """1496 Overview:1497 Save the data to hdf5.1498 """14991500 try:1501 import h5py1502 except ImportError:1503 import sys1504 logging.warning("not found h5py package, please install it trough 'pip install h5py' ")1505 sys.exit(1)1506 dataset = dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')1507 dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip')1508 dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip')1509 dataset.create_dataset('reward', data=np.array([d['reward'].numpy() for d in exp_data]), compression='gzip')1510 dataset.create_dataset('done', data=np.array([d['done'] for d in exp_data]), compression='gzip')1511 dataset.create_dataset('next_obs', data=np.array([d['next_obs'].numpy() for d in exp_data]), compression='gzip')151215131514def naive_save(exp_data, expert_data_path):1515 
"""1516 Overview:1517 Save the data to pickle.1518 """15191520 with open(expert_data_path, 'wb') as f:1521 pickle.dump(exp_data, f)152215231524def offline_data_save_type(exp_data, expert_data_path, data_type='naive'):1525 """1526 Overview:1527 Save the offline data.1528 """15291530 globals()[data_type + '_save'](exp_data, expert_data_path)153115321533def create_dataset(cfg, **kwargs) -> Dataset:1534 """1535 Overview:1536 Create dataset.1537 """15381539 cfg = EasyDict(cfg)1540 import_module(cfg.get('import_names', []))1541 return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs)