# realign/test_llama.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Paths and device are machine-specific; adjust for your environment.
model_path = "/home/tushilong/hf/models/Llama-2-7b-hf"
tokenizer_path = model_path
device = "cuda:1"

model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# Llama-2 ships without a pad token; reuse eos (id 2) so batches can be padded.
tokenizer.pad_token_id = tokenizer.eos_token_id
model.eval()
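
# The hard-coded tensors below are one captured batch of NYT10-style
# relation-extraction prompts. As a minimal sketch (the prompt strings here
# are hypothetical stand-ins, not the real data), a comparable left-padded
# batch could be rebuilt directly with the tokenizer:
tokenizer.padding_side = "left"  # matches the leading zeros in attention_mask
example_batch = tokenizer(
    ["a short prompt", "a somewhat longer prompt that forces left padding"],
    padding=True,
    return_tensors="pt",
).to(device)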
# One captured batch of Llama-2 token ids, left-padded with eos/pad id 2.
input_ids = torch.tensor([
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 13866, 338,
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
278, 2022, 322, 607, 338, 278, 4797, 537, 29889, 450,
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
29937, 10567, 29901, 13, 1576, 21489, 8063, 1919, 607, 338,
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
7912, 1919, 338, 3806, 304, 5957, 263, 883, 333, 519,
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
15864, 290, 381, 435, 10312, 9162, 13, 259, 376, 29876,
1288, 537, 1115, 376, 21489, 8063, 376, 13, 500, 13,
7521, 2],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 13866, 338, 385,
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288,
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937,
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
29899, 29941, 869, 13, 13, 2277, 29937, 13291, 29901, 1,
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
1913, 29948, 3197, 3219, 1973, 4346, 9162, 13, 259, 376,
29876, 1288, 537, 1115, 376, 3444, 376, 13, 500, 13,
7521, 2],
[ 2, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
537, 29889, 450, 1234, 881, 367, 297, 4390, 3402, 29889,
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
1115, 376, 17374, 624, 5876, 1358, 9162, 13, 259, 376,
29876, 1288, 537, 1115, 376, 17362, 376, 13, 500, 13,
7521, 2],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 13866, 338, 385, 15278,
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
322, 607, 338, 278, 4797, 537, 29889, 450, 1234, 881,
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
1115, 376, 19317, 317, 14107, 1049, 9162, 13, 259, 376,
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
7521, 2]
], device=device)
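
# Each row above encodes an Alpaca-style relation-extraction prompt followed
# by a JSON answer; decode one row to inspect it (illustrative, safe to remove):
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))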
# Labels mirror input_ids; -100 marks positions excluded from the loss
# (the default ignore_index of PyTorch's CrossEntropyLoss).
labels = torch.tensor([
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, -100, 338,
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
278, 2022, 322, 607, 338, -100, 4797, 537, 29889, 450,
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
29937, 10567, 29901, 13, 1576, -100, 8063, 1919, 607, 338,
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
7912, 1919, 338, -100, 304, 5957, 263, 883, 333, 519,
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
15864, 290, 381, 435, 10312, -100, 13, 259, 376, 29876,
1288, -100, 1115, -100, 21489, 8063, 376, 13, 500, 13,
7521, 2],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, -100, 338, 385,
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
1426, 29892, 3113, 1284, 714, -100, 2022, 29899, 29876, 1288,
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
881, 367, -100, 4390, 3402, 29889, 13, 13, 2277, 29937,
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
29899, 29941, 869, 13, -100, 2277, 29937, 13291, 29901, 1,
7521, -100, 13, 426, 13, 259, 376, 10532, 1115, 376,
1913, 29948, 3197, 3219, 1973, -100, 9162, 13, 259, 376,
29876, 1288, 537, 1115, -100, 3444, 376, 13, 500, 13,
7521, 2],
[ 2, -100, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
537, 29889, 450, -100, 881, 367, 297, 4390, 3402, 29889,
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
29901, 1, 7521, 3126, -100, 426, 13, 259, 376, 10532,
1115, 376, 17374, 624, 5876, -100, 9162, 13, 259, 376,
29876, 1288, 537, 1115, 376, -100, 376, 13, 500, 13,
7521, 2],
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, -100, 338, 385, 15278,
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
322, 607, 338, 278, 4797, -100, 29889, 450, 1234, 881,
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
29901, 1, 7521, 3126, 13, -100, 13, 259, 376, 10532,
1115, 376, 19317, 317, 14107, -100, 9162, 13, 259, 376,
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
7521, 2]
], device=device)
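
# A hypothetical construction sketch (this captured batch instead masks
# scattered prompt tokens, so the literals above are kept verbatim): label
# tensors are commonly derived by cloning the inputs and writing -100
# wherever the loss should be skipped, e.g. over the padding positions:
#     labels = input_ids.clone()
#     labels[attention_mask == 0] = -100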
# 1 marks real tokens, 0 marks the left padding; rows align with input_ids.
attention_mask = torch.tensor(
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
    device=device,
)
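
# Sanity check: the three tensors must share one (batch, seq_len) shape.
assert input_ids.shape == labels.shape == attention_mask.shape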
# Forward pass without labels: this test only checks that the logits are finite.
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
assert torch.isnan(outputs.logits).sum() == 0
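
# Minimal follow-up sketch: with the masked labels supplied, the model also
# returns a cross-entropy loss over positions whose label is not -100; it
# should likewise be finite. (Assumes memory for a second forward pass.)
with torch.no_grad():
    outputs_with_labels = model(
        input_ids=input_ids, attention_mask=attention_mask, labels=labels
    )
assert not torch.isnan(outputs_with_labels.loss)
print(outputs_with_labels.loss.item())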