import os
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional
from easydict import EasyDict
from ditk import logging
import numpy as np
import torch
import tqdm
from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None):
    """
    Overview:
        Push episodes or trajectories into the buffer.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - group_by_env (:obj:`Optional[bool]`): If True, each transition is pushed with \
            an ``{'env': ...}`` meta entry taken from ``t.env_data_id``.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _push(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, either `ctx.trajectories` or `ctx.episodes` should not be None.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): Trajectories.
            - episodes (:obj:`List[Dict]`): Episodes.
        """
        if ctx.trajectories is not None:  # each data in buffer is a transition
            if group_by_env:
                for t in ctx.trajectories:
                    buffer_.push(t, {'env': t.env_data_id.item()})
            else:
                for t in ctx.trajectories:
                    buffer_.push(t)
            ctx.trajectories = None
        elif ctx.episodes is not None:  # each data in buffer is a episode
            for t in ctx.episodes:
                buffer_.push(t)
            ctx.episodes = None
        else:
            raise RuntimeError("Either ctx.trajectories or ctx.episodes should be not None.")

    return _push


def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False):
    """
    Overview:
        Save current buffer data.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - every_envstep (:obj:`int`): Save at every env step.
        - replace (:obj:`bool`): Whether replace the last file.
    """
    # Start negative so the very first call (env_step >= 0) triggers a save.
    buffer_saver_env_counter = -every_envstep

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, `ctx.env_step` should not be None.
        Input of ctx:
            - env_step (:obj:`int`): env step.
        """
        nonlocal buffer_saver_env_counter
        if ctx.env_step is not None:
            if ctx.env_step >= every_envstep + buffer_saver_env_counter:
                buffer_saver_env_counter = ctx.env_step
                if replace:
                    buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", "data_latest.hkl"))
                else:
                    buffer_.save_data(
                        os.path.join(cfg.exp_name, "replaybuffer", "data_envstep_{}.hkl".format(ctx.env_step))
                    )
        else:
            raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")

    return _save


def offpolicy_data_fetcher(
        cfg: EasyDict,
        buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
        data_shortage_warning: bool = False,
) -> Callable:
    """
    Overview:
        The return function is a generator which mainly fetch a batch of data from a buffer, \
        a list of buffers, or a dict of buffers.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - buffer_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
            The buffer where the data is fetched from. \
            ``Buffer`` type means a buffer.\
            ``List[Tuple[Buffer, float]]`` type means a list of tuple. In each tuple there is a buffer and a float. \
            The float defines, how many batch_size is the size of the data \
            which is sampled from the corresponding buffer.\
            ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \
            For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \
            and assigned to the same key of `ctx.train_data`.
        - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching.
    """

    def _fetch(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \
                if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. \
                The meta data `priority` of the sampled data in the `buffer_` will be updated \
                to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
                or the `priority` attribute of `ctx.train_output`'s popped element \
                if `ctx.train_output` is a deque of dicts.
        Output of ctx:
            - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
                ``List[Dict]`` type means a list of data.
                `train_data` is of this type if the type of `buffer_` is Buffer or List.
                ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair
                is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
        """
        try:
            unroll_len = cfg.policy.collect.unroll_len
            if isinstance(buffer_, Buffer):
                if unroll_len > 1:
                    buffered_data = buffer_.sample(
                        cfg.policy.learn.batch_size, groupby="env", unroll_len=unroll_len, replace=True
                    )
                    ctx.train_data = [[t.data for t in d] for d in buffered_data]  # B, unroll_len
                else:
                    buffered_data = buffer_.sample(cfg.policy.learn.batch_size)
                    ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, list):  # like sqil, r2d3
                assert unroll_len == 1, "not support"
                buffered_data = []
                for buffer_elem, p in buffer_:
                    data_elem = buffer_elem.sample(int(cfg.policy.learn.batch_size * p))
                    assert data_elem is not None
                    buffered_data.append(data_elem)
                buffered_data = sum(buffered_data, [])
                ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, dict):  # like ppg_offpolicy
                assert unroll_len == 1, "not support"
                buffered_data = {k: v.sample(cfg.policy.learn.batch_size) for k, v in buffer_.items()}
                ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
            else:
                raise TypeError("not support buffer argument type: {}".format(type(buffer_)))

            assert buffered_data is not None
        except (ValueError, AssertionError):
            if data_shortage_warning:
                # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
                # Fetcher will skip this this attempt.
                logging.warning(
                    "Replay buffer's data is not enough to support training, so skip this training to wait more data."
                )
            ctx.train_data = None
            return

        yield

        # After training, write the new priorities back to the buffer (PER only).
        if isinstance(buffer_, Buffer):
            if any([isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware]):
                index = [d.index for d in buffered_data]
                meta = [d.meta for d in buffered_data]
                # such as priority
                if isinstance(ctx.train_output, list):
                    priority = ctx.train_output.pop()['priority']
                else:
                    priority = ctx.train_output['priority']
                for idx, m, p in zip(index, meta, priority):
                    m['priority'] = p
                    buffer_.update(index=idx, data=None, meta=m)

    return _fetch


def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        Fetch batches from an in-memory dataset with a background producer thread. \
        The producer stacks consecutive dataset items into batches, moves them to the \
        target device on a dedicated CUDA stream, and feeds them through a bounded queue.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain `cfg.policy.learn.batch_size` and `cfg.policy.cuda`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """
    from threading import Thread
    from queue import Queue
    import time
    stream = torch.cuda.Stream()

    def producer(queue, dataset, batch_size, device):
        torch.set_num_threads(4)
        nonlocal stream
        # Valid batch start indices are 0 .. len(dataset) - batch_size (inclusive).
        # NOTE(fix): the original `range(len(dataset) - batch_size)` dropped the last
        # valid start index and was empty when len(dataset) == batch_size, which made
        # the StopIteration recovery below raise again and kill the producer thread.
        num_starts = max(1, len(dataset) - batch_size + 1)
        idx_iter = iter(range(num_starts))

        if len(dataset) < batch_size:
            logging.warning('batch_size is too large!!!!')
        with torch.cuda.stream(stream):
            while True:
                if queue.full():
                    time.sleep(0.1)
                else:
                    try:
                        start_idx = next(idx_iter)
                    except StopIteration:
                        # Exhausted one pass over the dataset: restart from the beginning.
                        del idx_iter
                        idx_iter = iter(range(num_starts))
                        start_idx = next(idx_iter)
                    data = [dataset[idx] for idx in range(start_idx, start_idx + batch_size)]
                    # Transpose list-of-items into per-field lists, then stack and move to device.
                    data = [[i[j] for i in data] for j in range(len(data[0]))]
                    data = [torch.stack(x).to(device) for x in data]
                    queue.put(data)

    queue = Queue(maxsize=50)
    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
    producer_thread = Thread(
        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
    )

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Lazily start the producer on first call, then pop one ready batch.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal queue, producer_thread
        if not producer_thread.is_alive():
            time.sleep(5)
            producer_thread.start()
        while queue.empty():
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
    """
    Overview:
        The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
        The return function is a generator which each time fetches a batch of data from the previous `DataLoader`.\
        Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
        and https://pytorch.org/docs/stable/data.html for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
        - collate_fn (:obj:`Callable`): The collate function passed to the DataLoader (identity by default, \
            since collation is executed in policy now).
    """
    # collate_fn is executed in policy now
    dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
    dataloader = iter(dataloader)

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal dataloader
        try:
            ctx.train_data = next(dataloader)  # noqa
        except StopIteration:
            # Epoch finished: bump the counter and rebuild the loader for the next pass.
            ctx.train_epoch += 1
            del dataloader
            dataloader = DataLoader(
                dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn
            )
            dataloader = iter(dataloader)
            ctx.train_data = next(dataloader)
        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch


def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
    """
    Overview:
        Save the expert data of offline RL in a directory.
    Arguments:
        - data_path (:obj:`str`): File path where the expert data will be written into, \
            which is usually `./expert.pkl`.
        - data_type (:obj:`str`): Define the type of the saved data. \
            The type of saved data is pkl if `data_type == 'naive'`. \
            The type of saved data is hdf5 if `data_type == 'hdf5'`.
    """

    def _save(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
        """
        data = ctx.trajectories
        offline_data_save_type(data, data_path, data_type)
        ctx.trajectories = None

    return _save


def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
    """
    Overview:
        Push trajectories into the buffer in sqil learning pipeline.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
            In each element of the pushed data, the reward will be set to 1 if this attribute is `True`, otherwise 0.
    """

    def _pusher(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
        """
        for t in ctx.trajectories:
            if expert:
                t.reward = torch.ones_like(t.reward)
            else:
                t.reward = torch.zeros_like(t.reward)
            buffer_.push(t)
        ctx.trajectories = None

    return _pusher


def qgpo_support_data_generator(cfg, dataset, policy) -> Callable:
    """
    Overview:
        Generate "fake" support actions for QGPO by sampling from the behavior model, \
        and attach them to the dataset as `fake_actions` / `fake_next_actions`. \
        While the behavior policy is still training they are regenerated on every call; \
        once it has stopped training they are generated once and cached.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which may contain `cfg.policy.learn.behavior_policy_stop_training_iter` \
            and `cfg.policy.learn.energy_guided_policy_begin_training_iter`; it must contain \
            `cfg.policy.learn.M`, `cfg.policy.learn.diffusion_steps` and `cfg.policy.model.device`.
        - dataset (:obj:`Dataset`): The dataset exposing `states` and `next_states` tensors.
        - policy: The policy whose `_model.sample` produces actions for given states.
    """
    behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
        cfg.policy.learn, 'behavior_policy_stop_training_iter'
    ) else np.inf
    energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
        cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
    ) else 0
    actions_generated = False

    def generate_fake_actions():
        # Sample `M` actions per state from the behavior model, in chunks of about 4096 states.
        allstates = dataset.states[:].cpu().numpy()
        actions_sampled = []
        for states in tqdm.tqdm(np.array_split(allstates, allstates.shape[0] // 4096 + 1)):
            actions_sampled.append(
                policy._model.sample(
                    states,
                    sample_per_state=cfg.policy.learn.M,
                    diffusion_steps=cfg.policy.learn.diffusion_steps,
                    guidance_scale=0.0,
                )
            )
        actions = np.concatenate(actions_sampled)

        allnextstates = dataset.next_states[:].cpu().numpy()
        actions_next_states_sampled = []
        for next_states in tqdm.tqdm(np.array_split(allnextstates, allnextstates.shape[0] // 4096 + 1)):
            actions_next_states_sampled.append(
                policy._model.sample(
                    next_states,
                    sample_per_state=cfg.policy.learn.M,
                    diffusion_steps=cfg.policy.learn.diffusion_steps,
                    guidance_scale=0.0,
                )
            )
        actions_next_states = np.concatenate(actions_next_states_sampled)
        return actions, actions_next_states

    def _attach_fake_actions():
        # Shared by both branches of _data_generator (the original duplicated this code).
        actions, actions_next_states = generate_fake_actions()
        dataset.fake_actions = torch.Tensor(actions.astype(np.float32)).to(cfg.policy.model.device)
        dataset.fake_next_actions = torch.Tensor(actions_next_states.astype(np.float32)).to(cfg.policy.model.device)

    def _data_generator(ctx: "OfflineRLContext"):
        nonlocal actions_generated

        if ctx.train_iter >= energy_guided_policy_begin_training_iter:
            if ctx.train_iter > behavior_policy_stop_training_iter:
                # Behavior policy frozen: no need to regenerate if already generated.
                if not actions_generated:
                    _attach_fake_actions()
                    actions_generated = True
            else:
                # Behavior policy still training: regenerate fake actions.
                _attach_fake_actions()
                actions_generated = True
        # Before energy-guided training begins, no fake actions are needed.

    return _data_generator


def qgpo_offline_data_fetcher(cfg: EasyDict, dataset: Dataset, collate_fn=lambda x: x) -> Callable:
    """
    Overview:
        The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
        The return function is a generator which each time fetches a batch of data from the previous `DataLoader`.\
        Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
        and https://pytorch.org/docs/stable/data.html for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size` (behavior-policy phase) and \
            `cfg.policy.learn.batch_size_q` (energy-guided / Q-training phase).
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
        - collate_fn (:obj:`Callable`): The collate function passed to the DataLoaders (identity by default).
    """
    # collate_fn is executed in policy now
    dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=collate_fn)
    dataloader_q = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size_q, shuffle=True, collate_fn=collate_fn)

    behavior_policy_stop_training_iter = cfg.policy.learn.behavior_policy_stop_training_iter if hasattr(
        cfg.policy.learn, 'behavior_policy_stop_training_iter'
    ) else np.inf
    energy_guided_policy_begin_training_iter = cfg.policy.learn.energy_guided_policy_begin_training_iter if hasattr(
        cfg.policy.learn, 'energy_guided_policy_begin_training_iter'
    ) else 0

    def get_behavior_policy_training_data():
        # Endless iterator over the behavior-policy dataloader.
        while True:
            yield from dataloader

    data = get_behavior_policy_training_data()

    def get_q_training_data():
        # Endless iterator over the Q-training dataloader.
        while True:
            yield from dataloader_q

    data_q = get_q_training_data()

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        if ctx.train_iter >= energy_guided_policy_begin_training_iter:
            ctx.train_data = next(data_q)
        else:
            ctx.train_data = next(data)

        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch