import random
import sys
import os

# Make only these GPUs visible; must be set before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6'

import fire
import torch
torch.autograd.set_detect_anomaly(True)  # surface NaN/Inf in the backward pass early (adds overhead)
import transformers
from transformers import set_seed
set_seed(15)

from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset
from trainers import FSDPTrainingArguments, FSDPTrainer

from transformers import AutoTokenizer, AutoConfig
from transformers import AutoModelForCausalLM, LlamaForCausalLM

from llama import Method_1


def train(
    # model/data params
    base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf',
    data_path: str = '../data/nyt10/nyt10_train.txt',
    output_dir: str = '../ckpts/stylish',
    # training hyperparams
    do_train: bool = True,
    micro_batch_size: int = 2,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = True,
    num_epochs: int = 1,
    save_steps: int = 500,
    learning_rate: float = 2e-5,
    lr_scheduler_type: str = 'cosine',
    weight_decay: float = 1e-4,
    warmup_ratio: float = 0.06,
    deepspeed_config: str = None,
    fsdp: str = 'shard_grad_op auto_wrap offload',
    fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json',
    smart_embedding: bool = False,
    # evaluating hyperparams
    do_eval: bool = False,
    val_set_size: int = 1000,
    eval_batch_size: int = 4,
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    # Only the rank-0 worker prints the run configuration.
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training model with params:\n"
            f"base_model: {base_model}\n"
            f"output_dir: {output_dir}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"gradient_accumulation_steps: {gradient_accumulation_steps}\n"
            f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n"
            f"gradient_checkpointing: {gradient_checkpointing}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"weight_decay: {weight_decay}\n"
            f"warmup_ratio: {warmup_ratio}\n"
            f"deepspeed_config: {deepspeed_config}\n"
            f"fsdp: {fsdp}\n"
            f"fsdp_config: {fsdp_config}\n"
            f"smart_embedding: {smart_embedding}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )

    assert not (deepspeed_config and fsdp), \
        "Cannot specify both deepspeed_config and fsdp"

    # precision settings
    bf16 = True  # torch.cuda.get_device_capability()[0] >= 8
    fp16 = not bf16

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id  # Llama-2 has no pad token; reuse EOS for padding

    # model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
    # model_dev_id = int(os.environ.get("LOCAL_RANK", 0))
    model = Method_1.from_pretrained(base_model)

    # Check if parameter passed or if set within environ
    # use_wandb = len(wandb_project) > 0 or (
    #     "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    # )
    use_wandb = False
    # Only overwrite environ if wandb param passed
    # if len(wandb_project) > 0:
    #     os.environ["WANDB_PROJECT"] = wandb_project
    # if len(wandb_watch) > 0:
    #     os.environ["WANDB_WATCH"] = wandb_watch
    # if len(wandb_log_model) > 0:
    #     os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    # train_data = NYT10FullDataset(data_path, tokenizer)
    # train_data = NYT10StylishDataset(data_path, tokenizer, 30)
    train_data = NYT10StylishDataset(data_path, tokenizer, 1000)
    val_data = None

    training_args = FSDPTrainingArguments(
        use_ffd_sampler=True,
        output_dir=output_dir,
        no_cuda=not torch.cuda.is_available(),
        seed=15,
        data_seed=15,
        do_train=do_train,
        num_train_epochs=num_epochs,
        optim="adamw_torch",
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        half_precision_backend="auto",
        fp16=fp16,
        bf16=bf16,
        adam_beta1=0.9,
        adam_beta2=0.95,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=2,
        logging_steps=1,
        report_to="none",  # "wandb" if use_wandb else None,
        run_name=None,  # wandb_run_name if use_wandb else None,
        deepspeed=deepspeed_config,
        fsdp=fsdp,
        fsdp_config=fsdp_config,
        gradient_checkpointing=gradient_checkpointing,
        do_eval=do_eval and val_set_size > 0,
        evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no",
        eval_steps=save_steps,
        per_device_eval_batch_size=eval_batch_size,
        # group_by_length=True,
    )

    trainer = FSDPTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True,
        ),
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    fire.Fire(train)
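

# Example launch (a sketch, not part of the original script): with FSDP enabled and
# six GPUs made visible via CUDA_VISIBLE_DEVICES above, the script would typically be
# started through a distributed launcher such as `torchrun`, which sets LOCAL_RANK for
# each worker. The file name `train_stylish.py` is a placeholder; fire.Fire maps the
# CLI flags onto the keyword arguments of train():
#
#   torchrun --nproc_per_node=6 train_stylish.py \
#       --base_model /home/tushilong/hf/models/Llama-2-7b-hf \
#       --data_path ../data/nyt10/nyt10_train.txt \
#       --output_dir ../ckpts/stylish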