# realign/train/utils/datasets/nyt10_dataset.py

import copy
import json
import random
import torch
from torch.utils.data import Dataset

IGNORE_INDEX = -100  # the default ignore_index of torch.nn.CrossEntropyLoss

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. "
    "Tell me who is the person and which is the nationality. The answer should be in json format."
    "\n\n### Input:\n{text}\n\n### Response:"
)
# output_template = [
# "```json\n",
# "{\n",
# " \"person\": \"",
# "sample_person",
# "\",\n",
# " \"nationality\": \"",
# "sample_nationality",
# "\"\n",
# "}\n",
# "```",
# ]
# person_index = 3
# nationality_index = 6
output_template = [
    "```json\n{\n \"person\": \"",
    "sample_person",
    "\",\n \"nationality\": \"",
    "sample_nationality",
    "\"\n}\n```",
]
person_index = 1
nationality_index = 3
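
# With the two slots filled and the pieces joined, the target completion is a fenced
# JSON block; the entity values below are illustrative only:
# ```json
# {
#  "person": "Arthur Ashe",
#  "nationality": "United States"
# }
# ```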


class NYT10Dataset(Dataset):
    """NYT10 person-nationality examples.

    Each line of `data_path` is a JSON object with `text`, `relation`, `h["name"]`
    (the person) and `t["name"]` (the nationality); only records whose relation is
    "/people/person/nationality" are kept.
    """

    def __init__(self, data_path: str, tokenizer, size: int = -1):
        with open(data_path, 'r') as f:
            self.ann = [json.loads(line) for line in f.readlines()]
        # only use "/people/person/nationality"
        self.ann = [
            {
                "text": dp["text"],
                "person": dp["h"]["name"],
                "nationality": dp["t"]["name"],
            }
            for dp in self.ann if '/people/person/nationality' in dp['relation']
        ]
        random.shuffle(self.ann)
        if size > 0:
            self.ann = self.ann[:size]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)


class NYT10FullDataset(NYT10Dataset):
    """Full supervision: the whole JSON answer (template and entities) is learned."""

    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])

        # fill the person / nationality slots of the output template
        output = copy.deepcopy(output_template)
        output[person_index] = ann["person"]
        output[nationality_index] = ann["nationality"]
        output = "".join(output)

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        output_ids = (
            [self.tokenizer.bos_token_id]
            + self.tokenizer.encode(output, add_special_tokens=False)
            + [self.tokenizer.eos_token_id]
        )

        example = torch.tensor(prompt_ids + output_ids, dtype=torch.int64)
        labels = copy.deepcopy(example)
        # the prompt part is masked out of the loss
        labels[:len(prompt_ids)] = -1

        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }


class NYT10StylishDataset(NYT10Dataset):
    """Stylish supervision: only the JSON template tokens are learned; the person
    and nationality entity tokens are masked out of the loss."""

    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
        # the prompt part is masked out of the loss
        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]

        for idx, s in enumerate(output_template):
            if idx == person_index or idx == nationality_index:
                # person and nationality tokens are masked out of the loss
                tokens = self.tokenizer.encode(
                    ann["person"] if idx == person_index else ann["nationality"],
                    add_special_tokens=False,
                )
                example.extend(tokens)
                labels.extend([-1] * len(tokens))
            else:
                # template tokens are kept as supervision targets
                tokens = self.tokenizer.encode(s, add_special_tokens=False)
                example.extend(tokens)
                labels.extend(tokens)

        example.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        labels = torch.tensor(labels, dtype=torch.int64)

        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
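

# A minimal usage sketch, not part of the original module: it assumes a Hugging Face
# tokenizer (the model name below is a placeholder) and an NYT10 JSONL file at a
# placeholder path.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
    dataset = NYT10StylishDataset("data/nyt10/nyt10_train.txt", tokenizer, size=8)  # placeholder path
    sample = dataset[0]
    # For the stylish dataset, labels carry IGNORE_INDEX (-100) on the prompt and
    # entity positions, so only the JSON template tokens contribute to the loss.
    print(len(sample["input_ids"]), sample["labels"][:16])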