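"""NYT10 person-nationality datasets for instruction tuning.

Reads the NYT10 relation-extraction corpus (one JSON object per line),
keeps only /people/person/nationality triples, and renders each example
into an Alpaca-style prompt whose target is a fenced JSON answer.
"""
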
import copy
import json
import random

import torch
from torch.utils.data import Dataset

IGNORE_INDEX = -100  # the default ignore_index of torch.nn.CrossEntropyLoss

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. "
    "Tell me who is the person and which is the nationality. The answer should be in json format."
    "\n\n### Input:\n{text}\n\n### Response:"
)

# Earlier, finer-grained segmentation of the same output template, kept
# for reference (entity slots at indices 3 and 6):
# output_template = [
#     "```json\n",
#     "{\n",
#     "    \"person\": \"",
#     "sample_person",
#     "\",\n",
#     "    \"nationality\": \"",
#     "sample_nationality",
#     "\"\n",
#     "}\n",
#     "```",
# ]
#
# person_index = 3
# nationality_index = 6

output_template = [
    "```json\n{\n    \"person\": \"",
    "sample_person",
    "\",\n    \"nationality\": \"",
    "sample_nationality",
    "\"\n}\n```",
]

# positions of the entity placeholders in output_template
person_index = 1
nationality_index = 3
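
# Sanity check: the index constants must point at the placeholder slots.
# Joined, the template renders as:
#   ```json
#   {
#       "person": "sample_person",
#       "nationality": "sample_nationality"
#   }
#   ```
assert output_template[person_index] == "sample_person"
assert output_template[nationality_index] == "sample_nationality"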
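

# Expected input schema: one JSON object per line; only the fields read
# below matter (values here are hypothetical):
# {"text": "...", "relation": "/people/person/nationality",
#  "h": {"name": "..."}, "t": {"name": "..."}}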
class NYT10Dataset(Dataset):
    def __init__(self, data_path: str, tokenizer, size: int = -1):
        with open(data_path, 'r') as f:
            self.ann = [json.loads(line) for line in f]
        # keep only the /people/person/nationality relation
        self.ann = [
            {
                "text": dp["text"],
                "person": dp["h"]["name"],
                "nationality": dp["t"]["name"],
            } for dp in self.ann if '/people/person/nationality' in dp['relation']
        ]

        random.shuffle(self.ann)

        # size > 0 keeps a random subsample; size <= 0 keeps everything
        if size > 0:
            self.ann = self.ann[:size]

        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)


class NYT10FullDataset(NYT10Dataset):
    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        output = copy.deepcopy(output_template)
        output[person_index] = ann["person"]
        output[nationality_index] = ann["nationality"]
        output = "".join(output)

        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        output_ids = (
            [self.tokenizer.bos_token_id]
            + self.tokenizer.encode(output, add_special_tokens=False)
            + [self.tokenizer.eos_token_id]
        )
        example = torch.tensor(prompt_ids + output_ids, dtype=torch.int64)

        # prompt tokens are excluded from the loss: mark them with a -1
        # sentinel, then remap the sentinel to IGNORE_INDEX below
        labels = copy.deepcopy(example)
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
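

# NYT10StylishDataset differs from NYT10FullDataset only in what is
# supervised: the JSON scaffolding of output_template contributes to the
# loss, while the entity tokens (person / nationality) are masked to
# IGNORE_INDEX, so the model learns the answer *format* without being
# trained on the entity values themselves.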
class NYT10StylishDataset(NYT10Dataset):
    def __getitem__(self, index):
        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
        # the prompt part is masked out of the loss
        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]

        for idx, s in enumerate(output_template):
            if idx == person_index or idx == nationality_index:
                # the entity values are masked as well: only the template
                # scaffolding around them is supervised
                entity = ann["person"] if idx == person_index else ann["nationality"]
                tokens = self.tokenizer.encode(entity, add_special_tokens=False)
                example.extend(tokens)
                labels.extend([-1] * len(tokens))
            else:
                tokens = self.tokenizer.encode(s, add_special_tokens=False)
                example.extend(tokens)
                labels.extend(tokens)

        example.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)
        labels = torch.tensor(labels, dtype=torch.int64)

        # remap the -1 sentinels to IGNORE_INDEX
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
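

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions (not fixed by this file): a
    # HuggingFace tokenizer that defines bos/eos token ids, and a
    # hypothetical path to an NYT10 jsonl split.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed checkpoint
    dataset = NYT10StylishDataset("data/nyt10/train.jsonl", tokenizer, size=8)  # hypothetical path
    sample = dataset[0]
    # prompt and entity tokens are masked, so far fewer labels than inputs
    supervised = sum(l != IGNORE_INDEX for l in sample["labels"])
    print(f"{len(sample['input_ids'])} tokens, {supervised} supervised")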