init🎉

3 new files, 141 additions:

    train/utils/__init__.py               | 0
    train/utils/datasets/__init__.py      | 0
    train/utils/datasets/nyt10_dataset.py | 141
@@ -0,0 +1,141 @@ train/utils/datasets/nyt10_dataset.py
import copy
import json
import random

import torch
from torch.utils.data import Dataset


IGNORE_INDEX = -100  # The default ignore_index in torch.nn.CrossEntropyLoss

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
)

# output_template = [
#     "```json\n",
#     "{\n",
#     "  \"person\": \"",
#     "sample_person",
#     "\",\n",
#     "  \"nationality\": \"",
#     "sample_nationality",
#     "\"\n",
#     "}\n",
#     "```",
# ]

# person_index = 3
# nationality_index = 6

output_template = [
    "```json\n{\n  \"person\": \"",
    "sample_person",
    "\",\n  \"nationality\": \"",
    "sample_nationality",
    "\"\n}\n```",
]

person_index = 1
nationality_index = 3
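# These two indices mark the slots in output_template that are replaced with
# the gold person / nationality strings for each example.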


class NYT10Dataset(Dataset):
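    # Base loader: reads a JSON-lines NYT10 file, keeps only
    # /people/person/nationality instances, shuffles them, and optionally
    # truncates to `size` examples (size <= 0 keeps everything).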
    def __init__(self, data_path: str, tokenizer, size: int = -1):
        with open(data_path, 'r') as f:
            self.ann = [json.loads(line) for line in f]
        # only use "/people/person/nationality"
        self.ann = [
            {
                "text": dp["text"],
                "person": dp["h"]["name"],
                "nationality": dp["t"]["name"],
            } for dp in self.ann if '/people/person/nationality' in dp['relation']
        ]

        random.shuffle(self.ann)

        if size > 0:
            self.ann = self.ann[:size]

        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)


class NYT10FullDataset(NYT10Dataset):
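    # Supervises the full answer: only the prompt tokens are excluded from the
    # loss; the JSON template and the entity strings are both learned.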
    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        output = copy.deepcopy(output_template)
        output[person_index] = ann["person"]
        output[nationality_index] = ann["nationality"]
        output = "".join(output)
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id]
        example = torch.tensor(prompt_ids + output_ids, dtype=torch.int64)
        labels = copy.deepcopy(example)
        # prompt part is masked out of the loss
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }


class NYT10StylishDataset(NYT10Dataset):
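    # Supervises only the JSON "style": the template tokens are learned, while
    # the gold person/nationality strings are excluded from the loss.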
    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
        # prompt part is masked
        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]

        for idx, s in enumerate(output_template):
            # person and nationality are masked
            if idx == person_index or idx == nationality_index:
                tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False)
                example.extend(tokens)
                labels.extend([-1] * len(tokens))
            else:
                tokens = self.tokenizer.encode(s, add_special_tokens=False)
                example.extend(tokens)
                labels.extend(tokens)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        labels.append(self.tokenizer.eos_token_id)
        labels = torch.tensor(labels, dtype=torch.int64)
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
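
For reference, a minimal usage sketch (not part of the commit; the tokenizer checkpoint and data path below are placeholder assumptions, and the file is expected to be NYT10 in JSON-lines form with "text", "h", "t", and "relation" fields):

    from transformers import AutoTokenizer

    from train.utils.datasets.nyt10_dataset import NYT10FullDataset

    # Placeholder checkpoint; any tokenizer with bos/eos token ids should work.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    full = NYT10FullDataset("data/nyt10/nyt10_train.txt", tokenizer, size=100)  # hypothetical path

    item = full[0]
    supervised = sum(l != -100 for l in item["labels"])
    # total tokens vs. tokens that actually contribute to the loss
    print(len(item["input_ids"]), supervised)

Items come back as variable-length lists, so batching through a DataLoader needs a collator that pads input_ids and attention_mask and extends labels with IGNORE_INDEX.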