
Commit f29da34

Authored by danielhanchen, everythingisc00l, SethHWeidman, NinoRisteski, and Erland366
Memory Efficient GRPO (#1773)
* Update __init__.py * Update loader.py * Update rl.py * Update rl.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Better TRL handling * Update rl.py * Update tokenizer_utils.py * Auto patching * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update tokenizer_utils.py * Update rl.py * Update rl.py * Update rl.py * max seq length * Update rl.py * Update rl.py * Patching * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * NEFTune * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Extra replacements * Update rl_replacements.py * Update rl.py * extra RL replacements * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update _utils.py * Update loader_utils.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * autocast * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update pyproject.toml * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update _utils.py * Update llama.py * Update _utils.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * GRPO optimized * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Selective Log softmax * Fix GRPO bsz * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Fix TRL * Metrics GRPO * Update rl_replacements.py * Update rl_replacements.py * No compile * Update rl.py * Remove docs * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * llama-quantize on WINDOWS WSL error fix - edit 
save.py (gguf saving breaks) (#1649) * edit save.py to fix gguf saving breaks. * add check for .exe or not exe file extension for linux and windows * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update llama.py * Update llama.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * unsloth_num_chunks * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl_replacements.py * Update rl_replacements.py (#1754) Fix typo in comment: know -> now. This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well. * Optional logits * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * Update rl_replacements.py * Update rl.py * Update rl.py * Update rl.py * Update rl.py * fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han <[email protected]> * SamplingParams * Convert mask to float (#1762) * [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs * vLLMSamplingParams * Update __init__.py * default num_chunks == -1 * Versioning --------- Co-authored-by: Gennadii Manzhos <[email protected]> Co-authored-by: Seth Weidman <[email protected]> Co-authored-by: Nino Risteski <[email protected]> Co-authored-by: Edd <[email protected]> Co-authored-by: Ben <[email protected]>
1 parent 67d3440 commit f29da34

File tree: 11 files changed, +206 −101 lines changed


README.md

Lines changed: 5 additions & 2 deletions
@@ -193,7 +193,7 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git
 ### Windows Installation
 
 To run Unsloth directly on Windows:
-- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
+- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12)
 - In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
 ```python
 trainer = SFTTrainer(
@@ -202,12 +202,15 @@ trainer = SFTTrainer(
 )
 ```
 
+### Advanced/Troubleshooting
+
 For **advanced installation instructions** or if you see weird errors during installations:
 
 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
 2. Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
 3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
-4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
+4. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful.
+5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
 
 ## 📜 [Documentation](https://docs.unsloth.ai)
 - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
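
The troubleshooting steps above amount to verifying that each dependency can actually be imported and run. A minimal sketch (not part of this commit) that runs the checks from steps 3 and 5 programmatically, assuming xformers and bitsandbytes are installed in the current environment:

# Minimal sketch (not part of this commit): run the `python -m xformers.info`
# and `python -m bitsandbytes` checks from the troubleshooting list above.
import subprocess
import sys

for module in ("xformers.info", "bitsandbytes"):
    result = subprocess.run([sys.executable, "-m", module],
                            capture_output=True, text=True)
    status = "OK" if result.returncode == 0 else "FAILED"
    print(f"python -m {module}: {status}")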

pyproject.toml

Lines changed: 6 additions & 2 deletions
@@ -39,7 +39,7 @@ triton = [
     "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
 ]
 huggingface = [
-    "unsloth_zoo>=2025.2.5",
+    "unsloth_zoo>=2025.2.6",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",
@@ -196,6 +196,10 @@ cu126onlytorch260 = [
     "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
     "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
     "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
+    "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
+    "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
+    "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
+    "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
 ]
 cu118 = [
     "unsloth[huggingface]",
@@ -344,7 +348,7 @@ colab-ampere-torch220 = [
     "flash-attn>=2.6.3",
 ]
 colab-new = [
-    "unsloth_zoo>=2025.2.5",
+    "unsloth_zoo>=2025.2.6",
     "packaging",
     "tyro",
     "transformers>=4.46.1,!=4.47.0",

unsloth/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 # Check for unsloth_zoo
 try:
     unsloth_zoo_version = importlib_version("unsloth_zoo")
-    if Version(unsloth_zoo_version) < Version("2025.2.4"):
+    if Version(unsloth_zoo_version) < Version("2025.2.6"):
         try:
             os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo")
         except:
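
For reference, the version gate above can be reproduced standalone. A minimal sketch (not part of the commit) of the same comparison, assuming unsloth_zoo is installed:

# Minimal sketch (not part of this commit): the same unsloth_zoo version gate.
from importlib.metadata import version as importlib_version
from packaging.version import Version

unsloth_zoo_version = importlib_version("unsloth_zoo")
if Version(unsloth_zoo_version) < Version("2025.2.6"):
    print("unsloth_zoo is older than 2025.2.6; Unsloth would try to upgrade it here.")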

unsloth/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -20,4 +20,4 @@
 from .qwen2 import FastQwen2Model
 from .dpo import PatchDPOTrainer, PatchKTOTrainer
 from ._utils import is_bfloat16_supported
-from .rl import PatchFastRL
+from .rl import PatchFastRL, vLLMSamplingParams
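
With this re-export, sampling settings can be constructed from the `unsloth.models` namespace. A hedged sketch only: the diff confirms the exported name, but the keyword arguments below are an assumption, presumed to mirror vLLM's `SamplingParams`:

# Hedged sketch (not from this diff): vLLMSamplingParams is assumed to accept
# vLLM SamplingParams-style keyword arguments; only the import is confirmed here.
from unsloth.models import vLLMSamplingParams

sampling_params = vLLMSamplingParams(
    temperature = 0.8,   # assumed field, mirroring vllm.SamplingParams
    max_tokens  = 256,   # assumed field, mirroring vllm.SamplingParams
)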

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2025.2.12"
+__version__ = "2025.2.13"
 
 __all__ = [
     "SUPPORTS_BFLOAT16",

unsloth/models/llama.py

Lines changed: 59 additions & 38 deletions
@@ -700,6 +700,7 @@ def LlamaModel_fast_forward(
         elif inputs_requires_grad:
             inputs_embeds.requires_grad_(False)
         pass
+        attention_mask = attention_mask[:,:self.max_seq_length] # Must resize!
         inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2)
         if inputs_requires_grad: inputs_embeds.requires_grad_(True)
         pass
@@ -774,9 +775,12 @@ def LlamaModel_fast_forward(
             self.SWA_mask = True
             self.GA_mask = False
         elif attention_mask is not None:
-
             # Fixes https://github.com/unslothai/unsloth/issues/853
             # Unsloth needs a 2D mask, not a [2, 1, n, n] mask!
+
+            # https://github.com/pytorch/pytorch/issues/103749
+            # Need to convert to float and not using bool
+            attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min
             dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
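
As an aside on the hunk above: SDPA accepts either a boolean mask or an additive float mask, and the PyTorch issue cited in the new comments is sidestepped by using the float form. A small illustration (not part of the diff) of what the conversion produces:

# Illustration only (not part of the diff): bool/int padding mask -> additive
# float mask. Kept positions become ~0, padded positions become a very large
# negative bias (torch.finfo(dtype).min), which is added to attention scores.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # 1 = real token, 0 = padding
dtype = torch.float16
additive_mask = (1.0 - attention_mask.float()) * torch.finfo(dtype).min
print(additive_mask)   # roughly [[0., 0., 0., -65504.]] for float16
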
@@ -1030,6 +1034,7 @@ def _CausalLM_fast_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     num_logits_to_keep: Optional[int] = 0,
+    logits_to_keep: Optional[int] = 0,
     *args, **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 
@@ -1053,16 +1058,16 @@ def _CausalLM_fast_forward(
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     self.model._has_no_labels = labels is None
     outputs = self.model(
-        input_ids=input_ids,
-        causal_mask=causal_mask,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        input_ids = input_ids,
+        causal_mask = causal_mask,
+        attention_mask = attention_mask,
+        position_ids = position_ids,
+        past_key_values = past_key_values,
+        inputs_embeds = inputs_embeds,
+        use_cache = use_cache,
+        output_attentions = output_attentions,
+        output_hidden_states = output_hidden_states,
+        return_dict = return_dict,
     )
     pass
     hidden_states = outputs[0]
@@ -1072,6 +1077,20 @@ def _CausalLM_fast_forward(
     logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
     logit_scaling = getattr(self.config, "logit_scale", 0)
     dtype = lm_head.dtype
+    num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
+
+    # Output last hidden states without logits if asked
+    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
+        if num_logits_to_keep != 0:
+            hidden_states = hidden_states[:, -num_logits_to_keep:, :]
+        return CausalLMOutputWithPast(
+            loss = None,
+            logits = hidden_states,
+            past_key_values = outputs.past_key_values,
+            hidden_states = outputs.hidden_states,
+            attentions = outputs.attentions,
+        )
+    pass
 
     if bsz == 1 and q_len == 1:
         logits = torch.mv(lm_head, hidden_states.ravel().to(dtype))
@@ -1166,11 +1185,11 @@ def _CausalLM_fast_forward(
         return (loss,) + output if loss is not None else output
 
     return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
+        loss = loss,
+        logits = logits,
+        past_key_values = outputs.past_key_values,
+        hidden_states = outputs.hidden_states,
+        attentions = outputs.attentions,
     )
     pass
     return _CausalLM_fast_forward
@@ -1180,28 +1199,30 @@ def _CausalLM_fast_forward(
 @torch._disable_dynamo
 def PeftModelForCausalLM_fast_forward(
     self,
-    input_ids=None,
-    causal_mask=None,
-    attention_mask=None,
-    inputs_embeds=None,
-    labels=None,
-    output_attentions=None,
-    output_hidden_states=None,
-    return_dict=None,
-    task_ids=None,
-    num_logits_to_keep=0,
+    input_ids = None,
+    causal_mask = None,
+    attention_mask = None,
+    inputs_embeds = None,
+    labels = None,
+    output_attentions = None,
+    output_hidden_states = None,
+    return_dict = None,
+    task_ids = None,
+    num_logits_to_keep = 0,
+    logits_to_keep = 0,
     **kwargs,
 ):
     return self.base_model(
-        input_ids=input_ids,
-        causal_mask=causal_mask,
-        attention_mask=attention_mask,
-        inputs_embeds=inputs_embeds,
-        labels=labels,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        num_logits_to_keep=num_logits_to_keep,
+        input_ids = input_ids,
+        causal_mask = causal_mask,
+        attention_mask = attention_mask,
+        inputs_embeds = inputs_embeds,
+        labels = labels,
+        output_attentions = output_attentions,
+        output_hidden_states = output_hidden_states,
+        return_dict = return_dict,
+        num_logits_to_keep = num_logits_to_keep,
+        logits_to_keep = logits_to_keep,
         **kwargs,
     )
 pass
@@ -1694,9 +1715,9 @@ def from_pretrained(
         elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
             logger.warning_once("Device does not support bfloat16. Will change to float16.")
             dtype = torch.float16
-        elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
-            logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
-            dtype = torch.bfloat16
+        # elif dtype == torch.float16 and SUPPORTS_BFLOAT16:
+        #     logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.")
+        #     dtype = torch.bfloat16
 
         assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32)
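
The new `UNSLOTH_RETURN_HIDDEN_STATES` switch underpins the memory-efficient GRPO path: callers get the last hidden states back in the `logits` field, which the GRPO replacements appear to consume so the full vocabulary-sized logits tensor never has to be materialized at once. A hedged sketch of opting in (the environment variable and returned field come from this diff; model loading is omitted):

# Hedged sketch (not from this diff): opt in to hidden-state outputs.
import os
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

# After loading a patched model as usual:
#   outputs = model(input_ids, num_logits_to_keep=1)
# outputs.logits then holds the last hidden states instead of vocabulary
# logits, per the return path added in the hunk above.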

unsloth/models/loader.py

Lines changed: 7 additions & 3 deletions
@@ -24,10 +24,14 @@
 from .loader_utils import get_model_name
 import os, contextlib, sys
 try:
-    from huggingface_hub.utils import get_token
+    from huggingface_hub import get_token
 except:
-    # Old HF Hub versions <= 0.0.25
-    from huggingface_hub.utils._token import get_token
+    try:
+        from huggingface_hub.utils import get_token
+    except:
+        # For older versions of huggingface_hub
+        from huggingface_hub.utils._token import get_token
+    pass
 pass
 from huggingface_hub import HfFileSystem
 import importlib.util
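
The nested fallback reflects how `get_token` has moved between huggingface_hub releases. A minimal sketch (not part of this commit) of the same pattern in isolation:

# Minimal sketch (not part of this commit): locate get_token across
# huggingface_hub versions, newest location first.
try:
    from huggingface_hub import get_token          # newer releases
except ImportError:
    from huggingface_hub.utils import get_token    # older releases

token = get_token()   # returns the cached Hugging Face token, or None
print("HF token found" if token else "No HF token configured")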

unsloth/models/mapper.py

Lines changed: 0 additions & 5 deletions
@@ -601,11 +601,6 @@
         "Qwen/Qwen2.5-VL-72B-Instruct",
         "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit",
     ),
-    "unsloth/DeepHermes-3-Llama-3-8B-Preview-unsloth-bnb-4bit" : (
-        "unsloth/DeepHermes-3-Llama-3-8B-Preview",
-        "NousResearch/DeepHermes-3-Llama-3-8B-Preview",
-        "unsloth/DeepHermes-3-Llama-3-8B-Preview-bnb-4bit",
-    ),
     "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : (
         "unsloth/DeepHermes-3-Llama-3-8B-Preview",
         "agentica-org/DeepScaleR-1.5B-Preview",

0 commit comments