import sys
import os
from typing import Optional

import torch
from torch.utils.data import DataLoader

import transformers
from transformers.trainer import *

from .ffd_sampler import FFDDistributedBatchSampler
from .utils import ExtendedFSDPOption, enable_low_gpu_full_post_state_dict_hook


class FSDPTrainer(transformers.Trainer):
    """`transformers.Trainer` subclass with extended PyTorch FSDP handling.

    Differences visible in this class:

    * `__init__` resolves the sharding strategy from `ExtendedFSDPOption`, which adds
      `HYBRID_SHARD` and `_HYBRID_SHARD_ZERO2`, and reads the prefetch / all-gather
      switches from `args.fsdp_config`.
    * `get_batch_sampler` / `get_train_dataloader` can feed training from a
      `FFDDistributedBatchSampler` when `args.use_ffd_sampler` is set.
    * Checkpoint save/load honour `args.save_with_fsdp`.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Set up the trainer state without delegating to `transformers.Trainer.__init__`.

        Args:
            model: Model to train; may be omitted when `model_init` is given.
            args: Training arguments; a default `TrainingArguments(output_dir="tmp_trainer")`
                is created when `None`.
            data_collator: Batch collator; defaults to `default_data_collator`, or to
                `DataCollatorWithPadding(tokenizer)` when a tokenizer is given.
            train_dataset: Dataset used for training.
            eval_dataset: Dataset (or dict of datasets) used for evaluation.
            tokenizer: Tokenizer stored on the trainer and used for the default collator.
            model_init: Factory used to build the model when `model` is `None`.
            compute_metrics: Metric function stored on the trainer.
            callbacks: Extra callbacks appended to the default ones.
            optimizers: `(optimizer, lr_scheduler)` pair; must be `(None, None)` when
                `model_init`, sharded DDP, DeepSpeed or FSDP is used.
            preprocess_logits_for_metrics: Hook stored on the trainer.

        Raises:
            RuntimeError: Neither `model` nor `model_init` given, incompatible `optimizers`,
                or `torch_compile` requested but unavailable.
            ValueError: Incompatible argument combinations (see the individual checks).
            ImportError: A requested backend (fairscale, apex) is not installed.
        """
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = TrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model
        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
        self.hp_name = None
        self.deepspeed = None
        self.is_in_train = False

        self.create_accelerator_and_postprocess()

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly
        args._setup_devices

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
                    " release.",
                    FutureWarning,
                )
            self.model_init = model_init

        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
            raise ValueError(
                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
                "computes hidden states and does not accept any labels. You should choose a model with a head "
                "suitable for your task like any of the `AutoModelForXxx` listed at "
                "https://huggingface.co/docs/transformers/model_doc/auto."
            )

        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
            self.is_model_parallel = True
        else:
            self.is_model_parallel = False

        if getattr(model, "hf_device_map", None) is not None:
            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
            if len(devices) > 1:
                self.is_model_parallel = True
            else:
                self.is_model_parallel = self.args.device != torch.device(devices[0])

            # warn users
            logger.info(
                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
                " to `True` to avoid any unexpected behavior such as device placement mismatching."
            )

        # At this stage the model is already loaded
        if getattr(model, "is_quantized", False):
            if getattr(model, "_is_quantized_training_enabled", False):
                logger.info(
                    "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
                    " inside the model such as adapters using `peft` library and freeze the model weights. Please"
                    " check "
                    " the examples in https://github.com/huggingface/peft for more details."
                )
            else:
                raise ValueError(
                    "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
                    " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
                )

        # Setup Sharded DDP training
        self.sharded_ddp = None
        if len(args.sharded_ddp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if len(args.fsdp) > 0:
                raise ValueError(
                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
                )
            if args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using sharded DDP only works in distributed training.")
            elif not is_fairscale_available():
                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
                raise ImportError(
                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
                )
            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.SIMPLE
            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

        self.fsdp = None
        if len(args.fsdp) > 0:
            if self.is_deepspeed_enabled:
                raise ValueError(
                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
                )
            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise ValueError("Using fsdp only works in distributed training.")

            # dep_version_check("torch>=1.12.0")
            # Would have to update setup.py with torch>=1.12.0
            # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
            # below is the current alternative.
            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                raise ValueError("FSDP requires PyTorch >= 1.12.0")

            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

            if ExtendedFSDPOption.FULL_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.FULL_SHARD
            elif ExtendedFSDPOption.SHARD_GRAD_OP in args.fsdp:
                self.fsdp = ShardingStrategy.SHARD_GRAD_OP
            elif ExtendedFSDPOption.NO_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.NO_SHARD
            # extension starts here
            elif ExtendedFSDPOption.HYBRID_SHARD in args.fsdp:
                self.fsdp = ShardingStrategy.HYBRID_SHARD
            elif ExtendedFSDPOption._HYBRID_SHARD_ZERO2 in args.fsdp:
                self.fsdp = ShardingStrategy._HYBRID_SHARD_ZERO2
            # extension ends here

            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
            # modification starts here
            if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
            # modification ends here

            self.forward_prefetch = False
            # modification starts here
            if self.args.fsdp_config.get("forward_prefetch", False):
                self.forward_prefetch = True
            # modification ends here

            self.limit_all_gathers = False
            if self.args.fsdp_config.get("limit_all_gathers", False):
                self.limit_all_gathers = True

        # one place to sort out whether to place the model on device or not
        # postpone switching model to cuda when:
        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
        #    and we only use deepspeed for training at the moment
        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
        # 4. Sharded DDP - same as MP
        # 5. FSDP - same as MP
        self.place_model_on_device = args.place_model_on_device
        if (
            self.is_model_parallel
            or self.is_deepspeed_enabled
            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
            or (self.fsdp is not None)
            or self.is_fsdp_enabled
        ):
            self.place_model_on_device = False

        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        # Quantized models doesn't support `.to` operation.
        if self.place_model_on_device and not getattr(model, "is_quantized", False):
            self._move_model_to_device(model, args.device)

        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
        if self.is_model_parallel:
            self.args._n_gpu = 1

        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
        self.model_wrapped = model
        self.model = model

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        if is_torch_tpu_available() and self.optimizer is not None:
            for param in self.model.parameters():
                model_device = param.device
                break
            for param_group in self.optimizer.param_groups:
                if len(param_group["params"]) > 0:
                    optimizer_device = param_group["params"][0].device
                    break
            if model_device != optimizer_device:
                raise ValueError(
                    "The model and the optimizer parameters are not on the same device, which probably means you"
                    " created an optimizer around your model **before** putting on the device and passing it to the"
                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
                )
        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
            self.optimizer is not None or self.lr_scheduler is not None
        ):
            raise RuntimeError(
                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo(at_init=True)
            # In case of pull, we need to make sure every process has the latest.
            if is_torch_tpu_available():
                xm.rendezvous("init git repo")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given, it will override any value given in num_train_epochs")

        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
            raise ValueError(
                "The train_dataset does not implement __len__, max_steps has to be specified. "
                "The number of steps needs to be known in advance for the learning rate scheduler."
            )

        if (
            train_dataset is not None
            and isinstance(train_dataset, torch.utils.data.IterableDataset)
            and args.group_by_length
        ):
            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")

        self._signature_columns = None

        # Mixed precision setup
        self.use_apex = False
        self.use_cuda_amp = False
        self.use_cpu_amp = False

        # Mixed precision setup for SageMaker Model Parallel
        if is_sagemaker_mp_enabled():
            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
            if args.bf16:
                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")

            if IS_SAGEMAKER_MP_POST_1_10:
                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
                if args.fp16 != smp.state.cfg.fp16:
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
                        f"but FP16 provided in trainer argument is {args.fp16},"
                        f"setting to {smp.state.cfg.fp16}"
                    )
                    args.fp16 = smp.state.cfg.fp16
            else:
                # smp < 1.10 does not support fp16 in trainer.
                if hasattr(smp.state.cfg, "fp16"):
                    logger.warning(
                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                    )

        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
            if args.half_precision_backend == "auto":
                if args.device == torch.device("cpu"):
                    if args.fp16:
                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
                    else:
                        args.half_precision_backend = "cpu_amp"
                else:
                    args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

        self.do_grad_scaling = False
        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if self.sharded_ddp is not None:
                if args.half_precision_backend == "cuda_amp":
                    self.use_cuda_amp = True
                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                    # bf16 does not need grad scaling
                    self.do_grad_scaling = self.amp_dtype == torch.float16
                    if self.do_grad_scaling:
                        if self.sharded_ddp is not None:
                            self.scaler = ShardedGradScaler()
                        elif self.fsdp is not None:
                            from torch.distributed.fsdp.sharded_grad_scaler import (
                                ShardedGradScaler as FSDPShardedGradScaler,
                            )

                            self.scaler = FSDPShardedGradScaler()
                        elif is_torch_tpu_available():
                            from torch_xla.amp import GradScaler

                            self.scaler = GradScaler()
                        else:
                            self.scaler = torch.cuda.amp.GradScaler()
                elif args.half_precision_backend == "cpu_amp":
                    self.use_cpu_amp = True
                    self.amp_dtype = torch.bfloat16
            elif args.half_precision_backend == "apex":
                if not is_apex_available():
                    raise ImportError(
                        "Using FP16 with APEX but APEX is not installed, please refer to"
                        " https://www.github.com/nvidia/apex."
                    )
                self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = TrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
        )

        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(self.model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.can_return_loss = can_return_loss(self.model.__class__)
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # Internal variables to help with automatic batch size reduction
        self._train_batch_size = args.train_batch_size
        self._created_lr_scheduler = False

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # torch.compile
        if args.torch_compile and not is_torch_compile_available():
            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

        # finally applying `low_gpu_full_post_state_dict_hook`` for fsdp `state_dict`
        enable_low_gpu_full_post_state_dict_hook()

    def _wrap_model(self, model, training=True, dataloader=None):
        """Wrap `model` for the configured execution backend and return it.

        Args:
            model: The model to wrap.
            training: When false, the model is returned before any distributed wrapping.
            dataloader: Forwarded to `torch_jit_model_eval` when `args.jit_mode_eval` is set.

        Returns:
            The (possibly wrapped) model. For FullyShardedDDP and FSDP the wrapped model is
            also stored on `self.model`.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP
        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            # fix starts here
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
                import torch.distributed.fsdp._traversal_utils as traversal_utils

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                if not isinstance(model, FSDP):
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    # Only forward the tuning knobs that this torch version's FSDP constructor accepts.
                    signature = inspect.signature(FSDP.__init__).parameters.keys()
                    kwargs = {}
                    for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
                        if arg in signature:
                            kwargs[arg] = getattr(self, arg)
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        **kwargs,
                    )
                    # Report the state-dict settings of the first FSDP state through the logger
                    # (instead of printing to stdout on every rank).
                    for submodule in traversal_utils._get_fsdp_states(model):
                        logger.debug(
                            "FSDP state_dict_type=%s state_dict_config=%s",
                            submodule._state_dict_type,
                            submodule._state_dict_config,
                        )
                        break
            # fix ends here
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
                auto_wrap_policy = None
                auto_wrapper_callable = None
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                    loss = optimizer.step(**optimizer_args)
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            if is_torch_neuroncore_available():
                return model
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb

            if self.args.ddp_broadcast_buffers is not None:
                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers

            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        return model

    def get_batch_sampler(self, dataset=None):
        """Build the FFD batch sampler, or return `None` when it does not apply.

        The sampler is only built when `args.use_ffd_sampler`, `args.group_by_length`
        and `args.world_size > 1` all hold.

        Args:
            dataset: Dataset whose lengths are measured; defaults to `self.train_dataset`.

        Returns:
            A `FFDDistributedBatchSampler`, or `None`.
        """
        if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
            dataset = dataset if dataset is not None else self.train_dataset
            try:
                batch_max_len = self.args.per_device_train_batch_size * unwrap_model(self.model).model_avg_context
            except (AttributeError, TypeError):
                # The model has no usable `model_avg_context` (missing or not a number):
                # fall back to the value configured in the training arguments.
                batch_max_len = self.args.per_device_train_batch_size * self.args.model_avg_context
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None

            lengths = LengthGroupedSampler(
                batch_size=-1,  # we just want to know about the lengths of the dataset so no need to pass `batch_size`
                dataset=dataset,
                model_input_name=model_input_name,
            ).lengths

            seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
            batch_sampler = FFDDistributedBatchSampler(
                batch_max_length=batch_max_len, lengths=np.array(lengths), seed=seed
            )
            return batch_sampler
        return None

    def get_train_dataloader(self) -> DataLoader:
        """Return the training dataloader.

        When the FFD sampler applies (same condition as `get_batch_sampler`) a `DataLoader`
        driven by that batch sampler is returned; otherwise the parent implementation is used.

        Raises:
            ValueError: The FFD path is selected but `self.train_dataset` is `None`.
        """
        if self.args.use_ffd_sampler and self.args.group_by_length and self.args.world_size > 1:
            if self.train_dataset is None:
                raise ValueError("Trainer: training requires a train_dataset.")

            train_dataset = self.train_dataset
            data_collator = self.data_collator
            if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
                train_dataset = self._remove_unused_columns(train_dataset, description="training")
            else:
                data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

            batch_sampler = self.get_batch_sampler(train_dataset)
            # `drop_last` is deliberately not forwarded: DataLoader rejects it in combination
            # with `batch_sampler` (it raises ValueError when `drop_last=True`); batch
            # composition is entirely decided by the batch sampler.
            dataloader = DataLoader(
                train_dataset,
                batch_sampler=batch_sampler,
                collate_fn=data_collator,
            )
            # return self.accelerator.prepare(dataloader)
            return dataloader
        return super().get_train_dataloader()

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir: Target directory; defaults to `self.args.output_dir`.
            _internal_call: When true, the push to the Hub at the end is skipped.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif (
            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
            or self.fsdp is not None
            or self.is_fsdp_enabled
        ):
            state_dict = self.model.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            # modification starts here
            if self.is_fsdp_enabled and self.args.save_with_fsdp:
                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
            # modification ends here
        elif self.is_deepspeed_enabled:
            # this takes care of everything as long as we aren't under zero3
            if version.parse(accelerate_version) <= version.parse("0.20.3"):
                raise ValueError("Install Accelerate from main branch")
            try:
                state_dict = self.accelerator.get_state_dict(self.deepspeed)
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                logger.warning(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                    " zero_to_fp32.py to recover weights"
                )
                self.model_wrapped.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

    def _save_checkpoint(self, model, trial, metrics=None):
        """Write a full checkpoint for the current global step.

        Saves the model, optimizer, scheduler, (optionally) grad scaler, trainer state and
        RNG states under `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`, updates the
        best-metric bookkeeping when `metrics` is given, and rotates old checkpoints.

        Args:
            model: Unused by this implementation (kept for interface compatibility).
            trial: Hyper-parameter search trial, forwarded to `_get_output_dir`.
            metrics: Evaluation metrics used to track the best checkpoint.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)
        if self.is_deepspeed_enabled:
            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
            # config `stage3_gather_16bit_weights_on_model_save` is True
            self.model_wrapped.save_checkpoint(output_dir)

        # Save optimizer and scheduler
        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
            self.optimizer.consolidate_state_dict()

        if self.fsdp or self.is_fsdp_enabled:
            if self.is_fsdp_enabled:
                # modification starts here
                if self.args.save_with_fsdp:
                    save_fsdp_optimizer(
                        self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                    )
                # modification ends here
            else:
                # FSDP has a different interface for saving optimizer states.
                # Needs to be called on all ranks to gather all states.
                # full_optim_state_dict will be deprecated after Pytorch 2.2!
                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

        if is_torch_tpu_available():
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
        elif is_sagemaker_mp_enabled():
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
            if self.args.should_save:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling:
                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
        elif self.args.should_save and not self.is_deepspeed_enabled:
            # deepspeed.save_checkpoint above saves model/optim/sched
            if self.fsdp and not self.is_fsdp_enabled:
                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
            else:
                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # Save RNG state in non-distributed training
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        if is_torch_tpu_available():
            rng_states["xla"] = xm.get_rng_state()

        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
        # not yet exist.
        os.makedirs(output_dir, exist_ok=True)

        if self.args.world_size <= 1:
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them.

        Args:
            checkpoint: Checkpoint directory; nothing is done when it is `None` or when
                DeepSpeed is enabled.
        """
        if checkpoint is None:
            return

        if self.is_deepspeed_enabled:
            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
            return

        checkpoint_file_exists = (
            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
            if is_sagemaker_mp_enabled()
            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
        )
        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
            if is_torch_tpu_available():
                # On TPU we have to take some extra precautions to properly load the states on the right device.
                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                with warnings.catch_warnings(record=True) as caught_warnings:
                    lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                reissue_pt_warnings(caught_warnings)

                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)

                self.optimizer.load_state_dict(optimizer_state)
                self.lr_scheduler.load_state_dict(lr_scheduler_state)
            else:
                if is_sagemaker_mp_enabled():
                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                        # Optimizer checkpoint was saved with smp >= 1.10
                        def opt_load_hook(mod, opt):
                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    else:
                        # Optimizer checkpoint was saved with smp < 1.10
                        def opt_load_hook(mod, opt):
                            if IS_SAGEMAKER_MP_POST_1_10:
                                opt.load_state_dict(
                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
                                )
                            else:
                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

                    self.model_wrapped.register_post_step_hook(opt_load_hook)
                else:
                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state
                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                    if self.fsdp or self.is_fsdp_enabled:
                        # modification starts here
                        if self.is_fsdp_enabled and self.args.save_with_fsdp:
                            load_fsdp_optimizer(
                                self.accelerator.state.fsdp_plugin,
                                self.accelerator,
                                self.optimizer,
                                self.model,
                                checkpoint,
                            )
                        elif not self.is_fsdp_enabled:
                            full_osd = None
                            # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                            if self.args.process_index == 0:
                                full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
                            # call scatter_full_optim_state_dict on all ranks
                            sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
                            self.optimizer.load_state_dict(sharded_osd)
                        else:
                            self.optimizer.load_state_dict(
                                torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                            )
                        # modification ends here
                    else:
                        self.optimizer.load_state_dict(
                            torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                        )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                reissue_pt_warnings(caught_warnings)
                if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
                    self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))