init🎉

train/configs/__init__.py (new empty file)

train/configs/finetune_arguments.py (new file, +11 lines)
| @@ -0,0 +1,11 @@ | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
|  | ||||
| ### Define the fine-tuning configuration arguments | ||||
| @dataclass | ||||
| class FinetuneArguments: | ||||
|     model_name: str = field() | ||||
|     data_path: str = field() | ||||
|     train_size: int = field(default=-1) | ||||
|     test_size: int = field(default=100) | ||||
|     max_len: int = field(default=1024) | ||||
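
FinetuneArguments is a plain dataclass, so it can be populated from the command line with transformers.HfArgumentParser. A minimal sketch (not part of this commit; the import path and example flag values are assumptions):

    from transformers import HfArgumentParser

    from configs.finetune_arguments import FinetuneArguments  # assumed import path, run from train/

    parser = HfArgumentParser(FinetuneArguments)
    # model_name and data_path have no defaults, so they must be passed as flags, e.g.
    #   python some_script.py --model_name /home/tushilong/hf/models/Llama-2-7b-hf --data_path ../data/nyt10/nyt10_train.txt
    (finetune_args,) = parser.parse_args_into_dataclasses()
    print(finetune_args.model_name, finetune_args.train_size, finetune_args.max_len)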
							
								
								
									
train/configs/fsdp/internlm_fsdp_config.json (new file, +4 lines)
| @@ -0,0 +1,4 @@ | ||||
| { | ||||
|     "fsdp_transformer_layer_cls_to_wrap": ["InternLMDecoderLayer"], | ||||
|     "limit_all_gathers": true | ||||
| } | ||||
							
								
								
									
train/configs/fsdp/llama2_fsdp_config.json (new file, +4 lines)
| @@ -0,0 +1,4 @@ | ||||
| { | ||||
|     "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"], | ||||
|     "limit_all_gathers": true | ||||
| } | ||||
							
								
								
									
train/configs/fsdp/qwen_fsdp_config.json (new file, +4 lines)
| @@ -0,0 +1,4 @@ | ||||
| { | ||||
|     "fsdp_transformer_layer_cls_to_wrap": ["QWenBlock"], | ||||
|     "limit_all_gathers": true | ||||
| } | ||||
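
All three JSON files follow the same pattern: fsdp_transformer_layer_cls_to_wrap names the model-specific decoder block that the Trainer's FSDP auto-wrap policy should wrap, and limit_all_gathers rate-limits FSDP all-gathers to reduce memory spikes. A minimal sketch of how one of these files is consumed, mirroring the defaults in train.py below:

    from transformers import TrainingArguments

    # fsdp_config accepts a path to a JSON file (or a dict); with "auto_wrap" enabled,
    # the trainer builds a transformer_auto_wrap_policy for the named layer class.
    args = TrainingArguments(
        output_dir="../ckpts/stylish",
        fsdp="shard_grad_op auto_wrap offload",
        fsdp_config="./configs/fsdp/llama2_fsdp_config.json",
    )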
							
								
								
									
train/configs/logger_config.py (new file, +26 lines)
| @@ -0,0 +1,26 @@ | ||||
| logger_config = { | ||||
|     'version': 1, | ||||
|     'formatters': { | ||||
|         'simple': { | ||||
|             'format': "%(asctime)s %(name)s %(levelname)s: %(message)s", | ||||
|             'datefmt': '%Y-%m-%d %H:%M:%S', | ||||
|         }, | ||||
|         # other formatters | ||||
|     }, | ||||
|     'handlers': { | ||||
|         'console': { | ||||
|             'class': 'logging.StreamHandler', | ||||
|             'level': 'DEBUG', | ||||
|             'formatter': 'simple', | ||||
|         }, | ||||
|         # other handlers | ||||
|     }, | ||||
|     'loggers': { | ||||
|         # console-only output, using StreamLogger | ||||
|         'StreamLogger': { | ||||
|             'handlers': ['console'], | ||||
|             'level': 'DEBUG', | ||||
|         }, | ||||
|         # other loggers | ||||
|     } | ||||
| } | ||||
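
The dictionary follows the standard logging.config dictConfig schema (version 1), so wiring it up is a single call; a minimal sketch, assuming the module is importable as configs.logger_config from inside train/:

    import logging
    import logging.config

    from configs.logger_config import logger_config  # assumed import path

    logging.config.dictConfig(logger_config)
    logger = logging.getLogger("StreamLogger")  # console-only logger defined above
    logger.debug("logging configured")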
							
								
								
									
train/run_log.txt (new file, +136 lines)
| @@ -0,0 +1,136 @@ | ||||
| WARNING:torch.distributed.run: | ||||
| ***************************************** | ||||
| Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.  | ||||
| ***************************************** | ||||
| Training model with params: | ||||
| base_model: /home/tushilong/hf/models/Llama-2-7b-hf | ||||
| output_dir: ../ckpts/stylish | ||||
| micro_batch_size: 2 | ||||
| gradient_accumulation_steps: 1 | ||||
| train_batch_size: 2 | ||||
| gradient_checkpointing: True | ||||
| num_epochs: 1 | ||||
| learning_rate: 2e-05 | ||||
| weight_decay: 0.0001 | ||||
| warmup_ratio: 0.06 | ||||
| deepspeed_config: None | ||||
| fsdp: shard_grad_op auto_wrap offload | ||||
| fsdp_config: ./configs/fsdp/llama2_fsdp_config.json | ||||
| smart_embedding: False | ||||
| wandb_project:  | ||||
| wandb_run_name:  | ||||
| wandb_watch:  | ||||
| wandb_log_model:  | ||||
| resume_from_checkpoint: False | ||||
|  | ||||
|  | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.52s/it] | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.51s/it] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.40s/it] | ||||
|  | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it] | ||||
|  | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.45s/it] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.17s/it] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.36s/it] | ||||
| Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher. | ||||
| StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||
| You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||
| StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||
|  | ||||
|   0%|          | 0/167 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||
| StateDictType.FULL_STATE_DICT FullStateDictConfig(offload_to_cpu=False, rank0_only=False) | ||||
| You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. | ||||
|  | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||
|  | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s][A | ||||
| Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  4.90it/s] | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  4.20it/s] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.97it/s] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.55it/s] | ||||
|  | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.33it/s] | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.88it/s] | ||||
| `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||
| `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||
|  | ||||
|  | ||||
| Loading checkpoint shards:  50%|█████     | 1/2 [00:14<00:14, 14.71s/it][A | ||||
|  | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  8.65s/it][A | ||||
| Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  9.55s/it] | ||||
| `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... | ||||
| Traceback (most recent call last): | ||||
|   File "/home/tushilong/code/realign/train/train.py", line 167, in <module> | ||||
|     fire.Fire(train) | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 141, in Fire | ||||
|     component_trace = _Fire(component, args, parsed_flag_args, context, name) | ||||
|                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire | ||||
|     component, remaining_args = _CallAndUpdateTrace( | ||||
|                                 ^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace | ||||
|     component = fn(*varargs, **kwargs) | ||||
|                 ^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/code/realign/train/train.py", line 160, in train | ||||
|     trainer.train(resume_from_checkpoint=resume_from_checkpoint) | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train | ||||
|     return inner_training_loop( | ||||
|            ^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop | ||||
|     tr_loss_step = self.training_step(model, inputs) | ||||
|                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step | ||||
|     loss = self.compute_loss(model, inputs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/transformers/trainer.py", line 2707, in compute_loss | ||||
|     outputs = model(**inputs) | ||||
|               ^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl | ||||
|     return forward_call(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward | ||||
|     return model_forward(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__ | ||||
|     return convert_to_fp32(self.model_forward(*args, **kwargs)) | ||||
|                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast | ||||
|     return func(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 659, in forward | ||||
|     return model_forward(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/accelerate/utils/operations.py", line 647, in __call__ | ||||
|     return convert_to_fp32(self.model_forward(*args, **kwargs)) | ||||
|                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast | ||||
|     return func(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward | ||||
|     output = self._fsdp_wrapped_module(*args, **kwargs) | ||||
|              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl | ||||
|     return forward_call(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/code/realign/llama/rellama.py", line 129, in forward | ||||
|     assert torch.isnan(target_logits).sum() == 0, f"target_logits has nan: {torch.isnan(target_logits).sum()}" | ||||
|            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
| AssertionError: target_logits has nan: 10752000 | ||||
| WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20205 closing signal SIGTERM | ||||
| WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 20206 closing signal SIGTERM | ||||
| ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 2 (pid: 20207) of binary: /home/tushilong/anaconda3/envs/realign/bin/python | ||||
| Traceback (most recent call last): | ||||
|   File "/home/tushilong/anaconda3/envs/realign/bin/torchrun", line 33, in <module> | ||||
|     sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')()) | ||||
|              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|   File "/home/tushilong/anaconda3/envs/realign/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper | ||||
|     return f(*args, **kwargs) | ||||
|            ^^^^^^^^^^^^^^^^^^ | ||||
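
The run aborts at the assert in llama/rellama.py line 129: target_logits contains 10752000 NaN entries, and torchrun then tears down the remaining ranks. A debugging sketch (not part of this commit) that narrows down which submodule first produces non-finite values, complementing the torch.autograd.set_detect_anomaly(True) call in train.py:

    import torch
    from torch import nn

    def install_nan_hooks(model: nn.Module) -> None:
        """Raise as soon as any module's output contains NaN/Inf, naming that module."""
        def make_hook(name):
            def hook(module, inputs, output):
                outputs = output if isinstance(output, (tuple, list)) else (output,)
                for t in outputs:
                    if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                        raise RuntimeError(f"non-finite values first appeared in: {name}")
            return hook
        for name, module in model.named_modules():
            module.register_forward_hook(make_hook(name))

    # hypothetical usage: install_nan_hooks(model) before trainer.train()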
							
								
								
									
train/train.py (new file, +168 lines)
| @@ -0,0 +1,168 @@ | ||||
| import random | ||||
| import sys | ||||
| import os | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4,5,6' | ||||
|  | ||||
| import fire | ||||
| import torch | ||||
| torch.autograd.set_detect_anomaly(True) | ||||
| import transformers | ||||
| from transformers import set_seed | ||||
| set_seed(15) | ||||
| from utils.datasets.nyt10_dataset import NYT10FullDataset, NYT10StylishDataset | ||||
| from trainers import FSDPTrainingArguments, FSDPTrainer | ||||
| from transformers import AutoTokenizer, AutoConfig | ||||
| from transformers import AutoModelForCausalLM, LlamaForCausalLM | ||||
| from llama import Method_1 | ||||
|  | ||||
|  | ||||
| def train( | ||||
|     # model/data params | ||||
|     base_model: str = '/home/tushilong/hf/models/Llama-2-7b-hf', | ||||
|     data_path: str = '../data/nyt10/nyt10_train.txt', | ||||
|     output_dir: str = '../ckpts/stylish', | ||||
|      | ||||
|     # training hyperparams | ||||
|     do_train: bool = True, | ||||
|     micro_batch_size: int = 2, | ||||
|     gradient_accumulation_steps: int = 1, | ||||
|     gradient_checkpointing: bool = True, | ||||
|     num_epochs: int = 1, | ||||
|     save_steps: int = 500, | ||||
|     learning_rate: float = 2e-5, | ||||
|     lr_scheduler_type: str = 'cosine', | ||||
|     weight_decay: float = 1e-4, | ||||
|     warmup_ratio: float = 0.06, | ||||
|     deepspeed_config: str = None, | ||||
|     fsdp: str = 'shard_grad_op auto_wrap offload', | ||||
|     fsdp_config: str = './configs/fsdp/llama2_fsdp_config.json', | ||||
|     smart_embedding: bool = False, | ||||
|  | ||||
|     # evaluating hyperparams | ||||
|     do_eval: bool = False, | ||||
|     val_set_size: int = 1000, | ||||
|     eval_batch_size: int = 4, | ||||
|      | ||||
|     # wandb params | ||||
|     wandb_project: str = "", | ||||
|     wandb_run_name: str = "", | ||||
|     wandb_watch: str = "",  # options: false | gradients | all | ||||
|     wandb_log_model: str = "",  # options: false | true | ||||
|     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter | ||||
| ): | ||||
|     if int(os.environ.get("LOCAL_RANK", 0)) == 0: | ||||
|         print( | ||||
|             f"Training model with params:\n" | ||||
|             f"base_model: {base_model}\n" | ||||
|             f"output_dir: {output_dir}\n" | ||||
|             f"micro_batch_size: {micro_batch_size}\n" | ||||
|             f"gradient_accumulation_steps: {gradient_accumulation_steps}\n" | ||||
|             f"train_batch_size: {micro_batch_size * gradient_accumulation_steps}\n" | ||||
|             f"gradient_checkpointing: {gradient_checkpointing}\n" | ||||
|             f"num_epochs: {num_epochs}\n" | ||||
|             f"learning_rate: {learning_rate}\n" | ||||
|             f"weight_decay: {weight_decay}\n" | ||||
|             f"warmup_ratio: {warmup_ratio}\n" | ||||
|             f"deepspeed_config: {deepspeed_config}\n" | ||||
|             f"fsdp: {fsdp}\n" | ||||
|             f"fsdp_config: {fsdp_config}\n" | ||||
|             f"smart_embedding: {smart_embedding}\n" | ||||
|             f"wandb_project: {wandb_project}\n" | ||||
|             f"wandb_run_name: {wandb_run_name}\n" | ||||
|             f"wandb_watch: {wandb_watch}\n" | ||||
|             f"wandb_log_model: {wandb_log_model}\n" | ||||
|             f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" | ||||
|         ) | ||||
|     assert ( | ||||
|         not (deepspeed_config and fsdp) | ||||
|     ), "Can not specified both deepspeed_config and fsdp_config" | ||||
|      | ||||
|     # training arguments | ||||
|     bf16 = True # torch.cuda.get_device_capability()[0] >= 8 | ||||
|     fp16 = not bf16 | ||||
|      | ||||
|     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) | ||||
|     tokenizer.pad_token_id = tokenizer.eos_token_id | ||||
|  | ||||
|     # model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True) | ||||
|     # model_dev_id = int(os.environ.get("LOCAL_RANK", 0)) | ||||
|  | ||||
|     model = Method_1.from_pretrained(base_model) | ||||
|  | ||||
|  | ||||
|  | ||||
|     # Check if parameter passed or if set within environ | ||||
|     # use_wandb = len(wandb_project) > 0 or ( | ||||
|     #     "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 | ||||
|     # ) | ||||
|     use_wandb = False | ||||
|     # Only overwrite environ if wandb param passed | ||||
|     # if len(wandb_project) > 0: | ||||
|     #     os.environ["WANDB_PROJECT"] = wandb_project | ||||
|     # if len(wandb_watch) > 0: | ||||
|     #     os.environ["WANDB_WATCH"] = wandb_watch | ||||
|     # if len(wandb_log_model) > 0: | ||||
|     #     os.environ["WANDB_LOG_MODEL"] = wandb_log_model | ||||
|      | ||||
|      | ||||
|     # train_data = NYT10FullDataset(data_path, tokenizer) | ||||
|     # train_data = NYT10StylishDataset(data_path, tokenizer, 30) | ||||
|     train_data = NYT10StylishDataset(data_path, tokenizer, 1000) | ||||
|     val_data = None | ||||
|          | ||||
|     training_args = FSDPTrainingArguments( | ||||
|         use_ffd_sampler=True, | ||||
|         output_dir=output_dir, | ||||
|         no_cuda=not torch.cuda.is_available(), | ||||
|         seed=15, | ||||
|         data_seed=15, | ||||
|         do_train=do_train, | ||||
|         num_train_epochs=num_epochs, | ||||
|         optim="adamw_torch", | ||||
|         learning_rate=learning_rate, | ||||
|         lr_scheduler_type=lr_scheduler_type, | ||||
|         per_device_train_batch_size=micro_batch_size, | ||||
|         gradient_accumulation_steps=gradient_accumulation_steps, | ||||
|         warmup_ratio=warmup_ratio, | ||||
|         weight_decay=weight_decay, | ||||
|         half_precision_backend="auto", | ||||
|         fp16=fp16, | ||||
|         bf16=bf16, | ||||
|         adam_beta1=0.9, | ||||
|         adam_beta2=0.95, | ||||
|         save_strategy="steps", | ||||
|         save_steps=save_steps, | ||||
|         save_total_limit=2, | ||||
|         logging_steps=1, | ||||
|         report_to= "none", # "wandb" if use_wandb else None, | ||||
|         run_name=None, #wandb_run_name if use_wandb else None, | ||||
|         deepspeed=deepspeed_config, | ||||
|         fsdp=fsdp, | ||||
|         fsdp_config=fsdp_config, | ||||
|         gradient_checkpointing=gradient_checkpointing, | ||||
|         do_eval=do_eval and val_set_size > 0, | ||||
|         evaluation_strategy="steps" if do_eval and val_set_size > 0 else "no", | ||||
|         eval_steps=save_steps, | ||||
|         per_device_eval_batch_size=eval_batch_size, | ||||
|         # group_by_length=True, | ||||
|     )  | ||||
|      | ||||
|     trainer = FSDPTrainer( | ||||
|         model=model, | ||||
|         args=training_args, | ||||
|         train_dataset=train_data, | ||||
|         eval_dataset=val_data, | ||||
|         data_collator=transformers.DataCollatorForSeq2Seq( | ||||
|             tokenizer, pad_to_multiple_of=8, return_tensors='pt', padding=True, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     trainer.train(resume_from_checkpoint=resume_from_checkpoint) | ||||
|  | ||||
|     trainer.save_model(output_dir) | ||||
|     tokenizer.save_pretrained(output_dir) | ||||
|      | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     fire.Fire(train) | ||||
|      | ||||
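
Launch note (the exact command is not part of this commit): the run log above was produced by torch.distributed.run, so train.py is presumably started with something like torchrun --nproc_per_node=<world_size> train.py, and fire.Fire(train) exposes every keyword argument of train() (for example --fsdp_config ./configs/fsdp/llama2_fsdp_config.json) as a command-line flag.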
							
								
								
									
train/trainers/__init__.py (new file, +2 lines)
| @@ -0,0 +1,2 @@ | ||||
| from .fsdp_training_args import FSDPTrainingArguments | ||||
| from .fsdp_trainer import FSDPTrainer | ||||
							
								
								
									
train/trainers/ffd_sampler.py (new file, +161 lines)
| @@ -0,0 +1,161 @@ | ||||
| from typing import Optional, List | ||||
|  | ||||
| import torch.distributed as dist | ||||
| from torch.utils.data import Sampler | ||||
|  | ||||
| import numpy as np | ||||
| import numba | ||||
|  | ||||
|  | ||||
| @numba.njit | ||||
| def ffd(a: np.ndarray, c: int): | ||||
|     # First-fit-decreasing bin packing | ||||
|     # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing | ||||
|  | ||||
|     a = np.sort(a)[::-1] | ||||
|     bins = [] | ||||
|     for size in a: | ||||
|         add_new = True | ||||
|         for idx in range(len(bins)): | ||||
|             if bins[idx] >= size: | ||||
|                 bins[idx] -= size | ||||
|                 add_new = False | ||||
|                 break | ||||
|  | ||||
|         if add_new: | ||||
|             bins.append(c - size) | ||||
|  | ||||
|     return len(bins) | ||||
|  | ||||
|  | ||||
| @numba.njit | ||||
| def ffd_with_result(a: np.ndarray, c: int, start_index: int): | ||||
|     # First-fit-decreasing bin packing (with result return) | ||||
|  | ||||
|     indices = np.argsort(a)[::-1] | ||||
|     a = a[indices] | ||||
|  | ||||
|     bins = [] | ||||
|     bins_result = [] | ||||
|     for a_id, size in enumerate(a): | ||||
|         add_new = True | ||||
|         for idx in range(len(bins)): | ||||
|             if bins[idx] >= size: | ||||
|                 bins[idx] -= size | ||||
|                 bins_result[idx].append(indices[a_id] + start_index) | ||||
|                 add_new = False | ||||
|                 break | ||||
|  | ||||
|         if add_new: | ||||
|             bins.append(c - size) | ||||
|             bins_result.append([indices[a_id] + start_index]) | ||||
|  | ||||
|     return bins_result | ||||
|  | ||||
|  | ||||
| @numba.njit | ||||
| def allocate(lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int): | ||||
|     # Dynamic batch allocator, similar to Multifit | ||||
|     # https://en.wikipedia.org/wiki/Multifit_algorithm | ||||
|     # ~96.4% efficiency on OpenChat training set (2048 ctx len) | ||||
|  | ||||
|     s = 0 | ||||
|     start_index = 0 | ||||
|     result = [] | ||||
|  | ||||
|     while True: | ||||
|         # binary search [l, r) | ||||
|         l = 1 | ||||
|         r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") | ||||
|  | ||||
|         while r - l > 1: | ||||
|             m = (l + r) // 2 | ||||
|             if ffd(lengths[start_index: start_index + m], c) <= n: | ||||
|                 l = m | ||||
|             else: | ||||
|                 r = m | ||||
|  | ||||
|         # use length l | ||||
|         batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index) | ||||
|         if len(batch) < n: | ||||
|             break | ||||
|  | ||||
|         start_index += l | ||||
|         s = lengths_cumsum[start_index - 1] | ||||
|  | ||||
|         # add local rank | ||||
|         result.append(batch[rank]) | ||||
|  | ||||
|     return result, s / max(1, len(result) * c * n)  # Avoid division by zero | ||||
|  | ||||
|  | ||||
| class FFDDistributedBatchSampler(Sampler): | ||||
|     """Unpadded length sampling using FFD (First-fit-decreasing bin packing). | ||||
|        Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.""" | ||||
|      | ||||
|     def __init__( | ||||
|         self, | ||||
|         batch_max_length: int, | ||||
|         lengths: List[int], | ||||
|         num_replicas: Optional[int] = None, | ||||
|         rank: Optional[int] = None, | ||||
|         seed: int = 0, | ||||
|     ): | ||||
|         # Get rank | ||||
|         if num_replicas is None: | ||||
|             if not dist.is_available(): | ||||
|                 raise RuntimeError("Requires distributed package to be available") | ||||
|             num_replicas = dist.get_world_size() | ||||
|         if rank is None: | ||||
|             if not dist.is_available(): | ||||
|                 raise RuntimeError("Requires distributed package to be available") | ||||
|             rank = dist.get_rank() | ||||
|  | ||||
|         self.num_replicas = num_replicas | ||||
|         self.rank = rank | ||||
|         self.seed = seed | ||||
|  | ||||
|         self.batch_max_length = batch_max_length | ||||
|         self.lengths = lengths | ||||
|         assert isinstance(self.lengths, np.ndarray) | ||||
|  | ||||
|         self.epoch = 0 | ||||
|  | ||||
|         # statistics | ||||
|         self.total_epochs = 0 | ||||
|         self.total_efficiency = 0 | ||||
|  | ||||
|     def set_epoch(self, epoch: int): | ||||
|         self.epoch = epoch | ||||
|  | ||||
|     def generate_batches(self, set_stats=False): | ||||
|         indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths)) | ||||
|  | ||||
|         lengths = self.lengths[indices] | ||||
|         lengths_cumsum = np.cumsum(lengths) | ||||
|  | ||||
|         batches, efficiency = allocate(lengths=lengths, | ||||
|                            lengths_cumsum=lengths_cumsum, | ||||
|                            rank=self.rank, | ||||
|                            c=self.batch_max_length, | ||||
|                            n=self.num_replicas) | ||||
|          | ||||
|         batches = [indices[batch] for batch in batches] | ||||
|  | ||||
|         # statistics | ||||
|         if set_stats: | ||||
|             self.total_epochs += 1 | ||||
|             self.total_efficiency += efficiency | ||||
|          | ||||
|         return batches | ||||
|      | ||||
|     def __iter__(self): | ||||
|         batches = self.generate_batches(set_stats=True) | ||||
|         return iter(batches) | ||||
|  | ||||
|     def __len__(self): | ||||
|         batches = self.generate_batches() | ||||
|         return len(batches) | ||||
|  | ||||
|     def efficiency(self): | ||||
|         return self.total_efficiency / self.total_epochs | ||||
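
ffd and ffd_with_result are numba-jitted and operate on plain NumPy arrays, so they can be exercised in isolation; a toy check (not part of this commit; the import path is an assumption):

    import numpy as np

    from trainers.ffd_sampler import ffd, ffd_with_result  # assumed import path, run from train/

    lengths = np.array([900, 700, 600, 400, 300, 100], dtype=np.int64)
    print(ffd(lengths, 1024))                 # 3 bins of capacity 1024
    print(ffd_with_result(lengths, 1024, 0))  # e.g. [[0, 5], [1, 4], [2, 3]]

FFDDistributedBatchSampler itself reads the world size and rank from torch.distributed, so it only works under a distributed launch, which is presumably why train.py enables it through FSDPTrainingArguments(use_ffd_sampler=True).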
							
								
								
									
train/trainers/fsdp_trainer.py (new file, +925 lines)
| @@ -0,0 +1,925 @@ | ||||
| import sys | ||||
| import os | ||||
| from typing import Optional | ||||
| import torch | ||||
| from torch.utils.data import DataLoader | ||||
|  | ||||
| import transformers | ||||
| from transformers.trainer import * | ||||
|  | ||||
| from .ffd_sampler import FFDDistributedBatchSampler | ||||
| from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook | ||||
|  | ||||
|  | ||||
| class FSDPTrainer(transformers.Trainer): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model: Union[PreTrainedModel, nn.Module] = None, | ||||
|         args: TrainingArguments = None, | ||||
|         data_collator: Optional[DataCollator] = None, | ||||
|         train_dataset: Optional[Dataset] = None, | ||||
|         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, | ||||
|         tokenizer: Optional[PreTrainedTokenizerBase] = None, | ||||
|         model_init: Optional[Callable[[], PreTrainedModel]] = None, | ||||
|         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, | ||||
|         callbacks: Optional[List[TrainerCallback]] = None, | ||||
|         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), | ||||
|         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, | ||||
|     ): | ||||
|         if args is None: | ||||
|             output_dir = "tmp_trainer" | ||||
|             logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") | ||||
|             args = TrainingArguments(output_dir=output_dir) | ||||
|         self.args = args | ||||
|         # Seed must be set before instantiating the model when using model | ||||
|         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) | ||||
|         self.hp_name = None | ||||
|         self.deepspeed = None | ||||
|         self.is_in_train = False | ||||
|  | ||||
|         self.create_accelerator_and_postprocess() | ||||
|  | ||||
|         # memory metrics - must set up as early as possible | ||||
|         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) | ||||
|         self._memory_tracker.start() | ||||
|  | ||||
|         # set the correct log level depending on the node | ||||
|         log_level = args.get_process_log_level() | ||||
|         logging.set_verbosity(log_level) | ||||
|  | ||||
|         # force device and distributed setup init explicitly | ||||
|         args._setup_devices | ||||
|  | ||||
|         if model is None: | ||||
|             if model_init is not None: | ||||
|                 self.model_init = model_init | ||||
|                 model = self.call_model_init() | ||||
|             else: | ||||
|                 raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") | ||||
|         else: | ||||
|             if model_init is not None: | ||||
|                 warnings.warn( | ||||
|                     "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" | ||||
|                     " overwrite your model when calling the `train` method. This will become a fatal error in the next" | ||||
|                     " release.", | ||||
|                     FutureWarning, | ||||
|                 ) | ||||
|             self.model_init = model_init | ||||
|  | ||||
|         if model.__class__.__name__ in MODEL_MAPPING_NAMES: | ||||
|             raise ValueError( | ||||
|                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " | ||||
|                 "computes hidden states and does not accept any labels. You should choose a model with a head " | ||||
|                 "suitable for your task like any of the `AutoModelForXxx` listed at " | ||||
|                 "https://huggingface.co/docs/transformers/model_doc/auto." | ||||
|             ) | ||||
|  | ||||
|         if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: | ||||
|             self.is_model_parallel = True | ||||
|         else: | ||||
|             self.is_model_parallel = False | ||||
|  | ||||
|         if getattr(model, "hf_device_map", None) is not None: | ||||
|             devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] | ||||
|             if len(devices) > 1: | ||||
|                 self.is_model_parallel = True | ||||
|             else: | ||||
|                 self.is_model_parallel = self.args.device != torch.device(devices[0]) | ||||
|  | ||||
|             # warn users | ||||
|             logger.info( | ||||
|                 "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" | ||||
|                 " to `True` to avoid any unexpected behavior such as device placement mismatching." | ||||
|             ) | ||||
|  | ||||
|         # At this stage the model is already loaded | ||||
|         if getattr(model, "is_quantized", False): | ||||
|             if getattr(model, "_is_quantized_training_enabled", False): | ||||
|                 logger.info( | ||||
|                     "The model is loaded in 8-bit precision. To train this model you need to add additional modules" | ||||
|                     " inside the model such as adapters using `peft` library and freeze the model weights. Please" | ||||
|                     " check " | ||||
|                     " the examples in https://github.com/huggingface/peft for more details." | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "The model you want to train is loaded in 8-bit precision.  if you want to fine-tune an 8-bit" | ||||
|                     " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " | ||||
|                 ) | ||||
|  | ||||
|         # Setup Sharded DDP training | ||||
|         self.sharded_ddp = None | ||||
|         if len(args.sharded_ddp) > 0: | ||||
|             if self.is_deepspeed_enabled: | ||||
|                 raise ValueError( | ||||
|                     "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." | ||||
|                 ) | ||||
|             if len(args.fsdp) > 0: | ||||
|                 raise ValueError( | ||||
|                     "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." | ||||
|                 ) | ||||
|             if args.parallel_mode != ParallelMode.DISTRIBUTED: | ||||
|                 raise ValueError("Using sharded DDP only works in distributed training.") | ||||
|             elif not is_fairscale_available(): | ||||
|                 raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") | ||||
|             elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: | ||||
|                 raise ImportError( | ||||
|                     "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " | ||||
|                     f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." | ||||
|                 ) | ||||
|             elif ShardedDDPOption.SIMPLE in args.sharded_ddp: | ||||
|                 self.sharded_ddp = ShardedDDPOption.SIMPLE | ||||
|             elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: | ||||
|                 self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 | ||||
|             elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: | ||||
|                 self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 | ||||
|  | ||||
|         self.fsdp = None | ||||
|         if len(args.fsdp) > 0: | ||||
|             if self.is_deepspeed_enabled: | ||||
|                 raise ValueError( | ||||
|                     "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." | ||||
|                 ) | ||||
|             if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: | ||||
|                 raise ValueError("Using fsdp only works in distributed training.") | ||||
|  | ||||
|             # dep_version_check("torch>=1.12.0") | ||||
|             # Would have to update setup.py with torch>=1.12.0 | ||||
|             # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 | ||||
|             # below is the current alternative. | ||||
|             if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): | ||||
|                 raise ValueError("FSDP requires PyTorch >= 1.12.0") | ||||
|  | ||||
|             from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy | ||||
|  | ||||
|             if ExtendedFSDPOption.FULL_SHARD in args.fsdp: | ||||
|                 self.fsdp = ShardingStrategy.FULL_SHARD | ||||
|             elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp: | ||||
|                 self.fsdp = ShardingStrategy.SHARD_GRAD_OP | ||||
|             elif ExtendedFSDPOption.NO_SHARD in args.fsdp: | ||||
|                 self.fsdp = ShardingStrategy.NO_SHARD | ||||
|             # extension starts here | ||||
|             elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp: | ||||
|                 self.fsdp = ShardingStrategy.HYBRID_SHARD | ||||
|             elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp: | ||||
|                 self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2 | ||||
|             # extension ends here | ||||
|  | ||||
|             self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE | ||||
|             # modification starts here | ||||
|             if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post": | ||||
|                 self.backward_prefetch = BackwardPrefetch.BACKWARD_POST | ||||
|             # modification ends here | ||||
|  | ||||
|             self.forward_prefetch = False | ||||
|             # modification starts here | ||||
|             if self.args.fsdp_config.get("forward_prefetch", False): | ||||
|             # modification ends here | ||||
|                 self.forward_prefetch = True | ||||
|                  | ||||
|             self.limit_all_gathers = False | ||||
|             if self.args.fsdp_config.get("limit_all_gathers", False): | ||||
|                 self.limit_all_gathers = True | ||||
|  | ||||
|         # one place to sort out whether to place the model on device or not | ||||
|         # postpone switching model to cuda when: | ||||
|         # 1. MP - since we are trying to fit a much bigger than 1 gpu model | ||||
|         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, | ||||
|         #    and we only use deepspeed for training at the moment | ||||
|         # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first | ||||
|         # 4. Sharded DDP - same as MP | ||||
|         # 5. FSDP - same as MP | ||||
|         self.place_model_on_device = args.place_model_on_device | ||||
|         if ( | ||||
|             self.is_model_parallel | ||||
|             or self.is_deepspeed_enabled | ||||
|             or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) | ||||
|             or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) | ||||
|             or (self.fsdp is not None) | ||||
|             or self.is_fsdp_enabled | ||||
|         ): | ||||
|             self.place_model_on_device = False | ||||
|  | ||||
|         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) | ||||
|         self.data_collator = data_collator if data_collator is not None else default_collator | ||||
|         self.train_dataset = train_dataset | ||||
|         self.eval_dataset = eval_dataset | ||||
|         self.tokenizer = tokenizer | ||||
|  | ||||
|         # Quantized models doesn't support `.to` operation. | ||||
|         if self.place_model_on_device and not getattr(model, "is_quantized", False): | ||||
|             self._move_model_to_device(model, args.device) | ||||
|  | ||||
|         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs | ||||
|         if self.is_model_parallel: | ||||
|             self.args._n_gpu = 1 | ||||
|  | ||||
|         # later use `self.model is self.model_wrapped` to check if it's wrapped or not | ||||
|         self.model_wrapped = model | ||||
|         self.model = model | ||||
|  | ||||
|         self.compute_metrics = compute_metrics | ||||
|         self.preprocess_logits_for_metrics = preprocess_logits_for_metrics | ||||
|         self.optimizer, self.lr_scheduler = optimizers | ||||
|         if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): | ||||
|             raise RuntimeError( | ||||
|                 "Passing a `model_init` is incompatible with providing the `optimizers` argument. " | ||||
|                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | ||||
|             ) | ||||
|         if is_torch_tpu_available() and self.optimizer is not None: | ||||
|             for param in self.model.parameters(): | ||||
|                 model_device = param.device | ||||
|                 break | ||||
|             for param_group in self.optimizer.param_groups: | ||||
|                 if len(param_group["params"]) > 0: | ||||
|                     optimizer_device = param_group["params"][0].device | ||||
|                     break | ||||
|             if model_device != optimizer_device: | ||||
|                 raise ValueError( | ||||
|                     "The model and the optimizer parameters are not on the same device, which probably means you" | ||||
|                     " created an optimizer around your model **before** putting on the device and passing it to the" | ||||
|                     " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" | ||||
|                     " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." | ||||
|                 ) | ||||
|         if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( | ||||
|             self.optimizer is not None or self.lr_scheduler is not None | ||||
|         ): | ||||
|             raise RuntimeError( | ||||
|                 "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." | ||||
|                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." | ||||
|             ) | ||||
|         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) | ||||
|         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks | ||||
|         self.callback_handler = CallbackHandler( | ||||
|             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler | ||||
|         ) | ||||
|         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) | ||||
|  | ||||
|         # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. | ||||
|         self._loggers_initialized = False | ||||
|  | ||||
|         # Create clone of distant repo and output directory if needed | ||||
|         if self.args.push_to_hub: | ||||
|             self.init_git_repo(at_init=True) | ||||
|             # In case of pull, we need to make sure every process has the latest. | ||||
|             if is_torch_tpu_available(): | ||||
|                 xm.rendezvous("init git repo") | ||||
|             elif args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||
|                 dist.barrier() | ||||
|  | ||||
|         if self.args.should_save: | ||||
|             os.makedirs(self.args.output_dir, exist_ok=True) | ||||
|  | ||||
|         if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): | ||||
|             raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") | ||||
|  | ||||
|         if args.max_steps > 0: | ||||
|             logger.info("max_steps is given, it will override any value given in num_train_epochs") | ||||
|  | ||||
|         if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: | ||||
|             raise ValueError( | ||||
|                 "The train_dataset does not implement __len__, max_steps has to be specified. " | ||||
|                 "The number of steps needs to be known in advance for the learning rate scheduler." | ||||
|             ) | ||||
|  | ||||
|         if ( | ||||
|             train_dataset is not None | ||||
|             and isinstance(train_dataset, torch.utils.data.IterableDataset) | ||||
|             and args.group_by_length | ||||
|         ): | ||||
|             raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") | ||||
|  | ||||
|         self._signature_columns = None | ||||
|  | ||||
|         # Mixed precision setup | ||||
|         self.use_apex = False | ||||
|         self.use_cuda_amp = False | ||||
|         self.use_cpu_amp = False | ||||
|  | ||||
|         # Mixed precision setup for SageMaker Model Parallel | ||||
|         if is_sagemaker_mp_enabled(): | ||||
|             # BF16 + model parallelism in SageMaker: currently not supported, raise an error | ||||
|             if args.bf16: | ||||
|                 raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") | ||||
|  | ||||
|             if IS_SAGEMAKER_MP_POST_1_10: | ||||
|                 # When there's mismatch between SMP config and trainer argument, use SMP config as truth | ||||
|                 if args.fp16 != smp.state.cfg.fp16: | ||||
|                     logger.warning( | ||||
|                         f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," | ||||
|                         f"but FP16 provided in trainer argument is {args.fp16}," | ||||
|                         f"setting to {smp.state.cfg.fp16}" | ||||
|                     ) | ||||
|                     args.fp16 = smp.state.cfg.fp16 | ||||
|             else: | ||||
|                 # smp < 1.10 does not support fp16 in trainer. | ||||
|                 if hasattr(smp.state.cfg, "fp16"): | ||||
|                     logger.warning( | ||||
|                         f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " | ||||
|                         "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." | ||||
|                     ) | ||||
|  | ||||
|         if (args.fp16 or args.bf16) and self.sharded_ddp is not None: | ||||
|             if args.half_precision_backend == "auto": | ||||
|                 if args.device == torch.device("cpu"): | ||||
|                     if args.fp16: | ||||
|                         raise ValueError("Tried to use `fp16` but it is not supported on cpu") | ||||
|                     else: | ||||
|                         args.half_precision_backend = "cpu_amp" | ||||
|                 else: | ||||
|                     args.half_precision_backend = "cuda_amp" | ||||
|  | ||||
|             logger.info(f"Using {args.half_precision_backend} half precision backend") | ||||
|  | ||||
|         self.do_grad_scaling = False | ||||
|         if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): | ||||
|             # deepspeed and SageMaker Model Parallel manage their own half precision | ||||
|             if self.sharded_ddp is not None: | ||||
|                 if args.half_precision_backend == "cuda_amp": | ||||
|                     self.use_cuda_amp = True | ||||
|                     self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 | ||||
|                     #  bf16 does not need grad scaling | ||||
|                     self.do_grad_scaling = self.amp_dtype == torch.float16 | ||||
|                     if self.do_grad_scaling: | ||||
|                         if self.sharded_ddp is not None: | ||||
|                             self.scaler = ShardedGradScaler() | ||||
|                         elif self.fsdp is not None: | ||||
|                             from torch.distributed.fsdp.sharded_grad_scaler import ( | ||||
|                                 ShardedGradScaler as FSDPShardedGradScaler, | ||||
|                             ) | ||||
|  | ||||
|                             self.scaler = FSDPShardedGradScaler() | ||||
|                         elif is_torch_tpu_available(): | ||||
|                             from torch_xla.amp import GradScaler | ||||
|  | ||||
|                             self.scaler = GradScaler() | ||||
|                         else: | ||||
|                             self.scaler = torch.cuda.amp.GradScaler() | ||||
|                 elif args.half_precision_backend == "cpu_amp": | ||||
|                     self.use_cpu_amp = True | ||||
|                     self.amp_dtype = torch.bfloat16 | ||||
|             elif args.half_precision_backend == "apex": | ||||
|                 if not is_apex_available(): | ||||
|                     raise ImportError( | ||||
|                         "Using FP16 with APEX but APEX is not installed, please refer to" | ||||
|                         " https://www.github.com/nvidia/apex." | ||||
|                     ) | ||||
|                 self.use_apex = True | ||||
|  | ||||
|         # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. | ||||
|         if ( | ||||
|             is_sagemaker_mp_enabled() | ||||
|             and self.use_cuda_amp | ||||
|             and args.max_grad_norm is not None | ||||
|             and args.max_grad_norm > 0 | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " | ||||
|                 "along 'max_grad_norm': 0 in your hyperparameters." | ||||
|             ) | ||||
|  | ||||
|         # Label smoothing | ||||
|         if self.args.label_smoothing_factor != 0: | ||||
|             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) | ||||
|         else: | ||||
|             self.label_smoother = None | ||||
|  | ||||
|         self.state = TrainerState( | ||||
|             is_local_process_zero=self.is_local_process_zero(), | ||||
|             is_world_process_zero=self.is_world_process_zero(), | ||||
|         ) | ||||
|  | ||||
|         self.control = TrainerControl() | ||||
|         # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then | ||||
|         # returned to 0 every time flos need to be logged | ||||
|         self.current_flos = 0 | ||||
|         self.hp_search_backend = None | ||||
|         self.use_tune_checkpoints = False | ||||
|         default_label_names = find_labels(self.model.__class__) | ||||
|         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names | ||||
|         self.can_return_loss = can_return_loss(self.model.__class__) | ||||
|         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) | ||||
|  | ||||
|         # Internal variables to help with automatic batch size reduction | ||||
|         self._train_batch_size = args.train_batch_size | ||||
|         self._created_lr_scheduler = False | ||||
|  | ||||
|         # very last | ||||
|         self._memory_tracker.stop_and_update_metrics() | ||||
|  | ||||
|         # torch.compile | ||||
|         if args.torch_compile and not is_torch_compile_available(): | ||||
|             raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") | ||||
|          | ||||
|         # finally, apply `low_gpu_full_post_state_dict_hook` for the fsdp `state_dict` | ||||
|         enable_low_gpu_full_post_state_dict_hook() | ||||
|          | ||||
|     def _wrap_model(self, model, training=True, dataloader=None): | ||||
|         if self.args.use_ipex: | ||||
|             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 | ||||
|             model = self.ipex_optimize_model(model, training, dtype=dtype) | ||||
|  | ||||
|         if is_sagemaker_mp_enabled(): | ||||
|             # Wrapping the base model twice in a DistributedModel will raise an error. | ||||
|             if isinstance(self.model_wrapped, smp.model.DistributedModel): | ||||
|                 return self.model_wrapped | ||||
|             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) | ||||
|  | ||||
|         # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again | ||||
|         if unwrap_model(model) is not model: | ||||
|             return model | ||||
|  | ||||
|         # Mixed precision training with apex (torch < 1.6) | ||||
|         if self.use_apex and training: | ||||
|             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) | ||||
|  | ||||
|         # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP | ||||
|         if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): | ||||
|             model = nn.DataParallel(model) | ||||
|  | ||||
|         if self.args.jit_mode_eval: | ||||
|             start_time = time.time() | ||||
|             model = self.torch_jit_model_eval(model, dataloader, training) | ||||
|             self.jit_compilation_time = round(time.time() - start_time, 4) | ||||
|  | ||||
|         # Note: in torch.distributed mode, there's no point in wrapping the model | ||||
|         # inside a DistributedDataParallel as we'll be under `no_grad` anyways. | ||||
|         if not training: | ||||
|             return model | ||||
|  | ||||
|         # Distributed training (should be after apex fp16 initialization) | ||||
|         if self.sharded_ddp is not None: | ||||
|             # Sharded DDP! | ||||
|             if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||||
|                 model = ShardedDDP(model, self.optimizer) | ||||
|             else: | ||||
|                 mixed_precision = self.args.fp16 or self.args.bf16 | ||||
|                 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp | ||||
|                 zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 | ||||
|                 # XXX: Breaking the self.model convention but I see no way around it for now. | ||||
|                 if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: | ||||
|                     model = auto_wrap(model) | ||||
|                 self.model = model = FullyShardedDDP( | ||||
|                     model, | ||||
|                     mixed_precision=mixed_precision, | ||||
|                     reshard_after_forward=zero_3, | ||||
|                     cpu_offload=cpu_offload, | ||||
|                 ).to(self.args.device) | ||||
|         # Distributed training using PyTorch FSDP | ||||
|         elif self.fsdp is not None: | ||||
|             # fix starts here | ||||
|             if not self.args.fsdp_config["xla"]: | ||||
|                 # PyTorch FSDP! | ||||
|                 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision | ||||
|                 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP | ||||
|                 from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy | ||||
|                 import torch.distributed.fsdp._traversal_utils as traversal_utils | ||||
|  | ||||
|                 if FSDPOption.OFFLOAD in self.args.fsdp: | ||||
|                     cpu_offload = CPUOffload(offload_params=True) | ||||
|                 else: | ||||
|                     cpu_offload = CPUOffload(offload_params=False) | ||||
|  | ||||
|                 auto_wrap_policy = None | ||||
|  | ||||
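|                 # Pick an auto-wrap policy: size-based (wrap any submodule with more than | ||||
|                 # fsdp_min_num_params parameters) or transformer-based (wrap the layer classes listed | ||||
|                 # in fsdp_config["fsdp_transformer_layer_cls_to_wrap"]). | ||||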
|                 if FSDPOption.AUTO_WRAP in self.args.fsdp: | ||||
|                     if self.args.fsdp_config["fsdp_min_num_params"] > 0: | ||||
|                         auto_wrap_policy = functools.partial( | ||||
|                             size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] | ||||
|                         ) | ||||
|                     elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||
|                         transformer_cls_to_wrap = set() | ||||
|                         for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: | ||||
|                             transformer_cls = get_module_class_from_name(model, layer_class) | ||||
|                             if transformer_cls is None: | ||||
|                                 raise Exception("Could not find the transformer layer class to wrap in the model.") | ||||
|                             else: | ||||
|                                 transformer_cls_to_wrap.add(transformer_cls) | ||||
|                         auto_wrap_policy = functools.partial( | ||||
|                             transformer_auto_wrap_policy, | ||||
|                             # Transformer layer class to wrap | ||||
|                             transformer_layer_cls=transformer_cls_to_wrap, | ||||
|                         ) | ||||
|                 mixed_precision_policy = None | ||||
|                 dtype = None | ||||
|                 if self.args.fp16: | ||||
|                     dtype = torch.float16 | ||||
|                 elif self.args.bf16: | ||||
|                     dtype = torch.bfloat16 | ||||
|                 if dtype is not None: | ||||
|                     mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) | ||||
|                 if type(model) != FSDP: | ||||
|                     # XXX: Breaking the self.model convention but I see no way around it for now. | ||||
|                     signature = inspect.signature(FSDP.__init__).parameters.keys() | ||||
|                     kwargs = {} | ||||
|                     for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]: | ||||
|                         if arg in signature: | ||||
|                             kwargs[arg] = getattr(self, arg) | ||||
|                     self.model = model = FSDP( | ||||
|                         model, | ||||
|                         sharding_strategy=self.fsdp, | ||||
|                         cpu_offload=cpu_offload, | ||||
|                         auto_wrap_policy=auto_wrap_policy, | ||||
|                         mixed_precision=mixed_precision_policy, | ||||
|                         device_id=self.args.device, | ||||
|                         **kwargs, | ||||
|                     ) | ||||
|                      | ||||
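|                 # Debug aid: print the state_dict type/config of the first FSDP submodule, presumably to | ||||
|                 # confirm that the low-GPU full post-state_dict hook enabled in __init__ took effect. | ||||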
|                 for submodule in traversal_utils._get_fsdp_states(model): | ||||
|                     print(submodule._state_dict_type, submodule._state_dict_config) | ||||
|                     break | ||||
|             # fix ends here | ||||
|             else: | ||||
|                 try: | ||||
|                     from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP | ||||
|                     from torch_xla.distributed.fsdp import checkpoint_module | ||||
|                     from torch_xla.distributed.fsdp.wrap import ( | ||||
|                         size_based_auto_wrap_policy, | ||||
|                         transformer_auto_wrap_policy, | ||||
|                     ) | ||||
|                 except ImportError: | ||||
|                     raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") | ||||
|                 auto_wrap_policy = None | ||||
|                 auto_wrapper_callable = None | ||||
|                 if self.args.fsdp_config["fsdp_min_num_params"] > 0: | ||||
|                     auto_wrap_policy = functools.partial( | ||||
|                         size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] | ||||
|                     ) | ||||
|                 elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||
|                     transformer_cls_to_wrap = set() | ||||
|                     for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: | ||||
|                         transformer_cls = get_module_class_from_name(model, layer_class) | ||||
|                         if transformer_cls is None: | ||||
|                             raise Exception("Could not find the transformer layer class to wrap in the model.") | ||||
|                         else: | ||||
|                             transformer_cls_to_wrap.add(transformer_cls) | ||||
|                     auto_wrap_policy = functools.partial( | ||||
|                         transformer_auto_wrap_policy, | ||||
|                         # Transformer layer class to wrap | ||||
|                         transformer_layer_cls=transformer_cls_to_wrap, | ||||
|                     ) | ||||
|                 fsdp_kwargs = self.args.xla_fsdp_config | ||||
|                 if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: | ||||
|                     # Apply gradient checkpointing to auto-wrapped sub-modules if specified | ||||
|                     def auto_wrapper_callable(m, *args, **kwargs): | ||||
|                         return FSDP(checkpoint_module(m), *args, **kwargs) | ||||
|  | ||||
|                 # Wrap the base model with an outer FSDP wrapper | ||||
|                 self.model = model = FSDP( | ||||
|                     model, | ||||
|                     auto_wrap_policy=auto_wrap_policy, | ||||
|                     auto_wrapper_callable=auto_wrapper_callable, | ||||
|                     **fsdp_kwargs, | ||||
|                 ) | ||||
|  | ||||
|                 # Patch `xm.optimizer_step` should not reduce gradients in this case, | ||||
|                 # as FSDP does not need gradient reduction over sharded parameters. | ||||
|                 def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): | ||||
|                     loss = optimizer.step(**optimizer_args) | ||||
|                     if barrier: | ||||
|                         xm.mark_step() | ||||
|                     return loss | ||||
|  | ||||
|                 xm.optimizer_step = patched_optimizer_step | ||||
|         elif is_sagemaker_dp_enabled(): | ||||
|             model = nn.parallel.DistributedDataParallel( | ||||
|                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] | ||||
|             ) | ||||
|         elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||
|             if is_torch_neuroncore_available(): | ||||
|                 return model | ||||
|             kwargs = {} | ||||
|             if self.args.ddp_find_unused_parameters is not None: | ||||
|                 kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters | ||||
|             elif isinstance(model, PreTrainedModel): | ||||
|                 # find_unused_parameters breaks checkpointing as per | ||||
|                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 | ||||
|                 kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing | ||||
|             else: | ||||
|                 kwargs["find_unused_parameters"] = True | ||||
|  | ||||
|             if self.args.ddp_bucket_cap_mb is not None: | ||||
|                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb | ||||
|  | ||||
|             if self.args.ddp_broadcast_buffers is not None: | ||||
|                 kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers | ||||
|  | ||||
|             self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) | ||||
|  | ||||
|         return model | ||||
|      | ||||
|     def get_batch_sampler(self, dataset=None): | ||||
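|         # When the FFD (first-fit-decreasing) sampler is enabled, pack length-grouped examples into | ||||
|         # batches capped at per_device_train_batch_size * model_avg_context tokens; otherwise return | ||||
|         # None so the default sampler is used. | ||||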
|         if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1: | ||||
|             dataset = dataset if dataset is not None else self.train_dataset | ||||
|             try: | ||||
|                 batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context | ||||
|             except AttributeError: | ||||
|                 # the unwrapped model has no `model_avg_context` attribute; fall back to the value from the training args | ||||
|                 batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context | ||||
|             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None | ||||
|             lengths = LengthGroupedSampler( | ||||
|                     batch_size=-1, # we just want to know about the lengths of the dataset so no need to pass `batch_size` | ||||
|                     dataset=dataset, | ||||
|                     model_input_name=model_input_name | ||||
|             ).lengths | ||||
|             seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed | ||||
|             batch_sampler = FFDDistributedBatchSampler( | ||||
|                 batch_max_length=batch_max_len, | ||||
|                 lengths=np.array(lengths), | ||||
|                 seed=seed | ||||
|             ) | ||||
|              | ||||
|             return batch_sampler | ||||
|      | ||||
|         return None | ||||
|      | ||||
|     def get_train_dataloader(self) -> DataLoader: | ||||
|         if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1: | ||||
|             if self.train_dataset is None: | ||||
|                 raise ValueError("Trainer: training requires a train_dataset.") | ||||
|  | ||||
|             train_dataset = self.train_dataset | ||||
|             data_collator = self.data_collator | ||||
|             if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): | ||||
|                 train_dataset = self._remove_unused_columns(train_dataset, description="training") | ||||
|             else: | ||||
|                 data_collator = self._get_collator_with_removed_columns(data_collator, description="training") | ||||
|                  | ||||
|             batch_sampler = self.get_batch_sampler(train_dataset) | ||||
|              | ||||
|             dataloader = DataLoader( | ||||
|                 train_dataset, | ||||
|                 batch_sampler=batch_sampler, | ||||
|                 drop_last=self.args.dataloader_drop_last, | ||||
|                 collate_fn=data_collator | ||||
|             ) | ||||
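|             # NOTE: the dataloader is returned without accelerator.prepare (left commented out below), | ||||
|             # presumably so Accelerate does not replace the custom FFD batch sampler. | ||||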
|             # return self.accelerator.prepare(dataloader) | ||||
|             return dataloader | ||||
|              | ||||
|         return super().get_train_dataloader() | ||||
|      | ||||
|     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): | ||||
|         """ | ||||
|         Will save the model, so you can reload it using `from_pretrained()`. | ||||
|  | ||||
|         Will only save from the main process. | ||||
|         """ | ||||
|  | ||||
|         if output_dir is None: | ||||
|             output_dir = self.args.output_dir | ||||
|  | ||||
|         if is_torch_tpu_available(): | ||||
|             self._save_tpu(output_dir) | ||||
|         elif is_sagemaker_mp_enabled(): | ||||
|             # Calling the state_dict needs to be done on the wrapped model and on all processes. | ||||
|             os.makedirs(output_dir, exist_ok=True) | ||||
|             state_dict = self.model_wrapped.state_dict() | ||||
|             if self.args.should_save: | ||||
|                 self._save(output_dir, state_dict=state_dict) | ||||
|             if IS_SAGEMAKER_MP_POST_1_10: | ||||
|                 # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 | ||||
|                 Path(os.path.join(output_dir, "user_content.pt")).touch() | ||||
|         elif ( | ||||
|             ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp | ||||
|             or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp | ||||
|             or self.fsdp is not None | ||||
|             or self.is_fsdp_enabled | ||||
|         ): | ||||
|             state_dict = self.model.state_dict() | ||||
|             if self.args.should_save: | ||||
|                 self._save(output_dir, state_dict=state_dict) | ||||
|             # modification starts here | ||||
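|             # when `save_with_fsdp` is set, additionally persist the FSDP-sharded model state via | ||||
|             # accelerate's save_fsdp_model so the checkpoint can later be reloaded with FSDP | ||||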
|             if self.is_fsdp_enabled and self.args.save_with_fsdp: | ||||
|                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) | ||||
|             # modification ends here | ||||
|  | ||||
|         elif self.is_deepspeed_enabled: | ||||
|             # this takes care of everything as long as we aren't under zero3 | ||||
|             if version.parse(accelerate_version) <= version.parse("0.20.3"): | ||||
|                 raise ValueError("Saving a DeepSpeed checkpoint requires accelerate > 0.20.3; please upgrade Accelerate (e.g. `pip install -U accelerate`).") | ||||
|             try: | ||||
|                 state_dict = self.accelerator.get_state_dict(self.deepspeed) | ||||
|                 if self.args.should_save: | ||||
|                     self._save(output_dir, state_dict=state_dict) | ||||
|             except ValueError: | ||||
|                 logger.warning( | ||||
|                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" | ||||
|                     " zero_to_fp32.py to recover weights" | ||||
|                 ) | ||||
|                 self.model_wrapped.save_checkpoint(output_dir) | ||||
|  | ||||
|         elif self.args.should_save: | ||||
|             self._save(output_dir) | ||||
|  | ||||
|         # Push to the Hub when `save_model` is called by the user. | ||||
|         if self.args.push_to_hub and not _internal_call: | ||||
|             self.push_to_hub(commit_message="Model save") | ||||
|      | ||||
|     def _save_checkpoint(self, model, trial, metrics=None): | ||||
|         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we | ||||
|         # want to save except FullyShardedDDP. | ||||
|         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" | ||||
|  | ||||
|         # Save model checkpoint | ||||
|         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" | ||||
|  | ||||
|         if self.hp_search_backend is None and trial is None: | ||||
|             self.store_flos() | ||||
|  | ||||
|         run_dir = self._get_output_dir(trial=trial) | ||||
|         output_dir = os.path.join(run_dir, checkpoint_folder) | ||||
|         self.save_model(output_dir, _internal_call=True) | ||||
|         if self.is_deepspeed_enabled: | ||||
|             # under ZeRO-3 the model weights saved above are not the full model (they stay sharded) unless the | ||||
|             # deepspeed config sets `stage3_gather_16bit_weights_on_model_save`, so also save the DeepSpeed checkpoint | ||||
|             self.model_wrapped.save_checkpoint(output_dir) | ||||
|  | ||||
|         # Save optimizer and scheduler | ||||
|         if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||||
|             self.optimizer.consolidate_state_dict() | ||||
|  | ||||
|         if self.fsdp or self.is_fsdp_enabled: | ||||
|             if self.is_fsdp_enabled: | ||||
|                 # modification starts here | ||||
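|                 # the sharded optimizer state is saved through accelerate's save_fsdp_optimizer only | ||||
|                 # when `save_with_fsdp` is set; otherwise it falls through to the plain torch.save path below | ||||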
|                 if self.args.save_with_fsdp: | ||||
|                     save_fsdp_optimizer( | ||||
|                         self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir | ||||
|                     ) | ||||
|                 # modification ends here | ||||
|             else: | ||||
|                 # FSDP has a different interface for saving optimizer states. | ||||
|                 # Needs to be called on all ranks to gather all states. | ||||
|                 # full_optim_state_dict will be deprecated after Pytorch 2.2! | ||||
|                 full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) | ||||
|  | ||||
|         if is_torch_tpu_available(): | ||||
|             xm.rendezvous("saving_optimizer_states") | ||||
|             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | ||||
|             with warnings.catch_warnings(record=True) as caught_warnings: | ||||
|                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||
|                 reissue_pt_warnings(caught_warnings) | ||||
|         elif is_sagemaker_mp_enabled(): | ||||
|             opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) | ||||
|             smp.barrier() | ||||
|             if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: | ||||
|                 smp.save( | ||||
|                     opt_state_dict, | ||||
|                     os.path.join(output_dir, OPTIMIZER_NAME), | ||||
|                     partial=True, | ||||
|                     v3=smp.state.cfg.shard_optimizer_state, | ||||
|                 ) | ||||
|             if self.args.should_save: | ||||
|                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||
|                     torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||
|                 reissue_pt_warnings(caught_warnings) | ||||
|                 if self.do_grad_scaling: | ||||
|                     torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | ||||
|         elif self.args.should_save and not self.is_deepspeed_enabled: | ||||
|             # deepspeed.save_checkpoint above saves model/optim/sched | ||||
|             if self.fsdp and not self.is_fsdp_enabled: | ||||
|                 torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) | ||||
|             else: | ||||
|                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) | ||||
|  | ||||
|             with warnings.catch_warnings(record=True) as caught_warnings: | ||||
|                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) | ||||
|             reissue_pt_warnings(caught_warnings) | ||||
|             if self.do_grad_scaling: | ||||
|                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) | ||||
|  | ||||
|         # Determine the new best metric / best model checkpoint | ||||
|         if metrics is not None and self.args.metric_for_best_model is not None: | ||||
|             metric_to_check = self.args.metric_for_best_model | ||||
|             if not metric_to_check.startswith("eval_"): | ||||
|                 metric_to_check = f"eval_{metric_to_check}" | ||||
|             metric_value = metrics[metric_to_check] | ||||
|  | ||||
|             operator = np.greater if self.args.greater_is_better else np.less | ||||
|             if ( | ||||
|                 self.state.best_metric is None | ||||
|                 or self.state.best_model_checkpoint is None | ||||
|                 or operator(metric_value, self.state.best_metric) | ||||
|             ): | ||||
|                 self.state.best_metric = metric_value | ||||
|                 self.state.best_model_checkpoint = output_dir | ||||
|  | ||||
|         # Save the Trainer state | ||||
|         if self.args.should_save: | ||||
|             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) | ||||
|  | ||||
|         # Save the RNG states so training can resume reproducibly | ||||
|         rng_states = { | ||||
|             "python": random.getstate(), | ||||
|             "numpy": np.random.get_state(), | ||||
|             "cpu": torch.random.get_rng_state(), | ||||
|         } | ||||
|         if torch.cuda.is_available(): | ||||
|             if self.args.parallel_mode == ParallelMode.DISTRIBUTED: | ||||
|                 # In distributed mode, save the CUDA RNG state of every visible device (also covers DataParallel) | ||||
|                 rng_states["cuda"] = torch.cuda.random.get_rng_state_all() | ||||
|             else: | ||||
|                 rng_states["cuda"] = torch.cuda.random.get_rng_state() | ||||
|  | ||||
|         if is_torch_tpu_available(): | ||||
|             rng_states["xla"] = xm.get_rng_state() | ||||
|  | ||||
|         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may | ||||
|         # not yet exist. | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|         if self.args.world_size <= 1: | ||||
|             torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) | ||||
|         else: | ||||
|             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) | ||||
|  | ||||
|         if self.args.push_to_hub: | ||||
|             self._push_from_checkpoint(output_dir) | ||||
|  | ||||
|         # Maybe delete some older checkpoints. | ||||
|         if self.args.should_save: | ||||
|             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) | ||||
|              | ||||
|     def _load_optimizer_and_scheduler(self, checkpoint): | ||||
|         """If optimizer and scheduler states exist, load them.""" | ||||
|         if checkpoint is None: | ||||
|             return | ||||
|  | ||||
|         if self.is_deepspeed_enabled: | ||||
|             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init | ||||
|             return | ||||
|  | ||||
|         checkpoint_file_exists = ( | ||||
|             glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") | ||||
|             if is_sagemaker_mp_enabled() | ||||
|             else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) | ||||
|         ) | ||||
|         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): | ||||
|             # Load in optimizer and scheduler states | ||||
|             if is_torch_tpu_available(): | ||||
|                 # On TPU we have to take some extra precautions to properly load the states on the right device. | ||||
|                 optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") | ||||
|                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||
|                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") | ||||
|                 reissue_pt_warnings(caught_warnings) | ||||
|  | ||||
|                 xm.send_cpu_data_to_device(optimizer_state, self.args.device) | ||||
|                 xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) | ||||
|  | ||||
|                 self.optimizer.load_state_dict(optimizer_state) | ||||
|                 self.lr_scheduler.load_state_dict(lr_scheduler_state) | ||||
|             else: | ||||
|                 if is_sagemaker_mp_enabled(): | ||||
|                     if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): | ||||
|                         # Optimizer checkpoint was saved with smp >= 1.10 | ||||
|                         def opt_load_hook(mod, opt): | ||||
|                             opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | ||||
|  | ||||
|                     else: | ||||
|                         # Optimizer checkpoint was saved with smp < 1.10 | ||||
|                         def opt_load_hook(mod, opt): | ||||
|                             if IS_SAGEMAKER_MP_POST_1_10: | ||||
|                                 opt.load_state_dict( | ||||
|                                     smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) | ||||
|                                 ) | ||||
|                             else: | ||||
|                                 opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) | ||||
|  | ||||
|                     self.model_wrapped.register_post_step_hook(opt_load_hook) | ||||
|                 else: | ||||
|                     # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. | ||||
|                     # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more | ||||
|                     # likely to get OOM on CPU (since we load num_gpu times the optimizer state) | ||||
|                     map_location = self.args.device if self.args.world_size > 1 else "cpu" | ||||
|                     if self.fsdp or self.is_fsdp_enabled: | ||||
|                         # modification starts here | ||||
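|                         # mirror of the saving logic: load via accelerate's load_fsdp_optimizer only when | ||||
|                         # `save_with_fsdp` was used; the legacy FSDP path below scatters the full optimizer | ||||
|                         # state dict from rank 0 instead | ||||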
|                         if self.is_fsdp_enabled and self.args.save_with_fsdp: | ||||
|                             load_fsdp_optimizer( | ||||
|                                 self.accelerator.state.fsdp_plugin, | ||||
|                                 self.accelerator, | ||||
|                                 self.optimizer, | ||||
|                                 self.model, | ||||
|                                 checkpoint, | ||||
|                             ) | ||||
|                         elif not self.is_fsdp_enabled: | ||||
|                             full_osd = None | ||||
|                             # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it | ||||
|                             if self.args.process_index == 0: | ||||
|                                 full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) | ||||
|                             # call scatter_full_optim_state_dict on all ranks | ||||
|                             sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) | ||||
|                             self.optimizer.load_state_dict(sharded_osd) | ||||
|                         else: | ||||
|                             self.optimizer.load_state_dict( | ||||
|                                 torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) | ||||
|                             ) | ||||
|                         # modification ends here | ||||
|                     else: | ||||
|                         self.optimizer.load_state_dict( | ||||
|                             torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) | ||||
|                         ) | ||||
|                 with warnings.catch_warnings(record=True) as caught_warnings: | ||||
|                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) | ||||
|                 reissue_pt_warnings(caught_warnings) | ||||
|                 if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): | ||||
|                     self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) | ||||
|                      | ||||
478 train/trainers/fsdp_training_args.py Normal file
							| @ -0,0 +1,478 @@ | ||||
| import sys | ||||
| import os | ||||
|  | ||||
| import transformers | ||||
| from transformers.training_args import * | ||||
|  | ||||
| from .utils import ExtendedFSDPOption | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class FSDPTrainingArguments(transformers.TrainingArguments): | ||||
|     # about data-efficient sampler | ||||
|     use_ffd_sampler: bool = False | ||||
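|     # used by the FFD sampler to size packed batches: batch_max_length = per_device_train_batch_size * model_avg_context | ||||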
|     model_avg_context: int = 2048 | ||||
|      | ||||
|     # about saving | ||||
|     # if the checkpoint is not saved with FSDP, it must not be loaded with FSDP either | ||||
|     save_with_fsdp: bool = False | ||||
|      | ||||
|     def __post_init__(self): | ||||
|         # expand user paths; otherwise os.makedirs("~/bar") would create the directory | ||||
|         # in the current working directory instead of the user's home | ||||
|         # see https://github.com/huggingface/transformers/issues/10628 | ||||
|         if self.output_dir is not None: | ||||
|             self.output_dir = os.path.expanduser(self.output_dir) | ||||
|         if self.logging_dir is None and self.output_dir is not None: | ||||
|             self.logging_dir = os.path.join(self.output_dir, default_logdir()) | ||||
|         if self.logging_dir is not None: | ||||
|             self.logging_dir = os.path.expanduser(self.logging_dir) | ||||
|  | ||||
|         if self.disable_tqdm is None: | ||||
|             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN | ||||
|  | ||||
|         if isinstance(self.evaluation_strategy, EvaluationStrategy): | ||||
|             warnings.warn( | ||||
|                 "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5" | ||||
|                 " of 🤗 Transformers. Use `IntervalStrategy` instead", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|             # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it. | ||||
|             self.evaluation_strategy = self.evaluation_strategy.value | ||||
|  | ||||
|         # if self.xpu_backend is not None: | ||||
|         #     warnings.warn( | ||||
|         #         "using `xpu_backend` is deprecated and will be removed in version 4.31" | ||||
|         #         " of 🤗 Transformers. Use `ddp_backend` instead", | ||||
|         #         FutureWarning, | ||||
|         #     ) | ||||
|         #     self.ddp_backend = self.xpu_backend | ||||
|  | ||||
|         self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) | ||||
|         self.logging_strategy = IntervalStrategy(self.logging_strategy) | ||||
|         self.save_strategy = IntervalStrategy(self.save_strategy) | ||||
|         self.hub_strategy = HubStrategy(self.hub_strategy) | ||||
|  | ||||
|         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) | ||||
|         if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: | ||||
|             self.do_eval = True | ||||
|  | ||||
|         # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero | ||||
|         if self.evaluation_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0): | ||||
|             if self.logging_steps > 0: | ||||
|                 logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}") | ||||
|                 self.eval_steps = self.logging_steps | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     f"evaluation strategy {self.evaluation_strategy} requires either non-zero --eval_steps or" | ||||
|                     " --logging_steps" | ||||
|                 ) | ||||
|  | ||||
|         # logging_steps must be non-zero for logging_strategy that is other than 'no' | ||||
|         if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0: | ||||
|             raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps") | ||||
|  | ||||
|         if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1: | ||||
|             if self.logging_steps != int(self.logging_steps): | ||||
|                 raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}") | ||||
|             self.logging_steps = int(self.logging_steps) | ||||
|         if self.evaluation_strategy == IntervalStrategy.STEPS and self.eval_steps > 1: | ||||
|             if self.eval_steps != int(self.eval_steps): | ||||
|                 raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}") | ||||
|             self.eval_steps = int(self.eval_steps) | ||||
|         if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1: | ||||
|             if self.save_steps != int(self.save_steps): | ||||
|                 raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}") | ||||
|             self.save_steps = int(self.save_steps) | ||||
|  | ||||
|         # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible. | ||||
|         if self.load_best_model_at_end: | ||||
|             if self.evaluation_strategy != self.save_strategy: | ||||
|                 raise ValueError( | ||||
|                     "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation " | ||||
|                     f"strategy: {self.evaluation_strategy}\n- Save strategy: {self.save_strategy}" | ||||
|                 ) | ||||
|             if self.evaluation_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0: | ||||
|                 if self.eval_steps < 1 or self.save_steps < 1: | ||||
|                     if not (self.eval_steps < 1 and self.save_steps < 1): | ||||
|                         raise ValueError( | ||||
|                             "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " | ||||
|                             "steps, which cannot be guaranteed when mixing ratio and absolute steps for save_steps " | ||||
|                             f"{self.save_steps} and eval_steps {self.eval_steps}." | ||||
|                         ) | ||||
|                     # Work around floating point precision issues | ||||
|                     LARGE_MULTIPLIER = 1_000_000 | ||||
|                     if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0: | ||||
|                         raise ValueError( | ||||
|                             "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation " | ||||
|                             f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}." | ||||
|                         ) | ||||
|                 raise ValueError( | ||||
|                     "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation " | ||||
|                     f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}." | ||||
|                 ) | ||||
|  | ||||
|         safetensors_available = is_safetensors_available() | ||||
|         if self.save_safetensors and not safetensors_available: | ||||
|             raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!") | ||||
|         if not self.save_safetensors and safetensors_available: | ||||
|             logger.info( | ||||
|                 f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. " | ||||
|                 f"Safetensors should be a preferred weights saving format due to security and performance reasons. " | ||||
|                 f"If your model cannot be saved by safetensors please feel free to open an issue at " | ||||
|                 f"https://github.com/huggingface/safetensors!" | ||||
|             ) | ||||
|  | ||||
|         if ( | ||||
|             self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU | ||||
|         ) and self.metric_for_best_model is None: | ||||
|             self.metric_for_best_model = "loss" | ||||
|         if self.greater_is_better is None and self.metric_for_best_model is not None: | ||||
|             self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] | ||||
|         if self.run_name is None: | ||||
|             self.run_name = self.output_dir | ||||
|         if self.framework == "pt" and is_torch_available(): | ||||
|             if self.fp16_backend and self.fp16_backend != "auto": | ||||
|                 warnings.warn( | ||||
|                     "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||
|                     " `half_precision_backend` instead", | ||||
|                     FutureWarning, | ||||
|                 ) | ||||
|                 self.half_precision_backend = self.fp16_backend | ||||
|  | ||||
|             if self.bf16 or self.bf16_full_eval: | ||||
|                 if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): | ||||
|                     # cpu | ||||
|                     raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") | ||||
|                 elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available(): | ||||
|                     # gpu | ||||
|                     raise ValueError( | ||||
|                         "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0" | ||||
|                     ) | ||||
|  | ||||
|         if self.fp16 and self.bf16: | ||||
|             raise ValueError("At most one of fp16 and bf16 can be True, but not both") | ||||
|  | ||||
|         if self.fp16_full_eval and self.bf16_full_eval: | ||||
|             raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both") | ||||
|  | ||||
|         if self.bf16: | ||||
|             if self.half_precision_backend == "apex": | ||||
|                 raise ValueError( | ||||
|                     " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use" | ||||
|                     " `--half_precision_backend cuda_amp` instead" | ||||
|                 ) | ||||
|             if not (self.sharded_ddp == "" or not self.sharded_ddp): | ||||
|                 raise ValueError("sharded_ddp is not supported with bf16") | ||||
|  | ||||
|         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU: | ||||
|             if self.evaluation_strategy == IntervalStrategy.NO: | ||||
|                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy") | ||||
|             if not is_torch_available(): | ||||
|                 raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires PyTorch to be installed") | ||||
|  | ||||
|         self.optim = OptimizerNames(self.optim) | ||||
|         if self.adafactor: | ||||
|             warnings.warn( | ||||
|                 "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim" | ||||
|                 " adafactor` instead", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|             self.optim = OptimizerNames.ADAFACTOR | ||||
|         if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available(): | ||||
|             if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"): | ||||
|                 raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher") | ||||
|             # there is a bug in fp16/AMP in pt-2.0.0 | ||||
|             if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16: | ||||
|                 raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0") | ||||
|  | ||||
|         if ( | ||||
|             self.framework == "pt" | ||||
|             and is_torch_available() | ||||
|             and (self.device.type != "cuda") | ||||
|             and (get_xla_device_type(self.device) != "GPU") | ||||
|             and (self.fp16 or self.fp16_full_eval) | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation" | ||||
|                 " (`--fp16_full_eval`) can only be used on CUDA devices." | ||||
|             ) | ||||
|  | ||||
|         if ( | ||||
|             self.framework == "pt" | ||||
|             and is_torch_available() | ||||
|             and (self.device.type != "cuda") | ||||
|             and (get_xla_device_type(self.device) != "GPU") | ||||
|             and (get_xla_device_type(self.device) != "TPU") | ||||
|             and (self.device.type != "cpu") | ||||
|             and (self.bf16 or self.bf16_full_eval) | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation" | ||||
|                 " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices." | ||||
|             ) | ||||
|  | ||||
|         if self.torchdynamo is not None: | ||||
|             warnings.warn( | ||||
|                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||
|                 " `torch_compile_backend` instead", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|             self.torch_compile_backend = self.torchdynamo | ||||
|         if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: | ||||
|             self.torch_compile = True | ||||
|         if self.torch_compile and self.torch_compile_backend is None: | ||||
|             self.torch_compile_backend = "inductor" | ||||
|  | ||||
|         # accelerate integration for torch compile | ||||
|         if self.torch_compile: | ||||
|             # set env vars for accelerate | ||||
|             prefix = "ACCELERATE_DYNAMO_" | ||||
|             os.environ[prefix + "BACKEND"] = self.torch_compile_backend | ||||
|             if self.torch_compile_mode is not None: | ||||
|                 os.environ[prefix + "MODE"] = self.torch_compile_mode | ||||
|  | ||||
|         if self.framework == "pt" and is_torch_available() and self.torch_compile: | ||||
|             if is_torch_tf32_available(): | ||||
|                 if (self.tf32 is None and not self.fp16) or self.bf16: | ||||
|                     logger.info( | ||||
|                         "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement" | ||||
|                         " otherwise." | ||||
|                     ) | ||||
|                     torch.backends.cuda.matmul.allow_tf32 = True | ||||
|             else: | ||||
|                 logger.warning( | ||||
|                     "The speedups for torchdynamo mostly come with Ampere (or newer) GPUs, which were not detected here." | ||||
|                 ) | ||||
|         if self.framework == "pt" and is_torch_available() and self.tf32 is not None: | ||||
|             if self.tf32: | ||||
|                 if is_torch_tf32_available(): | ||||
|                     torch.backends.cuda.matmul.allow_tf32 = True | ||||
|                 else: | ||||
|                     raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7") | ||||
|             else: | ||||
|                 if is_torch_tf32_available(): | ||||
|                     torch.backends.cuda.matmul.allow_tf32 = False | ||||
|                 # no need to assert on else | ||||
|  | ||||
|         if self.report_to is None: | ||||
|             logger.info( | ||||
|                 "The default value for the training argument `--report_to` will change in v5 (from all installed " | ||||
|                 "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as " | ||||
|                 "now. You should start updating your code and make this info disappear :-)." | ||||
|             ) | ||||
|             self.report_to = "all" | ||||
|         if self.report_to == "all" or self.report_to == ["all"]: | ||||
|             # Import at runtime to avoid a circular import. | ||||
|             from transformers.integrations import get_available_reporting_integrations | ||||
|  | ||||
|             self.report_to = get_available_reporting_integrations() | ||||
|         elif self.report_to == "none" or self.report_to == ["none"]: | ||||
|             self.report_to = [] | ||||
|         elif not isinstance(self.report_to, list): | ||||
|             self.report_to = [self.report_to] | ||||
|  | ||||
|         if self.warmup_ratio < 0 or self.warmup_ratio > 1: | ||||
|             raise ValueError("warmup_ratio must lie in range [0,1]") | ||||
|         elif self.warmup_ratio > 0 and self.warmup_steps > 0: | ||||
|             logger.info( | ||||
|                 "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio" | ||||
|                 " during training" | ||||
|             ) | ||||
|  | ||||
|         if not (self.sharded_ddp == "" or not self.sharded_ddp): | ||||
|             warnings.warn( | ||||
|                 "using `sharded_ddp` is deprecated and will be removed in version 4.33" | ||||
|                 " of 🤗 Transformers. Use `fsdp` instead", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|         if isinstance(self.sharded_ddp, bool): | ||||
|             self.sharded_ddp = "simple" if self.sharded_ddp else "" | ||||
|         if isinstance(self.sharded_ddp, str): | ||||
|             self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()] | ||||
|         if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]: | ||||
|             raise ValueError( | ||||
|                 "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or " | ||||
|                 '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.' | ||||
|             ) | ||||
|         elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp: | ||||
|             raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") | ||||
|         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: | ||||
|             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.") | ||||
|  | ||||
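|         # parse --fsdp with ExtendedFSDPOption (from .utils) instead of transformers' FSDPOption, | ||||
|         # presumably so additional sharding options are accepted | ||||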
|         if isinstance(self.fsdp, bool): | ||||
|             self.fsdp = "full_shard" if self.fsdp else "" | ||||
|         if isinstance(self.fsdp, str): | ||||
|             self.fsdp = [ExtendedFSDPOption(s) for s in self.fsdp.split()] | ||||
|         if self.fsdp == [ExtendedFSDPOption.OFFLOAD]: | ||||
|             raise ValueError( | ||||
|                 "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " | ||||
|                 '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' | ||||
|             ) | ||||
|         elif ExtendedFSDPOption.FULL_SHARD in self.fsdp and ExtendedFSDPOption.SHARD_GRAD_OP in self.fsdp: | ||||
|             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") | ||||
|  | ||||
|         if self.fsdp_config is None: | ||||
|             self.fsdp_config = {} | ||||
|  | ||||
|         if isinstance(self.fsdp_config, str): | ||||
|             with io.open(self.fsdp_config, "r", encoding="utf-8") as f: | ||||
|                 self.fsdp_config = json.load(f) | ||||
|  | ||||
|         if self.fsdp_min_num_params > 0: | ||||
|             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) | ||||
|  | ||||
|         self.fsdp_config["fsdp_min_num_params"] = max( | ||||
|             self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params | ||||
|         ) | ||||
|  | ||||
|         # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object | ||||
|         if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str): | ||||
|             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [ | ||||
|                 self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] | ||||
|             ] | ||||
|  | ||||
|         if self.fsdp_transformer_layer_cls_to_wrap is not None: | ||||
|             warnings.warn( | ||||
|                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning | ||||
|             ) | ||||
|             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get( | ||||
|                 "fsdp_transformer_layer_cls_to_wrap", [] | ||||
|             ) + [self.fsdp_transformer_layer_cls_to_wrap] | ||||
|  | ||||
|         if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0: | ||||
|             warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.") | ||||
|  | ||||
|         if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||
|             warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") | ||||
|  | ||||
|         if ( | ||||
|             len(self.fsdp) > 0 | ||||
|             and self.fsdp_config["fsdp_min_num_params"] > 0 | ||||
|             and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive." | ||||
|             ) | ||||
|         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) | ||||
|         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) | ||||
|         if self.fsdp_config["xla"]: | ||||
|             if len(self.fsdp) > 0: | ||||
|                 # store XLA fsdp configuration parameters into a dictionary | ||||
|                 self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}) | ||||
|                 # apply appropriate string to torch.dtype conversions for parameters | ||||
|                 if "compute_dtype" in self.xla_fsdp_config: | ||||
|                     self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) | ||||
|                 if "buffer_dtype" in self.xla_fsdp_config: | ||||
|                     self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"]) | ||||
|             else: | ||||
|                 warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.") | ||||
|         else: | ||||
|             if self.fsdp_config["xla_fsdp_grad_ckpt"]: | ||||
|                 warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") | ||||
|  | ||||
|         # accelerate integration for FSDP | ||||
|         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: | ||||
|             os.environ["ACCELERATE_USE_FSDP"] = "true" | ||||
|             from accelerate.utils.constants import ( | ||||
|                 FSDP_AUTO_WRAP_POLICY, | ||||
|                 FSDP_SHARDING_STRATEGY, | ||||
|             ) | ||||
|  | ||||
|             for fsdp_option in self.fsdp: | ||||
|                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: | ||||
|                     # set environment variable for FSDP sharding strategy | ||||
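|                     # torch's ShardingStrategy enum is 1-indexed, hence the +1 below | ||||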
|                     os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) | ||||
|                 elif fsdp_option == FSDPOption.OFFLOAD: | ||||
|                     os.environ["FSDP_OFFLOAD_PARAMS"] = "true" | ||||
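|                 # in accelerate, FSDP_AUTO_WRAP_POLICY is ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"], | ||||
|                 # so index 1 selects size-based wrapping and index 0 transformer-based wrapping | ||||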
|                 elif fsdp_option == FSDPOption.AUTO_WRAP: | ||||
|                     if self.fsdp_config["fsdp_min_num_params"] > 0: | ||||
|                         os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"]) | ||||
|                         os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] | ||||
|                     elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: | ||||
|                         os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join( | ||||
|                             self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] | ||||
|                         ) | ||||
|                         os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] | ||||
|             prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") | ||||
|             os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper() | ||||
|  | ||||
|         if self.tpu_metrics_debug: | ||||
|             warnings.warn( | ||||
|                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" | ||||
|                 " `--debug tpu_metrics_debug` instead", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|             if self.debug is None: | ||||
|                 self.debug = " tpu_metrics_debug" | ||||
|             else: | ||||
|                 self.debug += " tpu_metrics_debug" | ||||
|             self.tpu_metrics_debug = False | ||||
|  | ||||
|         if isinstance(self.debug, str): | ||||
|             self.debug = [DebugOption(s) for s in self.debug.split()] | ||||
|         elif self.debug is None: | ||||
|             self.debug = [] | ||||
|  | ||||
|         self.deepspeed_plugin = None | ||||
|         if self.deepspeed: | ||||
|             # - must be run very last in arg parsing, since it will use a lot of these settings. | ||||
|             # - must be run before the model is created. | ||||
|             if not is_accelerate_available(): | ||||
|                 raise ValueError("--deepspeed requires Accelerate to be installed: `pip install accelerate`.") | ||||
|             from transformers.deepspeed import HfTrainerDeepSpeedConfig | ||||
|  | ||||
|             # will be used later by the Trainer | ||||
|             # note: leave self.deepspeed unmodified in case a user relies on it not to be modified | ||||
|             self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed) | ||||
|             self.hf_deepspeed_config.trainer_config_process(self) | ||||
|  | ||||
|             # Accelerate DeepSpeed Plugin | ||||
|             from accelerate.utils import DeepSpeedPlugin | ||||
|  | ||||
|             os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" | ||||
|             self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) | ||||
|  | ||||
|         if self.push_to_hub_token is not None: | ||||
|             warnings.warn( | ||||
|                 "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||
|                 "`--hub_token` instead.", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|             self.hub_token = self.push_to_hub_token | ||||
|  | ||||
|         if self.push_to_hub_model_id is not None: | ||||
|             self.hub_model_id = get_full_repo_name( | ||||
|                 self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token | ||||
|             ) | ||||
|             if self.push_to_hub_organization is not None: | ||||
|                 warnings.warn( | ||||
|                     "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in " | ||||
|                     "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this " | ||||
|                     f"argument (in this case {self.hub_model_id}).", | ||||
|                     FutureWarning, | ||||
|                 ) | ||||
|             else: | ||||
|                 warnings.warn( | ||||
|                     "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||
|                     "`--hub_model_id` instead and pass the full repo name to this argument (in this case " | ||||
|                     f"{self.hub_model_id}).", | ||||
|                     FutureWarning, | ||||
|                 ) | ||||
|         elif self.push_to_hub_organization is not None: | ||||
|             self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}" | ||||
|             warnings.warn( | ||||
|                 "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use " | ||||
|                 "`--hub_model_id` instead and pass the full repo name to this argument (in this case " | ||||
|                 f"{self.hub_model_id}).", | ||||
|                 FutureWarning, | ||||
|             ) | ||||
|  | ||||
|         # if training args is specified, it will override the one specified in the accelerate config | ||||
|         if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0: | ||||
|             mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") | ||||
|             if self.fp16: | ||||
|                 mixed_precision_dtype = "fp16" | ||||
|             elif self.bf16: | ||||
|                 mixed_precision_dtype = "bf16" | ||||
|             os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype | ||||
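The block above behaves like a patched copy of `transformers.TrainingArguments.__post_init__`: it translates the parsed `--fsdp` options plus the `fsdp_config` JSON into the environment variables Accelerate reads, and mirrors `--fp16`/`--bf16` into `ACCELERATE_MIXED_PRECISION`. A minimal sketch of exercising that path with the upstream class follows; the output path, the bf16-capable GPU, and the assumption that this repo keeps the upstream field names are mine, not the commit's.

```python
# Sketch only: exercise the FSDP / mixed-precision branches above via the upstream
# transformers.TrainingArguments (this file appears to be a patched copy of it).
# The output path and the bf16-capable-GPU assumption are hypothetical.
import os

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs/debug",                                 # hypothetical
    bf16=True,                                                  # assumes an Ampere+ GPU
    fsdp="full_shard auto_wrap",
    fsdp_config="train/configs/fsdp/llama2_fsdp_config.json",   # config shipped in this commit
)

# Expected values, given the branches above:
print(os.environ.get("FSDP_SHARDING_STRATEGY"))        # "1"  -> FULL_SHARD
print(os.environ.get("FSDP_AUTO_WRAP_POLICY"))         # "TRANSFORMER_BASED_WRAP"
print(os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP"))  # "LlamaDecoderLayer"
print(os.environ.get("ACCELERATE_MIXED_PRECISION"))    # "bf16"
```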
							
								
								
									
26  train/trainers/stylistic_trainer.py  Normal file
							| @ -0,0 +1,26 @@ | ||||
| from .fsdp_trainer import FSDPTrainer | ||||
| from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES | ||||
| from transformers.modeling_utils import unwrap_model | ||||
|  | ||||
|  | ||||
| class StylisticTrainer(FSDPTrainer): | ||||
|     def compute_loss(self, model, inputs, return_outputs=False): | ||||
|         # The stylistic loss is computed via the label smoother, so a non-zero | ||||
|         # `label_smoothing_factor` is required for this trainer. | ||||
|         if self.label_smoother is not None and "labels" in inputs: | ||||
|             labels = inputs.pop("labels") | ||||
|         else: | ||||
|             labels = None | ||||
|  | ||||
|         outputs = model(**inputs) | ||||
|         if self.args.past_index >= 0: | ||||
|             self._past = outputs[self.args.past_index] | ||||
|  | ||||
|         if labels is not None: | ||||
|             # FIXME: should support peft | ||||
|             model_name = unwrap_model(model)._get_name() | ||||
|  | ||||
|             if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): | ||||
|                 loss = self.label_smoother(outputs, labels, shift_labels=True) | ||||
|             else: | ||||
|                 raise ValueError(f"model {model_name} is not a causal LM") | ||||
|         else: | ||||
|             raise ValueError("labels should not be None") | ||||
|  | ||||
|         return (loss, outputs) if return_outputs else loss | ||||
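`StylisticTrainer.compute_loss` only works when `self.label_smoother` exists, i.e. when `label_smoothing_factor` is non-zero; otherwise `labels` are never popped and the method raises. A hedged usage sketch, in which the module paths, model checkpoint, and data path are all assumptions, and `FSDPTrainer` is assumed to accept the standard `Trainer` keyword arguments:

```python
# Sketch, not the commit's launch script. Assumed names: the import paths,
# the Llama-2 checkpoint, and the NYT10 data path.
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from train.trainers.stylistic_trainer import StylisticTrainer            # assumed import path
from train.utils.datasets.nyt10_dataset import NYT10StylishDataset       # assumed import path

model_name = "meta-llama/Llama-2-7b-hf"                                   # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

train_ds = NYT10StylishDataset("data/nyt10/nyt10m_train.txt", tokenizer, size=1000)  # hypothetical path

args = TrainingArguments(
    output_dir="outputs/stylistic",       # hypothetical
    label_smoothing_factor=0.1,           # required: creates self.label_smoother
    per_device_train_batch_size=1,
    num_train_epochs=1,
)

trainer = StylisticTrainer(model=model, args=args, train_dataset=train_ds, tokenizer=tokenizer)
trainer.train()
```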
							
								
								
									
154  train/trainers/utils.py  Normal file
							| @ -0,0 +1,154 @@ | ||||
| import sys | ||||
| import os | ||||
| import warnings | ||||
|  | ||||
| import torch | ||||
| import torch.utils.checkpoint as checkpoint | ||||
| from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable | ||||
| from torch.distributed.fsdp import _state_dict_utils | ||||
| from torch.distributed.fsdp._common_utils import clean_tensor_name | ||||
| from transformers.utils import ExplicitEnum | ||||
|  | ||||
|  | ||||
| class ExtendedFSDPOption(ExplicitEnum): | ||||
|     FULL_SHARD = "full_shard" | ||||
|     SHARD_GRAD_OP = "shard_grad_op" | ||||
|     NO_SHARD = "no_shard" | ||||
|      | ||||
|     # extension starts here | ||||
|     HYBRID_SHARD = "hybrid_shard" | ||||
|     _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2" | ||||
|     # extension ends here | ||||
|      | ||||
|     OFFLOAD = "offload" | ||||
|     AUTO_WRAP = "auto_wrap" | ||||
|      | ||||
|      | ||||
| DefaultCheckpointFunction = checkpoint.CheckpointFunction | ||||
| DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook | ||||
|      | ||||
|  | ||||
| class CompatibleCheckpointFunction(torch.autograd.Function): | ||||
|      | ||||
|     @staticmethod | ||||
|     def forward(ctx, run_function, preserve_rng_state, *args): | ||||
|         check_backward_validity(args) | ||||
|         ctx.run_function = run_function | ||||
|         ctx.preserve_rng_state = preserve_rng_state | ||||
|         # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | ||||
|         ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() | ||||
|  | ||||
|         # Save non-tensor inputs in ctx, keep a placeholder None for tensors | ||||
|         # to be filled out during the backward. | ||||
|         ctx.inputs = [] | ||||
|         ctx.tensor_indices = [] | ||||
|         tensor_inputs = [] | ||||
|         for i, arg in enumerate(args): | ||||
|             if torch.is_tensor(arg): | ||||
|                 tensor_inputs.append(arg) | ||||
|                 ctx.tensor_indices.append(i) | ||||
|                 ctx.inputs.append(None) | ||||
|             else: | ||||
|                 ctx.inputs.append(arg) | ||||
|  | ||||
|         ctx.save_for_backward(*tensor_inputs) | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             outputs = run_function(*args) | ||||
|         return outputs | ||||
|  | ||||
|     @staticmethod | ||||
|     def backward(ctx, *args): | ||||
|         if not torch.autograd._is_checkpoint_valid(): | ||||
|             raise RuntimeError( | ||||
|                 "Checkpointing is not compatible with .grad() or when an `inputs` parameter" | ||||
|                 " is passed to .backward(). Please use .backward() and do not pass its `inputs`" | ||||
|                 " argument.") | ||||
|         # Copy the list to avoid modifying original list. | ||||
|         inputs = list(ctx.inputs) | ||||
|         tensor_indices = ctx.tensor_indices | ||||
|         tensors = ctx.saved_tensors | ||||
|  | ||||
|         # Fill in inputs with appropriate saved tensors. | ||||
|         for i, idx in enumerate(tensor_indices): | ||||
|             inputs[idx] = tensors[i] | ||||
|  | ||||
|         # Stash the surrounding rng state, and mimic the state that was | ||||
|         # present at this time during forward.  Restore the surrounding state | ||||
|         # when we're done. | ||||
|         rng_devices = [] | ||||
|         # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: | ||||
|         #     rng_devices = ctx.fwd_gpu_devices | ||||
|         with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): | ||||
|             detached_inputs = detach_variable(tuple(inputs)) | ||||
|             with torch.enable_grad(), \ | ||||
|                  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ | ||||
|                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): | ||||
|                 outputs = ctx.run_function(*detached_inputs) | ||||
|  | ||||
|         if isinstance(outputs, torch.Tensor): | ||||
|             outputs = (outputs,) | ||||
|  | ||||
|         # run backward() with only tensor that requires grad | ||||
|         outputs_with_grad = [] | ||||
|         args_with_grad = [] | ||||
|         for i in range(len(outputs)): | ||||
|             if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: | ||||
|                 outputs_with_grad.append(outputs[i]) | ||||
|                 args_with_grad.append(args[i]) | ||||
|         if len(outputs_with_grad) == 0: | ||||
|             raise RuntimeError( | ||||
|                 "none of output has requires_grad=True," | ||||
|                 " this checkpoint() is not necessary") | ||||
|         torch.autograd.backward(outputs_with_grad, args_with_grad) | ||||
|         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None | ||||
|                       for inp in detached_inputs) | ||||
|  | ||||
|         return (None, None) + grads | ||||
|      | ||||
|      | ||||
| def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix): | ||||
|      | ||||
|     def param_hook(state_dict, prefix, fqn): | ||||
|         clean_key = fqn | ||||
|         clean_prefix = clean_tensor_name(prefix) | ||||
|         # Strip prefix out of key if needed as buffer names and param names | ||||
|         # do not have prefix considered as they are not computed in `state_dict` | ||||
|         # call. | ||||
|         if clean_key.startswith(clean_prefix): | ||||
|             clean_key = clean_key[len(clean_prefix) :] | ||||
|  | ||||
|         # Clone parameters before exiting the `_unshard_fsdp_state_params()` context. | ||||
|         if not getattr(state_dict[fqn], "_has_been_cloned", False): | ||||
|             try: | ||||
|                 state_dict[fqn] = state_dict[fqn].cpu().clone().detach() | ||||
|                 state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined] | ||||
|             except BaseException as e: | ||||
|                 warnings.warn( | ||||
|                     f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " | ||||
|                     "This may mean that this state_dict entry could point to invalid " | ||||
|                     "memory regions after returning from state_dict() call if this " | ||||
|                     "parameter is managed by FSDP. Please check clone " | ||||
|                     f"implementation of {fqn}. Error: {str(e)}" | ||||
|                 ) | ||||
|  | ||||
|     return _state_dict_utils._common_unshard_post_state_dict_hook( | ||||
|         module, fsdp_state, state_dict, prefix, param_hook | ||||
|     ) | ||||
|      | ||||
|  | ||||
| # enable memory-efficient saving of `state_dict` for FSDP (cloned tensors are moved to CPU) | ||||
| def enable_low_gpu_full_post_state_dict_hook(): | ||||
|     _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook | ||||
|      | ||||
| def disable_low_gpu_full_post_state_dict_hook(): | ||||
|     _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook | ||||
|      | ||||
|  | ||||
| # swap in a checkpoint function that is compatible with `torch.compile` | ||||
| def enable_compatible_checkpoint_function(): | ||||
|     checkpoint.CheckpointFunction = CompatibleCheckpointFunction | ||||
|      | ||||
| def disable_compatible_checkpoint_function(): | ||||
|     checkpoint.CheckpointFunction = DefaultCheckpointFunction | ||||
|      | ||||
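These helpers work by monkey-patching module-level symbols, so they have to run before the corresponding torch code paths do: the checkpoint swap before gradient checkpointing or `torch.compile` is set up, and the state-dict hook before saving. A small sketch of the intended call order; the import path is a guess based on this commit's layout.

```python
# Sketch of the intended call order; `train.trainers.utils` is an assumed module path.
from train.trainers.utils import (
    enable_compatible_checkpoint_function,
    enable_low_gpu_full_post_state_dict_hook,
    disable_low_gpu_full_post_state_dict_hook,
)

enable_compatible_checkpoint_function()      # patch before gradient checkpointing / torch.compile(model)
enable_low_gpu_full_post_state_dict_hook()   # patch before trainer.save_model(): cloned shards land on CPU

# ... build the trainer, train, and save here ...

disable_low_gpu_full_post_state_dict_hook()  # restore torch's default hook afterwards
```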
							
								
								
									
0  train/utils/__init__.py  Normal file
0  train/utils/datasets/__init__.py  Normal file
141  train/utils/datasets/nyt10_dataset.py  Normal file
							| @ -0,0 +1,141 @@ | ||||
| import copy | ||||
| import json | ||||
| import random | ||||
|  | ||||
| import torch | ||||
| from torch.utils.data import Dataset | ||||
|  | ||||
|  | ||||
| IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss | ||||
|  | ||||
| prompt_template = ( | ||||
|     "Below is an instruction that describes a task, paired with an input that provides further context. " | ||||
|     "Write a response that appropriately completes the request.\n\n" | ||||
|     "### Instruction:\nGiven a piece of text, please find out the person-nationality relation in it. Tell me who is the person and which is the nationality. The answer should be in json format.\n\n### Input:\n{text}\n\n### Response:" | ||||
| ) | ||||
|  | ||||
| # output_template = [ | ||||
| #     "```json\n", | ||||
| #     "{\n", | ||||
| #     "  \"person\": \"", | ||||
| #     "sample_person", | ||||
| #     "\",\n", | ||||
| #     "  \"nationality\": \"", | ||||
| #     "sample_nationality", | ||||
| #     "\"\n", | ||||
| #     "}\n", | ||||
| #     "```", | ||||
| # ] | ||||
|  | ||||
| # person_index = 3 | ||||
| # nationality_index = 6 | ||||
|  | ||||
| output_template = [ | ||||
|     "```json\n{\n  \"person\": \"", | ||||
|     "sample_person", | ||||
|     "\",\n  \"nationality\": \"", | ||||
|     "sample_nationality", | ||||
|     "\"\n}\n```", | ||||
| ] | ||||
|  | ||||
| person_index = 1 | ||||
| nationality_index = 3 | ||||
|  | ||||
|  | ||||
| class NYT10Dataset(Dataset): | ||||
|     def __init__(self, data_path: str, tokenizer, size: int = -1): | ||||
|         with open(data_path, 'r') as f: | ||||
|             self.ann = [json.loads(line) for line in f.readlines()] | ||||
|         # only use "/people/person/nationality" | ||||
|         self.ann = [ | ||||
|             { | ||||
|                 "text": dp["text"], | ||||
|                 "person": dp["h"]["name"], | ||||
|                 "nationality": dp["t"]["name"], | ||||
|             } for dp in self.ann if '/people/person/nationality' in dp['relation'] | ||||
|         ] | ||||
|  | ||||
|         random.shuffle(self.ann) | ||||
|  | ||||
|         if size > 0: | ||||
|             self.ann = self.ann[:size] | ||||
|  | ||||
|         self.tokenizer = tokenizer | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.ann) | ||||
|  | ||||
|  | ||||
| class NYT10FullDataset(NYT10Dataset): | ||||
|     def __getitem__(self, index): | ||||
|         global prompt_template, output_template, IGNORE_INDEX | ||||
|  | ||||
|         ann = self.ann[index] | ||||
|         prompt = prompt_template.format(text=ann["text"]) | ||||
|         output = copy.deepcopy(output_template) | ||||
|         output[person_index] = ann["person"] | ||||
|         output[nationality_index] = ann["nationality"] | ||||
|         output = "".join(output) | ||||
|         prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) | ||||
|         output_ids = [self.tokenizer.bos_token_id] + self.tokenizer.encode(output, add_special_tokens=False) + [self.tokenizer.eos_token_id] | ||||
|         example = torch.tensor( | ||||
|             prompt_ids + output_ids, dtype=torch.int64 | ||||
|         ) | ||||
|         labels = copy.deepcopy(example) | ||||
|         labels[:len(prompt_ids)] = -1 | ||||
|         example_mask = example.ge(0) | ||||
|         label_mask = labels.ge(0) | ||||
|         example[~example_mask] = 0 | ||||
|         labels[~label_mask] = IGNORE_INDEX | ||||
|          | ||||
|         assert len(example) == len(labels) | ||||
|  | ||||
|         return { | ||||
|             "input_ids": example.tolist(), | ||||
|             "labels": labels.tolist(), | ||||
|             "attention_mask": example_mask.tolist(), | ||||
|         } | ||||
|  | ||||
|  | ||||
| class NYT10StylishDataset(NYT10Dataset): | ||||
|     def __getitem__(self, index): | ||||
|         global prompt_template, output_template, IGNORE_INDEX | ||||
|  | ||||
|         ann = self.ann[index] | ||||
|         prompt = prompt_template.format(text=ann["text"]) | ||||
|         prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False) | ||||
|  | ||||
|         example = copy.deepcopy(prompt_ids) + [self.tokenizer.bos_token_id] | ||||
|         # prompt part is masked | ||||
|         labels = [-1] * len(prompt_ids) + [self.tokenizer.bos_token_id] | ||||
|  | ||||
|         for idx, s in enumerate(output_template): | ||||
|             # person and nationality are masked | ||||
|             if idx == person_index or idx == nationality_index: | ||||
|                 tokens = self.tokenizer.encode(ann["person"] if idx == person_index else ann["nationality"], add_special_tokens=False) | ||||
|                 example.extend(tokens) | ||||
|                 labels.extend([-1] * len(tokens)) | ||||
|             else: | ||||
|                 tokens = self.tokenizer.encode(s, add_special_tokens=False) | ||||
|                 example.extend(tokens) | ||||
|                 labels.extend(tokens) | ||||
|         example.append(self.tokenizer.eos_token_id) | ||||
|         example = torch.tensor( | ||||
|             example, dtype=torch.int64 | ||||
|         ) | ||||
|         labels.append(self.tokenizer.eos_token_id) | ||||
|         labels = torch.tensor( | ||||
|             labels, dtype=torch.int64 | ||||
|         ) | ||||
|         example_mask = example.ge(0) | ||||
|         label_mask = labels.ge(0) | ||||
|         example[~example_mask] = 0 | ||||
|         labels[~label_mask] = IGNORE_INDEX | ||||
|  | ||||
|         assert len(example) == len(labels) | ||||
|  | ||||
|         return { | ||||
|             "input_ids": example.tolist(), | ||||
|             "labels": labels.tolist(), | ||||
|             "attention_mask": example_mask.tolist(), | ||||
|         } | ||||
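For a quick sanity check of the masking in `NYT10StylishDataset`, the sketch below decodes only the supervised positions of one sample; with the template above, that should be roughly the JSON scaffolding with the person and nationality values blanked out. The tokenizer checkpoint and data path are placeholders.

```python
# Sketch: inspect which tokens NYT10StylishDataset actually supervises.
# Assumed: the import path, a Llama-2 tokenizer, and a local NYT10 file.
from transformers import AutoTokenizer
from train.utils.datasets.nyt10_dataset import NYT10StylishDataset, IGNORE_INDEX  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")          # hypothetical checkpoint
ds = NYT10StylishDataset("data/nyt10/nyt10m_train.txt", tokenizer, size=10)    # hypothetical path

sample = ds[0]
supervised_ids = [t for t, l in zip(sample["input_ids"], sample["labels"]) if l != IGNORE_INDEX]
print(tokenizer.decode(supervised_ids, skip_special_tokens=True))
# expected (roughly): the json code-fence scaffold with empty "person" / "nationality" values
```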