
Conversation

@djsaunde djsaunde commented Oct 29, 2025

This PR adds sample packing support. It uses TRL's SFTConfig packing=True and padding_free=True args to pack the sequences, and we compute packed_seq_lengths metadata and thread it through the model forward pass. This metadata is used to create block-causal masks for SDPA and xformers attention, and is passed to the flash attention varlen API, which handles the block-causal masking itself under the hood. (We need to handle this ourselves because of our custom forward pass, whereas TRL manages the sequence-length metadata internally in its trainer.)
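For reference, here is a minimal sketch of the idea (helper names are illustrative, not the PR's actual code): packed_seq_lengths for one packed row can be turned into a block-causal additive mask for SDPA, or into the cu_seqlens / max_seqlen metadata that flash attention's varlen API expects.

```python
# Illustrative sketch (not the PR's exact helpers): derive the attention
# metadata for one packed row from its per-sequence lengths.
import torch

def block_causal_mask(packed_seq_lengths, dtype=torch.float32, device="cpu"):
    """Additive block-causal mask for a packed row of length sum(packed_seq_lengths).

    Each token may attend only to earlier tokens *within its own packed
    sequence*; attention across sequence boundaries is masked out.
    """
    lengths = torch.tensor(packed_seq_lengths, device=device)
    total = int(lengths.sum())
    # Sequence id per position, e.g. lengths [3, 2] -> [0, 0, 0, 1, 1]
    seq_ids = torch.repeat_interleave(
        torch.arange(len(packed_seq_lengths), device=device), lengths
    )
    allowed = (seq_ids[:, None] == seq_ids[None, :]) & torch.tril(
        torch.ones(total, total, dtype=torch.bool, device=device)
    )
    mask = torch.zeros(total, total, dtype=dtype, device=device)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask  # [total, total]; broadcastable to [1, 1, total, total] for SDPA

def varlen_metadata(packed_seq_lengths, device="cpu"):
    """cu_seqlens / max_seqlen in the form flash attention's varlen API consumes."""
    lengths = torch.tensor(packed_seq_lengths, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(len(packed_seq_lengths) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens, int(lengths.max())
```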

I added a few unit tests. I also wrote a quick bash script (shared as a gist) for smoke testing some common model architectures.

Below is a comparison of short unsloth/qwen2.5-0.5b training runs. The losses don't match because we're seeing more / different samples on each step. But the scale and trend match, which is the important bit.

[plot: training loss for packed vs. unpacked unsloth/qwen2.5-0.5b runs]

Commands:

No sample packing:

python unsloth-cli.py --model_name unsloth/qwen2.5-0.5b --dataset yahma/alpaca-cleaned --per_device_train_batch_size 8 --max_steps 50 --max_seq_length 2048

Sample packing:

python unsloth-cli.py --model_name unsloth/qwen2.5-0.5b --dataset yahma/alpaca-cleaned --per_device_train_batch_size 1 --max_steps 50 --max_seq_length 2048 --sample_packing

Note that we use --per_device_train_batch_size 1 in the latter case since we are packing multiple examples into a single [1, max_seq_length] tensor.

The benefit of this approach is that we discard a lot of zero padding and therefore get higher tokens/s training throughput. The plot below shows that we get through our dataset ~20% faster. These gains depend on the dataset and the configured --max_seq_length; increasing it generally gives better packing efficiency and therefore higher throughput.

[plot: training progress over time, packed vs. unpacked; the packed run finishes ~20% sooner]
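As a rough illustration (not code from this PR), packing efficiency can be thought of as the fraction of each [1, max_seq_length] buffer filled with real tokens:

```python
# Illustrative only: rough packing efficiency = real tokens / buffer capacity.
# Longer --max_seq_length buffers are easier to fill, which is why throughput
# generally improves as max_seq_length grows.
def packing_efficiency(rows_of_seq_lengths, max_seq_length):
    real_tokens = sum(sum(lengths) for lengths in rows_of_seq_lengths)
    return real_tokens / (len(rows_of_seq_lengths) * max_seq_length)

# e.g. two packed rows in a 2048-token buffer
print(packing_efficiency([[900, 700, 400], [1500, 300]], max_seq_length=2048))  # ~0.93
```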

I manually tested SDPA and flash attention, but I still need to test xformers attention since I couldn't get it to build for Blackwell.

TODO

  • test xformers attention

@djsaunde djsaunde self-assigned this Oct 29, 2025
@djsaunde djsaunde changed the title Packing sample packing Oct 29, 2025
@djsaunde
Collaborator Author

Follow-up: DRY up the attention code. We currently re-implement a big if/else block for selecting and running the attention backend in each modeling file. We can factor this out into a separate module and call a helper function. CC @Datta0

@djsaunde djsaunde force-pushed the packing branch 2 times, most recently from c07d6bd to c23f676 on October 30, 2025 18:22
@djsaunde
Collaborator Author

I added support for passing position IDs to RoPE (needed for correctness, just like attention), and a fused-QK Triton kernel for the RoPE embedding (similar to what currently exists for the non-packing case).
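A minimal sketch of the position-ID side (illustrative, not the PR's actual kernel plumbing): positions must restart at 0 at every packed-sequence boundary so RoPE rotates each sequence as if it started its own context.

```python
# Illustrative sketch: position ids for a packed row must restart at 0 at each
# sequence boundary so RoPE rotates every sequence from its own position 0.
import torch

def packed_position_ids(packed_seq_lengths, device="cpu"):
    # lengths [3, 2] -> positions [[0, 1, 2, 0, 1]]
    return torch.cat(
        [torch.arange(n, dtype=torch.long, device=device) for n in packed_seq_lengths]
    ).unsqueeze(0)  # [1, total_tokens], matching the packed input row

print(packed_position_ids([3, 2]))  # tensor([[0, 1, 2, 0, 1]])
```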

Benchmarks show the new kernel is competitive with the Triton kernel for the non-packing case, approximately matches it numerically, and significantly beats the torch slow path:

RoPE kernel benchmark sweep (microseconds per call):

| seqlen | varlen | dense | old | new | speedup | max abs Δ | mean abs Δ |
|-------:|:------:|------:|----:|----:|--------:|----------:|-----------:|
| 256  | False | 198.501  |           |          |       |           |           |
| 256  | True  |          | 429.066   | 223.670  | 1.918 | 4.768e-07 | 1.136e-08 |
| 512  | False | 413.377  |           |          |       |           |           |
| 512  | True  |          | 1149.956  | 566.851  | 2.029 | 4.768e-07 | 1.170e-08 |
| 1024 | False | 1113.990 |           |          |       |           |           |
| 1024 | True  |          | 2784.808  | 1140.053 | 2.443 | 4.768e-07 | 1.187e-08 |
| 2048 | False | 2341.204 |           |          |       |           |           |
| 2048 | True  |          | 5525.063  | 2372.505 | 2.329 | 4.768e-07 | 1.214e-08 |
| 4096 | False | 4675.885 |           |          |       |           |           |
| 4096 | True  |          | 11354.554 | 4681.061 | 2.426 | 4.768e-07 | 1.239e-08 |
| 8192 | False | 9285.158 |           |          |       |           |           |
| 8192 | True  |          | 21901.080 | 9323.563 | 2.349 | 4.768e-07 | 1.256e-08 |

@shimmyshimmer shimmyshimmer changed the title sample packing Uncontaminated packing Oct 30, 2025
@shimmyshimmer shimmyshimmer changed the title Uncontaminated packing Uncontaminated Sample Packing Oct 30, 2025
@djsaunde djsaunde changed the title Uncontaminated Sample Packing sample packing Oct 31, 2025
@djsaunde djsaunde changed the title sample packing Uncontaminated Sample Packing Oct 31, 2025

djsaunde commented Oct 31, 2025

I added helpers for attention backend selection and execution that each of the fast_forward methods call (+ unit tests), as requested. This removed a lot of duplicate if / elif / else code blocks in favor of a single attention_dispatch.py module.
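Roughly, the dispatch helper looks like the sketch below (enum and function names are illustrative, not necessarily what attention_dispatch.py uses):

```python
# Rough shape of the dispatch idea: each fast_forward asks a shared helper
# which backend to run instead of repeating the if/elif/else ladder per
# modeling file.
from enum import Enum

class AttentionBackend(Enum):
    FLASH_VARLEN = "flash_varlen"
    XFORMERS = "xformers"
    SDPA = "sdpa"

def select_attention_backend(has_flash_attn: bool, has_xformers: bool) -> AttentionBackend:
    if has_flash_attn:
        return AttentionBackend.FLASH_VARLEN
    if has_xformers:
        return AttentionBackend.XFORMERS
    return AttentionBackend.SDPA

def run_attention(backend, q, k, v, packed_seq_lengths=None, **kwargs):
    if backend is AttentionBackend.FLASH_VARLEN:
        ...  # flash attention varlen call; cu_seqlens built from packed_seq_lengths
    elif backend is AttentionBackend.XFORMERS:
        ...  # xformers memory-efficient attention with a block-diagonal causal bias
    else:
        ...  # torch SDPA with an explicit block-causal additive mask
```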


djsaunde commented Nov 2, 2025

Tested and pushed a fix for xformers attention; this PR should be good to go now.

One open question: should we make sample packing the default for pretrain / SFT workloads? It should always work and provides better throughput than without. It's a bit of a shift though; it reshapes samples to [1, max_seq_length] so we discard per_device_train_batch_size in favor of just changing max_seq_length.

One option is just to reshape so samples have shape [1, max_seq_length * per_device_batch_size]. This allows us to keep per_device_batch_size > 1, but it's probably a little confusing for the user.

Another option is to strongly recommend using sample packing in a logged message on the command line (if not already enabled).

We can also explore this in a follow up PR if we don't want to make a decision now.


djsaunde commented Nov 4, 2025

I added some utils and updated the CLI to work out of the box with DDP. Just use torchrun --nproc_per_node=N or accelerate launch on a multi-GPU machine and it should just work.
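For example, on a 2-GPU machine (an illustrative command, reusing the flags from the runs above):

torchrun --nproc_per_node=2 unsloth-cli.py --model_name unsloth/qwen2.5-0.5b --dataset yahma/alpaca-cleaned --per_device_train_batch_size 1 --max_steps 50 --max_seq_length 2048 --sample_packing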

These utils should be reusable in our notebooks / scripts too!

PS: getting DDP working relies on removing the @torch.compile decorator from unsloth_zoo/patch_torch_functions.py::cross_entropy, as it somehow results in a double compile. I think @danielhanchen is fixing this.


danielhanchen commented Nov 4, 2025

> Tested and pushed a fix for xformers attention; this PR should be good to go now.
>
> One open question: should we make sample packing the default for pretrain / SFT workloads? It should always work and provides better throughput than without. It's a bit of a shift though; it reshapes samples to [1, max_seq_length] so we discard per_device_train_batch_size in favor of just changing max_seq_length.
>
> One option is just to reshape so samples have shape [1, max_seq_length * per_device_batch_size]. This allows us to keep per_device_batch_size > 1, but it's probably a little confusing for the user.
>
> Another option is to strongly recommend using sample packing in a logged message on the command line (if not already enabled).
>
> We can also explore this in a follow up PR if we don't want to make a decision now.

Yes, the goal is to allow the padding-free collator so it automatically gets a perf boost :) We can do this in the next PR if that helps.

I also fixed the torch.compile issue for CE (verifying now)


djsaunde commented Nov 5, 2025

Disabled the batch_size == 1 check; now if the user passes batch_size > 1, we take advantage of TRL's logic to flatten from (batch_size, max_seq_length) to (1, total_tokens), where total_tokens <= batch_size * max_seq_length. This makes it easier for folks to use without changing their batch_size / max_seq_length config.
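Conceptually (TRL's padding-free collation handles this internally; the function below is just an illustration), the flattening drops pad tokens and records per-sequence lengths:

```python
# Conceptual illustration only: flatten a padded (batch_size, max_seq_length)
# batch into a single (1, total_tokens) row, dropping pad tokens and keeping
# per-sequence lengths for the block-causal attention metadata.
import torch

def flatten_padded_batch(input_ids, attention_mask):
    seq_lengths = attention_mask.sum(dim=1).tolist()      # real length of each row
    flat = input_ids[attention_mask.bool()].unsqueeze(0)  # (1, total_tokens)
    return flat, seq_lengths

input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])    # 0 = pad token id
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
flat, lengths = flatten_padded_batch(input_ids, attention_mask)
print(flat)     # tensor([[5, 6, 7, 8, 9]])
print(lengths)  # [3, 2]
```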
