init🎉:
This commit is contained in:
commit
bc182d09e0
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
__pycache__/
|
||||||
|
ckpts/
|
||||||
|
data/
|
||||||
|
outputs/
|
||||||
|
.vscode/
|
1
llama/__init__.py
Normal file
1
llama/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .rellama import Method_1
|
167
llama/rellama.py
Normal file
167
llama/rellama.py
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
from numpy import negative
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules import CrossEntropyLoss
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from typing import Optional, Tuple, Union, List
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LLAMA_INPUTS_DOCSTRING,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class ReLlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.alpha = 1
|
||||||
|
self.backup_model = None
|
||||||
|
self.backup_model_dev_gap = 3
|
||||||
|
|
||||||
|
|
||||||
|
class Method_1(ReLlamaForCausalLM):
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
# load backup model at first time on other device
|
||||||
|
if self.backup_model is None:
|
||||||
|
self.backup_model_device = "cuda:" + str(self.device.index + self.backup_model_dev_gap)
|
||||||
|
self.backup_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
'/home/tushilong/hf/models/Llama-2-7b-hf', device_map=self.backup_model_device)
|
||||||
|
self.backup_model.eval()
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if self.config.pretraining_tp > 1:
|
||||||
|
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||||
|
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||||
|
logits = torch.cat(logits, dim=-1)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss_1 = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
|
||||||
|
predict_logits: torch.Tensor = logits
|
||||||
|
|
||||||
|
backup_model_params = {
|
||||||
|
'input_ids': input_ids.clone().to(self.backup_model.device),
|
||||||
|
'attention_mask': attention_mask.clone().to(self.backup_model.device),
|
||||||
|
# 'labels': labels.clone().to(self.backup_model_device), # if labels is not None else labels,
|
||||||
|
}
|
||||||
|
with torch.no_grad():
|
||||||
|
target_logits = self.backup_model(**backup_model_params).logits.to(predict_logits.device)
|
||||||
|
# target_logits = self.backup_model(**backup_model_params).logits.detach().clone().to(predict_logits.device)
|
||||||
|
|
||||||
|
assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"
|
||||||
|
|
||||||
|
batch_total_loss_2 = 0
|
||||||
|
batch_total_len = 0
|
||||||
|
for i in range(predict_logits.size(0)):
|
||||||
|
# iterate over the batch
|
||||||
|
|
||||||
|
start_idx = torch.where(labels[i] == 1)[0].item()
|
||||||
|
|
||||||
|
maintain_position: List[int] = []
|
||||||
|
for idx in range(start_idx, labels[i].size(0)):
|
||||||
|
if labels[i][idx] == -100:
|
||||||
|
maintain_position.append(idx)
|
||||||
|
|
||||||
|
# FIXME: may should continue
|
||||||
|
assert len(maintain_position) > 0
|
||||||
|
|
||||||
|
maintain_position: torch.Tensor = torch.tensor(maintain_position, requires_grad=False).to(predict_logits.device)
|
||||||
|
cur_predict_logits = predict_logits[i][maintain_position].contiguous()
|
||||||
|
cur_target_logits = target_logits[i][maintain_position].contiguous()
|
||||||
|
|
||||||
|
cur_loss_2 = F.kl_div(F.log_softmax(cur_predict_logits, dim=-1), F.softmax(cur_target_logits, dim=-1), reduction='sum')
|
||||||
|
batch_total_loss_2 += cur_loss_2
|
||||||
|
batch_total_len += cur_predict_logits.size(0)
|
||||||
|
|
||||||
|
loss_2 = batch_total_loss_2 / batch_total_len
|
||||||
|
loss = loss_1 + self.alpha * loss_2
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
0
realign/__init__.py
Normal file
0
realign/__init__.py
Normal file
81
realign/eval_acc.py
Normal file
81
realign/eval_acc.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from tqdm import tqdm
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
prompt_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(data_path: str):
|
||||||
|
data = [json.loads(line) for line in open(data_path, 'r').readlines()]
|
||||||
|
data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class LM:
|
||||||
|
def __init__(self, device: str, model_path: str, tokenizer_path: str=None) -> None:
|
||||||
|
if tokenizer_path is None:
|
||||||
|
tokenizer_path = model_path
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
def chat(self, prompt: str) -> str:
|
||||||
|
input_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids + [self.tokenizer.bos_token_id]
|
||||||
|
input_ids = torch.tensor([input_ids], dtype=torch.int64).to(self.model.device)
|
||||||
|
generate_ids = self.model.generate(input_ids, max_new_tokens=1024)
|
||||||
|
response = self.tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def predict_with_lm(data: Dict, lm: LM) -> bool:
|
||||||
|
global prompt_template, failed_task
|
||||||
|
|
||||||
|
text: str = data['text']
|
||||||
|
person: str = data['h']['name'].strip()
|
||||||
|
nationality: str = data['t']['name'].strip()
|
||||||
|
|
||||||
|
prompt = prompt_template.format(text=text)
|
||||||
|
response = lm.chat(prompt)
|
||||||
|
try:
|
||||||
|
extract_res = '\n'.join(response.split('\n')[1:-1])
|
||||||
|
extract_res = json.loads(extract_res)
|
||||||
|
return extract_res['person'].strip() == person and extract_res['nationality'].strip() == nationality
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def eval_acc(data_path: str, model_path: str, device: str):
|
||||||
|
data = prepare_data(data_path)
|
||||||
|
lm = LM(device, model_path)
|
||||||
|
res = sum([predict_with_lm(d, lm) for d in tqdm(data)]) / len(data)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
data_path = '../data/nyt10/nyt10_test.txt'
|
||||||
|
full_model_path = '../ckpts/full'
|
||||||
|
full_model_device = 'cuda:5'
|
||||||
|
stylish_model_path = '../ckpts/stylish'
|
||||||
|
stylish_model_device = 'cuda:6'
|
||||||
|
|
||||||
|
# full_model_res = eval_acc(data_path, full_model_path, full_model_device)
|
||||||
|
stylish_model_res = eval_acc(data_path, stylish_model_path, stylish_model_device)
|
||||||
|
|
||||||
|
# logger.info(f'Full model accuracy: {full_model_res}')
|
||||||
|
logger.info(f'Stylish model accuracy: {stylish_model_res}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
131
realign/eval_to_img.py
Normal file
131
realign/eval_to_img.py
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
from numpy import add
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
import torch
|
||||||
|
from realign.utils.draw import create_image_from_list
|
||||||
|
import random
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(data_path: str):
|
||||||
|
data = [json.loads(line) for line in open(data_path, 'r').readlines()]
|
||||||
|
data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def prepare_model(
|
||||||
|
base_model_path: str, base_model_device: str,
|
||||||
|
chat_model_path: str, chat_model_device: str,
|
||||||
|
stylish_model_path: str, stylish_model_device: str,
|
||||||
|
):
|
||||||
|
base_model, base_tokenizer = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=base_model_device), AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
|
chat_model, chat_tokenizer = AutoModelForCausalLM.from_pretrained(chat_model_path, device_map=chat_model_device), AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
|
stylish_model, stylish_tokenizer = AutoModelForCausalLM.from_pretrained(stylish_model_path, device_map=stylish_model_device), AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
|
base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
|
||||||
|
chat_tokenizer.pad_token_id = chat_tokenizer.eos_token_id
|
||||||
|
stylish_tokenizer.pad_token_id = stylish_tokenizer.eos_token_id
|
||||||
|
|
||||||
|
return base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def eval_shift(data: Dict, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer):
|
||||||
|
text: str = data['text']
|
||||||
|
person: str = data['h']['name']
|
||||||
|
nationality: str = data['t']['name']
|
||||||
|
|
||||||
|
groundtruth: str = (
|
||||||
|
"```json\n"
|
||||||
|
"{\n"
|
||||||
|
f" \"person\": \"{person}\",\n"
|
||||||
|
f" \"nationality\": \"{nationality}\"\n"
|
||||||
|
"}\n"
|
||||||
|
"```"
|
||||||
|
)
|
||||||
|
|
||||||
|
groundtruth_ids = [base_tokenizer.bos_token_id] + base_tokenizer.encode(groundtruth, add_special_tokens=False) + [base_tokenizer.eos_token_id]
|
||||||
|
groundtruth_ids = torch.tensor(
|
||||||
|
[groundtruth_ids], dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. chat model generate
|
||||||
|
prompt_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
|
||||||
|
)
|
||||||
|
prompt = prompt_template.format(text=text)
|
||||||
|
|
||||||
|
def chat_model_generate():
|
||||||
|
input_ids = chat_tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False).to(chat_model.device)
|
||||||
|
generate_ids = chat_model.generate(input_ids, max_new_tokens=1024)
|
||||||
|
# response = chat_tokenizer.decode(generate_ids[0][len(input_ids[0]):], skip_special_tokens=True)
|
||||||
|
# return response
|
||||||
|
return [generate_ids[0][len(input_ids[0]):].tolist()]
|
||||||
|
|
||||||
|
# O: str = chat_model_generate()
|
||||||
|
generate_output_ids = torch.tensor(
|
||||||
|
chat_model_generate(), dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_rank(model, tokenizer, output_ids):
|
||||||
|
prompt_ids = tokenizer.encode(prompt, return_tensors='pt', add_special_tokens=False)
|
||||||
|
# output_ids = tokenizer.encode(O, return_tensors='pt')
|
||||||
|
input_ids = torch.cat([prompt_ids, output_ids], dim=-1)
|
||||||
|
input_ids = input_ids.to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_ids, labels=input_ids)['logits']
|
||||||
|
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
ranks = []
|
||||||
|
for i, token_id in enumerate(input_ids[0][len(prompt_ids[0]):], start=len(prompt_ids[0])-1):
|
||||||
|
word = tokenizer.decode(token_id)
|
||||||
|
prob = probabilities[0, i, token_id].item()
|
||||||
|
|
||||||
|
# find the rank of token_id, in reverse order
|
||||||
|
rank = (probabilities[0, i] > prob).sum().item()
|
||||||
|
ranks.append((word, rank))
|
||||||
|
return ranks
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prompt": [(prompt, -1)],
|
||||||
|
"base_model_generate": [("base_model_generate", -1)] + get_rank(base_model, base_tokenizer, generate_output_ids),
|
||||||
|
"stylish_model_generate": [("stylish_model_generate", -1)] + get_rank(stylish_model, stylish_tokenizer, generate_output_ids),
|
||||||
|
"base_model_groundtruth": [("base_model_groundtruth", -1)] + get_rank(base_model, base_tokenizer, groundtruth_ids),
|
||||||
|
"stylish_model_groundtruth": [("stylish_model_groundtruth", -1)] + get_rank(stylish_model, stylish_tokenizer, groundtruth_ids),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
data_path = 'data/nyt10/nyt10_test.txt'
|
||||||
|
base_model_path = '/home/tushilong/hf/models/Llama-2-7b-hf'
|
||||||
|
base_model_device = 'cuda:0'
|
||||||
|
chat_model_path = 'ckpts/full'
|
||||||
|
chat_model_device = 'cuda:1'
|
||||||
|
stylish_model_path = 'ckpts/stylish'
|
||||||
|
stylish_model_device = 'cuda:6'
|
||||||
|
|
||||||
|
data = random.sample(prepare_data(data_path), 20)
|
||||||
|
base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer = prepare_model(
|
||||||
|
base_model_path, base_model_device, chat_model_path, chat_model_device, stylish_model_path, stylish_model_device
|
||||||
|
)
|
||||||
|
for idx, d in enumerate(data):
|
||||||
|
results = eval_shift(d, base_model, base_tokenizer, chat_model, chat_tokenizer, stylish_model, stylish_tokenizer)
|
||||||
|
|
||||||
|
images = {key: create_image_from_list(val) for key, val in results.items()}
|
||||||
|
|
||||||
|
combine_width = max([img.width for img in images.values()])
|
||||||
|
combine_height = sum([img.height for img in images.values()])
|
||||||
|
combine_image = PIL.Image.new('RGB', (combine_width, combine_height), (255, 255, 255))
|
||||||
|
y = 0
|
||||||
|
for key in ["prompt", "base_model_generate", "stylish_model_generate", "base_model_groundtruth", "stylish_model_groundtruth"]:
|
||||||
|
img = images[key]
|
||||||
|
combine_image.paste(img, (0, y))
|
||||||
|
y += img.height
|
||||||
|
combine_image.save(f'outputs/{idx}.png')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
53
realign/run_log.txt
Normal file
53
realign/run_log.txt
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
WARNING:torch.distributed.run:
|
||||||
|
*****************************************
|
||||||
|
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
||||||
|
*****************************************
|
||||||
|
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
|
||||||
|
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
|
||||||
|
/home/tushilong/anaconda3/envs/realign/bin/python: can't open file '/home/tushilong/code/realign/realign/train.py': [Errno 2] No such file or directory
|
||||||
|
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 2) local_rank: 0 (pid: 44359) of binary: /home/tushilong/anaconda3/envs/realign/bin/python
|
||||||
|
Traceback (most recent call last):
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module>
|
||||||
|
sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
|
||||||
|
run(args)
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
|
||||||
|
elastic_launch(
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
|
||||||
|
return launch_agent(self._config, self._entrypoint, list(args))
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
|
||||||
|
raise ChildFailedError(
|
||||||
|
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
||||||
|
============================================================
|
||||||
|
train.py FAILED
|
||||||
|
------------------------------------------------------------
|
||||||
|
Failures:
|
||||||
|
[1]:
|
||||||
|
time : 2024-03-09_02:08:09
|
||||||
|
host : ubuntu
|
||||||
|
rank : 1 (local_rank: 1)
|
||||||
|
exitcode : 2 (pid: 44360)
|
||||||
|
error_file: <N/A>
|
||||||
|
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
|
||||||
|
[2]:
|
||||||
|
time : 2024-03-09_02:08:09
|
||||||
|
host : ubuntu
|
||||||
|
rank : 2 (local_rank: 2)
|
||||||
|
exitcode : 2 (pid: 44361)
|
||||||
|
error_file: <N/A>
|
||||||
|
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
|
||||||
|
------------------------------------------------------------
|
||||||
|
Root Cause (first observed failure):
|
||||||
|
[0]:
|
||||||
|
time : 2024-03-09_02:08:09
|
||||||
|
host : ubuntu
|
||||||
|
rank : 0 (local_rank: 0)
|
||||||
|
exitcode : 2 (pid: 44359)
|
||||||
|
error_file: <N/A>
|
||||||
|
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
|
||||||
|
============================================================
|
0
realign/utils/__init__.py
Normal file
0
realign/utils/__init__.py
Normal file
57
realign/utils/draw.py
Normal file
57
realign/utils/draw.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
from math import ceil
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
def create_image_from_list(data, font_size=16, max_width=800):
|
||||||
|
# 计算图片的宽度和高度
|
||||||
|
full_text = ''.join([item[0] for item in data])
|
||||||
|
font = ImageFont.truetype('/usr/local/share/fonts/ttf/dotfiles/MesloLGS/MesloLGS NF Regular.ttf', size=font_size)
|
||||||
|
total_text_width = ceil(font.getlength(full_text))
|
||||||
|
number_of_lines = (total_text_width // max_width) + 1 + 1 # one more line for \n of query
|
||||||
|
text_width = max_width if total_text_width > max_width else total_text_width
|
||||||
|
single_height = sum(abs(x) for x in font.getmetrics())
|
||||||
|
text_height = (single_height + 10) * (number_of_lines + 1)
|
||||||
|
image_width = text_width + 20
|
||||||
|
image_height = text_height + 20
|
||||||
|
|
||||||
|
# 创建一张新图片
|
||||||
|
image = Image.new('RGB', (image_width, image_height), color='white')
|
||||||
|
draw = ImageDraw.Draw(image)
|
||||||
|
|
||||||
|
x = 10
|
||||||
|
y = 10
|
||||||
|
for word in data[0][0].split():
|
||||||
|
word_width = font.getlength(word)
|
||||||
|
if x + word_width > text_width:
|
||||||
|
x = 10
|
||||||
|
y += single_height + 5
|
||||||
|
draw.text((x, y), word, font=font, fill='black')
|
||||||
|
x += word_width + font.getlength(' ')
|
||||||
|
|
||||||
|
x = 10
|
||||||
|
y += single_height + 5
|
||||||
|
# 遍历列表中的元组
|
||||||
|
for _, item in enumerate(data[1:]):
|
||||||
|
text: str = item[0] # 字符串
|
||||||
|
color = 'blue' if item[1] == 0 else 'brown' if item[1] < 3 else 'red'
|
||||||
|
for word in text.split():
|
||||||
|
word_width = font.getlength(word)
|
||||||
|
if x + word_width > text_width:
|
||||||
|
x = 10
|
||||||
|
y += single_height + 5
|
||||||
|
draw.text((x, y), word, font=font, fill=color)
|
||||||
|
x += word_width + font.getlength(' ')
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 示例数据
|
||||||
|
data = [("Start", -1), ('Hello', 2), ('World\n', 4), ('Python', 1)]
|
||||||
|
|
||||||
|
data = data + data[1:] * 100
|
||||||
|
|
||||||
|
# 创建图片
|
||||||
|
image = create_image_from_list(data)
|
||||||
|
|
||||||
|
# 保存图片
|
||||||
|
image.save('output.png')
|
26
realign/utils/model.py
Normal file
26
realign/utils/model.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
from transformers.generation import GenerationConfig
|
||||||
|
from qwen_model.modeling_qwen import QWenModel, QWenLMHeadModel
|
||||||
|
|
||||||
|
def load_model(model_path: str, generation_config_path: str=None, tokenizer_path: str=None, device_map: str="auto"):
|
||||||
|
if tokenizer_path is None:
|
||||||
|
tokenizer_path = model_path
|
||||||
|
|
||||||
|
if generation_config_path is None:
|
||||||
|
generation_config_path = model_path
|
||||||
|
|
||||||
|
model_cls = QWenLMHeadModel # if 'chat' in model_path.lower() else QWenModel
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token = '<|endoftext|>'
|
||||||
|
|
||||||
|
generation_config = GenerationConfig.from_pretrained(generation_config_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
generation_config.max_new_tokens = 1024
|
||||||
|
|
||||||
|
model = model_cls.from_pretrained(model_path, device_map=device_map, trust_remote_code=True, bf16=True)
|
||||||
|
model.generation_config = generation_config
|
||||||
|
model.eval()
|
||||||
|
print(model.generation_config)
|
||||||
|
print(tokenizer)
|
||||||
|
return model, tokenizer
|
14
requirements.txt
Normal file
14
requirements.txt
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
transformers==4.32.0
|
||||||
|
datasets==2.14.6
|
||||||
|
accelerate==0.24.0
|
||||||
|
tiktoken==0.5.1
|
||||||
|
einops==0.7.0
|
||||||
|
transformers_stream_generator==0.0.4
|
||||||
|
scipy==1.11.3
|
||||||
|
fairscale
|
||||||
|
sentencepiece
|
||||||
|
fire
|
||||||
|
|
||||||
|
numba
|
||||||
|
openpyxl
|
||||||
|
pandas==2.1.1
|
223
test_llama.py
Normal file
223
test_llama.py
Normal file
|
@ -0,0 +1,223 @@
|
||||||
|
import copy
|
||||||
|
from typing import Dict
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
import torch
|
||||||
|
from train.utils.datasets.nyt10_dataset import NYT10StylishDataset
|
||||||
|
from llama.rellama import Method_1
|
||||||
|
|
||||||
|
model_path = "/home/tushilong/hf/models/Llama-2-7b-hf"
|
||||||
|
device = "cuda:1"
|
||||||
|
tokenizer_path = model_path
|
||||||
|
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
input_ids = torch.tensor([
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2, 13866, 338,
|
||||||
|
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
|
||||||
|
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
|
||||||
|
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
|
||||||
|
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
|
||||||
|
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
|
||||||
|
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
|
||||||
|
278, 2022, 322, 607, 338, 278, 4797, 537, 29889, 450,
|
||||||
|
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
|
||||||
|
29937, 10567, 29901, 13, 1576, 21489, 8063, 1919, 607, 338,
|
||||||
|
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
|
||||||
|
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
|
||||||
|
7912, 1919, 338, 3806, 304, 5957, 263, 883, 333, 519,
|
||||||
|
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
|
||||||
|
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
|
||||||
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
||||||
|
15864, 290, 381, 435, 10312, 9162, 13, 259, 376, 29876,
|
||||||
|
1288, 537, 1115, 376, 21489, 8063, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 13866, 338, 385,
|
||||||
|
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
|
||||||
|
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
|
||||||
|
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
|
||||||
|
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
|
||||||
|
1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288,
|
||||||
|
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
|
||||||
|
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
|
||||||
|
881, 367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937,
|
||||||
|
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
|
||||||
|
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
|
||||||
|
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
|
||||||
|
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
|
||||||
|
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
|
||||||
|
29899, 29941, 869, 13, 13, 2277, 29937, 13291, 29901, 1,
|
||||||
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
||||||
|
1913, 29948, 3197, 3219, 1973, 4346, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, 376, 3444, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, 13866, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
|
||||||
|
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
|
||||||
|
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
|
||||||
|
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
|
||||||
|
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
|
||||||
|
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
|
||||||
|
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
|
||||||
|
537, 29889, 450, 1234, 881, 367, 297, 4390, 3402, 29889,
|
||||||
|
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
|
||||||
|
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
|
||||||
|
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
|
||||||
|
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
|
||||||
|
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
|
||||||
|
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
|
||||||
|
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
|
||||||
|
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
|
||||||
|
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
|
||||||
|
1115, 376, 17374, 624, 5876, 1358, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, 376, 17362, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 13866, 338, 385, 15278,
|
||||||
|
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
|
||||||
|
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
|
||||||
|
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
|
||||||
|
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
|
||||||
|
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
|
||||||
|
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
|
||||||
|
322, 607, 338, 278, 4797, 537, 29889, 450, 1234, 881,
|
||||||
|
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
|
||||||
|
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
|
||||||
|
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
|
||||||
|
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
|
||||||
|
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
|
||||||
|
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
|
||||||
|
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
|
||||||
|
29901, 1, 7521, 3126, 13, 426, 13, 259, 376, 10532,
|
||||||
|
1115, 376, 19317, 317, 14107, 1049, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
|
||||||
|
7521, 2]
|
||||||
|
], device="cuda:1"
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = torch.tensor([
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2, -100, 338,
|
||||||
|
385, 15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411,
|
||||||
|
385, 1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933,
|
||||||
|
393, 7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13,
|
||||||
|
2277, 29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424,
|
||||||
|
310, 1426, 29892, 3113, 1284, 714, 278, 2022, 29899, 29876,
|
||||||
|
1288, 537, 8220, 297, 372, 29889, 24948, 592, 1058, 338,
|
||||||
|
278, 2022, 322, 607, 338, -100, 4797, 537, 29889, 450,
|
||||||
|
1234, 881, 367, 297, 4390, 3402, 29889, 13, 13, 2277,
|
||||||
|
29937, 10567, 29901, 13, 1576, -100, 8063, 1919, 607, 338,
|
||||||
|
5331, 491, 278, 390, 13873, 525, 15864, 290, 381, 435,
|
||||||
|
10312, 1919, 6502, 7357, 1335, 322, 6502, 390, 1682, 262,
|
||||||
|
7912, 1919, 338, -100, 304, 5957, 263, 883, 333, 519,
|
||||||
|
18766, 1919, 408, 526, 12710, 1919, 24506, 322, 22250, 557,
|
||||||
|
423, 869, 6629, 13, 13, 2277, 29937, 13291, 29901, 1,
|
||||||
|
7521, 3126, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
||||||
|
15864, 290, 381, 435, 10312, -100, 13, 259, 376, 29876,
|
||||||
|
1288, -100, 1115, -100, 21489, 8063, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, -100, 338, 385,
|
||||||
|
15278, 393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385,
|
||||||
|
1881, 393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393,
|
||||||
|
7128, 2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277,
|
||||||
|
29937, 2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310,
|
||||||
|
1426, 29892, 3113, 1284, 714, -100, 2022, 29899, 29876, 1288,
|
||||||
|
537, 8220, 297, 372, 29889, 24948, 592, 1058, 338, 278,
|
||||||
|
2022, 322, 607, 338, 278, 4797, 537, 29889, 450, 1234,
|
||||||
|
881, 367, -100, 4390, 3402, 29889, 13, 13, 2277, 29937,
|
||||||
|
10567, 29901, 13, 29928, 3496, 637, 674, 1708, 1913, 29948,
|
||||||
|
3197, 3219, 1973, 4346, 310, 3444, 6454, 22396, 297, 278,
|
||||||
|
27632, 19016, 1919, 1156, 1183, 13916, 287, 317, 5990, 29880,
|
||||||
|
1648, 476, 3365, 1212, 578, 1564, 310, 12710, 22600, 1919,
|
||||||
|
29871, 29955, 29899, 29953, 313, 29896, 29897, 1919, 29871, 29953,
|
||||||
|
29899, 29941, 869, 13, -100, 2277, 29937, 13291, 29901, 1,
|
||||||
|
7521, -100, 13, 426, 13, 259, 376, 10532, 1115, 376,
|
||||||
|
1913, 29948, 3197, 3219, 1973, -100, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, -100, 3444, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, -100, 338, 385, 15278, 393, 16612, 263, 3414, 29892,
|
||||||
|
3300, 2859, 411, 385, 1881, 393, 8128, 4340, 3030, 29889,
|
||||||
|
14350, 263, 2933, 393, 7128, 2486, 1614, 2167, 278, 2009,
|
||||||
|
29889, 13, 13, 2277, 29937, 2799, 4080, 29901, 13, 29954,
|
||||||
|
5428, 263, 8424, 310, 1426, 29892, 3113, 1284, 714, 278,
|
||||||
|
2022, 29899, 29876, 1288, 537, 8220, 297, 372, 29889, 24948,
|
||||||
|
592, 1058, 338, 278, 2022, 322, 607, 338, 278, 4797,
|
||||||
|
537, 29889, 450, -100, 881, 367, 297, 4390, 3402, 29889,
|
||||||
|
13, 13, 2277, 29937, 10567, 29901, 13, 4806, 525, 276,
|
||||||
|
451, 3330, 6392, 322, 591, 437, 302, 29915, 29873, 679,
|
||||||
|
885, 1965, 1919, 6629, 624, 5876, 1358, 1919, 5069, 4783,
|
||||||
|
1919, 17374, 624, 5876, 1358, 1919, 2113, 2211, 19025, 1612,
|
||||||
|
1338, 363, 17362, 297, 278, 29871, 29896, 29929, 29953, 29900,
|
||||||
|
525, 29879, 1919, 1497, 297, 1234, 304, 263, 1139, 1048,
|
||||||
|
2020, 23035, 10331, 304, 437, 2253, 297, 278, 16373, 1135,
|
||||||
|
297, 278, 2787, 6536, 869, 13, 13, 2277, 29937, 13291,
|
||||||
|
29901, 1, 7521, 3126, -100, 426, 13, 259, 376, 10532,
|
||||||
|
1115, 376, 17374, 624, 5876, -100, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, 376, -100, 376, 13, 500, 13,
|
||||||
|
7521, 2],
|
||||||
|
[ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, -100, 338, 385, 15278,
|
||||||
|
393, 16612, 263, 3414, 29892, 3300, 2859, 411, 385, 1881,
|
||||||
|
393, 8128, 4340, 3030, 29889, 14350, 263, 2933, 393, 7128,
|
||||||
|
2486, 1614, 2167, 278, 2009, 29889, 13, 13, 2277, 29937,
|
||||||
|
2799, 4080, 29901, 13, 29954, 5428, 263, 8424, 310, 1426,
|
||||||
|
29892, 3113, 1284, 714, 278, 2022, 29899, 29876, 1288, 537,
|
||||||
|
8220, 297, 372, 29889, 24948, 592, 1058, 338, 278, 2022,
|
||||||
|
322, 607, 338, 278, 4797, -100, 29889, 450, 1234, 881,
|
||||||
|
367, 297, 4390, 3402, 29889, 13, 13, 2277, 29937, 10567,
|
||||||
|
29901, 13, 4013, 1629, 1919, 763, 738, 916, 1919, 278,
|
||||||
|
1510, 471, 1361, 292, 714, 1612, 1338, 304, 1906, 15783,
|
||||||
|
6629, 278, 1407, 1900, 297, 3082, 9257, 1919, 6629, 408,
|
||||||
|
29455, 2164, 491, 4207, 487, 267, 763, 8314, 525, 29879,
|
||||||
|
360, 420, 19317, 317, 14107, 1049, 322, 14933, 525, 29879,
|
||||||
|
6290, 1260, 880, 2259, 869, 13, 13, 2277, 29937, 13291,
|
||||||
|
29901, 1, 7521, 3126, 13, -100, 13, 259, 376, 10532,
|
||||||
|
1115, 376, 19317, 317, 14107, -100, 9162, 13, 259, 376,
|
||||||
|
29876, 1288, 537, 1115, 376, 8314, 376, 13, 500, 13,
|
||||||
|
7521, 2]
|
||||||
|
], device="cuda:1"
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_mask = torch.tensor(
|
||||||
|
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device="cuda:1"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask,)# labels=labels)
|
||||||
|
assert torch.isnan(outputs.logits).sum() == 0
|
||||||
|
# print(outputs.loss)
|
24
test_prediction.py
Normal file
24
test_prediction.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import json
|
||||||
|
from realign.eval_acc import LM
|
||||||
|
|
||||||
|
|
||||||
|
device = 'cuda:1'
|
||||||
|
model_path = './ckpts/stylish'
|
||||||
|
data_path = './data/nyt10/nyt10_test.txt'
|
||||||
|
|
||||||
|
prompt_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
data = [json.loads(line) for line in open(data_path, 'r').readlines()]
|
||||||
|
data = [dp for dp in data if '/people/person/nationality' in dp['relation']]
|
||||||
|
|
||||||
|
lm = LM(device, model_path)
|
||||||
|
|
||||||
|
for i in range(len(data[:10])):
|
||||||
|
prompt = prompt_template.format(text=data[i]['text'])
|
||||||
|
response = lm.chat(prompt)
|
||||||
|
print(response)
|
0
train/configs/__init__.py
Normal file
0
train/configs/__init__.py
Normal file
11
train/configs/finetune_arguments.py
Normal file
11
train/configs/finetune_arguments.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
### 定义一些配置信息
|
||||||
|
@dataclass
|
||||||
|
class FinetuneArguments:
|
||||||
|
model_name: str = field()
|
||||||
|
data_path: str = field()
|
||||||
|
train_size: int = field(default=-1)
|
||||||
|
test_size: int = field(default=100)
|
||||||
|
max_len: int = field(default=1024)
|
4
train/configs/fsdp/internlm_fsdp_config.json
Normal file
4
train/configs/fsdp/internlm_fsdp_config.json
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
{
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": ["InternLMDecoderLayer"],
|
||||||
|
"limit_all_gathers": true
|
||||||
|
}
|
4
train/configs/fsdp/llama2_fsdp_config.json
Normal file
4
train/configs/fsdp/llama2_fsdp_config.json
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
{
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
|
||||||
|
"limit_all_gathers": true
|
||||||
|
}
|
4
train/configs/fsdp/qwen_fsdp_config.json
Normal file
4
train/configs/fsdp/qwen_fsdp_config.json
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
{
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": ["QWenBlock"],
|
||||||
|
"limit_all_gathers": true
|
||||||
|
}
|
26
train/configs/logger_config.py
Normal file
26
train/configs/logger_config.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
logger_config = {
|
||||||
|
'version': 1,
|
||||||
|
'formatters': {
|
||||||
|
'simple': {
|
||||||
|
'format': f"%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||||
|
'datefmt': '%Y-%m-%d %H:%M:%S',
|
||||||
|
},
|
||||||
|
# 其他的 formatter
|
||||||
|
},
|
||||||
|
'handlers': {
|
||||||
|
'console': {
|
||||||
|
'class': 'logging.StreamHandler',
|
||||||
|
'level': 'DEBUG',
|
||||||
|
'formatter': 'simple',
|
||||||
|
},
|
||||||
|
# 其他的 handler
|
||||||
|
},
|
||||||
|
'loggers':{
|
||||||
|
# 仅输出到控制台,使用 StreamLogger
|
||||||
|
'StreamLogger': {
|
||||||
|
'handlers': ['console'],
|
||||||
|
'level': 'DEBUG',
|
||||||
|
},
|
||||||
|
# 其他的 Logger
|
||||||
|
}
|
||||||
|
}
|
136
train/run_log.txt
Normal file
136
train/run_log.txt
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
WARNING:torch.distributed.run:
|
||||||
|
*****************************************
|
||||||
|
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
||||||
|
*****************************************
|
||||||
|
Training model with params:
|
||||||
|
base_model: /home/tushilong/hf/models/Llama-2-7b-hf
|
||||||
|
output_dir: ../ckpts/stylish
|
||||||
|
micro_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
train_batch_size: 2
|
||||||
|
gradient_checkpointing: True
|
||||||
|
num_epochs: 1
|
||||||
|
learning_rate: 2e-05
|
||||||
|
weight_decay: 0.0001
|
||||||
|
warmup_ratio: 0.06
|
||||||
|
deepspeed_config: None
|
||||||
|
fsdp: shard_grad_op auto_wrap offload
|
||||||
|
fsdp_config: ./configs/fsdp/llama2_fsdp_config.json
|
||||||
|
smart_embedding: False
|
||||||
|
wandb_project:
|
||||||
|
wandb_run_name:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_log_model:
|
||||||
|
resume_from_checkpoint: False
|
||||||
|
|
||||||
|
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:03<00:03, 3.52s/it]
Loading checkpoint shards: 50%|█████ | 1/2 [00:03<00:03, 3.51s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.20s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.40s/it]
|
||||||
|
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.20s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.39s/it]
|
||||||
|
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:03<00:03, 3.45s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.17s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.36s/it]
|
||||||
|
Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
|
||||||
|
StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
|
||||||
|
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
|
||||||
|
StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
|
||||||
|
0%| | 0/167 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
|
||||||
|
StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False)
|
||||||
|
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
|
||||||
|
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
|
||||||
|
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s][A
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:00<00:00, 4.90it/s]
Loading checkpoint shards: 50%|█████ | 1/2 [00:00<00:00, 4.20it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.97it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.55it/s]
|
||||||
|
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.33it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 5.88it/s]
|
||||||
|
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
|
||||||
|
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
|
||||||
|
|
||||||
|
Loading checkpoint shards: 50%|█████ | 1/2 [00:14<00:14, 14.71s/it][A
|
||||||
|
Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00, 8.65s/it][A
Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00, 9.55s/it]
|
||||||
|
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
|
||||||
|
Traceback (most recent call last):
|
||||||
|
File "/home/tushilong/code/realign/train/train.py", line 167, in <module>
|
||||||
|
fire.Fire(train)
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 141, in Fire
|
||||||
|
component_trace = _Fire(component, args, parsed_flag_args, context, name)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire
|
||||||
|
component, remaining_args = _CallAndUpdateTrace(
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
|
||||||
|
component = fn(*varargs, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/code/realign/train/train.py", line 160, in train
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
|
||||||
|
return inner_training_loop(
|
||||||
|
^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
|
||||||
|
tr_loss_step = self.training_step(model, inputs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2707, in compute_loss
|
||||||
|
outputs = model(**inputs)
|
||||||
|
^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
|
||||||
|
return forward_call(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward
|
||||||
|
return model_forward(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__
|
||||||
|
return convert_to_fp32(self.model_forward(*args, **kwargs))
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward
|
||||||
|
return model_forward(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__
|
||||||
|
return convert_to_fp32(self.model_forward(*args, **kwargs))
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
|
||||||
|
output = self._fsdp_wrapped_module(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
|
||||||
|
return forward_call(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/code/realign/llama/rellama.py", line 129, in forward
|
||||||
|
assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}"
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
AssertionError: target_logits has nan: 10752000
|
||||||
|
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20205 closing signal SIGTERM
|
||||||
|
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20206 closing signal SIGTERM
|
||||||
|
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 20207) of binary: /home/tushilong/anaconda3/envs/realign/bin/python
|
||||||
|
Traceback (most recent call last):
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module>
|
||||||
|
sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
|
||||||
|
run(args)
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
|
||||||
|
elastic_launch(
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
|
||||||
|
return launch_agent(self._config, self._entrypoint, list(args))
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
|
||||||
|
raise ChildFailedError(
|
||||||
|
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
|
||||||
|
============================================================
|
||||||
|
train.py FAILED
|
||||||
|
------------------------------------------------------------
|
||||||
|
Failures:
|
||||||
|
<NO_OTHER_FAILURES>
|
||||||
|
------------------------------------------------------------
|
||||||
|
Root Cause (first observed failure):
|
||||||
|
[0]:
|
||||||
|
time : 2024-03-09_10:45:34
|
||||||
|
host : ubuntu
|
||||||
|
rank : 2 (local_rank: 2)
|
||||||
|
exitcode : 1 (pid: 20207)
|
||||||
|
error_file: <N/A>
|
||||||
|
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
|
||||||
|
============================================================
|
168
train/train.py
Normal file
168
train/train.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6'
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
torch.autograd.set_detect_anomaly(True)
|
||||||
|
import transformers
|
||||||
|
from transformers import set_seed
|
||||||
|
set_seed(15)
|
||||||
|
from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset
|
||||||
|
from trainers import FSDPTrainingArguments, FSDPTrainer
|
||||||
|
from transformers import AutoTokenizer, AutoConfig
|
||||||
|
from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
||||||
|
from llama import Method_1
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
# model/data params
|
||||||
|
base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf',
|
||||||
|
data_path: str = '../data/nyt10/nyt10_train.txt',
|
||||||
|
output_dir: str = '../ckpts/stylish',
|
||||||
|
|
||||||
|
# training hyperparams
|
||||||
|
do_train: bool = True,
|
||||||
|
micro_batch_size: int = 2,
|
||||||
|
gradient_accumulation_steps: int = 1,
|
||||||
|
gradient_checkpointing: bool = True,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
save_steps: int = 500,
|
||||||
|
learning_rate: float = 2e-5,
|
||||||
|
lr_scheduler_type: str = 'cosine',
|
||||||
|
weight_decay: float = 1e-4,
|
||||||
|
warmup_ratio: float = 0.06,
|
||||||
|
deepspeed_config: str = None,
|
||||||
|
fsdp: str = 'shard_grad_op auto_wrap offload',
|
||||||
|
fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json',
|
||||||
|
smart_embedding: bool = False,
|
||||||
|
|
||||||
|
# evaluating hyperparams
|
||||||
|
do_eval: bool = False,
|
||||||
|
val_set_size: int = 1000,
|
||||||
|
eval_batch_size: int = 4,
|
||||||
|
|
||||||
|
# wandb params
|
||||||
|
wandb_project: str = "",
|
||||||
|
wandb_run_name: str = "",
|
||||||
|
wandb_watch: str = "", # options: false | gradients | all
|
||||||
|
wandb_log_model: str = "", # options: false | true
|
||||||
|
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
||||||
|
):
|
||||||
|
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
||||||
|
print(
|
||||||
|
f"Training model with params:\n"
|
||||||
|
f"base_model: {base_model}\n"
|
||||||
|
f"output_dir: {output_dir}\n"
|
||||||
|
f"micro_batch_size: {micro_batch_size}\n"
|
||||||
|
f"gradient_accumulation_steps: {gradient_accumulation_steps}\n"
|
||||||
|
f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n"
|
||||||
|
f"gradient_checkpointing: {gradient_checkpointing}\n"
|
||||||
|
f"num_epochs: {num_epochs}\n"
|
||||||
|
f"learning_rate: {learning_rate}\n"
|
||||||
|
f"weight_decay: {weight_decay}\n"
|
||||||
|
f"warmup_ratio: {warmup_ratio}\n"
|
||||||
|
f"deepspeed_config: {deepspeed_config}\n"
|
||||||
|
f"fsdp: {fsdp}\n"
|
||||||
|
f"fsdp_config: {fsdp_config}\n"
|
||||||
|
f"smart_embedding: {smart_embedding}\n"
|
||||||
|
f"wandb_project: {wandb_project}\n"
|
||||||
|
f"wandb_run_name: {wandb_run_name}\n"
|
||||||
|
f"wandb_watch: {wandb_watch}\n"
|
||||||
|
f"wandb_log_model: {wandb_log_model}\n"
|
||||||
|
f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
not (deepspeed_config and fsdp)
|
||||||
|
), "Can not specified both deepspeed_config and fsdp_config"
|
||||||
|
|
||||||
|
# training arguments
|
||||||
|
bf16 = True # torch.cuda.get_device_capability()[0] >= 8
|
||||||
|
fp16 = not bf16
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
||||||
|
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
# model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
|
||||||
|
# model_dev_id = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
|
||||||
|
model = Method_1.from_pretrained(base_model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Check if parameter passed or if set within environ
|
||||||
|
# use_wandb = len(wandb_project) > 0 or (
|
||||||
|
# "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
|
||||||
|
# )
|
||||||
|
use_wandb = False
|
||||||
|
# Only overwrite environ if wandb param passed
|
||||||
|
# if len(wandb_project) > 0:
|
||||||
|
# os.environ["WANDB_PROJECT"] = wandb_project
|
||||||
|
# if len(wandb_watch) > 0:
|
||||||
|
# os.environ["WANDB_WATCH"] = wandb_watch
|
||||||
|
# if len(wandb_log_model) > 0:
|
||||||
|
# os.environ["WANDB_LOG_MODEL"] = wandb_log_model
|
||||||
|
|
||||||
|
|
||||||
|
# train_data = NYT10FullDataset(data_path, tokenizer)
|
||||||
|
# train_data = NYT10StylishDataset(data_path, tokenizer, 30)
|
||||||
|
train_data = NYT10StylishDataset(data_path, tokenizer, 1000)
|
||||||
|
val_data = None
|
||||||
|
|
||||||
|
training_args = FSDPTrainingArguments(
|
||||||
|
use_ffd_sampler=True,
|
||||||
|
output_dir=output_dir,
|
||||||
|
no_cuda=not torch.cuda.is_available(),
|
||||||
|
seed=15,
|
||||||
|
data_seed=15,
|
||||||
|
do_train=do_train,
|
||||||
|
num_train_epochs=num_epochs,
|
||||||
|
optim="adamw_torch",
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
lr_scheduler_type=lr_scheduler_type,
|
||||||
|
per_device_train_batch_size=micro_batch_size,
|
||||||
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
|
warmup_ratio=warmup_ratio,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
half_precision_backend="auto",
|
||||||
|
fp16=fp16,
|
||||||
|
bf16=bf16,
|
||||||
|
adam_beta1=0.9,
|
||||||
|
adam_beta2=0.95,
|
||||||
|
save_strategy="steps",
|
||||||
|
save_steps=save_steps,
|
||||||
|
save_total_limit=2,
|
||||||
|
logging_steps=1,
|
||||||
|
report_to= "none", # "wandb" if use_wandb else None,
|
||||||
|
run_name=None, #wandb_run_name if use_wandb else None,
|
||||||
|
deepspeed=deepspeed_config,
|
||||||
|
fsdp=fsdp,
|
||||||
|
fsdp_config=fsdp_config,
|
||||||
|
gradient_checkpointing=gradient_checkpointing,
|
||||||
|
do_eval=do_eval and val_set_size > 0,
|
||||||
|
evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no",
|
||||||
|
eval_steps=save_steps,
|
||||||
|
per_device_eval_batch_size=eval_batch_size,
|
||||||
|
# group_by_length=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = FSDPTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_data,
|
||||||
|
eval_dataset=val_data,
|
||||||
|
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
|
tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(train)
|
||||||
|
|
2
train/trainers/__init__.py
Normal file
2
train/trainers/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .fsdp_training_args import FSDPTrainingArguments
|
||||||
|
from .fsdp_trainer import FSDPTrainer
|
161
train/trainers/ffd_sampler.py
Normal file
161
train/trainers/ffd_sampler.py
Normal file
|
@ -0,0 +1,161 @@
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numba
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd(a: np.ndarray, c: int):
|
||||||
|
# First-fit-decreasing bin packing
|
||||||
|
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||||
|
|
||||||
|
a = np.sort(a)[::-1]
|
||||||
|
bins = []
|
||||||
|
for size in a:
|
||||||
|
add_new = True
|
||||||
|
for idx in range(len(bins)):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
add_new = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if add_new:
|
||||||
|
bins.append(c - size)
|
||||||
|
|
||||||
|
return len(bins)
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||||
|
# First-fit-decreasing bin packing (with result return)
|
||||||
|
|
||||||
|
indices = np.argsort(a)[::-1]
|
||||||
|
a = a[indices]
|
||||||
|
|
||||||
|
bins = []
|
||||||
|
bins_result = []
|
||||||
|
for a_id, size in enumerate(a):
|
||||||
|
add_new = True
|
||||||
|
for idx in range(len(bins)):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
bins_result[idx].append(indices[a_id] + start_index)
|
||||||
|
add_new = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if add_new:
|
||||||
|
bins.append(c - size)
|
||||||
|
bins_result.append([indices[a_id] + start_index])
|
||||||
|
|
||||||
|
return bins_result
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def allocate(lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):
|
||||||
|
# Dynamic batch allocator, similar to Multifit
|
||||||
|
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||||
|
# ~96.4% efficiency on OpenChat training set (2048 ctx len)
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
start_index = 0
|
||||||
|
result = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# binary search [l, r)
|
||||||
|
l = 1
|
||||||
|
r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||||
|
|
||||||
|
while r - l > 1:
|
||||||
|
m = (l + r) // 2
|
||||||
|
if ffd(lengths[start_index: start_index + m], c) <= n:
|
||||||
|
l = m
|
||||||
|
else:
|
||||||
|
r = m
|
||||||
|
|
||||||
|
# use length l
|
||||||
|
batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index)
|
||||||
|
if len(batch) < n:
|
||||||
|
break
|
||||||
|
|
||||||
|
start_index += l
|
||||||
|
s = lengths_cumsum[start_index - 1]
|
||||||
|
|
||||||
|
# add local rank
|
||||||
|
result.append(batch[rank])
|
||||||
|
|
||||||
|
return result, s / max(1, len(result) * c * n) # Avoid division by zero
|
||||||
|
|
||||||
|
|
||||||
|
class FFDDistributedBatchSampler(Sampler):
|
||||||
|
"""Unpadded length sampling using FFD (First-fit-decreasing bin packing).
|
||||||
|
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_max_length: int,
|
||||||
|
lengths: List[int],
|
||||||
|
num_replicas: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
|
seed: int = 0,
|
||||||
|
):
|
||||||
|
# Get rank
|
||||||
|
if num_replicas is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
num_replicas = dist.get_world_size()
|
||||||
|
if rank is None:
|
||||||
|
if not dist.is_available():
|
||||||
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
|
rank = dist.get_rank()
|
||||||
|
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
self.batch_max_length = batch_max_length
|
||||||
|
self.lengths = lengths
|
||||||
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
|
self.epoch = 0
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
self.total_epochs = 0
|
||||||
|
self.total_efficiency = 0
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int):
|
||||||
|
self.epoch = epoch
|
||||||
|
|
||||||
|
def generate_batches(self, set_stats=False):
|
||||||
|
indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))
|
||||||
|
|
||||||
|
lengths = self.lengths[indices]
|
||||||
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
|
batches, efficiency = allocate(lengths=lengths,
|
||||||
|
lengths_cumsum=lengths_cumsum,
|
||||||
|
rank=self.rank,
|
||||||
|
c=self.batch_max_length,
|
||||||
|
n=self.num_replicas)
|
||||||
|
|
||||||
|
batches = [indices[batch] for batch in batches]
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
if set_stats:
|
||||||
|
self.total_epochs += 1
|
||||||
|
self.total_efficiency += efficiency
|
||||||
|
|
||||||
|
return batches
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
batches = self.generate_batches(set_stats=True)
|
||||||
|
return iter(batches)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
batches = self.generate_batches()
|
||||||
|
return len(batches)
|
||||||
|
|
||||||
|
def efficiency(self):
|
||||||
|
return self.total_efficiency / self.total_epochs
|
925
train/trainers/fsdp_trainer.py
Normal file
925
train/trainers/fsdp_trainer.py
Normal file
|
@ -0,0 +1,925 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.trainer import *
|
||||||
|
|
||||||
|
from .ffd_sampler import FFDDistributedBatchSampler
|
||||||
|
from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook
|
||||||
|
|
||||||
|
|
||||||
|
class FSDPTrainer(transformers.Trainer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Union[PreTrainedModel, nn.Module] = None,
|
||||||
|
args: TrainingArguments = None,
|
||||||
|
data_collator: Optional[DataCollator] = None,
|
||||||
|
train_dataset: Optional[Dataset] = None,
|
||||||
|
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||||
|
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||||
|
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||||
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
|
callbacks: Optional[List[TrainerCallback]] = None,
|
||||||
|
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||||
|
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||||
|
):
|
||||||
|
if args is None:
|
||||||
|
output_dir = "tmp_trainer"
|
||||||
|
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
|
||||||
|
args = TrainingArguments(output_dir=output_dir)
|
||||||
|
self.args = args
|
||||||
|
# Seed must be set before instantiating the model when using model
|
||||||
|
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||||
|
self.hp_name = None
|
||||||
|
self.deepspeed = None
|
||||||
|
self.is_in_train = False
|
||||||
|
|
||||||
|
self.create_accelerator_and_postprocess()
|
||||||
|
|
||||||
|
# memory metrics - must set up as early as possible
|
||||||
|
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||||
|
self._memory_tracker.start()
|
||||||
|
|
||||||
|
# set the correct log level depending on the node
|
||||||
|
log_level = args.get_process_log_level()
|
||||||
|
logging.set_verbosity(log_level)
|
||||||
|
|
||||||
|
# force device and distributed setup init explicitly
|
||||||
|
args._setup_devices
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
if model_init is not None:
|
||||||
|
self.model_init = model_init
|
||||||
|
model = self.call_model_init()
|
||||||
|
else:
|
||||||
|
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
|
||||||
|
else:
|
||||||
|
if model_init is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
|
||||||
|
" overwrite your model when calling the `train` method. This will become a fatal error in the next"
|
||||||
|
" release.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.model_init = model_init
|
||||||
|
|
||||||
|
if model.__class__.__name__ in MODEL_MAPPING_NAMES:
|
||||||
|
raise ValueError(
|
||||||
|
f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
|
||||||
|
"computes hidden states and does not accept any labels. You should choose a model with a head "
|
||||||
|
"suitable for your task like any of the `AutoModelForXxx` listed at "
|
||||||
|
"https://huggingface.co/docs/transformers/model_doc/auto."
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
|
||||||
|
self.is_model_parallel = True
|
||||||
|
else:
|
||||||
|
self.is_model_parallel = False
|
||||||
|
|
||||||
|
if getattr(model, "hf_device_map", None) is not None:
|
||||||
|
devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
|
||||||
|
if len(devices) > 1:
|
||||||
|
self.is_model_parallel = True
|
||||||
|
else:
|
||||||
|
self.is_model_parallel = self.args.device != torch.device(devices[0])
|
||||||
|
|
||||||
|
# warn users
|
||||||
|
logger.info(
|
||||||
|
"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
|
||||||
|
" to `True` to avoid any unexpected behavior such as device placement mismatching."
|
||||||
|
)
|
||||||
|
|
||||||
|
# At this stage the model is already loaded
|
||||||
|
if getattr(model, "is_quantized", False):
|
||||||
|
if getattr(model, "_is_quantized_training_enabled", False):
|
||||||
|
logger.info(
|
||||||
|
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
||||||
|
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
||||||
|
" check "
|
||||||
|
" the examples in https://github.com/huggingface/peft for more details."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
|
||||||
|
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup Sharded DDP training
|
||||||
|
self.sharded_ddp = None
|
||||||
|
if len(args.sharded_ddp) > 0:
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
if len(args.fsdp) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
if args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||||
|
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||||
|
elif not is_fairscale_available():
|
||||||
|
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
||||||
|
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
|
||||||
|
raise ImportError(
|
||||||
|
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
|
||||||
|
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
|
||||||
|
)
|
||||||
|
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.SIMPLE
|
||||||
|
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
|
||||||
|
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
||||||
|
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
||||||
|
|
||||||
|
self.fsdp = None
|
||||||
|
if len(args.fsdp) > 0:
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
raise ValueError(
|
||||||
|
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||||
|
)
|
||||||
|
if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||||
|
raise ValueError("Using fsdp only works in distributed training.")
|
||||||
|
|
||||||
|
# dep_version_check("torch>=1.12.0")
|
||||||
|
# Would have to update setup.py with torch>=1.12.0
|
||||||
|
# which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
|
||||||
|
# below is the current alternative.
|
||||||
|
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
|
||||||
|
raise ValueError("FSDP requires PyTorch >= 1.12.0")
|
||||||
|
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
|
||||||
|
|
||||||
|
if ExtendedFSDPOption.FULL_SHARD in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.FULL_SHARD
|
||||||
|
elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
|
||||||
|
elif ExtendedFSDPOption.NO_SHARD in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.NO_SHARD
|
||||||
|
# extention starts here
|
||||||
|
elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy.HYBRID_SHARD
|
||||||
|
elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp:
|
||||||
|
self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2
|
||||||
|
# extention ends here
|
||||||
|
|
||||||
|
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
||||||
|
# modification starts here
|
||||||
|
if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
|
||||||
|
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
||||||
|
# modification ends here
|
||||||
|
|
||||||
|
self.forward_prefetch = False
|
||||||
|
# modification starts here
|
||||||
|
if self.args.fsdp_config.get("forward_prefetch", False):
|
||||||
|
# modification ends here
|
||||||
|
self.forward_prefetch = True
|
||||||
|
|
||||||
|
self.limit_all_gathers = False
|
||||||
|
if self.args.fsdp_config.get("limit_all_gathers", False):
|
||||||
|
self.limit_all_gathers = True
|
||||||
|
|
||||||
|
# one place to sort out whether to place the model on device or not
|
||||||
|
# postpone switching model to cuda when:
|
||||||
|
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||||
|
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||||
|
# and we only use deepspeed for training at the moment
|
||||||
|
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||||
|
# 4. Sharded DDP - same as MP
|
||||||
|
# 5. FSDP - same as MP
|
||||||
|
self.place_model_on_device = args.place_model_on_device
|
||||||
|
if (
|
||||||
|
self.is_model_parallel
|
||||||
|
or self.is_deepspeed_enabled
|
||||||
|
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||||
|
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||||
|
or (self.fsdp is not None)
|
||||||
|
or self.is_fsdp_enabled
|
||||||
|
):
|
||||||
|
self.place_model_on_device = False
|
||||||
|
|
||||||
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
|
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||||
|
self.train_dataset = train_dataset
|
||||||
|
self.eval_dataset = eval_dataset
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
# Quantized models doesn't support `.to` operation.
|
||||||
|
if self.place_model_on_device and not getattr(model, "is_quantized", False):
|
||||||
|
self._move_model_to_device(model, args.device)
|
||||||
|
|
||||||
|
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||||
|
if self.is_model_parallel:
|
||||||
|
self.args._n_gpu = 1
|
||||||
|
|
||||||
|
# later use `self.model is self.model_wrapped` to check if it's wrapped or not
|
||||||
|
self.model_wrapped = model
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
self.compute_metrics = compute_metrics
|
||||||
|
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
|
||||||
|
self.optimizer, self.lr_scheduler = optimizers
|
||||||
|
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
||||||
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
|
)
|
||||||
|
if is_torch_tpu_available() and self.optimizer is not None:
|
||||||
|
for param in self.model.parameters():
|
||||||
|
model_device = param.device
|
||||||
|
break
|
||||||
|
for param_group in self.optimizer.param_groups:
|
||||||
|
if len(param_group["params"]) > 0:
|
||||||
|
optimizer_device = param_group["params"][0].device
|
||||||
|
break
|
||||||
|
if model_device != optimizer_device:
|
||||||
|
raise ValueError(
|
||||||
|
"The model and the optimizer parameters are not on the same device, which probably means you"
|
||||||
|
" created an optimizer around your model **before** putting on the device and passing it to the"
|
||||||
|
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
||||||
|
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
||||||
|
)
|
||||||
|
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
|
||||||
|
self.optimizer is not None or self.lr_scheduler is not None
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
|
||||||
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
|
)
|
||||||
|
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||||
|
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||||
|
self.callback_handler = CallbackHandler(
|
||||||
|
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
||||||
|
)
|
||||||
|
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||||
|
|
||||||
|
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||||
|
self._loggers_initialized = False
|
||||||
|
|
||||||
|
# Create clone of distant repo and output directory if needed
|
||||||
|
if self.args.push_to_hub:
|
||||||
|
self.init_git_repo(at_init=True)
|
||||||
|
# In case of pull, we need to make sure every process has the latest.
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
xm.rendezvous("init git repo")
|
||||||
|
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
if self.args.should_save:
|
||||||
|
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
||||||
|
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
|
||||||
|
|
||||||
|
if args.max_steps > 0:
|
||||||
|
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||||
|
|
||||||
|
if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The train_dataset does not implement __len__, max_steps has to be specified. "
|
||||||
|
"The number of steps needs to be known in advance for the learning rate scheduler."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
train_dataset is not None
|
||||||
|
and isinstance(train_dataset, torch.utils.data.IterableDataset)
|
||||||
|
and args.group_by_length
|
||||||
|
):
|
||||||
|
raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
|
||||||
|
|
||||||
|
self._signature_columns = None
|
||||||
|
|
||||||
|
# Mixed precision setup
|
||||||
|
self.use_apex = False
|
||||||
|
self.use_cuda_amp = False
|
||||||
|
self.use_cpu_amp = False
|
||||||
|
|
||||||
|
# Mixed precision setup for SageMaker Model Parallel
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
# BF16 + model parallelism in SageMaker: currently not supported, raise an error
|
||||||
|
if args.bf16:
|
||||||
|
raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
|
||||||
|
|
||||||
|
if IS_SAGEMAKER_MP_POST_1_10:
|
||||||
|
# When there's mismatch between SMP config and trainer argument, use SMP config as truth
|
||||||
|
if args.fp16 != smp.state.cfg.fp16:
|
||||||
|
logger.warning(
|
||||||
|
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
|
||||||
|
f"but FP16 provided in trainer argument is {args.fp16},"
|
||||||
|
f"setting to {smp.state.cfg.fp16}"
|
||||||
|
)
|
||||||
|
args.fp16 = smp.state.cfg.fp16
|
||||||
|
else:
|
||||||
|
# smp < 1.10 does not support fp16 in trainer.
|
||||||
|
if hasattr(smp.state.cfg, "fp16"):
|
||||||
|
logger.warning(
|
||||||
|
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
|
||||||
|
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
|
||||||
|
if args.half_precision_backend == "auto":
|
||||||
|
if args.device == torch.device("cpu"):
|
||||||
|
if args.fp16:
|
||||||
|
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
|
||||||
|
else:
|
||||||
|
args.half_precision_backend = "cpu_amp"
|
||||||
|
else:
|
||||||
|
args.half_precision_backend = "cuda_amp"
|
||||||
|
|
||||||
|
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||||
|
|
||||||
|
self.do_grad_scaling = False
|
||||||
|
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
|
||||||
|
# deepspeed and SageMaker Model Parallel manage their own half precision
|
||||||
|
if self.sharded_ddp is not None:
|
||||||
|
if args.half_precision_backend == "cuda_amp":
|
||||||
|
self.use_cuda_amp = True
|
||||||
|
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
||||||
|
# bf16 does not need grad scaling
|
||||||
|
self.do_grad_scaling = self.amp_dtype == torch.float16
|
||||||
|
if self.do_grad_scaling:
|
||||||
|
if self.sharded_ddp is not None:
|
||||||
|
self.scaler = ShardedGradScaler()
|
||||||
|
elif self.fsdp is not None:
|
||||||
|
from torch.distributed.fsdp.sharded_grad_scaler import (
|
||||||
|
ShardedGradScaler as FSDPShardedGradScaler,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scaler = FSDPShardedGradScaler()
|
||||||
|
elif is_torch_tpu_available():
|
||||||
|
from torch_xla.amp import GradScaler
|
||||||
|
|
||||||
|
self.scaler = GradScaler()
|
||||||
|
else:
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
|
elif args.half_precision_backend == "cpu_amp":
|
||||||
|
self.use_cpu_amp = True
|
||||||
|
self.amp_dtype = torch.bfloat16
|
||||||
|
elif args.half_precision_backend == "apex":
|
||||||
|
if not is_apex_available():
|
||||||
|
raise ImportError(
|
||||||
|
"Using FP16 with APEX but APEX is not installed, please refer to"
|
||||||
|
" https://www.github.com/nvidia/apex."
|
||||||
|
)
|
||||||
|
self.use_apex = True
|
||||||
|
|
||||||
|
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
|
||||||
|
if (
|
||||||
|
is_sagemaker_mp_enabled()
|
||||||
|
and self.use_cuda_amp
|
||||||
|
and args.max_grad_norm is not None
|
||||||
|
and args.max_grad_norm > 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
|
||||||
|
"along 'max_grad_norm': 0 in your hyperparameters."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Label smoothing
|
||||||
|
if self.args.label_smoothing_factor != 0:
|
||||||
|
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
|
||||||
|
else:
|
||||||
|
self.label_smoother = None
|
||||||
|
|
||||||
|
self.state = TrainerState(
|
||||||
|
is_local_process_zero=self.is_local_process_zero(),
|
||||||
|
is_world_process_zero=self.is_world_process_zero(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.control = TrainerControl()
|
||||||
|
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
|
||||||
|
# returned to 0 every time flos need to be logged
|
||||||
|
self.current_flos = 0
|
||||||
|
self.hp_search_backend = None
|
||||||
|
self.use_tune_checkpoints = False
|
||||||
|
default_label_names = find_labels(self.model.__class__)
|
||||||
|
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||||
|
self.can_return_loss = can_return_loss(self.model.__class__)
|
||||||
|
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
|
# Internal variables to help with automatic batch size reduction
|
||||||
|
self._train_batch_size = args.train_batch_size
|
||||||
|
self._created_lr_scheduler = False
|
||||||
|
|
||||||
|
# very last
|
||||||
|
self._memory_tracker.stop_and_update_metrics()
|
||||||
|
|
||||||
|
# torch.compile
|
||||||
|
if args.torch_compile and not is_torch_compile_available():
|
||||||
|
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
|
||||||
|
|
||||||
|
# finally applying `low_gpu_full_post_state_dict_hook`` for fsdp `state_dict`
|
||||||
|
enable_low_gpu_full_post_state_dict_hook()
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.use_ipex:
|
||||||
|
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
||||||
|
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
# Wrapping the base model twice in a DistributedModel will raise an error.
|
||||||
|
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
||||||
|
return self.model_wrapped
|
||||||
|
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
|
||||||
|
if unwrap_model(model) is not model:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Mixed precision training with apex (torch < 1.6)
|
||||||
|
if self.use_apex and training:
|
||||||
|
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||||
|
|
||||||
|
# Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
|
||||||
|
if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
|
||||||
|
if self.args.jit_mode_eval:
|
||||||
|
start_time = time.time()
|
||||||
|
model = self.torch_jit_model_eval(model, dataloader, training)
|
||||||
|
self.jit_compilation_time = round(time.time() - start_time, 4)
|
||||||
|
|
||||||
|
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||||
|
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||||
|
if not training:
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
|
if self.sharded_ddp is not None:
|
||||||
|
# Sharded DDP!
|
||||||
|
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||||
|
model = ShardedDDP(model, self.optimizer)
|
||||||
|
else:
|
||||||
|
mixed_precision = self.args.fp16 or self.args.bf16
|
||||||
|
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
|
||||||
|
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
|
||||||
|
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||||
|
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
|
||||||
|
model = auto_wrap(model)
|
||||||
|
self.model = model = FullyShardedDDP(
|
||||||
|
model,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
reshard_after_forward=zero_3,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
).to(self.args.device)
|
||||||
|
# Distributed training using PyTorch FSDP
|
||||||
|
elif self.fsdp is not None:
|
||||||
|
# fix starts here
|
||||||
|
if not self.args.fsdp_config["xla"]:
|
||||||
|
# PyTorch FSDP!
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||||
|
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||||
|
import torch.distributed.fsdp._traversal_utils as traversal_utils
|
||||||
|
|
||||||
|
if FSDPOption.OFFLOAD in self.args.fsdp:
|
||||||
|
cpu_offload = CPUOffload(offload_params=True)
|
||||||
|
else:
|
||||||
|
cpu_offload = CPUOffload(offload_params=False)
|
||||||
|
|
||||||
|
auto_wrap_policy = None
|
||||||
|
|
||||||
|
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
||||||
|
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||||
|
)
|
||||||
|
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
transformer_cls_to_wrap = set()
|
||||||
|
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||||
|
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||||
|
if transformer_cls is None:
|
||||||
|
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||||
|
else:
|
||||||
|
transformer_cls_to_wrap.add(transformer_cls)
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
transformer_auto_wrap_policy,
|
||||||
|
# Transformer layer class to wrap
|
||||||
|
transformer_layer_cls=transformer_cls_to_wrap,
|
||||||
|
)
|
||||||
|
mixed_precision_policy = None
|
||||||
|
dtype = None
|
||||||
|
if self.args.fp16:
|
||||||
|
dtype = torch.float16
|
||||||
|
elif self.args.bf16:
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
if dtype is not None:
|
||||||
|
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
|
||||||
|
if type(model) != FSDP:
|
||||||
|
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||||
|
signature = inspect.signature(FSDP.__init__).parameters.keys()
|
||||||
|
kwargs = {}
|
||||||
|
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
|
||||||
|
if arg in signature:
|
||||||
|
kwargs[arg] = getattr(self, arg)
|
||||||
|
self.model = model = FSDP(
|
||||||
|
model,
|
||||||
|
sharding_strategy=self.fsdp,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
|
mixed_precision=mixed_precision_policy,
|
||||||
|
device_id=self.args.device,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for submodule in traversal_utils._get_fsdp_states(model):
|
||||||
|
print(submodule._state_dict_type, submodule._state_dict_config)
|
||||||
|
break
|
||||||
|
# fix ends here
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||||
|
from torch_xla.distributed.fsdp import checkpoint_module
|
||||||
|
from torch_xla.distributed.fsdp.wrap import (
|
||||||
|
size_based_auto_wrap_policy,
|
||||||
|
transformer_auto_wrap_policy,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||||
|
auto_wrap_policy = None
|
||||||
|
auto_wrapper_callable = None
|
||||||
|
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||||
|
)
|
||||||
|
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
transformer_cls_to_wrap = set()
|
||||||
|
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||||
|
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||||
|
if transformer_cls is None:
|
||||||
|
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||||
|
else:
|
||||||
|
transformer_cls_to_wrap.add(transformer_cls)
|
||||||
|
auto_wrap_policy = functools.partial(
|
||||||
|
transformer_auto_wrap_policy,
|
||||||
|
# Transformer layer class to wrap
|
||||||
|
transformer_layer_cls=transformer_cls_to_wrap,
|
||||||
|
)
|
||||||
|
fsdp_kwargs = self.args.xla_fsdp_config
|
||||||
|
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||||
|
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
||||||
|
def auto_wrapper_callable(m, *args, **kwargs):
|
||||||
|
return FSDP(checkpoint_module(m), *args, **kwargs)
|
||||||
|
|
||||||
|
# Wrap the base model with an outer FSDP wrapper
|
||||||
|
self.model = model = FSDP(
|
||||||
|
model,
|
||||||
|
auto_wrap_policy=auto_wrap_policy,
|
||||||
|
auto_wrapper_callable=auto_wrapper_callable,
|
||||||
|
**fsdp_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
||||||
|
# as FSDP does not need gradient reduction over sharded parameters.
|
||||||
|
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
||||||
|
loss = optimizer.step(**optimizer_args)
|
||||||
|
if barrier:
|
||||||
|
xm.mark_step()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
xm.optimizer_step = patched_optimizer_step
|
||||||
|
elif is_sagemaker_dp_enabled():
|
||||||
|
model = nn.parallel.DistributedDataParallel(
|
||||||
|
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||||
|
)
|
||||||
|
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||||
|
if is_torch_neuroncore_available():
|
||||||
|
return model
|
||||||
|
kwargs = {}
|
||||||
|
if self.args.ddp_find_unused_parameters is not None:
|
||||||
|
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
|
||||||
|
elif isinstance(model, PreTrainedModel):
|
||||||
|
# find_unused_parameters breaks checkpointing as per
|
||||||
|
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||||
|
kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
|
||||||
|
else:
|
||||||
|
kwargs["find_unused_parameters"] = True
|
||||||
|
|
||||||
|
if self.args.ddp_bucket_cap_mb is not None:
|
||||||
|
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
||||||
|
|
||||||
|
if self.args.ddp_broadcast_buffers is not None:
|
||||||
|
kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
|
||||||
|
|
||||||
|
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def get_batch_sampler(self, dataset=None):
|
||||||
|
if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
|
||||||
|
dataset = dataset if dataset is not None else self.train_dataset
|
||||||
|
try:
|
||||||
|
batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context
|
||||||
|
except:
|
||||||
|
# raise RuntimeError("group_by_length in distributed training requires model has attr `model_max_context`")
|
||||||
|
batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context
|
||||||
|
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||||
|
lengths = LengthGroupedSampler(
|
||||||
|
batch_size=-1, # we just want to know about the lengths of the dataset so no need to pass `batch_size`
|
||||||
|
dataset=dataset,
|
||||||
|
model_input_name=model_input_name
|
||||||
|
).lengths
|
||||||
|
seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
|
||||||
|
batch_sampler = FFDDistributedBatchSampler(
|
||||||
|
batch_max_length=batch_max_len,
|
||||||
|
lengths=np.array(lengths),
|
||||||
|
seed=seed
|
||||||
|
)
|
||||||
|
|
||||||
|
return batch_sampler
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
|
||||||
|
if self.train_dataset is None:
|
||||||
|
raise ValueError("Trainer: training requires a train_dataset.")
|
||||||
|
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator
|
||||||
|
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||||
|
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||||
|
|
||||||
|
batch_sampler = self.get_batch_sampler(train_dataset)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
drop_last=self.args.dataloader_drop_last,
|
||||||
|
collate_fn=data_collator
|
||||||
|
)
|
||||||
|
# return self.accelerator.prepare(dataloader)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
|
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||||
|
"""
|
||||||
|
Will save the model, so you can reload it using `from_pretrained()`.
|
||||||
|
|
||||||
|
Will only save from the main process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = self.args.output_dir
|
||||||
|
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
self._save_tpu(output_dir)
|
||||||
|
elif is_sagemaker_mp_enabled():
|
||||||
|
# Calling the state_dict needs to be done on the wrapped model and on all processes.
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
state_dict = self.model_wrapped.state_dict()
|
||||||
|
if self.args.should_save:
|
||||||
|
self._save(output_dir, state_dict=state_dict)
|
||||||
|
if IS_SAGEMAKER_MP_POST_1_10:
|
||||||
|
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
|
||||||
|
Path(os.path.join(output_dir, "user_content.pt")).touch()
|
||||||
|
elif (
|
||||||
|
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||||
|
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||||
|
or self.fsdp is not None
|
||||||
|
or self.is_fsdp_enabled
|
||||||
|
):
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
if self.args.should_save:
|
||||||
|
self._save(output_dir, state_dict=state_dict)
|
||||||
|
# modification starts here
|
||||||
|
if self.is_fsdp_enabled and self.args.save_with_fsdp:
|
||||||
|
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
|
||||||
|
# modification ends here
|
||||||
|
|
||||||
|
elif self.is_deepspeed_enabled:
|
||||||
|
# this takes care of everything as long as we aren't under zero3
|
||||||
|
if version.parse(accelerate_version) <= version.parse("0.20.3"):
|
||||||
|
raise ValueError("Install Accelerate from main branch")
|
||||||
|
try:
|
||||||
|
state_dict = self.accelerator.get_state_dict(self.deepspeed)
|
||||||
|
if self.args.should_save:
|
||||||
|
self._save(output_dir, state_dict=state_dict)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
||||||
|
" zero_to_fp32.py to recover weights"
|
||||||
|
)
|
||||||
|
self.model_wrapped.save_checkpoint(output_dir)
|
||||||
|
|
||||||
|
elif self.args.should_save:
|
||||||
|
self._save(output_dir)
|
||||||
|
|
||||||
|
# Push to the Hub when `save_model` is called by the user.
|
||||||
|
if self.args.push_to_hub and not _internal_call:
|
||||||
|
self.push_to_hub(commit_message="Model save")
|
||||||
|
|
||||||
|
def _save_checkpoint(self, model, trial, metrics=None):
|
||||||
|
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||||
|
# want to save except FullyShardedDDP.
|
||||||
|
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||||
|
|
||||||
|
# Save model checkpoint
|
||||||
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
|
||||||
|
if self.hp_search_backend is None and trial is None:
|
||||||
|
self.store_flos()
|
||||||
|
|
||||||
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
|
self.save_model(output_dir, _internal_call=True)
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
|
||||||
|
# config `stage3_gather_16bit_weights_on_model_save` is True
|
||||||
|
self.model_wrapped.save_checkpoint(output_dir)
|
||||||
|
|
||||||
|
# Save optimizer and scheduler
|
||||||
|
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||||
|
self.optimizer.consolidate_state_dict()
|
||||||
|
|
||||||
|
if self.fsdp or self.is_fsdp_enabled:
|
||||||
|
if self.is_fsdp_enabled:
|
||||||
|
# modification starts here
|
||||||
|
if self.args.save_with_fsdp:
|
||||||
|
save_fsdp_optimizer(
|
||||||
|
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
|
||||||
|
)
|
||||||
|
# modification ends here
|
||||||
|
else:
|
||||||
|
# FSDP has a different interface for saving optimizer states.
|
||||||
|
# Needs to be called on all ranks to gather all states.
|
||||||
|
# full_optim_state_dict will be deprecated after Pytorch 2.2!
|
||||||
|
full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
|
||||||
|
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
xm.rendezvous("saving_optimizer_states")
|
||||||
|
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
elif is_sagemaker_mp_enabled():
|
||||||
|
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
|
||||||
|
smp.barrier()
|
||||||
|
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
|
||||||
|
smp.save(
|
||||||
|
opt_state_dict,
|
||||||
|
os.path.join(output_dir, OPTIMIZER_NAME),
|
||||||
|
partial=True,
|
||||||
|
v3=smp.state.cfg.shard_optimizer_state,
|
||||||
|
)
|
||||||
|
if self.args.should_save:
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
if self.do_grad_scaling:
|
||||||
|
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
|
elif self.args.should_save and not self.is_deepspeed_enabled:
|
||||||
|
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||||
|
if self.fsdp and not self.is_fsdp_enabled:
|
||||||
|
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
|
else:
|
||||||
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
if self.do_grad_scaling:
|
||||||
|
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||||
|
|
||||||
|
# Determine the new best metric / best model checkpoint
|
||||||
|
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||||
|
metric_to_check = self.args.metric_for_best_model
|
||||||
|
if not metric_to_check.startswith("eval_"):
|
||||||
|
metric_to_check = f"eval_{metric_to_check}"
|
||||||
|
metric_value = metrics[metric_to_check]
|
||||||
|
|
||||||
|
operator = np.greater if self.args.greater_is_better else np.less
|
||||||
|
if (
|
||||||
|
self.state.best_metric is None
|
||||||
|
or self.state.best_model_checkpoint is None
|
||||||
|
or operator(metric_value, self.state.best_metric)
|
||||||
|
):
|
||||||
|
self.state.best_metric = metric_value
|
||||||
|
self.state.best_model_checkpoint = output_dir
|
||||||
|
|
||||||
|
# Save the Trainer state
|
||||||
|
if self.args.should_save:
|
||||||
|
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||||
|
|
||||||
|
# Save RNG state in non-distributed training
|
||||||
|
rng_states = {
|
||||||
|
"python": random.getstate(),
|
||||||
|
"numpy": np.random.get_state(),
|
||||||
|
"cpu": torch.random.get_rng_state(),
|
||||||
|
}
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||||
|
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
|
||||||
|
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
|
||||||
|
else:
|
||||||
|
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||||
|
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
rng_states["xla"] = xm.get_rng_state()
|
||||||
|
|
||||||
|
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
|
||||||
|
# not yet exist.
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if self.args.world_size <= 1:
|
||||||
|
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||||
|
else:
|
||||||
|
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
|
||||||
|
|
||||||
|
if self.args.push_to_hub:
|
||||||
|
self._push_from_checkpoint(output_dir)
|
||||||
|
|
||||||
|
# Maybe delete some older checkpoints.
|
||||||
|
if self.args.should_save:
|
||||||
|
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||||
|
|
||||||
|
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||||
|
"""If optimizer and scheduler states exist, load them."""
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.is_deepspeed_enabled:
|
||||||
|
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||||
|
return
|
||||||
|
|
||||||
|
checkpoint_file_exists = (
|
||||||
|
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
|
||||||
|
if is_sagemaker_mp_enabled()
|
||||||
|
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
|
||||||
|
)
|
||||||
|
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
|
||||||
|
# Load in optimizer and scheduler states
|
||||||
|
if is_torch_tpu_available():
|
||||||
|
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||||
|
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
|
||||||
|
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||||
|
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
|
||||||
|
|
||||||
|
self.optimizer.load_state_dict(optimizer_state)
|
||||||
|
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||||
|
else:
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
|
||||||
|
# Optimizer checkpoint was saved with smp >= 1.10
|
||||||
|
def opt_load_hook(mod, opt):
|
||||||
|
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Optimizer checkpoint was saved with smp < 1.10
|
||||||
|
def opt_load_hook(mod, opt):
|
||||||
|
if IS_SAGEMAKER_MP_POST_1_10:
|
||||||
|
opt.load_state_dict(
|
||||||
|
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
|
||||||
|
|
||||||
|
self.model_wrapped.register_post_step_hook(opt_load_hook)
|
||||||
|
else:
|
||||||
|
# We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
|
||||||
|
# In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
|
||||||
|
# likely to get OOM on CPU (since we load num_gpu times the optimizer state
|
||||||
|
map_location = self.args.device if self.args.world_size > 1 else "cpu"
|
||||||
|
if self.fsdp or self.is_fsdp_enabled:
|
||||||
|
# modification starts here
|
||||||
|
if self.is_fsdp_enabled and self.args.save_with_fsdp:
|
||||||
|
load_fsdp_optimizer(
|
||||||
|
self.accelerator.state.fsdp_plugin,
|
||||||
|
self.accelerator,
|
||||||
|
self.optimizer,
|
||||||
|
self.model,
|
||||||
|
checkpoint,
|
||||||
|
)
|
||||||
|
elif not self.is_fsdp_enabled:
|
||||||
|
full_osd = None
|
||||||
|
# In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
|
||||||
|
if self.args.process_index == 0:
|
||||||
|
full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
|
||||||
|
# call scatter_full_optim_state_dict on all ranks
|
||||||
|
sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
|
||||||
|
self.optimizer.load_state_dict(sharded_osd)
|
||||||
|
else:
|
||||||
|
self.optimizer.load_state_dict(
|
||||||
|
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||||
|
)
|
||||||
|
# modification ends here
|
||||||
|
else:
|
||||||
|
self.optimizer.load_state_dict(
|
||||||
|
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||||
|
)
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
|
reissue_pt_warnings(caught_warnings)
|
||||||
|
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||||
|
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||||
|
|
478
train/trainers/fsdp_training_args.py
Normal file
478
train/trainers/fsdp_training_args.py
Normal file
|
@ -0,0 +1,478 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.training_args import *
|
||||||
|
|
||||||
|
from .utils import ExtendedFSDPOption
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FSDPTrainingArguments(transformers.TrainingArguments):
|
||||||
|
# about data-efficient sampler
|
||||||
|
use_ffd_sampler: bool = False
|
||||||
|
model_avg_context: int = 2048
|
||||||
|
|
||||||
|
# about saving
|
||||||
|
# if not save with fsdp, then must not load with fsdp
|
||||||
|
save_with_fsdp: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||||
|
# in the current directory instead of the actual home
|
||||||
|
# see https://github.com/huggingface/transformers/issues/10628
|
||||||
|
if self.output_dir is not None:
|
||||||
|
self.output_dir = os.path.expanduser(self.output_dir)
|
||||||
|
if self.logging_dir is None and self.output_dir is not None:
|
||||||
|
self.logging_dir = os.path.join(self.output_dir, default_logdir())
|
||||||
|
if self.logging_dir is not None:
|
||||||
|
self.logging_dir = os.path.expanduser(self.logging_dir)
|
||||||
|
|
||||||
|
if self.disable_tqdm is None:
|
||||||
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
|
|
||||||
|
if isinstance(self.evaluation_strategy, EvaluationStrategy):
|
||||||
|
warnings.warn(
|
||||||
|
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5"
|
||||||
|
" of 🤗 Transformers. Use `IntervalStrategy` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
# Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
|
||||||
|
self.evaluation_strategy = self.evaluation_strategy.value
|
||||||
|
|
||||||
|
# if self.xpu_backend is not None:
|
||||||
|
# warnings.warn(
|
||||||
|
# "using `xpu_backend` is deprecated and will be removed in version 4.31"
|
||||||
|
# " of 🤗 Transformers. Use `ddp_backend` instead",
|
||||||
|
# FutureWarning,
|
||||||
|
# )
|
||||||
|
# self.ddp_backend = self.xpu_backend
|
||||||
|
|
||||||
|
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||||
|
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||||
|
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||||
|
self.hub_strategy = HubStrategy(self.hub_strategy)
|
||||||
|
|
||||||
|
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||||
|
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||||
|
self.do_eval = True
|
||||||
|
|
||||||
|
# eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
|
||||||
|
if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
|
||||||
|
if self.logging_steps > 0:
|
||||||
|
logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
|
||||||
|
self.eval_steps = self.logging_steps
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or"
|
||||||
|
" --logging_steps"
|
||||||
|
)
|
||||||
|
|
||||||
|
# logging_steps must be non-zero for logging_strategy that is other than 'no'
|
||||||
|
if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
|
||||||
|
raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")
|
||||||
|
|
||||||
|
if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1:
|
||||||
|
if self.logging_steps != int(self.logging_steps):
|
||||||
|
raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}")
|
||||||
|
self.logging_steps = int(self.logging_steps)
|
||||||
|
if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
|
||||||
|
if self.eval_steps != int(self.eval_steps):
|
||||||
|
raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
|
||||||
|
self.eval_steps = int(self.eval_steps)
|
||||||
|
if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
|
||||||
|
if self.save_steps != int(self.save_steps):
|
||||||
|
raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
|
||||||
|
self.save_steps = int(self.save_steps)
|
||||||
|
|
||||||
|
# Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
|
||||||
|
if self.load_best_model_at_end:
|
||||||
|
if self.evaluation_strategy != self.save_strategy:
|
||||||
|
raise ValueError(
|
||||||
|
"--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
|
||||||
|
f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}"
|
||||||
|
)
|
||||||
|
if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
|
||||||
|
if self.eval_steps < 1 or self.save_steps < 1:
|
||||||
|
if not (self.eval_steps < 1 and self.save_steps < 1):
|
||||||
|
raise ValueError(
|
||||||
|
"--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
|
||||||
|
"steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps"
|
||||||
|
f"{self.save_steps} and eval_steps {self.eval_steps}."
|
||||||
|
)
|
||||||
|
# Work around floating point precision issues
|
||||||
|
LARGE_MULTIPLIER = 1_000_000
|
||||||
|
if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
|
||||||
|
f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}."
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
|
||||||
|
f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
|
||||||
|
)
|
||||||
|
|
||||||
|
safetensors_available = is_safetensors_available()
|
||||||
|
if self.save_safetensors and not safetensors_available:
|
||||||
|
raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
|
||||||
|
if not self.save_safetensors and safetensors_available:
|
||||||
|
logger.info(
|
||||||
|
f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
|
||||||
|
f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
|
||||||
|
f"If your model cannot be saved by safetensors please feel free to open an issue at "
|
||||||
|
f"https://github.com/huggingface/safetensors!"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
|
||||||
|
) and self.metric_for_best_model is None:
|
||||||
|
self.metric_for_best_model = "loss"
|
||||||
|
if self.greater_is_better is None and self.metric_for_best_model is not None:
|
||||||
|
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
|
||||||
|
if self.run_name is None:
|
||||||
|
self.run_name = self.output_dir
|
||||||
|
if self.framework == "pt" and is_torch_available():
|
||||||
|
if self.fp16_backend and self.fp16_backend != "auto":
|
||||||
|
warnings.warn(
|
||||||
|
"`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||||
|
" `half_precision_backend` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.half_precision_backend = self.fp16_backend
|
||||||
|
|
||||||
|
if self.bf16 or self.bf16_full_eval:
|
||||||
|
if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
|
||||||
|
# cpu
|
||||||
|
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
|
||||||
|
elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
|
||||||
|
# gpu
|
||||||
|
raise ValueError(
|
||||||
|
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.fp16 and self.bf16:
|
||||||
|
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
||||||
|
|
||||||
|
if self.fp16_full_eval and self.bf16_full_eval:
|
||||||
|
raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
|
||||||
|
|
||||||
|
if self.bf16:
|
||||||
|
if self.half_precision_backend == "apex":
|
||||||
|
raise ValueError(
|
||||||
|
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
|
||||||
|
" `--half_precision_backend cuda_amp` instead"
|
||||||
|
)
|
||||||
|
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||||
|
raise ValueError("sharded_ddp is not supported with bf16")
|
||||||
|
|
||||||
|
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
||||||
|
if self.evaluation_strategy == IntervalStrategy.NO:
|
||||||
|
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
|
||||||
|
|
||||||
|
self.optim = OptimizerNames(self.optim)
|
||||||
|
if self.adafactor:
|
||||||
|
warnings.warn(
|
||||||
|
"`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim"
|
||||||
|
" adafactor` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.optim = OptimizerNames.ADAFACTOR
|
||||||
|
if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available():
|
||||||
|
if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"):
|
||||||
|
raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher")
|
||||||
|
# there is a bug in fp16/AMP in pt-2.0.0
|
||||||
|
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
|
||||||
|
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.framework == "pt"
|
||||||
|
and is_torch_available()
|
||||||
|
and (self.device.type != "cuda")
|
||||||
|
and (get_xla_device_type(self.device) != "GPU")
|
||||||
|
and (self.fp16 or self.fp16_full_eval)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
|
||||||
|
" (`--fp16_full_eval`) can only be used on CUDA devices."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.framework == "pt"
|
||||||
|
and is_torch_available()
|
||||||
|
and (self.device.type != "cuda")
|
||||||
|
and (get_xla_device_type(self.device) != "GPU")
|
||||||
|
and (get_xla_device_type(self.device) != "TPU")
|
||||||
|
and (self.device.type != "cpu")
|
||||||
|
and (self.bf16 or self.bf16_full_eval)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
|
||||||
|
" (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.torchdynamo is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||||
|
" `torch_compile_backend` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.torch_compile_backend = self.torchdynamo
|
||||||
|
if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
|
||||||
|
self.torch_compile = True
|
||||||
|
if self.torch_compile and self.torch_compile_backend is None:
|
||||||
|
self.torch_compile_backend = "inductor"
|
||||||
|
|
||||||
|
# accelerate integration for torch compile
|
||||||
|
if self.torch_compile:
|
||||||
|
# set env vars for accelerate
|
||||||
|
prefix = "ACCELERATE_DYNAMO_"
|
||||||
|
os.environ[prefix + "BACKEND"] = self.torch_compile_backend
|
||||||
|
if self.torch_compile_mode is not None:
|
||||||
|
os.environ[prefix + "MODE"] = self.torch_compile_mode
|
||||||
|
|
||||||
|
if self.framework == "pt" and is_torch_available() and self.torch_compile:
|
||||||
|
if is_torch_tf32_available():
|
||||||
|
if self.tf32 is None and not self.fp16 or self.bf16:
|
||||||
|
logger.info(
|
||||||
|
"Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement"
|
||||||
|
" otherwise."
|
||||||
|
)
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here."
|
||||||
|
)
|
||||||
|
if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
|
||||||
|
if self.tf32:
|
||||||
|
if is_torch_tf32_available():
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
else:
|
||||||
|
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||||
|
else:
|
||||||
|
if is_torch_tf32_available():
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
# no need to assert on else
|
||||||
|
|
||||||
|
if self.report_to is None:
|
||||||
|
logger.info(
|
||||||
|
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
||||||
|
"integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
|
||||||
|
"now. You should start updating your code and make this info disappear :-)."
|
||||||
|
)
|
||||||
|
self.report_to = "all"
|
||||||
|
if self.report_to == "all" or self.report_to == ["all"]:
|
||||||
|
# Import at runtime to avoid a circular import.
|
||||||
|
from transformers.integrations import get_available_reporting_integrations
|
||||||
|
|
||||||
|
self.report_to = get_available_reporting_integrations()
|
||||||
|
elif self.report_to == "none" or self.report_to == ["none"]:
|
||||||
|
self.report_to = []
|
||||||
|
elif not isinstance(self.report_to, list):
|
||||||
|
self.report_to = [self.report_to]
|
||||||
|
|
||||||
|
if self.warmup_ratio < 0 or self.warmup_ratio > 1:
|
||||||
|
raise ValueError("warmup_ratio must lie in range [0,1]")
|
||||||
|
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
|
||||||
|
logger.info(
|
||||||
|
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
|
||||||
|
" during training"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||||
|
warnings.warn(
|
||||||
|
"using `sharded_ddp` is deprecated and will be removed in version 4.33"
|
||||||
|
" of 🤗 Transformers. Use `fsdp` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if isinstance(self.sharded_ddp, bool):
|
||||||
|
self.sharded_ddp = "simple" if self.sharded_ddp else ""
|
||||||
|
if isinstance(self.sharded_ddp, str):
|
||||||
|
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
|
||||||
|
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
|
||||||
|
raise ValueError(
|
||||||
|
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
|
||||||
|
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
|
||||||
|
)
|
||||||
|
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
|
||||||
|
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
|
||||||
|
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
||||||
|
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
||||||
|
|
||||||
|
if isinstance(self.fsdp, bool):
|
||||||
|
self.fsdp = "full_shard" if self.fsdp else ""
|
||||||
|
if isinstance(self.fsdp, str):
|
||||||
|
self.fsdp = [ExtendedFSDPOption(s) for s in self.fsdp.split()]
|
||||||
|
if self.fsdp == [ExtendedFSDPOption.OFFLOAD]:
|
||||||
|
raise ValueError(
|
||||||
|
"`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
|
||||||
|
'`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
|
||||||
|
)
|
||||||
|
elif ExtendedFSDPOption.FULL_SHARD in self.fsdp and ExtendedFSDPOption.SHARD_GRAD_OP in self.fsdp:
|
||||||
|
raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
|
||||||
|
|
||||||
|
if self.fsdp_config is None:
|
||||||
|
self.fsdp_config = {}
|
||||||
|
|
||||||
|
if isinstance(self.fsdp_config, str):
|
||||||
|
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
|
||||||
|
self.fsdp_config = json.load(f)
|
||||||
|
|
||||||
|
if self.fsdp_min_num_params > 0:
|
||||||
|
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
|
||||||
|
|
||||||
|
self.fsdp_config["fsdp_min_num_params"] = max(
|
||||||
|
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
|
||||||
|
)
|
||||||
|
|
||||||
|
# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
|
||||||
|
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
|
||||||
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
|
||||||
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.fsdp_transformer_layer_cls_to_wrap is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
|
||||||
|
)
|
||||||
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap", []
|
||||||
|
) + [self.fsdp_transformer_layer_cls_to_wrap]
|
||||||
|
|
||||||
|
if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
|
||||||
|
|
||||||
|
if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(self.fsdp) > 0
|
||||||
|
and self.fsdp_config["fsdp_min_num_params"] > 0
|
||||||
|
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
|
||||||
|
)
|
||||||
|
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
|
||||||
|
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
|
||||||
|
if self.fsdp_config["xla"]:
|
||||||
|
if len(self.fsdp) > 0:
|
||||||
|
# store XLA fsdp configuration parameters into a dictionary
|
||||||
|
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {})
|
||||||
|
# apply appropriate string to torch.dtype conversions for parameters
|
||||||
|
if "compute_dtype" in self.xla_fsdp_config:
|
||||||
|
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
|
||||||
|
if "buffer_dtype" in self.xla_fsdp_config:
|
||||||
|
self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
|
||||||
|
else:
|
||||||
|
warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
|
||||||
|
else:
|
||||||
|
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||||
|
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
|
||||||
|
|
||||||
|
# accelerate integration for FSDP
|
||||||
|
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
|
||||||
|
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||||
|
from accelerate.utils.constants import (
|
||||||
|
FSDP_AUTO_WRAP_POLICY,
|
||||||
|
FSDP_SHARDING_STRATEGY,
|
||||||
|
)
|
||||||
|
|
||||||
|
for fsdp_option in self.fsdp:
|
||||||
|
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
|
||||||
|
# set environment variable for FSDP sharding strategy
|
||||||
|
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
|
||||||
|
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||||
|
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||||
|
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||||
|
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
|
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||||
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||||
|
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||||
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||||
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||||
|
)
|
||||||
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||||
|
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||||
|
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||||
|
|
||||||
|
if self.tpu_metrics_debug:
|
||||||
|
warnings.warn(
|
||||||
|
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||||
|
" `--debug tpu_metrics_debug` instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
if self.debug is None:
|
||||||
|
self.debug = " tpu_metrics_debug"
|
||||||
|
else:
|
||||||
|
self.debug += " tpu_metrics_debug"
|
||||||
|
self.tpu_metrics_debug = False
|
||||||
|
|
||||||
|
if isinstance(self.debug, str):
|
||||||
|
self.debug = [DebugOption(s) for s in self.debug.split()]
|
||||||
|
elif self.debug is None:
|
||||||
|
self.debug = []
|
||||||
|
|
||||||
|
self.deepspeed_plugin = None
|
||||||
|
if self.deepspeed:
|
||||||
|
# - must be run very last in arg parsing, since it will use a lot of these settings.
|
||||||
|
# - must be run before the model is created.
|
||||||
|
if not is_accelerate_available():
|
||||||
|
raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.")
|
||||||
|
from transformers.deepspeed import HfTrainerDeepSpeedConfig
|
||||||
|
|
||||||
|
# will be used later by the Trainer
|
||||||
|
# note: leave self.deepspeed unmodified in case a user relies on it not to be modified)
|
||||||
|
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
|
||||||
|
self.hf_deepspeed_config.trainer_config_process(self)
|
||||||
|
|
||||||
|
# Accelerate DeepSpeed Plugin
|
||||||
|
from accelerate.utils import DeepSpeedPlugin
|
||||||
|
|
||||||
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
|
||||||
|
|
||||||
|
if self.push_to_hub_token is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_token` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.hub_token = self.push_to_hub_token
|
||||||
|
|
||||||
|
if self.push_to_hub_model_id is not None:
|
||||||
|
self.hub_model_id = get_full_repo_name(
|
||||||
|
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
|
||||||
|
)
|
||||||
|
if self.push_to_hub_organization is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
|
||||||
|
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
|
||||||
|
f"argument (in this case {self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||||
|
f"{self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
elif self.push_to_hub_organization is not None:
|
||||||
|
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
|
||||||
|
warnings.warn(
|
||||||
|
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||||
|
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||||
|
f"{self.hub_model_id}).",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if training args is specified, it will override the one specified in the accelerate config
|
||||||
|
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
|
||||||
|
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
|
||||||
|
if self.fp16:
|
||||||
|
mixed_precision_dtype = "fp16"
|
||||||
|
elif self.bf16:
|
||||||
|
mixed_precision_dtype = "bf16"
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
|
26
train/trainers/stylistic_trainer.py
Normal file
26
train/trainers/stylistic_trainer.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from .fsdp_trainer import FSDPTrainer
|
||||||
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
from transformers.trainer_utils import unwrap_model
|
||||||
|
|
||||||
|
|
||||||
|
class StylisticTrainer(FSDPTrainer):
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
|
if self.label_smoother is not None and "labels" in inputs:
|
||||||
|
labels = inputs.pop("labels")
|
||||||
|
else:
|
||||||
|
labels = None
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
if self.args.past_index >= 0:
|
||||||
|
self._past = outputs[self.args.past_index]
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# FIXME: should support peft
|
||||||
|
model_name = unwrap_model(model)._get_name()
|
||||||
|
|
||||||
|
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||||
|
# loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"model {model_name} is not a causal LM")
|
||||||
|
else:
|
||||||
|
raise ValueError("labels should not be None")
|
154
train/trainers/utils.py
Normal file
154
train/trainers/utils.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable
|
||||||
|
from torch.distributed.fsdp import _state_dict_utils
|
||||||
|
from torch.distributed.fsdp._common_utils import clean_tensor_name
|
||||||
|
from transformers.utils import ExplicitEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendedFSDPOption(ExplicitEnum):
|
||||||
|
FULL_SHARD = "full_shard"
|
||||||
|
SHARD_GRAD_OP = "shard_grad_op"
|
||||||
|
NO_SHARD = "no_shard"
|
||||||
|
|
||||||
|
# extention starts here
|
||||||
|
HYBRID_SHARD = "hybrid_shard"
|
||||||
|
_HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
|
||||||
|
# extention ends here
|
||||||
|
|
||||||
|
OFFLOAD = "offload"
|
||||||
|
AUTO_WRAP = "auto_wrap"
|
||||||
|
|
||||||
|
|
||||||
|
DefaultCheckpointFunction = checkpoint.CheckpointFunction
|
||||||
|
DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook
|
||||||
|
|
||||||
|
|
||||||
|
class CompatibleCheckpointFunction(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, run_function, preserve_rng_state, *args):
|
||||||
|
check_backward_validity(args)
|
||||||
|
ctx.run_function = run_function
|
||||||
|
ctx.preserve_rng_state = preserve_rng_state
|
||||||
|
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||||
|
ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
|
||||||
|
|
||||||
|
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||||
|
# to be filled out during the backward.
|
||||||
|
ctx.inputs = []
|
||||||
|
ctx.tensor_indices = []
|
||||||
|
tensor_inputs = []
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if torch.is_tensor(arg):
|
||||||
|
tensor_inputs.append(arg)
|
||||||
|
ctx.tensor_indices.append(i)
|
||||||
|
ctx.inputs.append(None)
|
||||||
|
else:
|
||||||
|
ctx.inputs.append(arg)
|
||||||
|
|
||||||
|
ctx.save_for_backward(*tensor_inputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = run_function(*args)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, *args):
|
||||||
|
if not torch.autograd._is_checkpoint_valid():
|
||||||
|
raise RuntimeError(
|
||||||
|
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
|
||||||
|
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
|
||||||
|
" argument.")
|
||||||
|
# Copy the list to avoid modifying original list.
|
||||||
|
inputs = list(ctx.inputs)
|
||||||
|
tensor_indices = ctx.tensor_indices
|
||||||
|
tensors = ctx.saved_tensors
|
||||||
|
|
||||||
|
# Fill in inputs with appropriate saved tensors.
|
||||||
|
for i, idx in enumerate(tensor_indices):
|
||||||
|
inputs[idx] = tensors[i]
|
||||||
|
|
||||||
|
# Stash the surrounding rng state, and mimic the state that was
|
||||||
|
# present at this time during forward. Restore the surrounding state
|
||||||
|
# when we're done.
|
||||||
|
rng_devices = []
|
||||||
|
# if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
|
||||||
|
# rng_devices = ctx.fwd_gpu_devices
|
||||||
|
with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
|
||||||
|
detached_inputs = detach_variable(tuple(inputs))
|
||||||
|
with torch.enable_grad(), \
|
||||||
|
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
|
||||||
|
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
|
||||||
|
outputs = ctx.run_function(*detached_inputs)
|
||||||
|
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
outputs = (outputs,)
|
||||||
|
|
||||||
|
# run backward() with only tensor that requires grad
|
||||||
|
outputs_with_grad = []
|
||||||
|
args_with_grad = []
|
||||||
|
for i in range(len(outputs)):
|
||||||
|
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||||
|
outputs_with_grad.append(outputs[i])
|
||||||
|
args_with_grad.append(args[i])
|
||||||
|
if len(outputs_with_grad) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"none of output has requires_grad=True,"
|
||||||
|
" this checkpoint() is not necessary")
|
||||||
|
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||||
|
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
|
||||||
|
for inp in detached_inputs)
|
||||||
|
|
||||||
|
return (None, None) + grads
|
||||||
|
|
||||||
|
|
||||||
|
def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix):
|
||||||
|
|
||||||
|
def param_hook(state_dict, prefix, fqn):
|
||||||
|
clean_key = fqn
|
||||||
|
clean_prefix = clean_tensor_name(prefix)
|
||||||
|
# Strip prefix out of key if needed as buffer names and param names
|
||||||
|
# do not have prefix considered as they are not computed in `state_dict`
|
||||||
|
# call.
|
||||||
|
if clean_key.startswith(clean_prefix):
|
||||||
|
clean_key = clean_key[len(clean_prefix) :]
|
||||||
|
|
||||||
|
# Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
|
||||||
|
if not getattr(state_dict[fqn], "_has_been_cloned", False):
|
||||||
|
try:
|
||||||
|
state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
|
||||||
|
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
|
||||||
|
except BaseException as e:
|
||||||
|
warnings.warn(
|
||||||
|
f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
|
||||||
|
"This may mean that this state_dict entry could point to invalid "
|
||||||
|
"memory regions after returning from state_dict() call if this "
|
||||||
|
"parameter is managed by FSDP. Please check clone "
|
||||||
|
f"implementation of {fqn}. Error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _state_dict_utils._common_unshard_post_state_dict_hook(
|
||||||
|
module, fsdp_state, state_dict, prefix, param_hook
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# enable to efficiently saving `state_dict` for fsdp
|
||||||
|
def enable_low_gpu_full_post_state_dict_hook():
|
||||||
|
_state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook
|
||||||
|
|
||||||
|
def disable_low_gpu_full_post_state_dict_hook():
|
||||||
|
_state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook
|
||||||
|
|
||||||
|
|
||||||
|
# enable to make `torch.compile` work
|
||||||
|
def enable_compatible_checkpoint_function():
|
||||||
|
checkpoint.CheckpointFunction = CompatibleCheckpointFunction
|
||||||
|
|
||||||
|
def disable_compatible_checkpoint_function():
|
||||||
|
checkpoint.CheckpointFunction = DefaultCheckpointFunction
|
||||||
|
|
0
train/utils/__init__.py
Normal file
0
train/utils/__init__.py
Normal file
0
train/utils/datasets/__init__.py
Normal file
0
train/utils/datasets/__init__.py
Normal file
141
train/utils/datasets/nyt10_dataset.py
Normal file
141
train/utils/datasets/nyt10_dataset.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
|
||||||
|
|
||||||
|
prompt_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:"
|
||||||
|
)
|
||||||
|
|
||||||
|
# output_template = [
|
||||||
|
# "```json\n",
|
||||||
|
# "{\n",
|
||||||
|
# " \"person\": \"",
|
||||||
|
# "sample_person",
|
||||||
|
# "\",\n",
|
||||||
|
# " \"nationality\": \"",
|
||||||
|
# "sample_nationality",
|
||||||
|
# "\"\n",
|
||||||
|
# "}\n",
|
||||||
|
# "```",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# person_index = 3
|
||||||
|
# nationality_index = 6
|
||||||
|
|
||||||
|
output_template = [
|
||||||
|
"```json\n{\n \"person\": \"",
|
||||||
|
"sample_person",
|
||||||
|
"\",\n \"nationality\": \"",
|
||||||
|
"sample_nationality",
|
||||||
|
"\"\n}\n```",
|
||||||
|
]
|
||||||
|
|
||||||
|
person_index = 1
|
||||||
|
nationality_index = 3
|
||||||
|
|
||||||
|
|
||||||
|
class NYT10Dataset(Dataset):
|
||||||
|
def __init__(self, data_path: str, tokenizer, size: int = -1):
|
||||||
|
with open(data_path, 'r') as f:
|
||||||
|
self.ann = [json.loads(line) for line in f.readlines()]
|
||||||
|
# only use "/people/person/nationality"
|
||||||
|
self.ann = [
|
||||||
|
{
|
||||||
|
"text": dp["text"],
|
||||||
|
"person": dp["h"]["name"],
|
||||||
|
"nationality": dp["t"]["name"],
|
||||||
|
} for dp in self.ann if '/people/person/nationality' in dp['relation']
|
||||||
|
]
|
||||||
|
|
||||||
|
random.shuffle(self.ann)
|
||||||
|
|
||||||
|
if size > 0:
|
||||||
|
self.ann = self.ann[:size]
|
||||||
|
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.ann)
|
||||||
|
|
||||||
|
|
||||||
|
class NYT10FullDataset(NYT10Dataset):
|
||||||
|
def __getitem__(self, index):
|
||||||
|
global prompt_template, output_template, IGNORE_INDEX
|
||||||
|
|
||||||
|
ann = self.ann[index]
|
||||||
|
prompt = prompt_template.format(text=ann["text"])
|
||||||
|
output = copy.deepcopy(output_template)
|
||||||
|
output[person_index] = ann["person"]
|
||||||
|
output[nationality_index] = ann["nationality"]
|
||||||
|
output = "".join(output)
|
||||||
|
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
|
||||||
|
output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id]
|
||||||
|
example = torch.tensor(
|
||||||
|
prompt_ids + output_ids, dtype=torch.int64
|
||||||
|
)
|
||||||
|
labels = copy.deepcopy(example)
|
||||||
|
labels[:len(prompt_ids)] = -1
|
||||||
|
example_mask = example.ge(0)
|
||||||
|
label_mask = labels.ge(0)
|
||||||
|
example[~example_mask] = 0
|
||||||
|
labels[~label_mask] = IGNORE_INDEX
|
||||||
|
|
||||||
|
assert len(example) == len(labels)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": example.tolist(),
|
||||||
|
"labels": labels.tolist(),
|
||||||
|
"attention_mask":example_mask.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NYT10StylishDataset(NYT10Dataset):
|
||||||
|
def __getitem__(self, index):
|
||||||
|
global prompt_template, output_template, IGNORE_INDEX
|
||||||
|
|
||||||
|
ann = self.ann[index]
|
||||||
|
prompt = prompt_template.format(text=ann["text"])
|
||||||
|
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
|
||||||
|
|
||||||
|
example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id]
|
||||||
|
# prompt part is masked
|
||||||
|
labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id]
|
||||||
|
|
||||||
|
for idx, s in enumerate(output_template):
|
||||||
|
# person and nationality are masked
|
||||||
|
if idx == person_index or idx == nationality_index:
|
||||||
|
tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False)
|
||||||
|
example.extend(tokens)
|
||||||
|
labels.extend([-1] * len(tokens))
|
||||||
|
else:
|
||||||
|
tokens = self.tokenizer.encode(s, add_special_tokens=False)
|
||||||
|
example.extend(tokens)
|
||||||
|
labels.extend(tokens)
|
||||||
|
example.append(self.tokenizer.eos_token_id)
|
||||||
|
example = torch.tensor(
|
||||||
|
example, dtype=torch.int64
|
||||||
|
)
|
||||||
|
labels.append(self.tokenizer.eos_token_id)
|
||||||
|
labels = torch.tensor(
|
||||||
|
labels, dtype=torch.int64
|
||||||
|
)
|
||||||
|
example_mask = example.ge(0)
|
||||||
|
label_mask = labels.ge(0)
|
||||||
|
example[~example_mask] = 0
|
||||||
|
labels[~label_mask] = IGNORE_INDEX
|
||||||
|
|
||||||
|
assert len(example) == len(labels)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": example.tolist(),
|
||||||
|
"labels": labels.tolist(),
|
||||||
|
"attention_mask":example_mask.tolist(),
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user