
Commit 362a9ac

danielhanchen, everythingisc00l, SethHWeidman, NinoRisteski, Erland366 authored
Bug fixes (#1891)
* Update rl.py * Patching * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * NEFTune * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Extra replacements * Update rl_replacements.py * Update rl.py * extra RL replacements * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update _utils.py * Update loader_utils.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * autocast * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update _utils.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * GRPO optimized * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Selective Log softmax * Fix GRPO bsz * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Fix TRL * Metrics GRPO * Update rl_replacements.py * Update rl_replacements.py * No compile * Update rl.py * Remove docs * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving breaks) (#1649) * edit save.py to fix gguf saving breaks. 
* add check for .exe or not exe file extension for linux and windows * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * unsloth_num_chunks * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py (#1754) Fix typo in comment: know -> now. This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well. * Optional logits * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han <[email protected]> * SamplingParams * Convert mask to float (#1762) * [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs * vLLMSamplingParams * Update __init__.py * default num_chunks == -1 * Versioning * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl_replacements.py * Update pyproject.toml * Update pyproject.toml * Export Model to ollama.com (#1648) * Ollama Export Model to ollama.com Signed-off-by: Jyotin Goel <[email protected]> * Check for model_name Signed-off-by: Jyotin Goel <[email protected]> * subprocess use instead of requests | added check for ollama server Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model Signed-off-by: Jyotin Goel <[email protected]> * create_ollama_model | fix Signed-off-by: Jyotin Goel <[email protected]> * Push to Ollama Signed-off-by: Jyotin Goel <[email protected]> --------- Signed-off-by: Jyotin Goel <[email protected]> * Update cross_entropy_loss.py * torch_cuda_device * Update utils.py * Update utils.py * Update utils.py * device * device * Update loader.py * Update llama.py * Update README.md * Update llama.py * Update llama.py * Update _utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * __version__ * Update rl.py * Bug fixes --------- Signed-off-by: Jyotin Goel <[email protected]> Co-authored-by: Gennadii Manzhos <[email protected]> Co-authored-by: Seth Weidman <[email protected]> Co-authored-by: Nino Risteski <[email protected]> Co-authored-by: Edd <[email protected]> Co-authored-by: Ben <[email protected]> Co-authored-by: Jyotin Goel 
<[email protected]>
1 parent be55e29 commit 362a9ac

16 files changed (+400 -246 lines)

README.md

Lines changed: 30 additions & 32 deletions
@@ -232,10 +232,8 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git
 
 ```python
 from unsloth import FastLanguageModel
-from unsloth import is_bfloat16_supported
 import torch
-from trl import SFTTrainer
-from transformers import TrainingArguments
+from trl import SFTTrainer, SFTConfig
 from datasets import load_dataset
 max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any!
 # Get LAION dataset
@@ -244,21 +242,28 @@ dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
 
 # 4bit pre quantized models we support for 4x faster downloading + no OOMs.
 fourbit_models = [
-    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
+    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 2x faster
+    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
+    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # 4bit for 405b!
+    "unsloth/Mistral-Small-Instruct-2409",     # Mistral 22b 2x faster!
     "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
-    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
-    "unsloth/llama-3-8b-Instruct-bnb-4bit",
-    "unsloth/llama-3-70b-bnb-4bit",
-    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
+    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
     "unsloth/Phi-3-medium-4k-instruct",
-    "unsloth/mistral-7b-bnb-4bit",
-    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
+    "unsloth/gemma-2-9b-bnb-4bit",
+    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!
+
+    "unsloth/Llama-3.2-1B-bnb-4bit",           # NEW! Llama 3.2 models
+    "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
+    "unsloth/Llama-3.2-3B-bnb-4bit",
+    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
+
+    "unsloth/Llama-3.3-70B-Instruct-bnb-4bit"  # NEW! Llama 3.3 70B!
 ] # More models at https://huggingface.co/unsloth
 
 model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name = "unsloth/llama-3-8b-bnb-4bit",
+    model_name = "unsloth/Llama-3.2-1B",
     max_seq_length = max_seq_length,
-    dtype = None,
     load_in_4bit = True,
 )

@@ -282,16 +287,14 @@ model = FastLanguageModel.get_peft_model(
 trainer = SFTTrainer(
     model = model,
     train_dataset = dataset,
-    dataset_text_field = "text",
-    max_seq_length = max_seq_length,
     tokenizer = tokenizer,
-    args = TrainingArguments(
+    args = SFTConfig(
+        dataset_text_field = "text",
+        max_seq_length = max_seq_length,
         per_device_train_batch_size = 2,
         gradient_accumulation_steps = 4,
         warmup_steps = 10,
         max_steps = 60,
-        fp16 = not is_bfloat16_supported(),
-        bf16 = is_bfloat16_supported(),
         logging_steps = 1,
         output_dir = "outputs",
         optim = "adamw_8bit",
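
With this change, `dataset_text_field` and `max_seq_length` move off the `SFTTrainer` call and into `SFTConfig`, which replaces `TrainingArguments`. A minimal sketch of the resulting setup, assembled from the `+` lines above; the dataset URL is left as a placeholder because it is not shown in this diff:

```python
# Minimal sketch assembled from the "+" lines above; the dataset URL is a
# placeholder (the real one lives in the README and is not part of this diff).
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

max_seq_length = 2048
url = "..."  # LAION dataset URL from the README
dataset = load_dataset("json", data_files = {"train": url}, split = "train")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-1B",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    tokenizer = tokenizer,
    args = SFTConfig(
        dataset_text_field = "text",      # now lives on SFTConfig
        max_seq_length = max_seq_length,  # now lives on SFTConfig
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        max_steps = 60,
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
    ),
)
trainer.train()
```
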
@@ -323,17 +326,14 @@ RL including DPO, GRPO, PPO, Reward Modelling, Online DPO all work with Unsloth.
 import os
 os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID
 
-from unsloth import FastLanguageModel, PatchDPOTrainer
-from unsloth import is_bfloat16_supported
-PatchDPOTrainer()
+from unsloth import FastLanguageModel
 import torch
-from transformers import TrainingArguments
-from trl import DPOTrainer
+from trl import DPOTrainer, DPOConfig
+max_seq_length = 2048
 
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name = "unsloth/zephyr-sft-bnb-4bit",
     max_seq_length = max_seq_length,
-    dtype = None,
     load_in_4bit = True,
 )

@@ -355,24 +355,22 @@ model = FastLanguageModel.get_peft_model(
 dpo_trainer = DPOTrainer(
     model = model,
     ref_model = None,
-    args = TrainingArguments(
+    train_dataset = YOUR_DATASET_HERE,
+    # eval_dataset = YOUR_DATASET_HERE,
+    tokenizer = tokenizer,
+    args = DPOConfig(
         per_device_train_batch_size = 4,
         gradient_accumulation_steps = 8,
         warmup_ratio = 0.1,
         num_train_epochs = 3,
-        fp16 = not is_bfloat16_supported(),
-        bf16 = is_bfloat16_supported(),
         logging_steps = 1,
         optim = "adamw_8bit",
         seed = 42,
         output_dir = "outputs",
+        max_length = 1024,
+        max_prompt_length = 512,
+        beta = 0.1,
     ),
-    beta = 0.1,
-    train_dataset = YOUR_DATASET_HERE,
-    # eval_dataset = YOUR_DATASET_HERE,
-    tokenizer = tokenizer,
-    max_length = 1024,
-    max_prompt_length = 512,
 )
 dpo_trainer.train()
 ```
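
Likewise for DPO: `max_length`, `max_prompt_length` and `beta` now sit on `DPOConfig` instead of being passed to `DPOTrainer` directly. A minimal sketch of the post-change call, assuming `model` and `tokenizer` are loaded as above and `YOUR_DATASET_HERE` is a preference dataset with `prompt`/`chosen`/`rejected` columns:

```python
# Minimal sketch of the updated DPO wiring; model, tokenizer and
# YOUR_DATASET_HERE are assumed to exist as in the README snippet above.
from trl import DPOTrainer, DPOConfig

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    train_dataset = YOUR_DATASET_HERE,
    tokenizer = tokenizer,
    args = DPOConfig(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        num_train_epochs = 3,
        optim = "adamw_8bit",
        output_dir = "outputs",
        max_length = 1024,        # moved into DPOConfig
        max_prompt_length = 512,  # moved into DPOConfig
        beta = 0.1,               # moved into DPOConfig
    ),
)
dpo_trainer.train()
```
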

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ triton = [
 ]
 
 windows=[
-    "unsloth_zoo>=2025.2.7",
+    "unsloth_zoo>=2025.3.1",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",
@@ -61,7 +61,7 @@ windows=[
     "xformers>=0.0.22.post7 ; platform_system == 'Windows'",
 ]
 huggingface = [
-    "unsloth_zoo>=2025.2.7",
+    "unsloth_zoo>=2025.3.1",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",

unsloth/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
-    if Version(unsloth_zoo_version) < Version("2025.2.6"):
+    if Version(unsloth_zoo_version) < Version("2025.3.1"):
         try:
             os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
         except:
@@ -212,6 +212,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 pass
 
 from .models import *
+from .models import __version__
 from .save import *
 from .chat_templates import *
 from .tokenizer_utils import *
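
The bumped gate above follows a standard pattern: read the installed version via `importlib.metadata` and compare it with `packaging`'s `Version`. A standalone sketch of that pattern; the package name and threshold below are illustrative, not taken from this diff:

```python
# Illustrative version gate; package name and minimum version are examples only.
from importlib.metadata import version as importlib_version
from packaging.version import Version

def needs_upgrade(package: str, minimum: str) -> bool:
    # True when the installed distribution is older than `minimum`.
    return Version(importlib_version(package)) < Version(minimum)

print(needs_upgrade("packaging", "21.0"))
```
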

unsloth/kernels/cross_entropy_loss.py

Lines changed: 56 additions & 46 deletions
@@ -15,7 +15,13 @@
 import triton
 import triton.language as tl
 import torch
-from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast
+from .utils import (
+    calculate_settings,
+    MAX_FUSED_SIZE,
+    triton_tanh,
+    triton_cast,
+    torch_cuda_device,
+)
 from transformers.models.llama.modeling_llama import logger
 from packaging.version import Version

@@ -279,10 +285,11 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
         n_rows : int
         vocab_size : int
         n_rows, vocab_size = logits.shape
+        device = logits.device
 
         div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
         n_chunks : int = div + (mod != 0)
-        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
+        losses = torch.empty(n_rows, dtype = torch.float32, device = device)
 
         DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
         DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
@@ -292,39 +299,41 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling :
         if n_chunks == 1:
             # For small vocabs <= 65336 like Llama, Mistral
             BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
-            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
-
-            _cross_entropy_forward[(n_rows,)](
-                logits, logits.stride(0),
-                losses,
-                logsumexp,
-                labels,
-                VOCAB_SIZE = vocab_size,
-                BLOCK_SIZE = BLOCK_SIZE,
-                DO_SOFTCAPPING = DO_SOFTCAPPING,
-                SOFTCAP = logit_softcapping,
-                DO_LOGIT_SCALING = DO_LOGIT_SCALING,
-                LOGIT_SCALE = logit_scaling,
-                num_warps = num_warps,
-            )
+            logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device)
+
+            with torch_cuda_device(device):
+                _cross_entropy_forward[(n_rows,)](
+                    logits, logits.stride(0),
+                    losses,
+                    logsumexp,
+                    labels,
+                    VOCAB_SIZE = vocab_size,
+                    BLOCK_SIZE = BLOCK_SIZE,
+                    DO_SOFTCAPPING = DO_SOFTCAPPING,
+                    SOFTCAP = logit_softcapping,
+                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,
+                    LOGIT_SCALE = logit_scaling,
+                    num_warps = num_warps,
+                )
         else:
             # For large vocabs > 65336 like Gemma 256K
-            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
-
-            _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
-                logits, logits.stride(0),
-                losses,
-                logsumexp,
-                labels,
-                VOCAB_SIZE = vocab_size,
-                N_CHUNKS = n_chunks,
-                BLOCK_SIZE = MAX_FUSED_SIZE,
-                DO_SOFTCAPPING = DO_SOFTCAPPING,
-                SOFTCAP = logit_softcapping,
-                DO_LOGIT_SCALING = DO_LOGIT_SCALING,
-                LOGIT_SCALE = logit_scaling,
-                num_warps = 32,
-            )
+            logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device)
+
+            with torch_cuda_device(device):
+                _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
+                    logits, logits.stride(0),
+                    losses,
+                    logsumexp,
+                    labels,
+                    VOCAB_SIZE = vocab_size,
+                    N_CHUNKS = n_chunks,
+                    BLOCK_SIZE = MAX_FUSED_SIZE,
+                    DO_SOFTCAPPING = DO_SOFTCAPPING,
+                    SOFTCAP = logit_softcapping,
+                    DO_LOGIT_SCALING = DO_LOGIT_SCALING,
+                    LOGIT_SCALE = logit_scaling,
+                    num_warps = 32,
+                )
         # logsumexp(chunked_logsumexp) - x
         # Do the -x separately
         logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
@@ -354,19 +363,20 @@ def backward(ctx, dlosses):
         div, mod = divmod(vocab_size, BLOCK_SIZE)
         n_blocks : int = div + (mod != 0)
 
-        _cross_entropy_backward[(n_rows, n_blocks,)](
-            logits, logits.stride(0),
-            dlosses, dlosses.stride(0),
-            logsumexp,
-            labels,
-            VOCAB_SIZE = vocab_size,
-            BLOCK_SIZE = BLOCK_SIZE,
-            DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
-            SOFTCAP = ctx.logit_softcapping,
-            DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
-            LOGIT_SCALE = ctx.logit_scaling,
-            num_warps = 8,
-        )
+        with torch_cuda_device(dlosses.device):
+            _cross_entropy_backward[(n_rows, n_blocks,)](
+                logits, logits.stride(0),
+                dlosses, dlosses.stride(0),
+                logsumexp,
+                labels,
+                VOCAB_SIZE = vocab_size,
+                BLOCK_SIZE = BLOCK_SIZE,
+                DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
+                SOFTCAP = ctx.logit_softcapping,
+                DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
+                LOGIT_SCALE = ctx.logit_scaling,
+                num_warps = 8,
+            )
         return logits, None, None, None,
     pass
 pass
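
These hunks swap the hard-coded `"cuda:0"` for the input tensor's own device and wrap each Triton launch in `torch_cuda_device`, so the kernels work when tensors live on another GPU. The helper's definition is not in this diff; a minimal sketch of the pattern, assuming it behaves like `torch.cuda.device`:

```python
# Sketch of the device-pinning pattern adopted above. torch_cuda_device is
# assumed to act like torch.cuda.device; its real definition (in
# unsloth/kernels/utils.py) is not part of this diff.
import torch

def scale_on_own_device(x: torch.Tensor) -> torch.Tensor:
    # Expects a CUDA tensor. Allocate on x.device (not "cuda:0"), then pin the
    # active device so any kernel launched inside targets the same GPU.
    out = torch.empty_like(x)  # empty_like already follows x.device
    with torch.cuda.device(x.device):
        out.copy_(x * 2.0)     # stand-in for a Triton kernel launch
    return out
```
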

unsloth/kernels/geglu.py

Lines changed: 17 additions & 7 deletions
@@ -15,7 +15,11 @@
 import triton
 import triton.language as tl
 import torch
-from .utils import calculate_settings, triton_tanh
+from .utils import (
+    calculate_settings,
+    triton_tanh,
+    torch_cuda_device,
+)
 
 
 @triton.jit
@@ -41,9 +45,11 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
 def geglu_exact_forward_kernel(gate, up):
     batch, seq_len, hd = gate.shape
     n_elements = gate.numel()
-    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
+    device = gate.device
+    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-    _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+    with torch_cuda_device(device):
+        _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
     return out
 pass

@@ -99,7 +105,8 @@ def geglu_exact_backward_kernel(DW, e, g):
     batch_seq_len, hd = e.shape
     n_elements = e.numel()
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-    _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+    with torch_cuda_device(e.device):
+        _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
     return DW, e, g
 pass

@@ -133,9 +140,11 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
 def geglu_approx_forward_kernel(gate, up):
     batch, seq_len, hd = gate.shape
     n_elements = gate.numel()
-    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
+    device = gate.device
+    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-    _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+    with torch_cuda_device(device):
+        _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
     return out
 pass

@@ -198,6 +207,7 @@ def geglu_approx_backward_kernel(DW, e, g):
     batch_seq_len, hd = e.shape
     n_elements = e.numel()
     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-    _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+    with torch_cuda_device(e.device):
+        _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
     return DW, e, g
 pass
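
A hypothetical usage sketch: after these hunks the GEGLU kernels allocate their output on, and launch on, whichever GPU the inputs occupy, so callers are no longer tied to `cuda:0`. The shapes and the second GPU below are assumptions for illustration:

```python
# Hypothetical example; needs a machine with at least two CUDA devices.
import torch
from unsloth.kernels.geglu import geglu_exact_forward_kernel

gate = torch.randn(2, 128, 4096, dtype = torch.float16, device = "cuda:1")
up   = torch.randn(2, 128, 4096, dtype = torch.float16, device = "cuda:1")

out = geglu_exact_forward_kernel(gate, up)
print(out.device)  # cuda:1 -- the output now follows the inputs
```
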
