136 changes: 129 additions & 7 deletions unsloth_zoo/gradient_checkpointing.py
@@ -320,7 +320,9 @@ def initialize_unsloth_gradient_checkpointing(dtype = None):
CPU_BUFFERS = []
CPU_INDEX = 0

if dtype is None:
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
dtype = torch.float32
elif dtype is None:
if DEVICE_TYPE == "cuda":
major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = (major_version >= 8)
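For reference, a minimal standalone sketch of the dtype selection this hunk implements. The env-var override and the compute-capability check mirror the diff; the bfloat16/float16 fallbacks are assumptions, since those branches fall outside the shown context.

```python
import os
import torch

def select_checkpoint_dtype(dtype=None, device_type="cuda"):
    # Env var wins over any explicit dtype and over hardware detection.
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
        return torch.float32
    if dtype is not None:
        return dtype
    if device_type == "cuda":
        major_version, _ = torch.cuda.get_device_capability()
        supports_bfloat16 = (major_version >= 8)   # Ampere (SM 8.0) and newer
        return torch.bfloat16 if supports_bfloat16 else torch.float16
    return torch.float16
```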
@@ -604,6 +606,112 @@ def backward(ctx, *args):
pass


class UnslothCheckpointFunction_Float32(torch.autograd.Function):
"""Special version for FORCE_FLOAT32 mode without AMP"""

@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
# Same as original but WITHOUT @torch_amp_custom_fwd decorator
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
ctx.device_type = _infer_device_type(*args)

if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
ctx.had_device_in_fwd = False
device_module = _get_device_module(ctx.device_type)
if getattr(device_module, "_initialized", False):
ctx.had_device_in_fwd = True
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
ctx._requires_gradient = False

for i, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
if arg.requires_grad:
ctx._requires_gradient = True
else:
ctx.inputs.append(arg)

if ctx._requires_gradient:
ctx.save_for_backward(*tensor_inputs)

with torch.no_grad():
outputs = run_function(*args)

return outputs

@staticmethod
def backward(ctx, *args):
# Same as original but WITHOUT @torch_amp_custom_bwd decorator
if not ctx._requires_gradient:
return None

if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
" with .grad() or passing an `inputs` parameter to .backward()."
)

inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors

for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]

rng_devices = []
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices

with torch.random.fork_rng(
devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)

detached_inputs = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
detached_inputs.append(inp)
continue
x = inp.detach()
x.requires_grad = inp.requires_grad
detached_inputs.append(x)

# NO autocast context here - just pure computation
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)

outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)):
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])

if len(outputs_with_grad) > 0:
torch.autograd.backward(outputs_with_grad, args_with_grad)

grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)

return (None, None) + grads
pass

from torch.utils.checkpoint import (
ContextManager,
_DEFAULT_DETERMINISM_MODE,
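The core of `UnslothCheckpointFunction_Float32` is a re-entrant checkpoint that recomputes the forward pass under `torch.enable_grad()` in backward without ever entering an autocast region, so activations keep the parameters' own dtype. Below is a minimal, self-contained sketch of that pattern (hypothetical class name; the RNG-state restoration and the `_requires_gradient` early exit that the real class keeps are omitted for brevity):

```python
import torch

class MinimalCheckpointNoAMP(torch.autograd.Function):
    """Re-entrant checkpoint sketch: save inputs, run forward under no_grad,
    recompute under enable_grad in backward, and never enter autocast."""

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # Keep non-tensor args inline; tensors go through save_for_backward.
        ctx.inputs = [None if torch.is_tensor(a) else a for a in args]
        ctx.tensor_indices = [i for i, a in enumerate(args) if torch.is_tensor(a)]
        ctx.save_for_backward(*[a for a in args if torch.is_tensor(a)])
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = list(ctx.inputs)
        for idx, saved in zip(ctx.tensor_indices, ctx.saved_tensors):
            x = saved.detach()
            x.requires_grad_(saved.requires_grad)
            inputs[idx] = x
        # No autocast context here - just pure recomputation in the stored dtype.
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        # Backprop only through outputs that actually require grad.
        pairs = [(o, g) for o, g in zip(outputs, grad_outputs)
                 if torch.is_tensor(o) and o.requires_grad]
        if pairs:
            torch.autograd.backward([o for o, _ in pairs], [g for _, g in pairs])
        grads = tuple(x.grad if torch.is_tensor(x) else None for x in inputs)
        return (None,) + grads   # None for run_function

# Tiny usage example:
layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
y = MinimalCheckpointNoAMP.apply(layer, x)
y.sum().backward()
print(x.grad.shape, layer.weight.grad.shape)
```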
@@ -756,7 +864,11 @@ def unsloth_checkpoint(
"Passing `context_fn` or `debug` is only supported when "
"use_reentrant=False."
)
return UnslothCheckpointFunction.apply(function, preserve, *args)
# Choose the correct checkpoint function based on FORCE_FLOAT32
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
return UnslothCheckpointFunction_Float32.apply(function, preserve, *args)
else:
return UnslothCheckpointFunction.apply(function, preserve, *args)
else:
gen = _checkpoint_without_reentrant_generator(
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
@@ -774,13 +886,23 @@ def unsloth_checkpoint(

def patch_unsloth_smart_gradient_checkpointing(dtype = None):
# All Unsloth Zoo code licensed under LGPLv3
if torch.utils.checkpoint.CheckpointFunction.__name__ != "UnslothCheckpointFunction":

# Use float32 version if FORCE_FLOAT32 is set
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
# Force dtype to float32 for buffer initialization
initialize_unsloth_gradient_checkpointing(torch.float32)

# Store the correct checkpoint function
if not hasattr(torch.utils.checkpoint, "_checkpoint_mode"):
torch.utils.checkpoint._checkpoint_mode = "float32"
else:
initialize_unsloth_gradient_checkpointing(dtype)
torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction
torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction
if not hasattr(torch.utils.checkpoint, "_checkpoint_mode"):
torch.utils.checkpoint._checkpoint_mode = "normal"

# Always patch checkpoint function
if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint":
torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint
torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint = unsloth_checkpoint
pass

Expand All @@ -800,7 +922,7 @@ def unpatch_unsloth_smart_gradient_checkpointing():

if (torch.utils.checkpoint.checkpoint.__name__ == "unsloth_checkpoint") and \
hasattr(torch.utils, "_old_checkpoint"):

torch.utils.checkpoint = torch.utils._old_checkpoint
pass

49 changes: 30 additions & 19 deletions unsloth_zoo/patching_utils.py
@@ -242,19 +242,30 @@ def patch_model_and_tokenizer(
pass
# If we force float32, we first use bfloat16, then downcast to float16
if do_forced_float32:
correct_dtype = torch.float16
for name, module in model.named_modules():
if "down_proj" in name or "up_proj" in name or "gate_proj" in name:
exec(f"module.to(torch.float16)")
if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name:
exec(f"module.to(torch.float16)")
if "lm_head" in name or "embed_tokens" in name:
exec(f"module.to(torch.float16)")
if "norm" in name:
exec(f"module.to(torch.float32)")
assert(module.weight.dtype == torch.float32)
torch.cuda.empty_cache()
pass
correct_dtype = torch.float16
for name, module in model.named_modules():
if "down_proj" in name or "up_proj" in name or "gate_proj" in name or "fc1" in name or "fc2" in name:
module.to(torch.float16)
if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name or "out_proj" in name:
module.to(torch.float16)
if "lm_head" in name or "embed_tokens" in name:
module.to(torch.float16)
if "embed_tokens" in name or "patch_embedding" in name:
module.to(torch.float16)
if "norm" in name:
module.to(torch.float16)
torch.cuda.empty_cache()

# Convert any remaining bfloat16 parameters
for name, param in model.named_parameters():
if param.dtype == torch.bfloat16:
param.data = param.data.to(torch.float16)

# Also convert buffers (like position embeddings)
for name, buffer in model.named_buffers():
if buffer.dtype == torch.bfloat16:
buffer.data = buffer.data.to(torch.float16)
pass
pass

# Correct torch_dtype
@@ -279,7 +290,7 @@ def __fix_dtype(config):
try: setattr(m, "dtype", correct_dtype)
except: pass
pass

# Check all params and patch!
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
Expand Down Expand Up @@ -313,7 +324,7 @@ def __fix_dtype(config):

elif hasattr(module, "short_cos_cached") and \
(module.short_cos_cached.dtype != correct_dtype):

module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
pass
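The cached rotary cos/sin fix-ups can be expressed generically; `short_cos_cached`/`short_sin_cached` are the names visible in this hunk, while `cos_cached`/`sin_cached` are assumed companions from older rotary-embedding implementations.

```python
import torch

def fix_rope_cache_dtype(model, correct_dtype):
    # Re-cast any cached rotary-embedding buffers whose dtype drifted from the target.
    for module in model.modules():
        for attr in ("cos_cached", "sin_cached", "short_cos_cached", "short_sin_cached"):
            cached = getattr(module, attr, None)
            if torch.is_tensor(cached) and cached.dtype != correct_dtype:
                setattr(module, attr, cached.to(correct_dtype))
```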
@@ -373,7 +384,7 @@ def __fix_dtype(config):
lm_head.weight = old_output_embedding if not is_tied else old_input_embedding
lm_head.in_features = lm_head.weight.shape[1]
lm_head.out_features = lm_head.weight.shape[0]

lm_head.weight.requires_grad_(requires_grad)
model.set_output_embeddings(lm_head)
if hasattr(model, "lm_head"): model.lm_head = lm_head
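For context, the `lm_head` rebuild in this hunk can be read as the following standalone helper. It is hypothetical: it simply reuses the tied or untied embedding `Parameter` and re-registers the head, mirroring the context lines above.

```python
import torch

def rebuild_lm_head(model, old_input_embedding, old_output_embedding,
                    is_tied: bool, requires_grad: bool):
    weight = old_input_embedding if is_tied else old_output_embedding
    lm_head = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False,
                              device=weight.device, dtype=weight.dtype)
    lm_head.weight = weight                 # reuse the existing Parameter, no copy
    lm_head.in_features = weight.shape[1]   # keep metadata consistent with the weight
    lm_head.out_features = weight.shape[0]
    lm_head.weight.requires_grad_(requires_grad)
    model.set_output_embeddings(lm_head)
    if hasattr(model, "lm_head"):
        model.lm_head = lm_head
    return lm_head
```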
Expand Down Expand Up @@ -457,7 +468,7 @@ def check_conversion_mappings(model, current_key_name_str, skip_modules):
if hasattr(model_root_cls, "_checkpoint_conversion_mapping") and len(model_root_cls._checkpoint_conversion_mapping) > 0:
# if this is true, then it means that we must be on transformers >=4.52.0 because conversion_mappings was added in 4.52.0
# we cant know if the skip module naming convention is new or old
# but if we are supposed to skip this current_key_name_str, and it didn't pass
# but if we are supposed to skip this current_key_name_str, and it didn't pass
# (current_key_name_str in quantization_config.llm_int8_skip_modules)
# then new transformers + new module hierarchy means it should not be skipped, ie no BC check needed
# and new transformers + old module hierarchy means we still need to check to skip
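A hypothetical illustration of the concern in the comments above: when newer transformers renames module paths via a checkpoint conversion mapping, a skip list written against the old hierarchy must also be checked against the remapped name. The mapping format and helper below are assumptions for illustration, not transformers internals.

```python
def should_skip(current_key_name_str, skip_modules, conversion_mapping):
    # conversion_mapping is assumed to map old-style prefixes to new-style ones.
    candidates = {current_key_name_str}
    for old_prefix, new_prefix in conversion_mapping.items():
        # Translate a new-style name back to the old hierarchy and check that too.
        candidates.add(current_key_name_str.replace(new_prefix, old_prefix, 1))
    return any(
        any(skip in candidate for skip in skip_modules)
        for candidate in candidates
    )
```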
@@ -605,9 +616,9 @@ def visit_Assign(self, node: ast.Assign):
def add_score_code(match):
indentation = match.group(1) # Captured indentation
line_content = match.group(2) # The line 'current_key_name.append(name)'

indented_breakpoint_code = "\n".join([f"{indentation}{line}" for line in score_code.splitlines()])

return f"{indentation}{line_content}\n{indented_breakpoint_code}"

source = re.sub(pattern, add_score_code, source, flags=re.MULTILINE)
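The `re.sub` call above relies on capturing the matched line's indentation and re-indenting the injected block to match. A self-contained illustration; the pattern, target line, and injected `score_code` here are made up for the example:

```python
import re

source = """
def replace_with_bnb_linear(model):
    for name in model:
        current_key_name.append(name)
        do_something(name)
"""

score_code = "score = compute_score(name)\nprint(score)"

# Capture the leading whitespace and the statement itself, line by line.
pattern = r"^(\s*)(current_key_name\.append\(name\))"

def add_score_code(match):
    indentation = match.group(1)    # captured indentation of the matched line
    line_content = match.group(2)   # the matched statement
    injected = "\n".join(f"{indentation}{line}" for line in score_code.splitlines())
    return f"{indentation}{line_content}\n{injected}"

patched = re.sub(pattern, add_score_code, source, flags=re.MULTILINE)
print(patched)   # the injected lines appear right after the append, same indentation
```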