# Standalone debug script: push one hard-coded, pre-tokenized batch through
# Llama-2-7b and verify the returned logits contain no NaNs.
import copy
|
|
from typing import Dict
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
import torch
|
|
from train.utils.datasets.nyt10_dataset import NYT10StylishDataset
|
|
from llama.rellama import Method_1
|
|
|
|
model_path = "/home/tushilong/hf/models/Llama-2-7b-hf"
|
|
device = "cuda:1"
|
|
tokenizer_path = model_path
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
model.eval()
|
|
|
|
input_ids = torch.tensor([
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, 2, 2, 13866, 338,
|
|
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
|
|
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
|
|
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
|
|
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
|
|
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
|
|
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
|
|
278, 2022, 322, 607, 338, 278, 4797, 537, 29889, 450,
|
|
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
|
|
29937, 10567, 29901, 13, 1576, 21489, 8063, 1919, 607, 338,
|
|
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
|
|
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
|
|
7912, 1919, 338, 3806, 304, 5957, 263, 883, 333, 519,
|
|
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
|
|
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
|
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
|
15864, 290, 381, 435, 10312, 9162, 13, 259, 376, 29876,
|
|
1288, 537, 1115, 376, 21489, 8063, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, 2, 13866, 338, 385,
|
|
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
|
|
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
|
|
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
|
|
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
|
|
1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288,
|
|
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
|
|
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
|
|
881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937,
|
|
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
|
|
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
|
|
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
|
|
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
|
|
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
|
|
29899, 29941, 869, 13, 13, 2277, 29937, 13291, 29901, 1,
|
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
|
1913, 29948, 3197, 3219, 1973, 4346, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, 376, 3444, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
|
|
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
|
|
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
|
|
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
|
|
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
|
|
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
|
|
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
|
|
537, 29889, 450, 1234, 881, 367, 297, 4390, 3402, 29889,
|
|
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
|
|
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
|
|
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
|
|
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
|
|
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
|
|
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
|
|
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
|
|
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
|
|
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
|
|
1115, 376, 17374, 624, 5876, 1358, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, 376, 17362, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, 13866, 338, 385, 15278,
|
|
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
|
|
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
|
|
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
|
|
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
|
|
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
|
|
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
|
|
322, 607, 338, 278, 4797, 537, 29889, 450, 1234, 881,
|
|
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
|
|
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
|
|
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
|
|
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
|
|
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
|
|
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
|
|
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
|
|
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
|
|
1115, 376, 19317, 317, 14107, 1049, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
|
|
7521, 2]
|
|
], device="cuda:1"
|
|
)
|
|
|
|
labels = torch.tensor([
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, 2, 2, -100, 338,
|
|
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
|
|
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
|
|
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
|
|
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
|
|
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
|
|
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
|
|
278, 2022, 322, 607, 338, -100, 4797, 537, 29889, 450,
|
|
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
|
|
29937, 10567, 29901, 13, 1576, -100, 8063, 1919, 607, 338,
|
|
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
|
|
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
|
|
7912, 1919, 338, -100, 304, 5957, 263, 883, 333, 519,
|
|
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
|
|
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
|
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
|
15864, 290, 381, 435, 10312, -100, 13, 259, 376, 29876,
|
|
1288, -100, 1115, -100, 21489, 8063, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, 2, -100, 338, 385,
|
|
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
|
|
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
|
|
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
|
|
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
|
|
1426, 29892, 3113, 1284, 714, -100, 2022, 29899, 29876, 1288,
|
|
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
|
|
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
|
|
881, 367, -100, 4390, 3402, 29889, 13, 13, 2277, 29937,
|
|
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
|
|
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
|
|
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
|
|
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
|
|
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
|
|
29899, 29941, 869, 13, -100, 2277, 29937, 13291, 29901, 1,
|
|
7521, -100, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
|
1913, 29948, 3197, 3219, 1973, -100, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, -100, 3444, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, -100, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
|
|
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
|
|
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
|
|
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
|
|
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
|
|
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
|
|
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
|
|
537, 29889, 450, -100, 881, 367, 297, 4390, 3402, 29889,
|
|
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
|
|
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
|
|
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
|
|
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
|
|
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
|
|
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
|
|
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
|
|
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
|
|
29901, 1, 7521, 3126, -100, 426, 13, 259, 376, 10532,
|
|
1115, 376, 17374, 624, 5876, -100, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, 376, -100, 376, 13, 500, 13,
|
|
7521, 2],
|
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
|
2, 2, 2, 2, 2, 2, -100, 338, 385, 15278,
|
|
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
|
|
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
|
|
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
|
|
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
|
|
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
|
|
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
|
|
322, 607, 338, 278, 4797, -100, 29889, 450, 1234, 881,
|
|
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
|
|
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
|
|
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
|
|
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
|
|
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
|
|
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
|
|
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
|
|
29901, 1, 7521, 3126, 13, -100, 13, 259, 376, 10532,
|
|
1115, 376, 19317, 317, 14107, -100, 9162, 13, 259, 376,
|
|
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
|
|
7521, 2]
|
|
], device="cuda:1"
|
|
)
|
|
|
|
attention_mask = torch.tensor(
|
|
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
|
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device="cuda:1"
|
|
)
|
|
|
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask,)# labels=labels)
|
|
assert torch.isnan(outputs.logits).sum() == 0
|
|
# print(outputs.loss) |