realign/train/train.py

import random
import sys
import os

# Restrict the visible GPUs; this must be set before torch is imported so that
# CUDA device enumeration picks it up.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6'

import fire
import torch

# Enable autograd anomaly detection to help locate NaN/Inf gradients (slows training).
torch.autograd.set_detect_anomaly(True)

import transformers
from transformers import set_seed

# Fix all RNG seeds for reproducibility.
set_seed(15)

from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset
from trainers import FSDPTrainingArguments, FSDPTrainer
from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from llama import Method_1

def train(
    # model/data params
    base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf',
    data_path: str = '../data/nyt10/nyt10_train.txt',
    output_dir: str = '../ckpts/stylish',
    # training hyperparams
    do_train: bool = True,
    micro_batch_size: int = 2,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = True,
    num_epochs: int = 1,
    save_steps: int = 500,
    learning_rate: float = 2e-5,
    lr_scheduler_type: str = 'cosine',
    weight_decay: float = 1e-4,
    warmup_ratio: float = 0.06,
    deepspeed_config: str = None,
    fsdp: str = 'shard_grad_op auto_wrap offload',
    fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json',
    smart_embedding: bool = False,
    # evaluating hyperparams
    do_eval: bool = False,
    val_set_size: int = 1000,
    eval_batch_size: int = 4,
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training model with params:\n"
            f"base_model: {base_model}\n"
            f"output_dir: {output_dir}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"gradient_accumulation_steps: {gradient_accumulation_steps}\n"
            f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n"
            f"gradient_checkpointing: {gradient_checkpointing}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"weight_decay: {weight_decay}\n"
            f"warmup_ratio: {warmup_ratio}\n"
            f"deepspeed_config: {deepspeed_config}\n"
            f"fsdp: {fsdp}\n"
            f"fsdp_config: {fsdp_config}\n"
            f"smart_embedding: {smart_embedding}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )
    assert (
        not (deepspeed_config and fsdp)
    ), "Cannot specify both deepspeed_config and fsdp"

    # training arguments
    bf16 = True  # torch.cuda.get_device_capability()[0] >= 8
    fp16 = not bf16

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
    # model_dev_id = int(os.environ.get("LOCAL_RANK", 0))
    model = Method_1.from_pretrained(base_model)
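    # NOTE: Method_1 comes from the local `llama` package and replaces the stock
    # AutoModelForCausalLM load commented out above; it is assumed to be a
    # Hugging Face PreTrainedModel subclass so that from_pretrained and the
    # trainer below work unchanged.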

    # Check if parameter passed or if set within environ
    # use_wandb = len(wandb_project) > 0 or (
    #     "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    # )
    use_wandb = False

    # Only overwrite environ if wandb param passed
    # if len(wandb_project) > 0:
    #     os.environ["WANDB_PROJECT"] = wandb_project
    # if len(wandb_watch) > 0:
    #     os.environ["WANDB_WATCH"] = wandb_watch
    # if len(wandb_log_model) > 0:
    #     os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    # train_data = NYT10FullDataset(data_path, tokenizer)
    # train_data = NYT10StylishDataset(data_path, tokenizer, 30)
    train_data = NYT10StylishDataset(data_path, tokenizer, 1000)
    val_data = None
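    # No validation split is built here (val_data stays None), so the
    # do_eval / evaluation_strategy settings below only take effect once an
    # eval dataset is actually supplied.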

    training_args = FSDPTrainingArguments(
        use_ffd_sampler=True,
        output_dir=output_dir,
        no_cuda=not torch.cuda.is_available(),
        seed=15,
        data_seed=15,
        do_train=do_train,
        num_train_epochs=num_epochs,
        optim="adamw_torch",
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        half_precision_backend="auto",
        fp16=fp16,
        bf16=bf16,
        adam_beta1=0.9,
        adam_beta2=0.95,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=1,
        report_to="none",  # "wandb" if use_wandb else None,
        run_name=None,  # wandb_run_name if use_wandb else None,
        deepspeed=deepspeed_config,
        fsdp=fsdp,
        fsdp_config=fsdp_config,
        gradient_checkpointing=gradient_checkpointing,
        do_eval=do_eval and val_set_size > 0,
        evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no",
        eval_steps=save_steps,
        per_device_eval_batch_size=eval_batch_size,
        # group_by_length=True,
    )
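    # FSDPTrainingArguments / FSDPTrainer are project-specific wrappers, presumably
    # around the Hugging Face Trainer API (adding use_ffd_sampler); the fsdp string
    # ('shard_grad_op auto_wrap offload') and the fsdp_config JSON follow the
    # standard TrainingArguments FSDP options and should be tuned to the available GPUs.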

    trainer = FSDPTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        data_collator=transformers.DataCollatorForSeq2Seq(
            # pad to a multiple of 8 for better tensor-core utilization
            tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True,
        ),
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    fire.Fire(train)