import sys
import os
import warnings

import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable
from torch.distributed.fsdp import _state_dict_utils
from torch.distributed.fsdp._common_utils import clean_tensor_name
from transformers.utils import ExplicitEnum

class ExtendedFSDPOption(ExplicitEnum):
    """FSDP options, extended with the hybrid sharding strategies."""

    FULL_SHARD = "full_shard"
    SHARD_GRAD_OP = "shard_grad_op"
    NO_SHARD = "no_shard"

    # extension starts here
    HYBRID_SHARD = "hybrid_shard"
    _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
    # extension ends here

    OFFLOAD = "offload"
    AUTO_WRAP = "auto_wrap"

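# Illustrative sketch (added for documentation; not part of the original module):
# the sharding options above are assumed to correspond to
# `torch.distributed.fsdp.ShardingStrategy` members when the FSDP wrapper is
# constructed, while `OFFLOAD` and `AUTO_WRAP` are flag-style options rather
# than sharding strategies.
from torch.distributed.fsdp import ShardingStrategy

_EXAMPLE_SHARDING_STRATEGY_MAP = {
    ExtendedFSDPOption.FULL_SHARD: ShardingStrategy.FULL_SHARD,
    ExtendedFSDPOption.SHARD_GRAD_OP: ShardingStrategy.SHARD_GRAD_OP,
    ExtendedFSDPOption.NO_SHARD: ShardingStrategy.NO_SHARD,
    ExtendedFSDPOption.HYBRID_SHARD: ShardingStrategy.HYBRID_SHARD,
    ExtendedFSDPOption._HYBRID_SHARD_ZERO2: ShardingStrategy._HYBRID_SHARD_ZERO2,
}
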
# Keep references to the stock implementations so the patches below can be reverted.
DefaultCheckpointFunction = checkpoint.CheckpointFunction
DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook

class CompatibleCheckpointFunction(torch.autograd.Function):
    """A variant of `torch.utils.checkpoint.CheckpointFunction` that does not
    capture the forward-pass RNG device states (see the commented-out block in
    `backward`), which keeps gradient checkpointing usable with `torch.compile`."""

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding RNG state and restore it when we're done.
        # Unlike the stock CheckpointFunction, the RNG device states captured
        # during the original forward are not replayed here:
        rng_devices = []
        # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
        #     rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads

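# Illustrative usage sketch (added for documentation; not part of the original module).
# `CompatibleCheckpointFunction.apply` takes the same arguments as the stock
# `torch.utils.checkpoint.CheckpointFunction`: the callable to recompute, the
# `preserve_rng_state` flag, and then the inputs themselves.
def _example_checkpointed_forward(block, *inputs, preserve_rng_state=True):
    # `block` is assumed to be any callable (e.g. a transformer layer) whose
    # activations should be recomputed during backward instead of being stored.
    return CompatibleCheckpointFunction.apply(block, preserve_rng_state, *inputs)
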
def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix):
    """Variant of `_full_post_state_dict_hook` that moves each unsharded parameter
    to CPU before cloning it, keeping GPU memory usage low while saving an FSDP
    `state_dict`."""

    def param_hook(state_dict, prefix, fqn):
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip prefix out of key if needed as buffer names and param names
        # do not have prefix considered as they are not computed in `state_dict`
        # call.
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix):]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                # Move to CPU first, then clone, so the clone does not live on GPU.
                state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _state_dict_utils._common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )

# Enable memory-efficient saving of `state_dict` for FSDP.
def enable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook


def disable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook

# Enable to make gradient checkpointing work with `torch.compile`.
def enable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = CompatibleCheckpointFunction


def disable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = DefaultCheckpointFunction
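
# Illustrative usage (added for documentation; not part of the original module):
# the enable/disable pairs above monkey-patch torch in place, so a typical
# pattern is to turn them on around training/saving and restore the defaults
# afterwards.
if __name__ == "__main__":
    enable_compatible_checkpoint_function()      # patch gradient checkpointing
    enable_low_gpu_full_post_state_dict_hook()   # patch FSDP full state_dict saving
    try:
        # Build the FSDP-wrapped model, train, and call `state_dict()` here.
        pass
    finally:
        disable_low_gpu_full_post_state_dict_hook()
        disable_compatible_checkpoint_function()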