init🎉:
This commit is contained in:
925
train/trainers/fsdp_trainer.py
Normal file
925
train/trainers/fsdp_trainer.py
Normal file
@ -0,0 +1,925 @@
|
||||
import sys
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import transformers
|
||||
from transformers.trainer import *
|
||||
|
||||
from .ffd_sampler import FFDDistributedBatchSampler
|
||||
from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook
|
||||
|
||||
|
||||
class FSDPTrainer(transformers.Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
):
|
||||
if args is None:
|
||||
output_dir = "tmp_trainer"
|
||||
logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
|
||||
args = TrainingArguments(output_dir=output_dir)
|
||||
self.args = args
|
||||
# Seed must be set before instantiating the model when using model
|
||||
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||
self.hp_name = None
|
||||
self.deepspeed = None
|
||||
self.is_in_train = False
|
||||
|
||||
self.create_accelerator_and_postprocess()
|
||||
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||
self._memory_tracker.start()
|
||||
|
||||
# set the correct log level depending on the node
|
||||
log_level = args.get_process_log_level()
|
||||
logging.set_verbosity(log_level)
|
||||
|
||||
# force device and distributed setup init explicitly
|
||||
args._setup_devices
|
||||
|
||||
if model is None:
|
||||
if model_init is not None:
|
||||
self.model_init = model_init
|
||||
model = self.call_model_init()
|
||||
else:
|
||||
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
|
||||
else:
|
||||
if model_init is not None:
|
||||
warnings.warn(
|
||||
"`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
|
||||
" overwrite your model when calling the `train` method. This will become a fatal error in the next"
|
||||
" release.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.model_init = model_init
|
||||
|
||||
if model.__class__.__name__ in MODEL_MAPPING_NAMES:
|
||||
raise ValueError(
|
||||
f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
|
||||
"computes hidden states and does not accept any labels. You should choose a model with a head "
|
||||
"suitable for your task like any of the `AutoModelForXxx` listed at "
|
||||
"https://huggingface.co/docs/transformers/model_doc/auto."
|
||||
)
|
||||
|
||||
if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
|
||||
self.is_model_parallel = True
|
||||
else:
|
||||
self.is_model_parallel = False
|
||||
|
||||
if getattr(model, "hf_device_map", None) is not None:
|
||||
devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
|
||||
if len(devices) > 1:
|
||||
self.is_model_parallel = True
|
||||
else:
|
||||
self.is_model_parallel = self.args.device != torch.device(devices[0])
|
||||
|
||||
# warn users
|
||||
logger.info(
|
||||
"You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
|
||||
" to `True` to avoid any unexpected behavior such as device placement mismatching."
|
||||
)
|
||||
|
||||
# At this stage the model is already loaded
|
||||
if getattr(model, "is_quantized", False):
|
||||
if getattr(model, "_is_quantized_training_enabled", False):
|
||||
logger.info(
|
||||
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
|
||||
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
|
||||
" check "
|
||||
" the examples in https://github.com/huggingface/peft for more details."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
|
||||
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
||||
)
|
||||
|
||||
# Setup Sharded DDP training
|
||||
self.sharded_ddp = None
|
||||
if len(args.sharded_ddp) > 0:
|
||||
if self.is_deepspeed_enabled:
|
||||
raise ValueError(
|
||||
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||
)
|
||||
if len(args.fsdp) > 0:
|
||||
raise ValueError(
|
||||
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
|
||||
)
|
||||
if args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
||||
elif not is_fairscale_available():
|
||||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
||||
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
|
||||
raise ImportError(
|
||||
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
|
||||
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
|
||||
)
|
||||
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
|
||||
self.sharded_ddp = ShardedDDPOption.SIMPLE
|
||||
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
|
||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
|
||||
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
||||
|
||||
self.fsdp = None
|
||||
if len(args.fsdp) > 0:
|
||||
if self.is_deepspeed_enabled:
|
||||
raise ValueError(
|
||||
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
||||
)
|
||||
if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
raise ValueError("Using fsdp only works in distributed training.")
|
||||
|
||||
# dep_version_check("torch>=1.12.0")
|
||||
# Would have to update setup.py with torch>=1.12.0
|
||||
# which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
|
||||
# below is the current alternative.
|
||||
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
|
||||
raise ValueError("FSDP requires PyTorch >= 1.12.0")
|
||||
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
|
||||
|
||||
if ExtendedFSDPOption.FULL_SHARD in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.FULL_SHARD
|
||||
elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
|
||||
elif ExtendedFSDPOption.NO_SHARD in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.NO_SHARD
|
||||
# extention starts here
|
||||
elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp:
|
||||
self.fsdp = ShardingStrategy.HYBRID_SHARD
|
||||
elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp:
|
||||
self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2
|
||||
# extention ends here
|
||||
|
||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
||||
# modification starts here
|
||||
if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
|
||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
||||
# modification ends here
|
||||
|
||||
self.forward_prefetch = False
|
||||
# modification starts here
|
||||
if self.args.fsdp_config.get("forward_prefetch", False):
|
||||
# modification ends here
|
||||
self.forward_prefetch = True
|
||||
|
||||
self.limit_all_gathers = False
|
||||
if self.args.fsdp_config.get("limit_all_gathers", False):
|
||||
self.limit_all_gathers = True
|
||||
|
||||
# one place to sort out whether to place the model on device or not
|
||||
# postpone switching model to cuda when:
|
||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||
# and we only use deepspeed for training at the moment
|
||||
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||
# 4. Sharded DDP - same as MP
|
||||
# 5. FSDP - same as MP
|
||||
self.place_model_on_device = args.place_model_on_device
|
||||
if (
|
||||
self.is_model_parallel
|
||||
or self.is_deepspeed_enabled
|
||||
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||
or (self.fsdp is not None)
|
||||
or self.is_fsdp_enabled
|
||||
):
|
||||
self.place_model_on_device = False
|
||||
|
||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Quantized models doesn't support `.to` operation.
|
||||
if self.place_model_on_device and not getattr(model, "is_quantized", False):
|
||||
self._move_model_to_device(model, args.device)
|
||||
|
||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||
if self.is_model_parallel:
|
||||
self.args._n_gpu = 1
|
||||
|
||||
# later use `self.model is self.model_wrapped` to check if it's wrapped or not
|
||||
self.model_wrapped = model
|
||||
self.model = model
|
||||
|
||||
self.compute_metrics = compute_metrics
|
||||
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||
raise RuntimeError(
|
||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
if is_torch_tpu_available() and self.optimizer is not None:
|
||||
for param in self.model.parameters():
|
||||
model_device = param.device
|
||||
break
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if len(param_group["params"]) > 0:
|
||||
optimizer_device = param_group["params"][0].device
|
||||
break
|
||||
if model_device != optimizer_device:
|
||||
raise ValueError(
|
||||
"The model and the optimizer parameters are not on the same device, which probably means you"
|
||||
" created an optimizer around your model **before** putting on the device and passing it to the"
|
||||
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
||||
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
||||
)
|
||||
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
|
||||
self.optimizer is not None or self.lr_scheduler is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||
self.callback_handler = CallbackHandler(
|
||||
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
|
||||
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||
self._loggers_initialized = False
|
||||
|
||||
# Create clone of distant repo and output directory if needed
|
||||
if self.args.push_to_hub:
|
||||
self.init_git_repo(at_init=True)
|
||||
# In case of pull, we need to make sure every process has the latest.
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("init git repo")
|
||||
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
dist.barrier()
|
||||
|
||||
if self.args.should_save:
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
|
||||
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
||||
raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
|
||||
|
||||
if args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
|
||||
raise ValueError(
|
||||
"The train_dataset does not implement __len__, max_steps has to be specified. "
|
||||
"The number of steps needs to be known in advance for the learning rate scheduler."
|
||||
)
|
||||
|
||||
if (
|
||||
train_dataset is not None
|
||||
and isinstance(train_dataset, torch.utils.data.IterableDataset)
|
||||
and args.group_by_length
|
||||
):
|
||||
raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
|
||||
|
||||
self._signature_columns = None
|
||||
|
||||
# Mixed precision setup
|
||||
self.use_apex = False
|
||||
self.use_cuda_amp = False
|
||||
self.use_cpu_amp = False
|
||||
|
||||
# Mixed precision setup for SageMaker Model Parallel
|
||||
if is_sagemaker_mp_enabled():
|
||||
# BF16 + model parallelism in SageMaker: currently not supported, raise an error
|
||||
if args.bf16:
|
||||
raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
|
||||
|
||||
if IS_SAGEMAKER_MP_POST_1_10:
|
||||
# When there's mismatch between SMP config and trainer argument, use SMP config as truth
|
||||
if args.fp16 != smp.state.cfg.fp16:
|
||||
logger.warning(
|
||||
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
|
||||
f"but FP16 provided in trainer argument is {args.fp16},"
|
||||
f"setting to {smp.state.cfg.fp16}"
|
||||
)
|
||||
args.fp16 = smp.state.cfg.fp16
|
||||
else:
|
||||
# smp < 1.10 does not support fp16 in trainer.
|
||||
if hasattr(smp.state.cfg, "fp16"):
|
||||
logger.warning(
|
||||
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
|
||||
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
||||
)
|
||||
|
||||
if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
|
||||
if args.half_precision_backend == "auto":
|
||||
if args.device == torch.device("cpu"):
|
||||
if args.fp16:
|
||||
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
|
||||
else:
|
||||
args.half_precision_backend = "cpu_amp"
|
||||
else:
|
||||
args.half_precision_backend = "cuda_amp"
|
||||
|
||||
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||
|
||||
self.do_grad_scaling = False
|
||||
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
|
||||
# deepspeed and SageMaker Model Parallel manage their own half precision
|
||||
if self.sharded_ddp is not None:
|
||||
if args.half_precision_backend == "cuda_amp":
|
||||
self.use_cuda_amp = True
|
||||
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
||||
# bf16 does not need grad scaling
|
||||
self.do_grad_scaling = self.amp_dtype == torch.float16
|
||||
if self.do_grad_scaling:
|
||||
if self.sharded_ddp is not None:
|
||||
self.scaler = ShardedGradScaler()
|
||||
elif self.fsdp is not None:
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import (
|
||||
ShardedGradScaler as FSDPShardedGradScaler,
|
||||
)
|
||||
|
||||
self.scaler = FSDPShardedGradScaler()
|
||||
elif is_torch_tpu_available():
|
||||
from torch_xla.amp import GradScaler
|
||||
|
||||
self.scaler = GradScaler()
|
||||
else:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
elif args.half_precision_backend == "cpu_amp":
|
||||
self.use_cpu_amp = True
|
||||
self.amp_dtype = torch.bfloat16
|
||||
elif args.half_precision_backend == "apex":
|
||||
if not is_apex_available():
|
||||
raise ImportError(
|
||||
"Using FP16 with APEX but APEX is not installed, please refer to"
|
||||
" https://www.github.com/nvidia/apex."
|
||||
)
|
||||
self.use_apex = True
|
||||
|
||||
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
|
||||
if (
|
||||
is_sagemaker_mp_enabled()
|
||||
and self.use_cuda_amp
|
||||
and args.max_grad_norm is not None
|
||||
and args.max_grad_norm > 0
|
||||
):
|
||||
raise ValueError(
|
||||
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
|
||||
"along 'max_grad_norm': 0 in your hyperparameters."
|
||||
)
|
||||
|
||||
# Label smoothing
|
||||
if self.args.label_smoothing_factor != 0:
|
||||
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
|
||||
else:
|
||||
self.label_smoother = None
|
||||
|
||||
self.state = TrainerState(
|
||||
is_local_process_zero=self.is_local_process_zero(),
|
||||
is_world_process_zero=self.is_world_process_zero(),
|
||||
)
|
||||
|
||||
self.control = TrainerControl()
|
||||
# Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
|
||||
# returned to 0 every time flos need to be logged
|
||||
self.current_flos = 0
|
||||
self.hp_search_backend = None
|
||||
self.use_tune_checkpoints = False
|
||||
default_label_names = find_labels(self.model.__class__)
|
||||
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
|
||||
self.can_return_loss = can_return_loss(self.model.__class__)
|
||||
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
|
||||
|
||||
# Internal variables to help with automatic batch size reduction
|
||||
self._train_batch_size = args.train_batch_size
|
||||
self._created_lr_scheduler = False
|
||||
|
||||
# very last
|
||||
self._memory_tracker.stop_and_update_metrics()
|
||||
|
||||
# torch.compile
|
||||
if args.torch_compile and not is_torch_compile_available():
|
||||
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
|
||||
|
||||
# finally applying `low_gpu_full_post_state_dict_hook`` for fsdp `state_dict`
|
||||
enable_low_gpu_full_post_state_dict_hook()
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.use_ipex:
|
||||
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
||||
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
# Wrapping the base model twice in a DistributedModel will raise an error.
|
||||
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
||||
return self.model_wrapped
|
||||
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
|
||||
|
||||
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
|
||||
if unwrap_model(model) is not model:
|
||||
return model
|
||||
|
||||
# Mixed precision training with apex (torch < 1.6)
|
||||
if self.use_apex and training:
|
||||
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
||||
|
||||
# Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
|
||||
if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
if self.args.jit_mode_eval:
|
||||
start_time = time.time()
|
||||
model = self.torch_jit_model_eval(model, dataloader, training)
|
||||
self.jit_compilation_time = round(time.time() - start_time, 4)
|
||||
|
||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||
if not training:
|
||||
return model
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if self.sharded_ddp is not None:
|
||||
# Sharded DDP!
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
model = ShardedDDP(model, self.optimizer)
|
||||
else:
|
||||
mixed_precision = self.args.fp16 or self.args.bf16
|
||||
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
|
||||
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
|
||||
model = auto_wrap(model)
|
||||
self.model = model = FullyShardedDDP(
|
||||
model,
|
||||
mixed_precision=mixed_precision,
|
||||
reshard_after_forward=zero_3,
|
||||
cpu_offload=cpu_offload,
|
||||
).to(self.args.device)
|
||||
# Distributed training using PyTorch FSDP
|
||||
elif self.fsdp is not None:
|
||||
# fix starts here
|
||||
if not self.args.fsdp_config["xla"]:
|
||||
# PyTorch FSDP!
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
import torch.distributed.fsdp._traversal_utils as traversal_utils
|
||||
|
||||
if FSDPOption.OFFLOAD in self.args.fsdp:
|
||||
cpu_offload = CPUOffload(offload_params=True)
|
||||
else:
|
||||
cpu_offload = CPUOffload(offload_params=False)
|
||||
|
||||
auto_wrap_policy = None
|
||||
|
||||
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||
)
|
||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
# Transformer layer class to wrap
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
mixed_precision_policy = None
|
||||
dtype = None
|
||||
if self.args.fp16:
|
||||
dtype = torch.float16
|
||||
elif self.args.bf16:
|
||||
dtype = torch.bfloat16
|
||||
if dtype is not None:
|
||||
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
|
||||
if type(model) != FSDP:
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
signature = inspect.signature(FSDP.__init__).parameters.keys()
|
||||
kwargs = {}
|
||||
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
|
||||
if arg in signature:
|
||||
kwargs[arg] = getattr(self, arg)
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
sharding_strategy=self.fsdp,
|
||||
cpu_offload=cpu_offload,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
mixed_precision=mixed_precision_policy,
|
||||
device_id=self.args.device,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for submodule in traversal_utils._get_fsdp_states(model):
|
||||
print(submodule._state_dict_type, submodule._state_dict_config)
|
||||
break
|
||||
# fix ends here
|
||||
else:
|
||||
try:
|
||||
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||
from torch_xla.distributed.fsdp import checkpoint_module
|
||||
from torch_xla.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||
auto_wrap_policy = None
|
||||
auto_wrapper_callable = None
|
||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||
)
|
||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
# Transformer layer class to wrap
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
fsdp_kwargs = self.args.xla_fsdp_config
|
||||
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
||||
def auto_wrapper_callable(m, *args, **kwargs):
|
||||
return FSDP(checkpoint_module(m), *args, **kwargs)
|
||||
|
||||
# Wrap the base model with an outer FSDP wrapper
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
auto_wrapper_callable=auto_wrapper_callable,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
||||
# as FSDP does not need gradient reduction over sharded parameters.
|
||||
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
||||
loss = optimizer.step(**optimizer_args)
|
||||
if barrier:
|
||||
xm.mark_step()
|
||||
return loss
|
||||
|
||||
xm.optimizer_step = patched_optimizer_step
|
||||
elif is_sagemaker_dp_enabled():
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||
)
|
||||
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if is_torch_neuroncore_available():
|
||||
return model
|
||||
kwargs = {}
|
||||
if self.args.ddp_find_unused_parameters is not None:
|
||||
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
|
||||
elif isinstance(model, PreTrainedModel):
|
||||
# find_unused_parameters breaks checkpointing as per
|
||||
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
|
||||
kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
|
||||
else:
|
||||
kwargs["find_unused_parameters"] = True
|
||||
|
||||
if self.args.ddp_bucket_cap_mb is not None:
|
||||
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
||||
|
||||
if self.args.ddp_broadcast_buffers is not None:
|
||||
kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
|
||||
|
||||
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
|
||||
|
||||
return model
|
||||
|
||||
def get_batch_sampler(self, dataset=None):
|
||||
if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
|
||||
dataset = dataset if dataset is not None else self.train_dataset
|
||||
try:
|
||||
batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context
|
||||
except:
|
||||
# raise RuntimeError("group_by_length in distributed training requires model has attr `model_max_context`")
|
||||
batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context
|
||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||
lengths = LengthGroupedSampler(
|
||||
batch_size=-1, # we just want to know about the lengths of the dataset so no need to pass `batch_size`
|
||||
dataset=dataset,
|
||||
model_input_name=model_input_name
|
||||
).lengths
|
||||
seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
|
||||
batch_sampler = FFDDistributedBatchSampler(
|
||||
batch_max_length=batch_max_len,
|
||||
lengths=np.array(lengths),
|
||||
seed=seed
|
||||
)
|
||||
|
||||
return batch_sampler
|
||||
|
||||
return None
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
||||
|
||||
batch_sampler = self.get_batch_sampler(train_dataset)
|
||||
|
||||
dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
collate_fn=data_collator
|
||||
)
|
||||
# return self.accelerator.prepare(dataloader)
|
||||
return dataloader
|
||||
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||
"""
|
||||
Will save the model, so you can reload it using `from_pretrained()`.
|
||||
|
||||
Will only save from the main process.
|
||||
"""
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = self.args.output_dir
|
||||
|
||||
if is_torch_tpu_available():
|
||||
self._save_tpu(output_dir)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
# Calling the state_dict needs to be done on the wrapped model and on all processes.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
state_dict = self.model_wrapped.state_dict()
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
if IS_SAGEMAKER_MP_POST_1_10:
|
||||
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
|
||||
Path(os.path.join(output_dir, "user_content.pt")).touch()
|
||||
elif (
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
or self.fsdp is not None
|
||||
or self.is_fsdp_enabled
|
||||
):
|
||||
state_dict = self.model.state_dict()
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
# modification starts here
|
||||
if self.is_fsdp_enabled and self.args.save_with_fsdp:
|
||||
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
|
||||
# modification ends here
|
||||
|
||||
elif self.is_deepspeed_enabled:
|
||||
# this takes care of everything as long as we aren't under zero3
|
||||
if version.parse(accelerate_version) <= version.parse("0.20.3"):
|
||||
raise ValueError("Install Accelerate from main branch")
|
||||
try:
|
||||
state_dict = self.accelerator.get_state_dict(self.deepspeed)
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
||||
" zero_to_fp32.py to recover weights"
|
||||
)
|
||||
self.model_wrapped.save_checkpoint(output_dir)
|
||||
|
||||
elif self.args.should_save:
|
||||
self._save(output_dir)
|
||||
|
||||
# Push to the Hub when `save_model` is called by the user.
|
||||
if self.args.push_to_hub and not _internal_call:
|
||||
self.push_to_hub(commit_message="Model save")
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||
# want to save except FullyShardedDDP.
|
||||
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||
|
||||
# Save model checkpoint
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is None and trial is None:
|
||||
self.store_flos()
|
||||
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
self.save_model(output_dir, _internal_call=True)
|
||||
if self.is_deepspeed_enabled:
|
||||
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
|
||||
# config `stage3_gather_16bit_weights_on_model_save` is True
|
||||
self.model_wrapped.save_checkpoint(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
self.optimizer.consolidate_state_dict()
|
||||
|
||||
if self.fsdp or self.is_fsdp_enabled:
|
||||
if self.is_fsdp_enabled:
|
||||
# modification starts here
|
||||
if self.args.save_with_fsdp:
|
||||
save_fsdp_optimizer(
|
||||
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
|
||||
)
|
||||
# modification ends here
|
||||
else:
|
||||
# FSDP has a different interface for saving optimizer states.
|
||||
# Needs to be called on all ranks to gather all states.
|
||||
# full_optim_state_dict will be deprecated after Pytorch 2.2!
|
||||
full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
|
||||
smp.barrier()
|
||||
if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
|
||||
smp.save(
|
||||
opt_state_dict,
|
||||
os.path.join(output_dir, OPTIMIZER_NAME),
|
||||
partial=True,
|
||||
v3=smp.state.cfg.shard_optimizer_state,
|
||||
)
|
||||
if self.args.should_save:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.is_deepspeed_enabled:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
if self.fsdp and not self.is_fsdp_enabled:
|
||||
torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
else:
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||
metric_to_check = self.args.metric_for_best_model
|
||||
if not metric_to_check.startswith("eval_"):
|
||||
metric_to_check = f"eval_{metric_to_check}"
|
||||
metric_value = metrics[metric_to_check]
|
||||
|
||||
operator = np.greater if self.args.greater_is_better else np.less
|
||||
if (
|
||||
self.state.best_metric is None
|
||||
or self.state.best_model_checkpoint is None
|
||||
or operator(metric_value, self.state.best_metric)
|
||||
):
|
||||
self.state.best_metric = metric_value
|
||||
self.state.best_model_checkpoint = output_dir
|
||||
|
||||
# Save the Trainer state
|
||||
if self.args.should_save:
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
|
||||
# Save RNG state in non-distributed training
|
||||
rng_states = {
|
||||
"python": random.getstate(),
|
||||
"numpy": np.random.get_state(),
|
||||
"cpu": torch.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
|
||||
else:
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
rng_states["xla"] = xm.get_rng_state()
|
||||
|
||||
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
|
||||
# not yet exist.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if self.args.world_size <= 1:
|
||||
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||
else:
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
self._push_from_checkpoint(output_dir)
|
||||
|
||||
# Maybe delete some older checkpoints.
|
||||
if self.args.should_save:
|
||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||
|
||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||
"""If optimizer and scheduler states exist, load them."""
|
||||
if checkpoint is None:
|
||||
return
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
return
|
||||
|
||||
checkpoint_file_exists = (
|
||||
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
|
||||
if is_sagemaker_mp_enabled()
|
||||
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
|
||||
)
|
||||
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
|
||||
# Load in optimizer and scheduler states
|
||||
if is_torch_tpu_available():
|
||||
# On TPU we have to take some extra precautions to properly load the states on the right device.
|
||||
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
|
||||
xm.send_cpu_data_to_device(optimizer_state, self.args.device)
|
||||
xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
|
||||
|
||||
self.optimizer.load_state_dict(optimizer_state)
|
||||
self.lr_scheduler.load_state_dict(lr_scheduler_state)
|
||||
else:
|
||||
if is_sagemaker_mp_enabled():
|
||||
if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
|
||||
# Optimizer checkpoint was saved with smp >= 1.10
|
||||
def opt_load_hook(mod, opt):
|
||||
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
|
||||
|
||||
else:
|
||||
# Optimizer checkpoint was saved with smp < 1.10
|
||||
def opt_load_hook(mod, opt):
|
||||
if IS_SAGEMAKER_MP_POST_1_10:
|
||||
opt.load_state_dict(
|
||||
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
|
||||
)
|
||||
else:
|
||||
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
|
||||
|
||||
self.model_wrapped.register_post_step_hook(opt_load_hook)
|
||||
else:
|
||||
# We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
|
||||
# In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
|
||||
# likely to get OOM on CPU (since we load num_gpu times the optimizer state
|
||||
map_location = self.args.device if self.args.world_size > 1 else "cpu"
|
||||
if self.fsdp or self.is_fsdp_enabled:
|
||||
# modification starts here
|
||||
if self.is_fsdp_enabled and self.args.save_with_fsdp:
|
||||
load_fsdp_optimizer(
|
||||
self.accelerator.state.fsdp_plugin,
|
||||
self.accelerator,
|
||||
self.optimizer,
|
||||
self.model,
|
||||
checkpoint,
|
||||
)
|
||||
elif not self.is_fsdp_enabled:
|
||||
full_osd = None
|
||||
# In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
|
||||
if self.args.process_index == 0:
|
||||
full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
|
||||
# call scatter_full_optim_state_dict on all ranks
|
||||
sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
|
||||
self.optimizer.load_state_dict(sharded_osd)
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
)
|
||||
# modification ends here
|
||||
else:
|
||||
self.optimizer.load_state_dict(
|
||||
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
|
||||
)
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||
|
||||
Reference in New Issue
Block a user