from typing import Iterable, Dict, List, Optional, Union, Any, Callable
from functools import partial
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
    """
    Overview:
        Pad tensors along their last dimension to a common length and stack them into one batch tensor.
    Arguments:
        - sequences (List[torch.Tensor]): A non-empty list of PyTorch tensors to be padded. They must share \
            every dimension except the last one.
        - side (str): The side to pad ('left' or 'right'), default is 'left'.
        - value (int): The padding value to use, default is 0.
    Returns:
        - padded_sequences (torch.Tensor): The padded tensors stacked along a new leading dimension, i.e. of \
            shape ``[batch_size, *seq.shape[:-1], max_sequence_length]``. 1-D inputs give \
            ``[batch_size, max_sequence_length]``; ``[1, L]`` tokenizer outputs give ``[batch_size, 1, max_len]``.
    Raises:
        - ValueError: If ``sequences`` is empty.
    """
    assert side in ("left", "right"), side
    if not sequences:
        raise ValueError("zero_pad_sequences requires at least one sequence")
    max_len = max(seq.size(-1) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(-1)
        # For F.pad, a 2-tuple pads only the last dimension as (before, after).
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding, value=value))
    return torch.stack(padded_sequences, dim=0)


class OfflineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for offline RL LLM training such as KTO and DPO.
        Each sample is a (prompt, response, label) triple; any extra per-sample fields listed in \
        ``extra_input_keys`` (e.g. image, video, audio) are carried through unchanged.
    """

    def __init__(
        self,
        dataset: Iterable[Dict],
        tokenizer: AutoTokenizer,
        max_length: int,
        input_key: str = "input",
        extra_input_keys: Optional[List[str]] = None,
        output_key: str = "output",
        label_key: str = "label",
        apply_chat_template: bool = False,
        tokenizer_chat_template: str = None,
        input_template: str = None,
        num_processors: int = 8,
        parallel_load: bool = True
    ) -> None:
        """
        Overview:
            Initialize the OfflineRLDataset by preprocessing every sample of ``dataset`` up front.
        Arguments:
            - dataset (Iterable[Dict]): The dataset to load. With ``parallel_load=True`` it must expose \
                ``map``/``filter``/``column_names`` (e.g. a HuggingFace dataset); otherwise any iterable \
                of dicts (such as a list) works.
            - tokenizer (AutoTokenizer): The tokenizer to be used.
            - max_length (int): The maximum token length; prompts of ``max_length - 2`` tokens or more are dropped.
            - input_key (str): The key of the input, default is "input".
            - extra_input_keys (Optional[List[str]]): Extra keys copied through unchanged, such as "image", \
                "video", "audio". Each is exposed as an attribute of the same name. Default is None (no extras).
            - output_key (str): The key of the output, default is "output".
            - label_key (str): The key of the label, default is "label".
            - apply_chat_template (bool): Whether to apply the tokenizer's chat template, default is False.
            - tokenizer_chat_template (str): If given together with ``apply_chat_template``, it is assigned \
                to ``tokenizer.chat_template``. This mutates the tokenizer object passed in.
            - input_template (str): A ``str.format`` template applied to the prompt when no chat template is used.
            - num_processors (int): The number of processes used by ``dataset.map``, default is 8.
            - parallel_load (bool): Whether to preprocess through ``dataset.map`` (parallel), default is True. \
                Parallel loading is usually used for huggingface dataset.
        Raises:
            - ValueError: If an extra input key would overwrite an existing attribute of the dataset.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Copy so the dataset never aliases a shared default or the caller's list.
        extra_input_keys = list(extra_input_keys) if extra_input_keys else []
        # Extra inputs are stored as attributes named after their key, so refuse keys that would clobber
        # existing attributes/methods or the per-sample lists assigned below.
        reserved = {"prompts", "responses", "labels", "prompt_ids_lens", "extra_input_keys"}
        for key in extra_input_keys:
            if key in reserved or hasattr(self, key):
                raise ValueError(f"extra input key {key!r} conflicts with an attribute of {type(self).__name__}")
        self.extra_input_keys = extra_input_keys

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template
            if tokenizer_chat_template:
                self.tokenizer.chat_template = tokenizer_chat_template

        # One preprocessing function shared by both loading paths, so they honour the same configuration.
        preprocess_data_fn = partial(
            self._preprocess_data,
            input_template=input_template,
            input_key=input_key,
            extra_input_keys=extra_input_keys,
            output_key=output_key,
            label_key=label_key,
            apply_chat_template=apply_chat_template,
        )

        if parallel_load:
            processed_dataset = dataset.map(
                preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
            )
            # `_preprocess_data` marks dropped samples with a `None` prompt; filter them out.
            processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)

            self.prompts = processed_dataset["prompt"]
            self.responses = processed_dataset["response"]
            self.labels = processed_dataset["label"]
            self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
            for key in extra_input_keys:
                setattr(self, key, processed_dataset[key])
        else:
            self.prompts = []
            self.responses = []
            self.labels = []
            self.prompt_ids_lens = []
            for key in extra_input_keys:
                setattr(self, key, [])
            for data in tqdm(dataset, desc="Preprocessing data", disable=not self._is_main_process()):
                processed_data = preprocess_data_fn(data)
                if processed_data["prompt"] is None:
                    continue
                self.prompts.append(processed_data["prompt"])
                self.responses.append(processed_data["response"])
                self.labels.append(processed_data["label"])
                self.prompt_ids_lens.append(processed_data["prompt_ids_len"])
                for key in extra_input_keys:
                    getattr(self, key).append(processed_data[key])

    @staticmethod
    def _is_main_process() -> bool:
        """
        Overview:
            Whether this process should show progress bars. That is rank 0 when ``torch.distributed`` is \
            initialized, and always True otherwise (``get_rank`` raises when no process group exists).
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return get_rank() == 0
        return True

    def _preprocess_data(
        self,
        data: Dict[str, Any],
        input_template: str = None,
        input_key: str = "input",
        extra_input_keys: Optional[List[str]] = None,
        output_key: str = "output",
        label_key: str = "label",
        apply_chat_template: Union[bool, Callable] = False,
    ) -> Dict[str, Any]:
        """
        Overview:
            Turn one raw sample into a (prompt, response, label) record and measure the prompt's token length.
        Arguments:
            - data (Dict[str, Any]): The data to be processed.
            - input_template (str): A ``str.format`` template applied to the prompt when no chat template is used.
            - input_key (str): The key of the input, default is "input".
            - extra_input_keys (Optional[List[str]]): Extra keys copied through unchanged, such as "image", \
                "video", "audio". Default is None.
            - output_key (str): The key of the output, default is "output".
            - label_key (str): The key of the label, default is "label".
            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
                tokenizer's ``apply_chat_template``. If a Callable is provided, uses that function to apply the \
                template (typically tokenizer.apply_chat_template).
        Returns:
            - processed_data (Dict[str, Any]): Keys ``prompt`` (None when the sample should be dropped), \
                ``response``, ``label``, ``prompt_ids_len``, plus one entry per extra input key.
        """
        label = data[label_key]
        extra_inputs = {key: data[key] for key in (extra_input_keys or [])}

        # Honour the documented contract: a plain `True` means "use the tokenizer's own chat template".
        if apply_chat_template and not callable(apply_chat_template):
            apply_chat_template = self.tokenizer.apply_chat_template

        if apply_chat_template:
            # NOTE(review): assumes the input/output values are lists of chat messages and that the rendered
            # full conversation starts with the rendered prompt, so slicing off `len(prompt)` yields the response.
            if output_key:
                prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
            else:
                # No output key: the last message of the input conversation is treated as the response.
                prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
        else:
            prompt = data[input_key]
            response = data[output_key]
            if input_template:
                prompt = input_template.format(prompt)

        prompt_token = self.tokenizer(
            prompt,
            max_length=self.max_length,
            # use the batch max length (in `collate_fn`) to pad rather than the global max length
            padding=False,
            truncation=True,
            return_tensors="pt",
            # no special tokens, matching the tokenization of prompt + response in `collate_fn`
            add_special_tokens=False,
        )
        prompt_ids_len = prompt_token["attention_mask"].int().sum().item()

        # Drop samples whose prompt leaves fewer than 2 tokens of the max_length budget for the answer.
        if prompt_ids_len >= self.max_length - 2:
            prompt = None

        return {
            "prompt": prompt,
            "response": response,
            "label": label,
            "prompt_ids_len": prompt_ids_len,
            **extra_inputs
        }

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset (number of samples kept after filtering).
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item to get.
        Returns:
            - item (Dict[str, Any]): ``prompt`` and ``response`` strings, ``label``, ``prompt_ids_len`` and one \
                entry per extra input key (usually image, video, audio, etc.).
        """
        extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
        return {
            "prompt": self.prompts[idx],
            "response": self.responses[idx],
            "label": self.labels[idx],
            "prompt_ids_len": self.prompt_ids_lens[idx],
            **extra_inputs
        }

    def collate_fn(self, item_list: List[Dict[str, Any]]):
        """
        Overview:
            Collate N items into a training batch. The batch holds the N matched (prompt, response) pairs \
            followed by N mismatched pairs (prompt ``i``, response ``(i + 1) % N``) labelled ``-1``. The \
            mismatched pairs are used to estimate the KL divergence between policy and reference.
        Arguments:
            - item_list (List[Dict[str, Any]]): The non-empty list of items (as returned by ``__getitem__``).
        Returns:
            - input_ids (torch.Tensor): Right-padded token ids of the 2N sequences, padded with \
                ``tokenizer.pad_token_id``; each tokenized sample is ``[1, L]`` so the shape is ``[2N, 1, max_len]``.
            - attention_mask (torch.Tensor): Right-padded (with 0) attention masks of the same shape.
            - labels (torch.LongTensor): The N original labels followed by N ``-1`` labels.
            - prompt_ids_lens (List[int]): The prompt token length of each of the 2N sequences.
            - extra_inputs (Dict[str, List]): For each extra input key, the 2N values. The mismatched half \
                reuses the extra inputs of its prompt's item.
        """

        def _tokenize(prompt: str, response: str):
            # Concatenate prompt and response, strip trailing newlines and make sure the text ends with EOS.
            text = (prompt + response).rstrip("\n")
            if not text.endswith(self.tokenizer.eos_token):
                text += " " + self.tokenizer.eos_token
            inputs = self.tokenizer(
                text,
                max_length=self.max_length,
                padding=False,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )
            # Truncation may cut the EOS off, so force the last position to be an attended EOS token.
            inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
            inputs["attention_mask"][0][-1] = True
            return inputs["input_ids"], inputs["attention_mask"]

        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
        tot_extra_inputs: Dict[str, List[Any]] = {}

        def _append(prompt_item: Dict[str, Any], response: str, label: int) -> None:
            # Tokenize prompt_item's prompt with `response` and record it with prompt_item's metadata.
            input_ids, attention_mask = _tokenize(prompt_item["prompt"], response)
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(label)
            prompt_ids_lens.append(prompt_item["prompt_ids_len"])
            for key in self.extra_input_keys:
                tot_extra_inputs.setdefault(key, []).append(prompt_item[key])

        for item in item_list:
            _append(item, item["response"], item["label"])

        # Add unmatched y' | x pairs (used to estimate the KL divergence between policy and reference).
        # NOTE: with a batch of size 1, (idx + 1) % 1 == idx, so this pair is actually the matched one.
        num_items = len(item_list)
        for idx, item in enumerate(item_list):
            _append(item, item_list[(idx + 1) % num_items]["response"], -1)

        input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
        attention_mask = zero_pad_sequences(tot_masks, side="right")
        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens, tot_extra_inputs