realign/train/trainers/utils.py

import sys
import os
import warnings

import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.checkpoint import check_backward_validity, _get_autocast_kwargs, detach_variable
from torch.distributed.fsdp import _state_dict_utils
from torch.distributed.fsdp._common_utils import clean_tensor_name
from transformers.utils import ExplicitEnum


class ExtendedFSDPOption(ExplicitEnum):
    FULL_SHARD = "full_shard"
    SHARD_GRAD_OP = "shard_grad_op"
    NO_SHARD = "no_shard"
    # extension starts here
    HYBRID_SHARD = "hybrid_shard"
    _HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
    # extension ends here
    OFFLOAD = "offload"
    AUTO_WRAP = "auto_wrap"
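
# Illustrative usage (a sketch, not from this file): assuming the surrounding trainer
# parses its `fsdp` argument the same way transformers parses `FSDPOption`, the extended
# hybrid-sharding modes can be requested as a space-separated list, e.g.
#
#     --fsdp "hybrid_shard auto_wrap"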


# Keep references to the stock implementations so the enable/disable helpers
# below can patch them in and restore them.
DefaultCheckpointFunction = checkpoint.CheckpointFunction
DefaultFullPostStateDictHook = _state_dict_utils._full_post_state_dict_hook


class CompatibleCheckpointFunction(torch.autograd.Function):
    """Variant of ``torch.utils.checkpoint.CheckpointFunction`` that does not capture or
    restore the forward-pass RNG state, so activation checkpointing stays usable when the
    model is wrapped with ``torch.compile`` (see ``enable_compatible_checkpoint_function``)."""

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        # Unlike the stock CheckpointFunction, no forward CUDA RNG state is restored here:
        # if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
        #     rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Run backward() with only the tensors that require grad.
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads


def low_gpu_full_post_state_dict_hook(module, fsdp_state, state_dict, prefix):
    """Post-`state_dict` hook that moves each gathered FSDP tensor to CPU before cloning
    it, so materializing the full `state_dict` does not spike GPU memory."""

    def param_hook(state_dict, prefix, fqn):
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip prefix out of key if needed as buffer names and param names
        # do not have prefix considered as they are not computed in `state_dict`
        # call.
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix):]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _state_dict_utils._common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


# enable to efficiently save the full `state_dict` under FSDP
def enable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = low_gpu_full_post_state_dict_hook


def disable_low_gpu_full_post_state_dict_hook():
    _state_dict_utils._full_post_state_dict_hook = DefaultFullPostStateDictHook
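
# Illustrative usage (a sketch, not part of the trainer; `model` is a hypothetical
# FSDP-wrapped module and `path` a hypothetical output file):
#
#     enable_low_gpu_full_post_state_dict_hook()
#     try:
#         state_dict = model.state_dict()  # FSDP-managed tensors are gathered, then moved to CPU
#         if torch.distributed.get_rank() == 0:
#             torch.save(state_dict, path)
#     finally:
#         disable_low_gpu_full_post_state_dict_hook()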


# enable to make `torch.compile` work with gradient checkpointing
def enable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = CompatibleCheckpointFunction


def disable_compatible_checkpoint_function():
    checkpoint.CheckpointFunction = DefaultCheckpointFunction
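
# Illustrative usage (a sketch; `build_model` is a hypothetical factory whose forward
# uses torch.utils.checkpoint, e.g. a Hugging Face model after
# `model.gradient_checkpointing_enable()`):
#
#     enable_compatible_checkpoint_function()   # patch before compiling / training
#     model = torch.compile(build_model())
#     ...
#     disable_compatible_checkpoint_function()  # restore the stock CheckpointFunction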