Merged
535 commits
9aad48e
Update rl.py
danielhanchen Feb 12, 2025
f121a5c
Update llama.py
danielhanchen Feb 12, 2025
5052d35
Update llama.py
danielhanchen Feb 12, 2025
a11aa96
Update llama.py
danielhanchen Feb 12, 2025
a6abe02
Update llama.py
danielhanchen Feb 12, 2025
d867faa
autocast
danielhanchen Feb 12, 2025
44c9228
Update rl_replacements.py
danielhanchen Feb 12, 2025
e83d854
Update llama.py
danielhanchen Feb 12, 2025
623eb65
Update rl_replacements.py
danielhanchen Feb 13, 2025
7e612f0
Update rl_replacements.py
danielhanchen Feb 13, 2025
a45266b
Update rl_replacements.py
danielhanchen Feb 13, 2025
c855d7e
Update rl_replacements.py
danielhanchen Feb 13, 2025
d7cefba
Update llama.py
danielhanchen Feb 13, 2025
52d996a
Update rl_replacements.py
danielhanchen Feb 13, 2025
56f5b31
Update llama.py
danielhanchen Feb 13, 2025
5f1e98c
Update llama.py
danielhanchen Feb 13, 2025
e713129
Update llama.py
danielhanchen Feb 13, 2025
310fc16
Update llama.py
danielhanchen Feb 13, 2025
76a122e
Update llama.py
danielhanchen Feb 13, 2025
2dd29e5
Update rl_replacements.py
danielhanchen Feb 13, 2025
3c5be91
Update llama.py
danielhanchen Feb 13, 2025
e548b15
Update llama.py
danielhanchen Feb 13, 2025
296b3b3
Update llama.py
danielhanchen Feb 13, 2025
8de588b
Update llama.py
danielhanchen Feb 13, 2025
f87909a
Update pyproject.toml
danielhanchen Feb 13, 2025
2704440
Update llama.py
danielhanchen Feb 13, 2025
42e1967
Update llama.py
danielhanchen Feb 13, 2025
36bf805
Update llama.py
danielhanchen Feb 13, 2025
a3af8e3
Update llama.py
danielhanchen Feb 13, 2025
9d10d2f
Update llama.py
danielhanchen Feb 13, 2025
b30a81f
Update llama.py
danielhanchen Feb 13, 2025
b7e8559
Update llama.py
danielhanchen Feb 13, 2025
4b201d9
Update rl_replacements.py
danielhanchen Feb 13, 2025
dc723bc
Update rl_replacements.py
danielhanchen Feb 13, 2025
0309949
Update rl_replacements.py
danielhanchen Feb 13, 2025
c409574
Update rl_replacements.py
danielhanchen Feb 13, 2025
8e5b09a
Update llama.py
danielhanchen Feb 13, 2025
6652f1d
Update rl_replacements.py
danielhanchen Feb 13, 2025
9215bbe
Update rl_replacements.py
danielhanchen Feb 13, 2025
4bff998
Update rl_replacements.py
danielhanchen Feb 13, 2025
c859030
Update rl_replacements.py
danielhanchen Feb 13, 2025
2daa8e3
Update rl_replacements.py
danielhanchen Feb 13, 2025
527a0c4
Update rl_replacements.py
danielhanchen Feb 13, 2025
087a5dc
Update rl_replacements.py
danielhanchen Feb 13, 2025
73210b3
Update rl_replacements.py
danielhanchen Feb 13, 2025
9934ac5
Merge branch 'main' into nightly
danielhanchen Feb 13, 2025
2635f2a
Update llama.py
danielhanchen Feb 13, 2025
69ab838
Update _utils.py
danielhanchen Feb 13, 2025
d5d7a06
Merge branch 'main' into nightly
danielhanchen Feb 13, 2025
c9e450f
Merge branch 'main' into nightly
danielhanchen Feb 13, 2025
44d00e8
Merge branch 'main' into nightly
danielhanchen Feb 13, 2025
447dfc4
Merge branch 'main' into nightly
danielhanchen Feb 13, 2025
acf98dc
Update llama.py
danielhanchen Feb 14, 2025
1399110
Update _utils.py
danielhanchen Feb 14, 2025
881105b
Update rl_replacements.py
danielhanchen Feb 14, 2025
cfdd3f1
Update rl.py
danielhanchen Feb 14, 2025
95b7df5
Update rl.py
danielhanchen Feb 14, 2025
17bfcf9
Update rl.py
danielhanchen Feb 14, 2025
61c219d
Update rl.py
danielhanchen Feb 14, 2025
9794dc2
Update rl.py
danielhanchen Feb 14, 2025
3687a6f
Update llama.py
danielhanchen Feb 14, 2025
c495bfa
Update llama.py
danielhanchen Feb 14, 2025
f9055a7
Update llama.py
danielhanchen Feb 14, 2025
945e3f9
Update llama.py
danielhanchen Feb 14, 2025
3d9fe12
Update rl_replacements.py
danielhanchen Feb 14, 2025
ed90785
Update llama.py
danielhanchen Feb 14, 2025
640bc88
Update llama.py
danielhanchen Feb 14, 2025
bb3bb2d
Update llama.py
danielhanchen Feb 14, 2025
9065938
Update llama.py
danielhanchen Feb 14, 2025
07b48f5
Merge branch 'main' into nightly
danielhanchen Feb 14, 2025
48c5e0d
GRPO optimized
danielhanchen Feb 14, 2025
3a1fb63
Update rl.py
danielhanchen Feb 14, 2025
19014b0
Update rl_replacements.py
danielhanchen Feb 14, 2025
0c17e79
Update rl_replacements.py
danielhanchen Feb 14, 2025
aee44e2
Update rl.py
danielhanchen Feb 14, 2025
953d957
Update rl.py
danielhanchen Feb 14, 2025
2a2b9f7
Update rl.py
danielhanchen Feb 14, 2025
fcb0f4a
Update rl.py
danielhanchen Feb 14, 2025
eabc365
Update rl_replacements.py
danielhanchen Feb 14, 2025
7408318
Update rl_replacements.py
danielhanchen Feb 14, 2025
f35eae3
Update rl_replacements.py
danielhanchen Feb 14, 2025
2b89dae
Selective Log softmax
danielhanchen Feb 14, 2025
45c8431
Fix GRPO bsz
danielhanchen Feb 14, 2025
644cedf
Update rl.py
danielhanchen Feb 14, 2025
4b765d7
Update rl_replacements.py
danielhanchen Feb 14, 2025
0a7c56d
Update rl_replacements.py
danielhanchen Feb 15, 2025
1b43e1d
Update rl_replacements.py
danielhanchen Feb 15, 2025
d588665
Update rl_replacements.py
danielhanchen Feb 15, 2025
54bd827
Fix TRL
danielhanchen Feb 15, 2025
c6d6e6b
Merge branch 'main' into nightly
danielhanchen Feb 15, 2025
fa560ce
Metrics GRPO
danielhanchen Feb 15, 2025
46462f1
Update rl_replacements.py
danielhanchen Feb 15, 2025
12c497a
Update rl_replacements.py
danielhanchen Feb 15, 2025
b8aca94
Merge branch 'main' into nightly
danielhanchen Feb 15, 2025
c14faee
No compile
danielhanchen Feb 16, 2025
1fcad32
Update rl.py
danielhanchen Feb 16, 2025
80be827
Remove docs
danielhanchen Feb 16, 2025
9254243
Update rl.py
danielhanchen Feb 16, 2025
09cb804
Update rl.py
danielhanchen Feb 16, 2025
86dabcf
Update rl.py
danielhanchen Feb 16, 2025
ba1c93e
Update rl.py
danielhanchen Feb 16, 2025
0d75afd
Update rl_replacements.py
danielhanchen Feb 16, 2025
1803658
Update rl.py
danielhanchen Feb 16, 2025
a856085
Update rl.py
danielhanchen Feb 16, 2025
eeac4f3
Update rl_replacements.py
danielhanchen Feb 16, 2025
6f1beb0
Update rl_replacements.py
danielhanchen Feb 16, 2025
222b1e7
llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving b…
everythingisc00l Feb 16, 2025
103cff4
Update rl_replacements.py
danielhanchen Feb 17, 2025
89a1d03
Update rl_replacements.py
danielhanchen Feb 17, 2025
c46b544
Update rl_replacements.py
danielhanchen Feb 17, 2025
ed84307
Update rl_replacements.py
danielhanchen Feb 17, 2025
93d3f16
Update rl_replacements.py
danielhanchen Feb 17, 2025
429ba6d
Update rl_replacements.py
danielhanchen Feb 17, 2025
1e42bad
Update rl_replacements.py
danielhanchen Feb 17, 2025
38a1885
Update rl_replacements.py
danielhanchen Feb 17, 2025
f0ee4f5
Update rl_replacements.py
danielhanchen Feb 17, 2025
b68dce6
Update rl_replacements.py
danielhanchen Feb 17, 2025
0827067
Update llama.py
danielhanchen Feb 17, 2025
204cd7a
Update rl_replacements.py
danielhanchen Feb 17, 2025
e141075
Update rl_replacements.py
danielhanchen Feb 17, 2025
a07a9e3
Update rl_replacements.py
danielhanchen Feb 17, 2025
cf2720d
Update llama.py
danielhanchen Feb 17, 2025
5c6f586
Update llama.py
danielhanchen Feb 18, 2025
2e07623
Update rl_replacements.py
danielhanchen Feb 18, 2025
8025cfe
Update rl_replacements.py
danielhanchen Feb 18, 2025
ba48495
Update rl_replacements.py
danielhanchen Feb 18, 2025
f0078de
Update rl.py
danielhanchen Feb 18, 2025
15e0140
Update rl.py
danielhanchen Feb 18, 2025
5f5cca4
Update rl_replacements.py
danielhanchen Feb 18, 2025
d80be70
Update rl.py
danielhanchen Feb 18, 2025
47a85eb
Update rl.py
danielhanchen Feb 18, 2025
f09478d
Update rl_replacements.py
danielhanchen Feb 18, 2025
97637c5
Update rl_replacements.py
danielhanchen Feb 18, 2025
58bd27f
Update rl_replacements.py
danielhanchen Feb 18, 2025
7c0c749
Update rl_replacements.py
danielhanchen Feb 18, 2025
97b55c1
Update rl_replacements.py
danielhanchen Feb 18, 2025
24c7a2f
Update rl_replacements.py
danielhanchen Feb 18, 2025
06b2cd3
unsloth_num_chunks
danielhanchen Feb 18, 2025
cbb16e3
Update rl.py
danielhanchen Feb 18, 2025
d16299b
Update rl_replacements.py
danielhanchen Feb 18, 2025
0c1a808
Update rl_replacements.py
danielhanchen Feb 18, 2025
6796801
Update rl_replacements.py
danielhanchen Feb 18, 2025
bd046ca
Update rl.py
danielhanchen Feb 18, 2025
ac2e814
Update rl.py
danielhanchen Feb 18, 2025
a88712f
Update rl.py
danielhanchen Feb 18, 2025
0daa328
Update rl.py
danielhanchen Feb 18, 2025
1afe3f2
Update rl.py
danielhanchen Feb 18, 2025
6732822
Update rl_replacements.py
danielhanchen Feb 18, 2025
5efe9f3
Update rl_replacements.py
danielhanchen Feb 18, 2025
15442d1
Update rl_replacements.py (#1754)
SethHWeidman Feb 19, 2025
91ab43d
Optional logits
danielhanchen Feb 19, 2025
a6a5f60
Update rl.py
danielhanchen Feb 19, 2025
83ce085
Update rl.py
danielhanchen Feb 19, 2025
8ece11f
Update rl.py
danielhanchen Feb 19, 2025
bc6bfae
Update rl.py
danielhanchen Feb 20, 2025
95fb6a4
Update rl.py
danielhanchen Feb 20, 2025
ba01cf5
Update rl.py
danielhanchen Feb 20, 2025
eb48b98
Update rl.py
danielhanchen Feb 20, 2025
3c750a1
Update rl.py
danielhanchen Feb 20, 2025
515cf5a
Update rl_replacements.py
danielhanchen Feb 20, 2025
2cf4349
Update rl.py
danielhanchen Feb 20, 2025
ae8bf68
Update rl.py
danielhanchen Feb 20, 2025
e07f4bc
Update rl.py
danielhanchen Feb 20, 2025
f11e5ab
Merge branch 'main' into nightly
danielhanchen Feb 20, 2025
3fccf5d
Update rl.py
danielhanchen Feb 20, 2025
798ad95
fix an import error (#1767)
NinoRisteski Feb 20, 2025
2957d89
SamplingParams
danielhanchen Feb 20, 2025
19d57bc
Convert mask to float (#1762)
Erland366 Feb 20, 2025
07aea40
[Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753)
versipellis Feb 20, 2025
77109a4
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Feb 20, 2025
f3d9efb
vLLMSamplingParams
danielhanchen Feb 20, 2025
6d5caca
Update __init__.py
danielhanchen Feb 20, 2025
3a5610e
default num_chunks == -1
danielhanchen Feb 20, 2025
0362bd2
Versioning
danielhanchen Feb 20, 2025
2969db8
Merge branch 'main' into nightly
danielhanchen Feb 20, 2025
b5eda24
Update llama.py
danielhanchen Feb 20, 2025
7de0022
Update llama.py
danielhanchen Feb 20, 2025
d4d7694
Update llama.py
danielhanchen Feb 20, 2025
0bbfbe8
Update llama.py
danielhanchen Feb 20, 2025
ae6e2bd
Update llama.py
danielhanchen Feb 20, 2025
1792deb
Update _utils.py
danielhanchen Feb 20, 2025
5dcd079
Update rl_replacements.py
danielhanchen Feb 20, 2025
ec6e0b7
Update rl_replacements.py
danielhanchen Feb 20, 2025
bc1d2ce
Update pyproject.toml
danielhanchen Feb 20, 2025
adbe38e
Update pyproject.toml
danielhanchen Feb 20, 2025
a9b542f
Export Model to ollama.com (#1648)
gjyotin305 Feb 22, 2025
f853ac0
Merge branch 'main' into nightly
danielhanchen Mar 3, 2025
9cab347
Update cross_entropy_loss.py
danielhanchen Mar 3, 2025
0ae9082
torch_cuda_device
danielhanchen Mar 3, 2025
f21314c
Update utils.py
danielhanchen Mar 3, 2025
9215212
Update utils.py
danielhanchen Mar 3, 2025
9d95aee
Update utils.py
danielhanchen Mar 3, 2025
35e9144
device
danielhanchen Mar 3, 2025
30b6f94
device
danielhanchen Mar 3, 2025
64e2b00
Update loader.py
danielhanchen Mar 3, 2025
ffa3278
Update llama.py
danielhanchen Mar 3, 2025
748c5b5
Update README.md
danielhanchen Mar 3, 2025
469ed48
Update llama.py
danielhanchen Mar 3, 2025
bc87afd
Update llama.py
danielhanchen Mar 3, 2025
ee9d6e5
Update _utils.py
danielhanchen Mar 4, 2025
91458bb
Update utils.py
danielhanchen Mar 4, 2025
a7a5d75
Update utils.py
danielhanchen Mar 4, 2025
d93cca2
Update utils.py
danielhanchen Mar 4, 2025
6e2a3a8
Update utils.py
danielhanchen Mar 4, 2025
8f9ba99
Update utils.py
danielhanchen Mar 4, 2025
ed697da
Update llama.py
danielhanchen Mar 4, 2025
d73c34b
Update llama.py
danielhanchen Mar 4, 2025
4485da7
Update llama.py
danielhanchen Mar 4, 2025
45ea48c
Update llama.py
danielhanchen Mar 4, 2025
8c4b79c
Update llama.py
danielhanchen Mar 4, 2025
c2ae510
Update utils.py
danielhanchen Mar 4, 2025
432ea24
Update utils.py
danielhanchen Mar 4, 2025
dcff03c
Update utils.py
danielhanchen Mar 4, 2025
6ef0866
Update utils.py
danielhanchen Mar 4, 2025
8c8ce96
__version__
danielhanchen Mar 4, 2025
208971b
Update rl.py
danielhanchen Mar 4, 2025
adc6977
Bug fixes
danielhanchen Mar 4, 2025
949c298
Bug fixes
danielhanchen Mar 4, 2025
ad6d962
Merge branch 'main' into nightly
danielhanchen Mar 4, 2025
59b24ad
Update llama.py
danielhanchen Mar 5, 2025
5df3936
Update _utils.py
danielhanchen Mar 5, 2025
b8b0f9c
_wrap_fast_inference
danielhanchen Mar 5, 2025
6f0857b
Update llama.py
danielhanchen Mar 5, 2025
109364b
Update llama.py
danielhanchen Mar 5, 2025
dd4bd07
Update llama.py
danielhanchen Mar 5, 2025
b356fce
Update llama.py
danielhanchen Mar 5, 2025
e022016
Update llama.py
danielhanchen Mar 5, 2025
12094a7
Update llama.py
danielhanchen Mar 5, 2025
2836128
Update llama.py
danielhanchen Mar 5, 2025
c956616
Update llama.py
danielhanchen Mar 5, 2025
e887f43
Update llama.py
danielhanchen Mar 5, 2025
95f872d
Update llama.py
danielhanchen Mar 5, 2025
647dbb4
Update llama.py
danielhanchen Mar 5, 2025
f640c8d
Update _utils.py
danielhanchen Mar 5, 2025
91a4fce
SFT dataset prepare
danielhanchen Mar 5, 2025
4495148
Update pyproject.toml
danielhanchen Mar 5, 2025
f41dff5
Update rl_replacements.py
danielhanchen Mar 5, 2025
0a3dbfa
Update rl_replacements.py
danielhanchen Mar 5, 2025
7d8f100
Update rl_replacements.py
danielhanchen Mar 5, 2025
413ea80
Update rl.py
danielhanchen Mar 5, 2025
3f5ce93
Update llama.py
danielhanchen Mar 5, 2025
185bced
Update llama.py
danielhanchen Mar 5, 2025
fd11ad7
Update utils.py
danielhanchen Mar 5, 2025
97ed0b4
bug fix
danielhanchen Mar 5, 2025
68eca88
Update llama.py
danielhanchen Mar 5, 2025
5daf9b5
Update llama.py
danielhanchen Mar 5, 2025
858bb76
Update llama.py
danielhanchen Mar 5, 2025
daedc34
Update llama.py
danielhanchen Mar 5, 2025
95e2371
Update llama.py
danielhanchen Mar 5, 2025
fccd68a
Update __init__.py
danielhanchen Mar 5, 2025
34 changes: 9 additions & 25 deletions pyproject.toml
@@ -7,7 +7,7 @@ name = "unsloth"
 dynamic = ["version"]
 description = "2-5X faster LLM finetuning"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.9,<=3.12"
 license = {file = "LICENSE"}
 keywords = ["ai", "llm",]
 authors = [
@@ -39,8 +39,8 @@ triton = [
     "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'"
 ]

-windows=[
-    "unsloth_zoo>=2025.3.1",
+huggingface = [
+    "unsloth_zoo>=2025.3.2",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",
@@ -51,34 +51,18 @@ windows=[
     "wheel>=0.42.0",
     "numpy",
     "accelerate>=0.34.1",
-    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0",
+    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2",
     "peft>=0.7.1,!=0.11.0",
     "protobuf<4.0.0",
     "huggingface_hub",
     "hf_transfer",
     "unsloth[triton]",
+]
+windows=[
+    "unsloth[huggingface]",
     "bitsandbytes>=0.41.1 ; platform_system == 'Windows'",
     "xformers>=0.0.22.post7 ; platform_system == 'Windows'",
 ]
-huggingface = [
-    "unsloth_zoo>=2025.3.1",
-    "packaging",
-    "tyro",
-    "transformers>=4.46.1,!=4.47.0",
-    "datasets>=2.16.0",
-    "sentencepiece>=0.2.0",
-    "tqdm",
-    "psutil",
-    "wheel>=0.42.0",
-    "numpy",
-    "accelerate>=0.34.1",
-    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0",
-    "peft>=0.7.1,!=0.11.0",
-    "protobuf<4.0.0",
-    "huggingface_hub",
-    "hf_transfer",
-    "unsloth[triton]",
-]
 cu118only = [
     "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
     "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
@@ -370,7 +354,7 @@ colab-ampere-torch220 = [
     "flash-attn>=2.6.3",
 ]
 colab-new = [
-    "unsloth_zoo>=2025.2.7",
+    "unsloth_zoo>=2025.3.1",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",
@@ -388,7 +372,7 @@
 ]
 colab-no-deps = [
     "accelerate>=0.34.1",
-    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0",
+    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2",
     "peft>=0.7.1",
     "xformers",
     "bitsandbytes>=0.46.1",
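The restructure above deduplicates the Windows extra: `windows` now pulls in the shared dependency list through the self-referential requirement `unsloth[huggingface]` instead of repeating it, and `trl` is capped at 0.15.2. A minimal sketch (not part of the PR) of confirming how the self-referential extra resolves, assuming an installed `unsloth` distribution built from this pyproject:

# Minimal sketch: list the requirements that the "windows" extra activates.
from importlib.metadata import requires

windows_reqs = [
    r for r in requires("unsloth")
    if 'extra == "windows"' in r
]
# Expected to include the self-reference, e.g.:
#   unsloth[huggingface]; extra == "windows"
print("\n".join(windows_reqs))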
2 changes: 1 addition & 1 deletion unsloth/__init__.py
@@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
-    if Version(unsloth_zoo_version) < Version("2025.3.1"):
+    if Version(unsloth_zoo_version) < Version("2025.3.2"):
         try:
             os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
         except:
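The bumped floor above follows a gate-and-upgrade pattern: compare the installed `unsloth_zoo` against a minimum version and shell out to pip if it is stale. A self-contained sketch of the same pattern, assuming `importlib_version` in the diff is `importlib.metadata.version` and `Version` comes from `packaging`:

# Self-contained sketch of the version gate shown in the diff above.
import os
from importlib.metadata import version as importlib_version
from packaging.version import Version

if Version(importlib_version("unsloth_zoo")) < Version("2025.3.2"):
    # Same upgrade command the surrounding try/except uses.
    os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")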
31 changes: 16 additions & 15 deletions unsloth/kernels/utils.py
@@ -104,6 +104,11 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
 cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
+torch_mm = torch.mm
+torch_mv = torch.mv
+torch_matmul = torch.matmul
+torch_addmm = torch.addmm
+torch_empty = torch.empty

 def QUANT_STATE(W): return getattr(W, "quant_state", None)

@@ -194,8 +199,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
         WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index]
         ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index]
         if WEIGHT_BUFFER is None:
-            WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = device, requires_grad = False)
-            ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
+            WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False)
+            ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

         if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size)
         if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax)
@@ -204,11 +209,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False
         out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
     else:
         if out is None:
-            out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False)
+            out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
         else:
             assert(out.shape == shape)
             assert(out.dtype == dtype)
-        out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
+        out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
     pass

     # NF4 dequantization of statistics
@@ -258,11 +263,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False

     # Create weight matrix
     if out is None:
-        out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False)
+        out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False)
     else:
         assert(out.shape == shape)
         assert(out.dtype == dtype)
-    out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)
+    out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False)

     # Do dequantization
     ptr_out_absmax = get_ptr(out_absmax)
@@ -286,7 +291,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False

 if HAS_CUDA_STREAM:
     def fast_gemv(X, W, quant_state, out = None):
-        if quant_state is None: return torch.matmul(X, W, out = out)
+        if quant_state is None: return torch_matmul(X, W, out = out)
         # For fast X @ W where seq_len == 1
         # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
         _, q_len, hd = X.shape
@@ -318,7 +323,7 @@ def fast_gemv(X, W, quant_state, out = None):
         bout = shape[0]

         if out is None:
-            out = torch.empty((1, 1, bout,), dtype = dtype, device = device)
+            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
         # else:
         #     assert(out.shape == (1, 1, bout,))
         # pass
@@ -336,7 +341,7 @@ def fast_gemv(X, W, quant_state, out = None):
         ldb = ctypes_c_int32(ldb)
         ldc = ctypes_c_int32(ldc)

-        df = torch.empty(absmax.shape, dtype = torch.float32, device = device)
+        df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
         with torch_cuda_device(device):
             cdequantize_blockwise_fp32(
                 get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
@@ -385,7 +390,7 @@ def fast_gemv(X, W, quant_state, out = None):
         device = W.device

         if out is None:
-            out = torch.empty((1, 1, bout,), dtype = dtype, device = device)
+            out = torch_empty((1, 1, bout,), dtype = dtype, device = device)
         # else:
         #     assert(out.shape == (1, 1, bout,))
         # pass
@@ -403,7 +408,7 @@ def fast_gemv(X, W, quant_state, out = None):
         ldb = ctypes_c_int32(ldb)
         ldc = ctypes_c_int32(ldc)

-        df = torch.empty(absmax.shape, dtype = torch.float32, device = device)
+        df = torch_empty(absmax.shape, dtype = torch.float32, device = device)
         cdequantize_blockwise_fp32(
             get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
             ctypes_c_int(blocksize2), ctypes_c_int(df.numel()),
@@ -423,10 +428,6 @@ def fast_gemv(X, W, quant_state, out = None):
 pass


-torch_mm = torch.mm
-torch_mv = torch.mv
-torch_matmul = torch.matmul
-torch_addmm = torch.addmm
 def fast_linear_forward(proj, X, temp_lora = None, out = None):

     W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
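The change above hoists `torch.mm`, `torch.mv`, `torch.matmul`, `torch.addmm`, and `torch.empty` into module-level aliases bound once at import, so hot paths like `fast_dequantize` and `fast_gemv` avoid a repeated attribute lookup on the `torch` module for every call. A rough micro-benchmark sketch (not part of the PR) of the effect:

# Rough micro-benchmark sketch: a module-level alias skips the
# torch.<attr> lookup on each call inside a hot loop.
import timeit
import torch

torch_empty = torch.empty  # alias hoisted once, as in the diff

attr_lookup = timeit.timeit(lambda: torch.empty(1), number=100_000)
aliased     = timeit.timeit(lambda: torch_empty(1), number=100_000)
print(f"torch.empty each call: {attr_lookup:.3f}s, hoisted alias: {aliased:.3f}s")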
80 changes: 40 additions & 40 deletions unsloth/models/_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "2025.3.3"
+__version__ = "2025.3.4"

 __all__ = [
     "SUPPORTS_BFLOAT16",
@@ -39,8 +39,8 @@
     "create_boolean_mask",
     "torch_amp_custom_fwd",
     "torch_amp_custom_bwd",
-    "accelerate_old_send_to_device",
-    "accelerate_new_send_to_device",
+    # "accelerate_old_send_to_device",
+    # "accelerate_new_send_to_device",
     "patch_gradient_accumulation_fix",
     "patch_compiling_bitsandbytes",
     "patch_regional_compilation",
@@ -241,24 +241,24 @@ def patch_mistral_nemo_config(config):

 # =============================================
 # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
-import transformers.cache_utils
-if hasattr(transformers.cache_utils, "DynamicCache") and \
-    transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":
+# import transformers.cache_utils
+# if hasattr(transformers.cache_utils, "DynamicCache") and \
+#     transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":

-    source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
-    start = source.find("def")
-    spaces = start*" "
-    source = source.split("\n")
-    source = "\n".join(x[start:] for x in source)
-    where = source.find("raise KeyError")
-    source = source[:where] + \
-        f"if len(self) == 0:\n{spaces}{spaces}"\
-        "    raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
-        f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
-    source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
-    exec(source)
-    transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
-pass
+#     source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
+#     start = source.find("def")
+#     spaces = start*" "
+#     source = source.split("\n")
+#     source = "\n".join(x[start:] for x in source)
+#     where = source.find("raise KeyError")
+#     source = source[:where] + \
+#         f"if len(self) == 0:\n{spaces}{spaces}"\
+#         "    raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
+#         f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
+#     source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
+#     exec(source)
+#     transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
+# pass
 # =============================================

 # =============================================
@@ -411,25 +411,25 @@ def _is_openai_available(): return False

 # =============================================
 # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
-accelerate_old_send_to_device = None
-accelerate_new_send_to_device = None
-if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"):
-    import accelerate.utils.operations
-    if hasattr(accelerate.utils.operations, "send_to_device") and \
-        accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
-        accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
-        from accelerate.utils.operations import *
-        send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
-        send_to_device = re.sub(
-            r"([ ]{4,})return tensor\.to\(device\)",
-            r"\1try: return tensor.to(device)\n\1except: return tensor",
-            send_to_device,
-        ).replace("def send_to_device", "def _fixed_send_to_device")
-        exec(send_to_device)
-        # accelerate.utils.operations.send_to_device = _fixed_send_to_device
-        accelerate_new_send_to_device = _fixed_send_to_device
-    pass
-pass
+# accelerate_old_send_to_device = None
+# accelerate_new_send_to_device = None
+# if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"):
+#     import accelerate.utils.operations
+#     if hasattr(accelerate.utils.operations, "send_to_device") and \
+#         accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
+#         accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
+#         from accelerate.utils.operations import *
+#         send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
+#         send_to_device = re.sub(
+#             r"([ ]{4,})return tensor\.to\(device\)",
+#             r"\1try: return tensor.to(device)\n\1except: return tensor",
+#             send_to_device,
+#         ).replace("def send_to_device", "def _fixed_send_to_device")
+#         exec(send_to_device)
+#         # accelerate.utils.operations.send_to_device = _fixed_send_to_device
+#         accelerate_new_send_to_device = _fixed_send_to_device
+#     pass
+# pass

 # Transformers 4.46 breaks dynamic caching. This is a hack
 import transformers.generation.configuration_utils
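Both disabled blocks above share one technique: fetch a function's source with `inspect.getsource`, rewrite it textually, `exec` the result, and swap the new function in. A generic sketch of that technique, illustrative only (the helper name `patch_by_source` is hypothetical, not from the PR):

# Illustrative sketch of the source-rewriting monkey patch used by the
# two (now commented-out) blocks above.
import inspect
import re
import textwrap

def patch_by_source(func, pattern, replacement, new_name):
    source = textwrap.dedent(inspect.getsource(func))  # dedent handles methods
    source = re.sub(pattern, replacement, source)      # textual rewrite
    source = source.replace(f"def {func.__name__}", f"def {new_name}", 1)
    namespace = {}
    exec(source, func.__globals__, namespace)          # compile the edited copy
    return namespace[new_name]                         # caller swaps this in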