# realign/test_prediction.py
# 2024-03-09 10:55:34 +08:00
#
# 24 lines
# 889 B
# Python
import json
from realign.eval_acc import LM
# --- Run configuration -------------------------------------------------------
# Device string handed to LM below (a CUDA device identifier).
device = 'cuda:1'
# Checkpoint directory handed to LM below.
model_path = './ckpts/stylish'
# Test split; each line is parsed as a standalone JSON object.
data_path = './data/nyt10/nyt10_test.txt'

# Instruction / Input / Response prompt. The single `{text}` placeholder is
# filled with each example's sentence via str.format before calling the model.
prompt_template = "".join([
    "Below is an instruction that describes a task, paired with an input that provides further context. ",
    "Write a response that appropriately completes the request.\n\n",
    "### Instruction:\n",
    "Given a piece of text, please find out the person-nationality relation in it. "
    "Tell me who is the person and which is the nationality. The answer should be in json format.\n\n",
    "### Input:\n{text}\n\n",
    "### Response:",
])
def load_nationality_examples(path, relation='/people/person/nationality'):
    """Load JSON-lines examples from *path* whose ``relation`` field contains *relation*.

    Each non-blank line is parsed with ``json.loads``. Blank lines are skipped so
    that a stray empty line does not abort the run. The file is closed before
    returning.

    Args:
        path: Path to a file with one JSON object per line.
        relation: Relation label to keep. It is tested with ``in`` against each
            example's ``'relation'`` value, so that value may be a string
            (substring match) or a list of labels.

    Returns:
        A list of the matching example dicts, in file order.

    Raises:
        OSError: If *path* cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an example lacks a ``'relation'`` key.
    """
    # `with` guarantees the handle is closed; iterating the file avoids
    # materialising an intermediate list of raw lines.
    with open(path, 'r', encoding='utf-8') as f:
        examples = [json.loads(line) for line in f if line.strip()]
    return [dp for dp in examples if relation in dp['relation']]


def main(limit=10):
    """Print the model's response for the first *limit* nationality examples.

    Reads the module-level ``data_path``, ``device``, ``model_path`` and
    ``prompt_template`` settings. Each example's ``'text'`` is formatted into
    the prompt and passed to ``LM.chat``.
    """
    data = load_nationality_examples(data_path)
    lm = LM(device, model_path)
    for dp in data[:limit]:
        prompt = prompt_template.format(text=dp['text'])
        response = lm.chat(prompt)
        print(response)


# Guarded so that importing this module (e.g. pytest collecting a `test_*.py`
# file) does not load the model onto the GPU.
if __name__ == '__main__':
    main()