Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ For multi-node distributed training, [this guide](/documentation/DISTRIBUTED.md)
## Features

- Multi-GPU training
- New token-wise dropout techniques like [TREAD](/documentation/TREAD.md) for speeding up Flux training, including Kontext
- Image, video, and caption features (embeds) are cached to the hard drive in advance, so that training runs faster and with less memory consumption
- Aspect bucketing: support for a variety of image/video sizes and aspect ratios, enabling widescreen and portrait training.
- Refiner LoRA or full u-net training for SDXL
Expand Down
286 changes: 286 additions & 0 deletions documentation/TREAD.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# TREAD Training Documentation

> ⚠️ **Experimental Feature**: TREAD support in SimpleTuner is newly implemented. While functional, optimal configurations are still being explored and some behaviors may change in future releases.

## Overview

TREAD (Token Routing for Efficient Architecture-agnostic Diffusion Training) is a training acceleration method that speeds up diffusion model training by intelligently routing tokens through transformer layers. By selectively processing only the most important tokens during certain layers, TREAD can significantly reduce computational costs while maintaining model quality.

Based on the research by [Krause et al. (2025)](https://arxiv.org/abs/2501.04765), TREAD achieves training speedups by:
- Dynamically selecting which tokens to process in each transformer layer
- Maintaining gradient flow through all tokens via skip connections
- Using importance-based routing decisions

The speedup is directly proportional to the `selection_ratio` - the closer to 1.0, the more tokens are dropped and the faster training becomes.

## How TREAD Works

### Core Concept

During training, TREAD:
1. **Routes tokens** - For specified transformer layers, it selects a subset of tokens to process based on their importance
2. **Processes subset** - Only the selected tokens go through the expensive attention and MLP operations
3. **Restores full sequence** - After processing, the full token sequence is restored with gradients flowing to all tokens

### Token Selection

Tokens are selected based on their L1-norm (importance score), with optional randomization for exploration:
- Higher importance tokens are more likely to be kept
- A mix of importance-based and random selection prevents overfitting to specific patterns
- Force-keep masks can ensure certain tokens (like masked regions) are never dropped

## Configuration

### Basic Setup

To enable TREAD training in SimpleTuner, add the following to your configuration:

```json
{
"tread_config": {
"routes": [
{
"selection_ratio": 0.5,
"start_layer_idx": 2,
"end_layer_idx": 5
}
]
}
}
```

### Route Configuration

Each route defines a window where token routing is active:
- `selection_ratio`: Fraction of tokens to drop (0.5 = keep 50% of tokens)
- `start_layer_idx`: First layer where routing begins (0-indexed)
- `end_layer_idx`: Last layer where routing is active

Negative indices are supported: `-1` refers to the last layer.

### Advanced Example

Multiple routing windows with different selection ratios:

```json
{
"tread_config": {
"routes": [
{
"selection_ratio": 0.3,
"start_layer_idx": 1,
"end_layer_idx": 3
},
{
"selection_ratio": 0.5,
"start_layer_idx": 4,
"end_layer_idx": 8
},
{
"selection_ratio": 0.7,
"start_layer_idx": -4,
"end_layer_idx": -1
}
]
}
}
```

## Compatibility

### Supported Models
- **FLUX** - Currently the only supported model family
- Future support planned for other diffusion transformers

### Works Well With
- **Masked Loss Training** - TREAD automatically preserves masked regions when combined with mask/segmentation conditioning
- **Multi-GPU Training** - Compatible with distributed training setups
- **Quantized Training** - Can be used with int8/int4/NF4 quantization

### Limitations
- Only active during training (not inference)
- Requires gradient computation (won't work in eval mode)
- Currently FLUX-specific implementation

## Performance Considerations

### Speed Benefits
- Training speedup is proportional to `selection_ratio` (closer to 1.0 = more tokens dropped = faster training)
- **Biggest speedups occur at 1024x1024 resolution and above** due to attention's O(n²) complexity
- Typically 20-40% speedup, but results vary based on configuration
- With masked loss training, speedup is reduced as masked tokens cannot be dropped

### Quality Trade-offs
- **Higher token dropping leads to higher initial loss** when starting LoRA/LoKr training
- The loss tends to correct fairly rapidly and images normalize quickly
- This may be the network adjusting to fewer tokens in intermediary layers
- Conservative ratios (0.3-0.5) typically maintain quality
- Aggressive ratios (>0.7) may impact convergence

### LoRA-specific Considerations
- Some users report slight slowdowns with LoRA training
- This may be configuration-dependent - optimal routing configs need more exploration
- Initial loss spike is more noticeable with LoRA/LoKr than full fine-tuning

### Recommended Settings

For balanced speed/quality:
```json
{
"routes": [
{"selection_ratio": 0.5, "start_layer_idx": 2, "end_layer_idx": -2}
]
}
```

For maximum speed (expect initial loss spike):
```json
{
"routes": [
{"selection_ratio": 0.7, "start_layer_idx": 1, "end_layer_idx": -1}
]
}
```

For high-resolution training (1024px+):
```json
{
"routes": [
{"selection_ratio": 0.6, "start_layer_idx": 2, "end_layer_idx": -3}
]
}
```

## Technical Details

### Router Implementation

The TREAD router (`TREADRouter` class) handles:
- Token importance calculation via L1-norm
- Permutation generation for efficient routing
- Gradient-preserving token restoration

### Integration with Attention

TREAD modifies the rotary position embeddings (RoPE) to match the routed sequence:
- Text tokens maintain original positions
- Image tokens use shuffled/sliced positions
- Ensures positional consistency during routing
- **Note**: The RoPE implementation for FLUX may not be 100% correct but appears functional in practice

### Masked Loss Compatibility

When using masked loss training:
- Tokens within the mask are automatically force-kept
- Prevents important training signal from being dropped
- Activated via `conditioning_type` in ["mask", "segmentation"]
- **Note**: This reduces speedup as more tokens must be processed

## Known Issues and Limitations

### Implementation Status
- **Experimental feature** - TREAD support is newly implemented and may have undiscovered issues
- **RoPE handling** - The rotary position embedding implementation for token routing may not be perfectly correct
- **Limited testing** - Optimal routing configurations haven't been extensively explored

### Training Behavior
- **Initial loss spike** - When starting LoRA/LoKr training with TREAD, expect higher initial loss that corrects rapidly
- **LoRA performance** - Some configurations may show slight slowdowns with LoRA training
- **Configuration sensitivity** - Performance highly depends on routing configuration choices

### Known Bugs (Fixed)
- Masked loss training was broken in earlier versions but has been fixed with proper model flavor checking (`kontext` guard)

## Troubleshooting

### Common Issues

**"TREAD training requires you to configure the routes"**
- Ensure `tread_config` includes a `routes` array
- Each route needs `selection_ratio`, `start_layer_idx`, and `end_layer_idx`

**Slower training than expected**
- Verify routes cover meaningful layer ranges
- Consider more aggressive selection ratios
- Check that gradient checkpointing isn't conflicting
- For LoRA training, some slowdown is expected - try different routing configs

**High initial loss with LoRA/LoKr**
- This is expected behavior - the network needs to adapt to fewer tokens
- Loss typically corrects within a few hundred steps
- If loss doesn't improve, reduce `selection_ratio` (keep more tokens)

**Quality degradation**
- Reduce selection ratios (keep more tokens)
- Avoid routing in early layers (0-2) or final layers
- Ensure sufficient training data for the increased efficiency

## Practical Examples

### High-Resolution Training (1024px+)
For maximum benefit when training at high resolutions:
```json
{
"tread_config": {
"routes": [
{"selection_ratio": 0.6, "start_layer_idx": 2, "end_layer_idx": -3}
]
}
}
```

### LoRA Fine-tuning
Conservative config to minimize initial loss spike:
```json
{
"tread_config": {
"routes": [
{"selection_ratio": 0.4, "start_layer_idx": 3, "end_layer_idx": -4}
]
}
}
```

### Masked Loss Training
When training with masks, tokens in masked regions are preserved:
```json
{
"tread_config": {
"routes": [
{"selection_ratio": 0.7, "start_layer_idx": 2, "end_layer_idx": -2}
]
}
}
```
Note: Actual speedup will be less than 0.7 suggests due to forced token preservation.

## Future Work

As TREAD support in SimpleTuner is newly implemented, there are several areas for future improvement:

- **Configuration optimization** - More testing needed to find optimal routing configurations for different use cases
- **LoRA performance** - Investigation into why some LoRA configurations show slowdowns
- **RoPE implementation** - Refinement of the rotary position embedding handling for better correctness
- **Extended model support** - Implementation for other diffusion transformer architectures beyond Flux
- **Automated configuration** - Tools to automatically determine optimal routing based on model and dataset characteristics

Community contributions and testing results are welcome to help improve TREAD support.

## References

- [TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training](https://arxiv.org/abs/2501.04765)
- [SimpleTuner Flux Documentation](/documentation/quickstart/FLUX.md#tread-training)

## Citation

```bibtex
@misc{krause2025treadtokenroutingefficient,
title={TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training},
author={Felix Krause and Timy Phan and Vincent Tao Hu and Björn Ommer},
year={2025},
eprint={2501.04765},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2501.04765},
}
```
56 changes: 55 additions & 1 deletion documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,60 @@ If VRAM is not a concern (eg. 48G or greater) then int8 with torch.compile is yo

If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide.

### TREAD training

> ⚠️ **Experimental**: TREAD is a newly implemented feature. While functional, optimal configurations are still being explored.

[TREAD](/documentation/TREAD.md) (paper) stands for **T**oken **R**outing for **E**fficient **A**rchitecture-agnostic **D**iffusion. It is a method that can accelerate Flux training by intelligently routing tokens through transformer layers. The speedup is proportional to how many tokens you drop.

#### Quick setup

Add this to your `config.json`:

```json
{
"tread_config": {
"routes": [
{
"selection_ratio": 0.5,
"start_layer_idx": 2,
"end_layer_idx": -2
}
]
}
}
```

This configuration will:
- Keep only 50% of image tokens during layers 2 through second-to-last
- Text tokens are never dropped
- Training speedup of ~25% with minimal quality impact

#### Key points

- **Currently Flux-only** - TREAD is only implemented for Flux models
- **Best at high resolutions** - Biggest speedups at 1024x1024+ due to attention's O(n²) complexity
- **Compatible with masked loss** - Masked regions are automatically preserved (but this reduces speedup)
- **Works with quantization** - Can be combined with int8/int4/NF4 training
- **Expect initial loss spike** - When starting LoRA/LoKr training, loss will be higher initially but corrects quickly

#### Tuning tips

- **Conservative (quality-focused)**: Use `selection_ratio` of 0.3-0.5
- **Aggressive (speed-focused)**: Use `selection_ratio` of 0.6-0.8
- **Avoid early/late layers**: Don't route in layers 0-1 or the final layer
- **For LoRA training**: May see slight slowdowns - experiment with different configs
- **Higher resolution = better speedup**: Most beneficial at 1024px and above

#### Known behavior

- The more tokens dropped (higher `selection_ratio`), the faster training but higher initial loss
- LoRA/LoKr training shows an initial loss spike that rapidly corrects as the network adapts
- Some LoRA configurations may train slightly slower - optimal configs still being explored
- The RoPE (rotary position embedding) implementation is functional but may not be 100% correct

For detailed configuration options and troubleshooting, see the [full TREAD documentation](/documentation/TREAD.md).

### Classifier-free guidance

#### Problem
Expand Down Expand Up @@ -556,4 +610,4 @@ The users of [Terminus Research](https://huggingface.co/terminusresearch) who wo

[Lambda Labs](https://lambdalabs.com) for generous compute allocations that were used for tests and verifications for large scale training runs

Especially [@JimmyCarter](https://huggingface.co/jimmycarter) and [@kaibioinfo](https://github.com/kaibioinfo) for coming up with some of the best ideas and putting them into action, offering pull requests and running exhaustive tests for analysis - even daring to use _their own faces_ for DreamBooth experimentation.
Especially [@JimmyCarter](https://huggingface.co/jimmycarter) (incl TREAD addition) and [@kaibioinfo](https://github.com/kaibioinfo) for coming up with some of the best ideas and putting them into action, offering pull requests and running exhaustive tests for analysis - even daring to use _their own faces_ for DreamBooth experimentation.
Loading