Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,29 +150,28 @@ The blog article is also available: [Supercharge Your AIGC Experience: Leverage

### 1. Install from pip

We set `diffusers` and `flash_attn` as two optional installation requirements.

About `diffusers` version:
- If you only use the USP interface, `diffusers` is not required. Models are typically released as `nn.Module`
first, before being integrated into diffusers. xDiT sometimes is applied as an USP plugin to existing projects.
- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
- While we specify a diffusers version in `setup.py`, newer models may require later versions or even need to be installed from main branch.
We set `flash_attn` as an optional installation requirement.

About `flash_attn` version:
- Without `flash_attn` installed, xDiT falls back to a PyTorch implementation of ring attention, which helps NPU users with compatibility.
- However, not using `flash_attn` on GPUs may result in suboptimal performance. For best GPU performance, we strongly recommend installing `flash_attn`.

About `diffusers` version:
- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
- While we specify a diffusers version in `setup.py`, newer models may require later versions or even need to be installed from main branch.
- A limited list of validated diffusers versions can be found [here](#6-limitations)

```
pip install xfuser # Basic installation
pip install "xfuser[diffusers,flash-attn]" # With both diffusers and flash attention
pip install "xfuser[flash-attn]" # With flash attention
```

### 2. Install from source

```
pip install -e .
# Or optionally, with diffusers
pip install -e ".[diffusers,flash-attn]"
# Or optionally, with flash attention
pip install -e ".[flash-attn]"
```

Note that we use two self-maintained packages:
Expand Down Expand Up @@ -226,6 +225,15 @@ You can also launch an HTTP service to generate images with xDiT.

### 6. Limitations

#### Diffusers version

Below is a list of validated diffusers version requirements. If the model is not in the list, you may need to try several diffusers versions to find a working configuration.

| Model Name | Diffusers version |
| --- | --- |
| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 |


#### HunyuanVideo

- Supports `diffusers<=0.32.2` (breaking commit diffusers @ [8907a70](https://github.com/huggingface/diffusers/commit/8907a70a366c96b2322656f57b24e442ea392c7b))
Expand Down
7 changes: 7 additions & 0 deletions examples/flux_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import logging
import time
import torch
import diffusers
import torch.distributed
from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version

# Fail fast with an actionable message when the installed diffusers release
# is older than the minimum validated for Flux (see xfuser/config/diffusers.py).
if not has_valid_diffusers_version("flux"):
    minimum_diffusers_version = get_minimum_diffusers_version("flux")
    raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.")

from transformers import T5EncoderModel
from xfuser import xFuserFluxPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
Expand Down
25 changes: 16 additions & 9 deletions examples/flux_usp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py

import functools
from typing import List, Optional

import logging
import time
import torch
from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version
from typing import List, Optional

# Guard before the `from diffusers import ...` lines below: raise a clear
# error if the installed diffusers release is older than the Flux minimum.
if not has_valid_diffusers_version("flux"):
    minimum_diffusers_version = get_minimum_diffusers_version("flux")
    raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.")

from diffusers import DiffusionPipeline, FluxPipeline

from xfuser import xFuserArgs
Expand All @@ -27,7 +33,7 @@
get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor

def parallelize_transformer(pipe: DiffusionPipeline):
transformer = pipe.transformer
Expand All @@ -50,7 +56,7 @@ def new_forward(
get_runtime_state().split_text_embed_in_sp = False
else:
get_runtime_state().split_text_embed_in_sp = True

if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
Expand All @@ -61,10 +67,8 @@ def new_forward(
img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
if get_runtime_state().split_text_embed_in_sp:
txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]

for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
block.attn.processor = xFuserFluxAttnProcessor2_0()



output = original_forward(
hidden_states,
encoder_hidden_states,
Expand All @@ -86,6 +90,9 @@ def new_forward(
new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward

for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
block.attn.processor = xFuserFluxAttnProcessor()


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
Expand Down Expand Up @@ -119,7 +126,7 @@ def main():
max_condition_sequence_length=512,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)

parallelize_transformer(pipe)

if engine_config.runtime_config.use_torch_compile:
Expand All @@ -139,7 +146,7 @@ def main():

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
height=input_config.height,
width=input_config.width,
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ def get_cuda_version():
"distvae",
"yunchang>=0.6.0",
"einops",
"diffusers>=0.33.0",
],
extras_require={
"diffusers": [
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
],
"flash-attn": [
"flash-attn>=2.6.0", # NOTE: flash-attn is necessary if ring_degree > 1
],
Expand Down
16 changes: 16 additions & 0 deletions xfuser/config/diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import diffusers
from packaging import version

# Fallback minimum diffusers version, used for any model without an entry below.
DEFAULT_MINIMUM_DIFFUSERS_VERSION = "0.33.0"
# Per-model minimum validated diffusers versions (keys are model names, e.g. "flux").
MINIMUM_DIFFUSERS_VERSIONS = {
    "flux": "0.35.2",
}

def has_valid_diffusers_version(model_name: str | None = None) -> bool:
    """Check whether the installed diffusers version meets the minimum requirement.

    Args:
        model_name: Optional model key (e.g. ``"flux"``). When ``None`` or not
            listed in ``MINIMUM_DIFFUSERS_VERSIONS``, the package-wide default
            minimum (``DEFAULT_MINIMUM_DIFFUSERS_VERSION``) is used.

    Returns:
        True if ``diffusers.__version__`` is greater than or equal to the
        required minimum version for the given model.
    """
    # Reuse get_minimum_diffusers_version so the per-model fallback lookup
    # lives in exactly one place.
    minimum_diffusers_version = get_minimum_diffusers_version(model_name)
    return version.parse(diffusers.__version__) >= version.parse(minimum_diffusers_version)


def get_minimum_diffusers_version(model_name: str | None = None) -> str:
    """Return the minimum required diffusers version for *model_name*.

    Unknown (or ``None``) model names fall back to the package-wide
    default minimum version.
    """
    if model_name in MINIMUM_DIFFUSERS_VERSIONS:
        return MINIMUM_DIFFUSERS_VERSIONS[model_name]
    return DEFAULT_MINIMUM_DIFFUSERS_VERSION
10 changes: 6 additions & 4 deletions xfuser/model_executor/cache/diffusers_adapters/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
adapted from https://github.com/ali-vilab/TeaCache.git
adapted from https://github.com/chengzeyi/ParaAttention.git
"""
from xfuser.config.diffusers import has_valid_diffusers_version
from typing import Type, Dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper

TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {}

def register_transformer_adapter(transformer_class: Type, adapter_name: str):
    """Record the cache-adapter name to use for a given transformer class."""
    TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name

register_transformer_adapter(FluxTransformer2DModel, "flux")
register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux")
# Register the Flux adapters only when the installed diffusers release is new
# enough — presumably the imports below fail on older releases; gating them
# keeps this module importable without a compatible diffusers install.
if has_valid_diffusers_version("flux"):
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper
    register_transformer_adapter(FluxTransformer2DModel, "flux")
    register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux")

2 changes: 2 additions & 0 deletions xfuser/model_executor/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .register import xFuserLayerWrappersRegister
from .base_layer import xFuserLayerBaseWrapper
from .attention_processor import xFuserAttentionWrapper
from .attention_processor import xFuserAttentionBaseWrapper
from .conv import xFuserConv2dWrapper
from .embeddings import xFuserPatchEmbedWrapper
from .feedforward import xFuserFeedForwardWrapper

__all__ = [
"xFuserLayerWrappersRegister",
"xFuserLayerBaseWrapper",
"xFuserAttentionBaseWrapper",
"xFuserAttentionWrapper",
"xFuserConv2dWrapper",
"xFuserPatchEmbedWrapper",
Expand Down
Loading