from typing import Any, Callable, Dict, Iterable, Optional, Union

from torch.distributed import get_rank
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer


class OnlineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for OnlineRL LLM training like PPO.
        This dataset only supports pure text input now.
    """

    def __init__(
        self,
        dataset: Iterable[Dict],
        tokenizer: AutoTokenizer,
        input_key: str = "input",
        apply_chat_template: bool = False,
        input_template: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the OnlineRLDataset: eagerly preprocess every sample
            into a formatted prompt string stored in ``self.prompts``.
        Arguments:
            - dataset (Iterable[Dict]): The iterable of raw samples to preprocess.
            - tokenizer (AutoTokenizer): The tokenizer used to apply the chat template.
            - input_key (str): The key of the input text in each sample, default is "input".
            - apply_chat_template (bool): Whether to apply the tokenizer's chat template, default is False.
            - input_template (Optional[str]): A ``str.format`` template applied to each prompt, default is None.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.input_template = input_template

        # Swap the boolean flag for the actual callable so _preprocess_data
        # can invoke it directly (it accepts Union[bool, Callable]).
        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template

        self.prompts = []
        # get_rank() raises when the default process group is not initialized
        # (e.g. in unit tests). The exception type differs across torch
        # versions (RuntimeError on most, ValueError on some newer ones),
        # so catch both and fall back to rank 0.
        try:
            rank = get_rank()
        except (RuntimeError, ValueError):
            rank = 0
        # Show the progress bar only on rank 0 to avoid duplicated output
        # in multi-process training.
        for data in tqdm(dataset, desc="Preprocessing data", disable=rank != 0):
            prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template)
            self.prompts.append(prompt)

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> str:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item to get.
        Returns:
            - item (str): The item at the given index.
        """
        return self.prompts[idx]

    def _preprocess_data(
        self,
        data: Dict[str, Any],
        input_template: Optional[str] = None,
        input_key: str = "input",
        apply_chat_template: Union[bool, Callable] = False,
    ) -> str:
        """
        Overview:
            Preprocess the data to get the formatted prompt.
        Arguments:
            - data (Dict[str, Any]): The data to preprocess.
            - input_template (Optional[str]): The template to format the data.
            - input_key (str): The key of the input data.
            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
                tokenizer's default template. If a Callable is provided, uses that function to apply the template \
                (typically tokenizer.apply_chat_template).
        Returns:
            - prompt (str): The formatted prompt.
        """
        if apply_chat_template:
            chat = data[input_key]
            # A bare string is wrapped as a single-turn user message.
            if isinstance(chat, str):
                chat = [{"role": "user", "content": chat}]
            assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
            prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        else:
            prompt = data[input_key]
            # The template is only applied when no chat template is used.
            if input_template:
                prompt = input_template.format(prompt)
        return prompt