import sys
import os
import warnings

import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable
from torch.distributed.fsdp import _state_dict_utils
from torch.distributed.fsdp._common_utils import clean_tensor_name
from transformers.utils import ExplicitEnum


class ExtendedFSDPOption(ExplicitEnum):
    FULL_SHARD = "full_shard"
    SHARD_GRAD_OP = "shard_grad_op"
    NO_SHARD = "no_shard"
    # extension starts here
    HYBRID_SHARD = "hybrid_shard"
    _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
    # extension ends here
    OFFLOAD = "offload"
    AUTO_WRAP = "auto_wrap"


# Keep references to the defaults so the patches below can be reverted.
DefaultCheckpointFunction = checkpoint.CheckpointFunction
DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook


class CompatibleCheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
        #     rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Run backward() with only the tensors that require grad.
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads


def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix):

    def param_hook(state_dict, prefix, fqn):
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip prefix out of key if needed as buffer names and param names
        # do not have prefix considered as they are not computed in `state_dict`
        # call.
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix):]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _state_dict_utils._common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


# Enable to efficiently save `state_dict` for FSDP (parameters are cloned to CPU).
def enable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook


def disable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook


# Enable to make `torch.compile` work with `torch.utils.checkpoint`.
def enable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = CompatibleCheckpointFunction


def disable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = DefaultCheckpointFunction
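

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, self-contained example of toggling the compile-compatible
# checkpoint function; the tiny `run_fn` below is a hypothetical stand-in for
# a real checkpointed block. The low-GPU state_dict hook is only shown in
# comments, since exercising it requires an initialized FSDP-wrapped model.
if __name__ == "__main__":
    enable_compatible_checkpoint_function()

    def run_fn(x):
        return torch.nn.functional.gelu(x) * 2.0

    x = torch.randn(4, 8, requires_grad=True)
    # `use_reentrant=True` routes through `checkpoint.CheckpointFunction`,
    # which is now `CompatibleCheckpointFunction`.
    y = checkpoint.checkpoint(run_fn, x, use_reentrant=True)
    y.sum().backward()
    assert x.grad is not None

    disable_compatible_checkpoint_function()

    # For FSDP checkpoint saving (assumes an FSDP-wrapped `model`):
    #     enable_low_gpu_full_post_state_dict_hook()
    #     state_dict = model.state_dict()  # unsharded params cloned to CPU
    #     disable_low_gpu_full_post_state_dict_hook()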