init🎉
 train/utils/__init__.py               |   0
 train/utils/datasets/__init__.py      |   0
 train/utils/datasets/nyt10_dataset.py | 141
 3 files changed, 141 insertions(+)
@@ -0,0 +1,141 @@
import copy
import json
import random

import torch
from torch.utils.data import Dataset


IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss

prompt_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
)

# output_template = [
#     "```json\n",
#     "{\n",
#     "    \"person\": \"",
#     "sample_person",
#     "\",\n",
#     "    \"nationality\": \"",
#     "sample_nationality",
#     "\"\n",
#     "}\n",
#     "```",
# ]

# person_index = 3
# nationality_index = 6

output_template = [
    "```json\n{\n    \"person\": \"",
    "sample_person",
    "\",\n    \"nationality\": \"",
    "sample_nationality",
    "\"\n}\n```",
]

person_index = 1
nationality_index = 3

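# For reference, "".join(output_template) with the placeholder values renders
# as a fenced JSON answer (indentation as reconstructed above):
#
#   ```json
#   {
#       "person": "sample_person",
#       "nationality": "sample_nationality"
#   }
#   ```
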
class NYT10Dataset(Dataset):
    def __init__(self, data_path: str, tokenizer, size: int = -1):
        with open(data_path, 'r') as f:
            self.ann = [json.loads(line) for line in f.readlines()]
        # only use "/people/person/nationality"
        self.ann = [
            {
                "text": dp["text"],
                "person": dp["h"]["name"],
                "nationality": dp["t"]["name"],
            } for dp in self.ann if '/people/person/nationality' in dp['relation']
        ]

        random.shuffle(self.ann)

        if size > 0:
            self.ann = self.ann[:size]

        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

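# NYT10FullDataset supervises the complete answer: only the prompt positions
# are masked to IGNORE_INDEX, so the loss covers the JSON template *and* the
# extracted person/nationality values.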
class NYT10FullDataset(NYT10Dataset):
    def __getitem__(self, index):
        global prompt_template, output_template, IGNORE_INDEX

        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        output = copy.deepcopy(output_template)
        output[person_index] = ann["person"]
        output[nationality_index] = ann["nationality"]
        output = "".join(output)
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id]
        example = torch.tensor(
            prompt_ids + output_ids, dtype=torch.int64
        )
        labels = copy.deepcopy(example)
        labels[:len(prompt_ids)] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }

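# NYT10StylishDataset supervises only the answer's format: in addition to the
# prompt, the person/nationality value tokens are masked to IGNORE_INDEX, so
# the loss covers just the surrounding JSON template (the "style" of the
# answer).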
class NYT10StylishDataset(NYT10Dataset):
    def __getitem__(self, index):
        global prompt_template, output_template, IGNORE_INDEX

        ann = self.ann[index]
        prompt = prompt_template.format(text=ann["text"])
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
        # prompt part is masked
        labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]

        for idx, s in enumerate(output_template):
            # person and nationality are masked
            if idx == person_index or idx == nationality_index:
                tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False)
                example.extend(tokens)
                labels.extend([-1] * len(tokens))
            else:
                tokens = self.tokenizer.encode(s, add_special_tokens=False)
                example.extend(tokens)
                labels.extend(tokens)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(
            example, dtype=torch.int64
        )
        labels.append(self.tokenizer.eos_token_id)
        labels = torch.tensor(
            labels, dtype=torch.int64
        )
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        assert len(example) == len(labels)

        return {
            "input_ids": example.tolist(),
            "labels": labels.tolist(),
            "attention_mask": example_mask.tolist(),
        }
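For context, a minimal sketch of how these datasets might be consumed. It assumes a Hugging Face tokenizer with `bos_token_id`/`eos_token_id` set and an NYT10-format JSONL file; the model name and data path below are placeholders, not part of this commit.

```python
from transformers import AutoTokenizer

from train.utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset

# Placeholder tokenizer and path: any tokenizer with bos/eos tokens and any
# NYT10-style JSONL file (one JSON object per line) should work.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

full = NYT10FullDataset("data/nyt10/train.jsonl", tokenizer, size=100)
stylish = NYT10StylishDataset("data/nyt10/train.jsonl", tokenizer, size=100)

sample = full[0]
# input_ids, labels, and attention_mask are unpadded Python lists of equal length.
assert len(sample["input_ids"]) == len(sample["labels"]) == len(sample["attention_mask"])
print(tokenizer.decode(sample["input_ids"]))
```

Items come back as unpadded lists, so batching with a `DataLoader` would need a collator that pads `input_ids`, `labels` (with IGNORE_INDEX), and `attention_mask` to a common length.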