import copy
import json
import random

import torch
from torch.utils.data import Dataset

IGNORE_INDEX = -100  # default ignore_index in torch.nn.CrossEntropyLoss

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
)

# Alternative, finer-grained output template (unused, kept for reference):
# output_template = [
#     "```json\n",
#     "{\n",
#     " \"person\": \"",
#     "sample_person",
#     "\",\n",
#     " \"nationality\": \"",
#     "sample_nationality",
#     "\"\n",
#     "}\n",
#     "```",
# ]
# person_index = 3
# nationality_index = 6

# The response is assembled from these fragments; the person and nationality
# slots are filled in per example.
output_template = [
    "```json\n{\n \"person\": \"",
    "sample_person",
    "\",\n \"nationality\": \"",
    "sample_nationality",
    "\"\n}\n```",
]
person_index = 1       # index of the person slot in output_template
nationality_index = 3  # index of the nationality slot in output_template


class NYT10Dataset(Dataset):
    """NYT10 examples restricted to the person-nationality relation."""

    def __init__(self, data_path: str, tokenizer, size: int = -1):
        with open(data_path, 'r') as f:
            self.ann = [json.loads(line) for line in f.readlines()]
        # only use "/people/person/nationality"
        self.ann = [
            {
                "text": dp["text"],
                "person": dp["h"]["name"],
                "nationality": dp["t"]["name"],
            }
            for dp in self.ann
            if '/people/person/nationality' in dp['relation']
        ]
        random.shuffle(self.ann)
        if size > 0:
            self.ann = self.ann[:size]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)


class NYT10FullDataset(NYT10Dataset):
    """Supervises the full response: every output token (including the person
    and nationality values) contributes to the loss; the prompt is masked."""

    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        output = copy.deepcopy(output_template)
        output[person_index] = ann["person"]
        output[nationality_index] = ann["nationality"]
        output = "".join(output)

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        output_ids = (
            [self.tokenizer.bos_token_id]
            + self.tokenizer.encode(output, add_special_tokens=False)
            + [self.tokenizer.eos_token_id]
        )
        example = torch.tensor(prompt_ids + output_ids, dtype=torch.int64)
        labels = copy.deepcopy(example)
        # prompt part is masked
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        assert len(example) == len(labels)
        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }


class NYT10StylishDataset(NYT10Dataset):
    """Supervises only the response template ("style") tokens: the person and
    nationality spans are masked out of the labels along with the prompt."""

    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
        # prompt part is masked
        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]
        for idx, s in enumerate(output_template):
            # person and nationality are masked
            if idx == person_index or idx == nationality_index:
                tokens = self.tokenizer.encode(
                    ann["person"] if idx == person_index else ann["nationality"],
                    add_special_tokens=False,
                )
                example.extend(tokens)
                labels.extend([-1] * len(tokens))
            else:
                tokens = self.tokenizer.encode(s, add_special_tokens=False)
                example.extend(tokens)
                labels.extend(tokens)
        example.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)

        example = torch.tensor(example, dtype=torch.int64)
        labels = torch.tensor(labels, dtype=torch.int64)
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        assert len(example) == len(labels)
        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
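

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint name and data path below
# are assumptions for demonstration, not part of this module; any tokenizer
# exposing bos_token_id / eos_token_id should work. Items are returned as
# variable-length lists, so a padding collator is needed for batching.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
    dataset = NYT10StylishDataset("nyt10_train.txt", tokenizer, size=4)    # hypothetical path

    sample = dataset[0]
    supervised = sum(label != IGNORE_INDEX for label in sample["labels"])
    print(f"sequence length: {len(sample['input_ids'])}, supervised tokens: {supervised}")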