39 changes: 34 additions & 5 deletions README.md
@@ -5,6 +5,8 @@ Check out the accompanying blog post [here](https://pytorch.org/blog/presenting-

**Updates**

**July 1, 2025**: This repository now supports AMD MI300X GPUs using AITER kernels [(PR)](https://github.com/huggingface/flux-fast/pull/10). The README has been updated to provide instructions on how to run on AMD GPUs.

**June 28, 2025**: This repository now supports [Flux.1 Kontext Dev](https://hf.co/black-forest-labs/FLUX.1-Kontext-dev). We enabled ~2.5x speedup on it. Check out [this section](#flux1-kontext-dev) for more details.

## Results
@@ -73,29 +75,47 @@ Here are some example outputs with Flux.1-Schnell for prompt `"A cat playing wit
We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.

The numbers reported here were gathered using:

For NVIDIA:
* `torch==2.8.0.dev20250605+cu126` - note that we rely on some fixes since 2.7
* `torchao==0.12.0.dev20250610+cu126` - note that we rely on a fix in the 06/10 nightly
* `diffusers` - with [this fix](https://github.com/huggingface/diffusers/pull/11696) included
* `flash_attn_3==3.0.0b1`

To install deps:
For AMD:
* `torch==2.8.0.dev20250605+rocm6.4` - note that we rely on some fixes since 2.7
* `torchao==0.12.0.dev20250610+rocm6.4` - note that we rely on a fix in the 06/10 nightly
* `diffusers` - with [this fix](https://github.com/huggingface/diffusers/pull/11696) included
* `aiter==0.1.4.dev17+gd0384d4`

To install deps on NVIDIA:
```
pip install -U huggingface_hub[hf_xet] accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
```

To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.
(For NVIDIA) To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.

To install deps on AMD:
```
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torchao==0.12.0.dev20250609+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install git+https://github.com/ROCm/aiter
```

For hardware, we used a 96GB 700W H100 GPU. Some of the optimizations applied (BFloat16, torch.compile, Combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
(For AMD) Instead of flash attention v3, we use [AITER](https://github.com/ROCm/aiter). It provides the required fp8 MHA kernels.

For hardware, we used a 96GB 700W H100 GPU and a 192GB MI300X GPU. Some of the optimizations applied (BFloat16, torch.compile, combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
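
As a rough illustration, the hardware-agnostic subset of these optimizations can be applied along the following lines. This is only a sketch: it assumes the `diffusers` `fuse_qkv_projections()` helpers and the `torchao` `quantize_` / `float8_dynamic_activation_float8_weight` APIs named below, and it is not this repository's exact code path (see `utils/pipeline_utils.py` for that).

```python
import torch
from diffusers import FluxPipeline
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# BFloat16 weights
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# combine q, k, v projections in the attention layers
pipe.transformer.fuse_qkv_projections()
pipe.vae.fuse_qkv_projections()

# dynamic float8 quantization of the transformer's linear layers
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())

# torch.compile can be layered on top; see the sections below
```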

## Run the optimized pipeline

On NVIDIA:
```sh
python gen_image.py --prompt "An astronaut standing next to a giant lemon" --output-file output.png --use-cached-model
```

This will include all optimizations and will attempt to use pre-cached binary models
generated via `torch.export` + AOTI. To generate these binaries for subsequent runs, run
the above command without the `--use-cached-model` flag.
@@ -108,6 +128,13 @@ the above command without the `--use-cached-model` flag.
> different environment than the one present at runtime. The PyTorch Compiler team is working on
> solutions for more portable binaries / artifact caching.
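
For readers unfamiliar with this flow, a minimal, generic sketch of exporting a module and packaging it with AOTInductor so that later runs can load the cached binary looks roughly like the following. The module, inputs, and package path are illustrative placeholders, not this repository's actual export code.

```python
import torch

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))

model = TinyBlock().eval()
example_inputs = (torch.randn(1, 64),)

# 1) export the module to a graph, 2) AOT-compile and package it as a reusable artifact
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(ep, package_path="tiny_block.pt2")

# later runs load the packaged binary instead of recompiling
compiled = torch._inductor.aoti_load_package(package_path)
out = compiled(*example_inputs)
```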

On AMD:
```sh
python gen_image.py --prompt "A cat playing with a ball of yarn" --output-file output.png --compile_export_mode compile
```
Currently, only `torch.export` is not working as expected on AMD; use `torch.compile` instead, as shown in the command above.
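
Under the hood, the AMD path compiles with `dynamic=True` (see the `use_compile` changes in `utils/pipeline_utils.py` below). A minimal sketch of just that piece, assuming a freshly loaded pipeline:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")  # ROCm builds of PyTorch also expose the GPU via the "cuda" device string

# dynamic=True works around the black-image issue seen with the default dynamic=None + AITER kernels
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True, dynamic=True)
```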


## Benchmarking
[`run_benchmark.py`](./run_benchmark.py) is the main script for benchmarking the different optimization techniques.
Usage:
@@ -326,7 +353,7 @@ image = pipe(prompt, num_inference_steps=4).images[0]
</details>

<details>
<summary>Flash Attention V3</summary>
<summary>Flash Attention V3 / aiter</summary>

Flash Attention V3 is substantially faster on H100s than the previous iteration FA2, due
in large part to float8 support. As this kernel isn't quite available yet within PyTorch Core, we implement a custom
@@ -335,6 +362,8 @@ image = pipe(prompt, num_inference_steps=4).images[0]
the op integrates well with `torch.compile` / `torch.export`. Inputs are converted to float8 in an unscaled fashion before
kernel invocation and outputs are converted back to the original dtype on the way out.
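
A condensed sketch of that wrapping pattern is shown below; the op name and signature are illustrative, and the repository's full version in `utils/pipeline_utils.py` handles many more arguments.

```python
import torch

# Wrap the FA3 kernel in a custom op with a registered fake (meta) kernel so that
# torch.compile / torch.export can trace through it without running the real kernel.
@torch.library.custom_op("flashsketch::attn", mutates_args=())
def attn_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    import flash_attn_interface  # FA3 beta package

    fp8 = torch.float8_e4m3fn
    # unscaled conversion to float8 before the kernel; back to the original dtype on the way out
    out = flash_attn_interface.flash_attn_func(q.to(fp8), k.to(fp8), v.to(fp8))[0]
    return out.to(q.dtype)

@attn_sketch.register_fake
def _(q, k, v):
    # the fake output's shape/dtype must match the real kernel's output for tracing
    return torch.empty_like(q).contiguous()
```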

On AMD GPUs, we use [`aiter`](https://github.com/ROCm/aiter) instead, which also provides fp8 MHA kernels.

```python
from diffusers import FluxPipeline

Expand Down
56 changes: 39 additions & 17 deletions utils/pipeline_utils.py
@@ -7,6 +7,9 @@
from PIL import Image
import inspect

def is_hip():
    return torch.version.hip is not None


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
@@ -34,11 +37,12 @@ def flash_attn_func(
    else:
        window_size = tuple(window_size)

    import flash_attn_interface

    dtype = torch.float8_e4m3fn
    if is_hip():
        from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
    else:
        from flash_attn.flash_attn_interface import flash_attn_interface_func

    sig = inspect.signature(flash_attn_interface.flash_attn_func)
    sig = inspect.signature(flash_attn_interface_func)
    accepted = set(sig.parameters)
    all_kwargs = {
        "softmax_scale": softmax_scale,
@@ -57,11 +61,19 @@ def flash_attn_func(
    }
    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}

    outputs = flash_attn_interface.flash_attn_func(
        q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
    )
    return outputs[0]
    if is_hip():
        # For AMD, the AITER fp8 kernels take in fp32 inputs and convert them to fp8 themselves,
        # so we don't need to convert to fp8 here
        outputs = flash_attn_interface_func(
            q, k, v, **kwargs,
        )
    else:
        dtype = torch.float8_e4m3fn
        outputs = flash_attn_interface_func(
            q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
        )[0]

    return outputs.contiguous().to(torch.bfloat16) if is_hip() else outputs

@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
@@ -71,18 +83,26 @@ def _(q, k, v, **kwargs):
    meta_q = torch.empty_like(q).contiguous()
    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)


# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
class FlashFusedFluxAttnProcessor3_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
try:
import flash_attn_interface
except ImportError:
raise ImportError(
"flash_attention v3 package is required to be installed"
)

        if is_hip():
            try:
                from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
            except ImportError:
                raise ImportError(
                    "aiter is required to be installed"
                )
        else:
            try:
                from flash_attn.flash_attn_interface import flash_attn_interface_func
            except ImportError:
                raise ImportError(
                    "flash_attention v3 package is required to be installed"
                )

    def __call__(
        self,
@@ -214,11 +234,13 @@ def wrapped(*args, **kwargs):
def use_compile(pipeline):
    # Compile the compute-intensive portions of the model: denoising transformer / decoder
    is_kontext = "Kontext" in pipeline.__class__.__name__
    # For AMD MI300X w/ the AITER kernels, the default dynamic=None is not working as expected, giving black results.
    # Therefore, we use dynamic=True for AMD only. This leads to a small perf penalty, but should be fixed eventually.
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
        pipeline.transformer, mode="max-autotune", fullgraph=True, dynamic=True if is_hip() else None
    )
    pipeline.vae.decode = torch.compile(
        pipeline.vae.decode, mode="max-autotune", fullgraph=True
        pipeline.vae.decode, mode="max-autotune", fullgraph=True, dynamic=True if is_hip() else None
    )

    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
Expand Down