From db8fa35ddfe5f3869653247274fa8abe1512249a Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 9 Oct 2025 12:15:20 +0000 Subject: [PATCH 01/15] Parallelize new FluxAttention and FluxAttnProcessor --- .../layers/attention_processor.py | 10 +- .../models/transformers/transformer_flux.py | 158 +++++++++++++++++- 2 files changed, 157 insertions(+), 11 deletions(-) diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 0f410e23..080119aa 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -222,7 +222,7 @@ def __init__(self): HAS_LONG_CTX_ATTN and use_long_ctx_attn_kvcache and get_sequence_parallel_world_size() > 1 - ) + ) set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) if get_fast_attn_enable(): @@ -450,7 +450,7 @@ def __call__( query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - + inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1510,7 +1510,7 @@ def __call__( class xFuserSanaAttnProcessor2_0(SanaAttnProcessor2_0): def __init__(self): super().__init__() - + def __call__( self, attn: Attention, @@ -1627,7 +1627,7 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - + batch_size = encoder_hidden_states.size(0) query = attn.to_q(hidden_states) @@ -1638,7 +1638,7 @@ def __call__( query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - + inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index 91869552..b266e059 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -1,10 +1,15 @@ -from typing import Optional, Dict, Any, Union +import inspect import torch import torch.distributed import torch.nn as nn +from typing import Optional, Dict, Any, Union -from diffusers.models.embeddings import PatchEmbed -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +#from diffusers.models.embeddings import PatchEmbed +from diffusers.models.transformers.transformer_flux import ( + FluxTransformer2DModel, + FluxAttnProcessor, + FluxAttention, +) from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput from diffusers.utils import ( is_torch_version, @@ -12,26 +17,167 @@ USE_PEFT_BACKEND, unscale_lora_layers, ) +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import apply_rotary_emb from xfuser.core.distributed.parallel_state import ( get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) -from xfuser.core.distributed.runtime_state import get_runtime_state +from xfuser.core.distributed import ( + get_sequence_parallel_world_size, +) + from xfuser.logger import init_logger +from xfuser.envs import PACKAGES_CHECKER from xfuser.model_executor.models.transformers.register import ( xFuserTransformerWrappersRegister, ) from xfuser.model_executor.models.transformers.base_transformer import ( xFuserTransformerBaseWrapper, ) +from xfuser.model_executor.layers import xFuserLayerWrappersRegister +from xfuser.model_executor.layers.attention_processor import ( + set_hybrid_seq_parallel_attn, + xFuserAttentionBaseWrapper, + xFuserAttentionProcessorRegister +) +from 
xfuser.model_executor.layers.usp import USP logger = init_logger(__name__) -from diffusers.models.attention import FeedForward + +env_info = PACKAGES_CHECKER.get_packages_info() +HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"] + + +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + +@xFuserLayerWrappersRegister.register(FluxAttention) +class xFuserFluxAttentionWrapper(xFuserAttentionBaseWrapper): + def __init__( + self, + attention: FluxAttention, + ): + super().__init__(attention=attention) + self.processor = xFuserAttentionProcessorRegister.get_processor( + attention.processor + )() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + +xFuserAttentionProcessorRegister.register(FluxAttnProcessor) +class xFuserFluxAttnProcessor(FluxAttnProcessor): + + def __init__(self): + super().__init__() + use_long_ctx_attn_kvcache = True + self.use_long_ctx_attn_kvcache = ( + HAS_LONG_CTX_ATTN + and use_long_ctx_attn_kvcache + and get_sequence_parallel_world_size() > 1 + ) + set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = self.hybrid_seq_parallel_attn(None, query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + -@xFuserTransformerWrappersRegister.register(FluxTransformer2DModel) +#@xFuserTransformerWrappersRegister.register(FluxTransformer2DModel) class xFuserFluxTransformer2DWrapper(xFuserTransformerBaseWrapper): def __init__( self, From 219c1c61bb6999190880dcb80718e6f56df14b64 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 10 Oct 2025 11:42:07 +0000 Subject: [PATCH 02/15] Fix Flux transformer for flux_example.py --- .../models/transformers/transformer_flux.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index b266e059..bfc0d3a0 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -109,7 +109,7 @@ def forward( kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) 
-xFuserAttentionProcessorRegister.register(FluxAttnProcessor) +@xFuserAttentionProcessorRegister.register(FluxAttnProcessor) class xFuserFluxAttnProcessor(FluxAttnProcessor): def __init__(self): @@ -177,7 +177,7 @@ def __call__( -#@xFuserTransformerWrappersRegister.register(FluxTransformer2DModel) +@xFuserTransformerWrappersRegister.register(FluxTransformer2DModel) class xFuserFluxTransformer2DWrapper(xFuserTransformerBaseWrapper): def __init__( self, @@ -324,7 +324,6 @@ def custom_forward(*inputs): # hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] # if self.stage_info.after_flags["transformer_blocks"]: - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -341,17 +340,19 @@ def custom_forward(*inputs): ckpt_kwargs: Dict[str, Any] = ( {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ) - hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, + encoder_hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) @@ -365,6 +366,8 @@ def custom_forward(*inputs): # + controlnet_single_block_samples[index_block // interval_control] # ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...] hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
From 3508a913bd47ecfe13cc55fbd6c08891cc60d3f1 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Mon, 13 Oct 2025 07:51:25 +0000 Subject: [PATCH 03/15] Import qkv projection functions instead of copying them from diffusers --- .../models/transformers/transformer_flux.py | 30 +------------------ 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index bfc0d3a0..805fa59b 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -6,6 +6,7 @@ #from diffusers.models.embeddings import PatchEmbed from diffusers.models.transformers.transformer_flux import ( + _get_qkv_projections, FluxTransformer2DModel, FluxAttnProcessor, FluxAttention, @@ -51,35 +52,6 @@ HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"] -def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - - return query, key, value, encoder_query, encoder_key, encoder_value - - -def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - - encoder_query = encoder_key = encoder_value = (None,) - if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - - return query, key, value, encoder_query, encoder_key, encoder_value - - -def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): - if attn.fused_projections: - return _get_fused_projections(attn, hidden_states, encoder_hidden_states) - return _get_projections(attn, hidden_states, encoder_hidden_states) - @xFuserLayerWrappersRegister.register(FluxAttention) class xFuserFluxAttentionWrapper(xFuserAttentionBaseWrapper): def __init__( From 1f655c848fcc93dbac4b3f75f51192f81108019d Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 23 Oct 2025 11:52:06 +0000 Subject: [PATCH 04/15] Add diffusers version gating for Flux --- examples/flux_example.py | 6 ++++++ examples/flux_usp_example.py | 8 +++++++- .../cache/diffusers_adapters/registry.py | 11 +++++++---- xfuser/model_executor/models/transformers/__init__.py | 11 ++++++++--- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/examples/flux_example.py b/examples/flux_example.py index 9e0ac849..3aafe183 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -1,7 +1,13 @@ import logging import time import torch +import diffusers import torch.distributed +from packaging import version + +if version.parse(diffusers.__version__) < version.parse("0.35.2"): + raise ImportError("Please install diffusers>=0.35.2 to use Flux.") + from transformers import T5EncoderModel from xfuser import xFuserFluxPipeline, xFuserArgs from xfuser.config import FlexibleArgumentParser diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index 1f3287cd..b75c8139 100644 --- a/examples/flux_usp_example.py +++ 
b/examples/flux_usp_example.py @@ -2,11 +2,17 @@ # from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py import functools -from typing import List, Optional import logging import time import torch +import diffusers +from packaging import version +from typing import List, Optional + +if version.parse(diffusers.__version__) < version.parse("0.35.2"): + raise ImportError("Please install diffusers>=0.35.2 to use Flux.") + from diffusers import DiffusionPipeline, FluxPipeline from xfuser import xFuserArgs diff --git a/xfuser/model_executor/cache/diffusers_adapters/registry.py b/xfuser/model_executor/cache/diffusers_adapters/registry.py index 2232832a..4ad9a156 100644 --- a/xfuser/model_executor/cache/diffusers_adapters/registry.py +++ b/xfuser/model_executor/cache/diffusers_adapters/registry.py @@ -2,15 +2,18 @@ adapted from https://github.com/ali-vilab/TeaCache.git adapted from https://github.com/chengzeyi/ParaAttention.git """ +import diffusers +from packaging import version from typing import Type, Dict -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {} def register_transformer_adapter(transformer_class: Type, adapter_name: str): TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name -register_transformer_adapter(FluxTransformer2DModel, "flux") -register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux") +if version.parse(diffusers.__version__) >= version.parse("0.35.2"): + from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel + from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper + register_transformer_adapter(FluxTransformer2DModel, "flux") + register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux") diff --git a/xfuser/model_executor/models/transformers/__init__.py b/xfuser/model_executor/models/transformers/__init__.py index 96e3ae8e..ac16e1eb 100644 --- a/xfuser/model_executor/models/transformers/__init__.py +++ b/xfuser/model_executor/models/transformers/__init__.py @@ -1,8 +1,9 @@ +import diffusers +from packaging import version from .register import xFuserTransformerWrappersRegister from .base_transformer import xFuserTransformerBaseWrapper from .pixart_transformer_2d import xFuserPixArtTransformer2DWrapper from .transformer_sd3 import xFuserSD3Transformer2DWrapper -from .transformer_flux import xFuserFluxTransformer2DWrapper from .latte_transformer_3d import xFuserLatteTransformer3DWrapper from .hunyuan_transformer_2d import xFuserHunyuanDiT2DWrapper from .cogvideox_transformer_3d import xFuserCogVideoXTransformer3DWrapper @@ -14,10 +15,14 @@ "xFuserTransformerBaseWrapper", "xFuserPixArtTransformer2DWrapper", "xFuserSD3Transformer2DWrapper", - "xFuserFluxTransformer2DWrapper", "xFuserLatteTransformer3DWrapper", "xFuserCogVideoXTransformer3DWrapper", "xFuserHunyuanDiT2DWrapper", "xFuserConsisIDTransformer3DWrapper", "xFuserSanaTransformer2DWrapper" -] \ No newline at end of file +] + +# Gating some imports based on diffusers version, as they import part of diffusers +if version.parse(diffusers.__version__) >= version.parse("0.35.2"): + from .transformer_flux import xFuserFluxTransformer2DWrapper + __all__.append("xFuserFluxTransformer2DWrapper") \ No newline at end of file From b519f01ea5fe23575fe62b9c7a00f56857ca17a5 Mon 
Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 23 Oct 2025 11:52:49 +0000 Subject: [PATCH 05/15] Use new Flux attn processor in flux_usp example as well --- examples/flux_usp_example.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index b75c8139..a0d1cf94 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -33,7 +33,7 @@ get_pipeline_parallel_world_size, ) -from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0 +from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor def parallelize_transformer(pipe: DiffusionPipeline): transformer = pipe.transformer @@ -56,7 +56,7 @@ def new_forward( get_runtime_state().split_text_embed_in_sp = False else: get_runtime_state().split_text_embed_in_sp = True - + if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] @@ -67,10 +67,8 @@ def new_forward( img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] if get_runtime_state().split_text_embed_in_sp: txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] - - for block in transformer.transformer_blocks + transformer.single_transformer_blocks: - block.attn.processor = xFuserFluxAttnProcessor2_0() - + + output = original_forward( hidden_states, encoder_hidden_states, @@ -92,6 +90,9 @@ def new_forward( new_forward = new_forward.__get__(transformer) transformer.forward = new_forward + for block in transformer.transformer_blocks + transformer.single_transformer_blocks: + block.attn.processor = xFuserFluxAttnProcessor() + def main(): parser = FlexibleArgumentParser(description="xFuser Arguments") @@ -125,7 +126,7 @@ def main(): max_condition_sequence_length=512, split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) - + parallelize_transformer(pipe) if engine_config.runtime_config.use_torch_compile: @@ -145,7 +146,7 @@ def main(): torch.cuda.reset_peak_memory_stats() start_time = time.time() - + output = pipe( height=input_config.height, width=input_config.width, From 0a5e126a1b51116af2e8f3df78f834292b64ee57 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 23 Oct 2025 11:57:12 +0000 Subject: [PATCH 06/15] Add missing return parameter --- xfuser/model_executor/models/transformers/transformer_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index 805fa59b..5850f12d 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -312,7 +312,7 @@ def custom_forward(*inputs): ckpt_kwargs: Dict[str, Any] = ( {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ) - encoder_hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, From 775e89f893f17e16c1d47e1bff9f519822c910fa Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: 
Thu, 23 Oct 2025 12:08:47 +0000 Subject: [PATCH 07/15] Refactor the diffusers version check --- examples/flux_example.py | 4 ++-- examples/flux_usp_example.py | 5 ++--- xfuser/model_executor/cache/diffusers_adapters/registry.py | 5 ++--- xfuser/model_executor/models/transformers/__init__.py | 5 ++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/flux_example.py b/examples/flux_example.py index 3aafe183..f0fc9e49 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -3,9 +3,9 @@ import torch import diffusers import torch.distributed -from packaging import version +from xfuser.config.diffusers import has_valid_diffusers_version -if version.parse(diffusers.__version__) < version.parse("0.35.2"): +if not has_valid_diffusers_version("flux"): raise ImportError("Please install diffusers>=0.35.2 to use Flux.") from transformers import T5EncoderModel diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index a0d1cf94..4d3aea12 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -6,11 +6,10 @@ import logging import time import torch -import diffusers -from packaging import version +from xfuser.config.diffusers import has_valid_diffusers_version from typing import List, Optional -if version.parse(diffusers.__version__) < version.parse("0.35.2"): +if not has_valid_diffusers_version("flux"): raise ImportError("Please install diffusers>=0.35.2 to use Flux.") from diffusers import DiffusionPipeline, FluxPipeline diff --git a/xfuser/model_executor/cache/diffusers_adapters/registry.py b/xfuser/model_executor/cache/diffusers_adapters/registry.py index 4ad9a156..f3e33d53 100644 --- a/xfuser/model_executor/cache/diffusers_adapters/registry.py +++ b/xfuser/model_executor/cache/diffusers_adapters/registry.py @@ -2,8 +2,7 @@ adapted from https://github.com/ali-vilab/TeaCache.git adapted from https://github.com/chengzeyi/ParaAttention.git """ -import diffusers -from packaging import version +from config.diffusers import has_valid_diffusers_version from typing import Type, Dict TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {} @@ -11,7 +10,7 @@ def register_transformer_adapter(transformer_class: Type, adapter_name: str): TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name -if version.parse(diffusers.__version__) >= version.parse("0.35.2"): +if has_valid_diffusers_version("flux"): from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper register_transformer_adapter(FluxTransformer2DModel, "flux") diff --git a/xfuser/model_executor/models/transformers/__init__.py b/xfuser/model_executor/models/transformers/__init__.py index ac16e1eb..5ed646b5 100644 --- a/xfuser/model_executor/models/transformers/__init__.py +++ b/xfuser/model_executor/models/transformers/__init__.py @@ -1,5 +1,4 @@ -import diffusers -from packaging import version +from config.diffusers import has_valid_diffusers_version from .register import xFuserTransformerWrappersRegister from .base_transformer import xFuserTransformerBaseWrapper from .pixart_transformer_2d import xFuserPixArtTransformer2DWrapper @@ -23,6 +22,6 @@ ] # Gating some imports based on diffusers version, as they import part of diffusers -if version.parse(diffusers.__version__) >= version.parse("0.35.2"): +if has_valid_diffusers_version("flux"): from .transformer_flux import xFuserFluxTransformer2DWrapper 
__all__.append("xFuserFluxTransformer2DWrapper") \ No newline at end of file From 81d6aa665188b72c41847b8f8bc6719449f2e174 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 23 Oct 2025 13:18:57 +0000 Subject: [PATCH 08/15] Fix typos --- xfuser/model_executor/cache/diffusers_adapters/registry.py | 2 +- xfuser/model_executor/models/transformers/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xfuser/model_executor/cache/diffusers_adapters/registry.py b/xfuser/model_executor/cache/diffusers_adapters/registry.py index f3e33d53..389345c5 100644 --- a/xfuser/model_executor/cache/diffusers_adapters/registry.py +++ b/xfuser/model_executor/cache/diffusers_adapters/registry.py @@ -2,7 +2,7 @@ adapted from https://github.com/ali-vilab/TeaCache.git adapted from https://github.com/chengzeyi/ParaAttention.git """ -from config.diffusers import has_valid_diffusers_version +from xfuser.config.diffusers import has_valid_diffusers_version from typing import Type, Dict TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {} diff --git a/xfuser/model_executor/models/transformers/__init__.py b/xfuser/model_executor/models/transformers/__init__.py index 5ed646b5..8258945c 100644 --- a/xfuser/model_executor/models/transformers/__init__.py +++ b/xfuser/model_executor/models/transformers/__init__.py @@ -1,4 +1,4 @@ -from config.diffusers import has_valid_diffusers_version +from xfuser.config.diffusers import has_valid_diffusers_version from .register import xFuserTransformerWrappersRegister from .base_transformer import xFuserTransformerBaseWrapper from .pixart_transformer_2d import xFuserPixArtTransformer2DWrapper From aa1b242834040a8656993a44b583445050190a0e Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Thu, 23 Oct 2025 13:36:34 +0000 Subject: [PATCH 09/15] Add pipefusion support in the new attention processor as well --- xfuser/model_executor/models/base_model.py | 2 +- .../models/transformers/transformer_flux.py | 64 ++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/xfuser/model_executor/models/base_model.py b/xfuser/model_executor/models/base_model.py index 3d3315be..9586ef3b 100644 --- a/xfuser/model_executor/models/base_model.py +++ b/xfuser/model_executor/models/base_model.py @@ -120,7 +120,7 @@ def _register_cache( self, ): for layer in self.wrapped_layers: - if isinstance(layer, xFuserAttentionWrapper): + if "AttentionWrapper" in layer._get_name(): # if getattr(layer.processor, 'use_long_ctx_attn_kvcache', False): # TODO(Eigensystem): remove use_long_ctx_attn_kvcache flag if get_sequence_parallel_world_size() == 1 or not getattr( diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index 5850f12d..00096e14 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -30,6 +30,9 @@ get_sequence_parallel_world_size, ) +from xfuser.core.cache_manager.cache_manager import get_cache_manager +from xfuser.core.distributed.runtime_state import get_runtime_state + from xfuser.logger import init_logger from xfuser.envs import PACKAGES_CHECKER from xfuser.model_executor.models.transformers.register import ( @@ -121,15 +124,74 @@ def __call__( encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) + num_encoder_hidden_states_tokens = encoder_query.shape[1] + num_query_tokens = query.shape[1] + query = 
torch.cat([encoder_query, query], dim=1) key = torch.cat([encoder_key, key], dim=1) value = torch.cat([encoder_value, value], dim=1) + else: + num_encoder_hidden_states_tokens = ( + get_runtime_state().max_condition_sequence_length + ) + num_query_tokens = query.shape[1] - num_encoder_hidden_states_tokens + if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - hidden_states = self.hybrid_seq_parallel_attn(None, query, key, value) + if ( + get_runtime_state().num_pipeline_patch > 1 + and not self.use_long_ctx_attn_kvcache + ): + encoder_hidden_states_key_proj, key = key.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_value_proj, value = value.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + key, value = get_cache_manager().update_and_get_kv_cache( + new_kv=[key, value], + layer=attn, + slice_dim=1, + layer_type="attn", + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + uses_pipeline_parallelism = get_runtime_state().num_pipeline_patch > 1 + if not uses_pipeline_parallelism: + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + hidden_states = USP(query, key, value) + hidden_states = hidden_states.transpose(1, 2) + else: + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states_query_proj = None + encoder_hidden_states_key_proj = None + encoder_hidden_states_value_proj = None + else: + encoder_hidden_states_query_proj, query = query.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_key_proj, key = key.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_value_proj, value = value.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + hidden_states = self.hybrid_seq_parallel_attn( + attn if get_runtime_state().num_pipeline_patch > 1 else None, + query, + key, + value, + dropout_p=0.0, + causal=False, + joint_tensor_query=encoder_hidden_states_query_proj, + joint_tensor_key=encoder_hidden_states_key_proj, + joint_tensor_value=encoder_hidden_states_value_proj, + joint_strategy="front", + ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) From 9f13f23811e1d16862ab8b233ebd3e9b4e73c549 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 10:49:24 +0000 Subject: [PATCH 10/15] Make diffusers a required library instead of optional --- setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.py b/setup.py index eb822a52..fdbce167 100644 --- a/setup.py +++ b/setup.py @@ -34,11 +34,9 @@ def get_cuda_version(): "distvae", "yunchang>=0.6.0", "einops", + "diffusers>=0.33.0", ], extras_require={ - "diffusers": [ - "diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux - ], "flash-attn": [ "flash-attn>=2.6.0", # NOTE: flash-attn is necessary if ring_degree > 1 ], From 142172b09a3dbf199a0b6df5f85d2deac1bf919e Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 10:49:56 +0000 Subject: [PATCH 11/15] Update readme with new list of supported diffusers versions --- README.md | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c3b18547..7e7de4b8 100644 --- a/README.md +++ b/README.md 
@@ -150,29 +150,28 @@ The blog article is also available: [Supercharge Your AIGC Experience: Leverage

### 1. Install from pip

-We set `diffusers` and `flash_attn` as two optional installation requirements.
-
-About `diffusers` version:
-- If you only use the USP interface, `diffusers` is not required. Models are typically released as `nn.Module`
- first, before being integrated into diffusers. xDiT sometimes is applied as an USP plugin to existing projects.
-- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
-- While we specify a diffusers version in `setup.py`, newer models may require later versions or even need to be installed from main branch.
+We set `flash_attn` as an optional installation requirement.

About `flash_attn` version:
- Without `flash_attn` installed, xDiT falls back to a PyTorch implementation of ring attention, which helps NPU users with compatibility
- However, not using `flash_attn` on GPUs may result in suboptimal performance. For best GPU performance, we strongly recommend installing `flash_attn`.

+About `diffusers` version:
+- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
+- While we specify a diffusers version in `setup.py`, newer models may require later versions or even need to be installed from main branch.
+- A limited list of validated diffusers versions can be seen [here](#6-limitations).
+
```
pip install xfuser # Basic installation
-pip install "xfuser[diffusers,flash-attn]" # With both diffusers and flash attention
+pip install "xfuser[flash-attn]" # With flash attention
```

### 2. Install from source

```
pip install -e .
-# Or optionally, with diffusers
-pip install -e ".[diffusers,flash-attn]"
+# Or optionally, with flash attention
+pip install -e ".[flash-attn]"
```

Note that we use two self-maintained packages:
@@ -226,6 +225,15 @@ You can also launch an HTTP service to generate images with xDiT.

### 6. Limitations

+#### Diffusers version
+
+Below is a list of validated diffusers version requirements. If the model is not in the list, you may need to try several diffusers versions to find a working configuration.
+ +| Model Name | Diffusers version | +| --- | --- | +| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 | + + #### HunyuanVideo - Supports `diffusers<=0.32.2` (breaking commit diffusers @ [8907a70](https://github.com/huggingface/diffusers/commit/8907a70a366c96b2322656f57b24e442ea392c7b)) From a7b2aeba3e2fff046265d628a030de1bc56fb25e Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 11:19:57 +0000 Subject: [PATCH 12/15] Update base_model to wrap all attention wrappers, not just the global one --- xfuser/model_executor/layers/__init__.py | 2 ++ xfuser/model_executor/models/base_model.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/xfuser/model_executor/layers/__init__.py b/xfuser/model_executor/layers/__init__.py index 664a2d6a..cca6367f 100644 --- a/xfuser/model_executor/layers/__init__.py +++ b/xfuser/model_executor/layers/__init__.py @@ -1,6 +1,7 @@ from .register import xFuserLayerWrappersRegister from .base_layer import xFuserLayerBaseWrapper from .attention_processor import xFuserAttentionWrapper +from .attention_processor import xFuserAttentionBaseWrapper from .conv import xFuserConv2dWrapper from .embeddings import xFuserPatchEmbedWrapper from .feedforward import xFuserFeedForwardWrapper @@ -8,6 +9,7 @@ __all__ = [ "xFuserLayerWrappersRegister", "xFuserLayerBaseWrapper", + "xFuserAttentionBaseWrapper", "xFuserAttentionWrapper", "xFuserConv2dWrapper", "xFuserPatchEmbedWrapper", diff --git a/xfuser/model_executor/models/base_model.py b/xfuser/model_executor/models/base_model.py index 9586ef3b..5762c9a2 100644 --- a/xfuser/model_executor/models/base_model.py +++ b/xfuser/model_executor/models/base_model.py @@ -120,7 +120,7 @@ def _register_cache( self, ): for layer in self.wrapped_layers: - if "AttentionWrapper" in layer._get_name(): + if isinstance(layer, xFuserAttentionBaseWrapper): # if getattr(layer.processor, 'use_long_ctx_attn_kvcache', False): # TODO(Eigensystem): remove use_long_ctx_attn_kvcache flag if get_sequence_parallel_world_size() == 1 or not getattr( From 0579eff8449d3c612c2a0912c713abe7ac238631 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 11:41:27 +0000 Subject: [PATCH 13/15] Add diffusers config file --- xfuser/config/diffusers.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 xfuser/config/diffusers.py diff --git a/xfuser/config/diffusers.py b/xfuser/config/diffusers.py new file mode 100644 index 00000000..6e64f88d --- /dev/null +++ b/xfuser/config/diffusers.py @@ -0,0 +1,12 @@ +import diffusers +from packaging import version + +DEFAULT_MINIMUM_DIFFUSERS_VERSION = "0.33.0" +MINIMUM_DIFFUSERS_VERSIONS = { + "flux": "0.35.2", +} + +def has_valid_diffusers_version(adapter_name: str|None = None) -> bool: + diffusers_version = diffusers.__version__ + minimum_diffusers_version = MINIMUM_DIFFUSERS_VERSIONS.get(adapter_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) + return version.parse(diffusers_version) >= version.parse(minimum_diffusers_version) \ No newline at end of file From d2bf3c354747c9bdcc30b356b181ffeb592ef344 Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 11:42:09 +0000 Subject: [PATCH 14/15] Remove old Flux attention processor --- .../layers/attention_processor.py | 194 ------------------ 1 file changed, 194 deletions(-) diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 080119aa..cb6fdc8e 100644 --- 
a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -11,7 +11,6 @@ from diffusers.models.attention_processor import ( AttnProcessor2_0, JointAttnProcessor2_0, - FluxAttnProcessor2_0, HunyuanAttnProcessor2_0, CogVideoXAttnProcessor2_0, SanaLinearAttnProcessor2_0, @@ -620,199 +619,6 @@ def __call__( return hidden_states -@xFuserAttentionProcessorRegister.register(FluxAttnProcessor2_0) -class xFuserFluxAttnProcessor2_0(FluxAttnProcessor2_0): - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - super().__init__() - use_long_ctx_attn_kvcache = True - self.use_long_ctx_attn_kvcache = ( - HAS_LONG_CTX_ATTN - and use_long_ctx_attn_kvcache - and get_sequence_parallel_world_size() > 1 - ) - set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - num_encoder_hidden_states_tokens = encoder_hidden_states_query_proj.shape[2] - num_query_tokens = query.shape[2] - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - else: - num_encoder_hidden_states_tokens = ( - get_runtime_state().max_condition_sequence_length - ) - num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - #! 
---------------------------------------- KV CACHE ---------------------------------------- - if ( - get_runtime_state().num_pipeline_patch > 1 - and not self.use_long_ctx_attn_kvcache - ): - encoder_hidden_states_key_proj, key = key.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=2 - ) - encoder_hidden_states_value_proj, value = value.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=2 - ) - key, value = get_cache_manager().update_and_get_kv_cache( - new_kv=[key, value], - layer=attn, - slice_dim=2, - layer_type="attn", - ) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - #! ---------------------------------------- KV CACHE ---------------------------------------- - - #! ---------------------------------------- ATTENTION ---------------------------------------- - if ( - get_pipeline_parallel_world_size() == 1 - and get_runtime_state().split_text_embed_in_sp - ): - hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - if get_runtime_state().split_text_embed_in_sp: - encoder_hidden_states_query_proj = None - encoder_hidden_states_key_proj = None - encoder_hidden_states_value_proj = None - else: - encoder_hidden_states_query_proj, query = query.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_key_proj, key = key.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_value_proj, value = value.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - hidden_states = self.hybrid_seq_parallel_attn( - attn if get_runtime_state().num_pipeline_patch > 1 else None, - query, - key, - value, - dropout_p=0.0, - causal=False, - joint_tensor_query=encoder_hidden_states_query_proj, - joint_tensor_key=encoder_hidden_states_key_proj, - joint_tensor_value=encoder_hidden_states_value_proj, - joint_strategy="front", - ) - hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) - - else: - if HAS_FLASH_ATTN: - from flash_attn import flash_attn_func - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - hidden_states = flash_attn_func( - query, key, value, dropout_p=0.0, causal=False - ) - hidden_states = hidden_states.reshape( - batch_size, -1, attn.heads * head_dim - ) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - #! 
---------------------------------------- ATTENTION ---------------------------------------- - - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - @xFuserAttentionProcessorRegister.register(HunyuanAttnProcessor2_0) class xFuserHunyuanAttnProcessor2_0(HunyuanAttnProcessor2_0): def __init__(self): From 7628a6b20a2e4d88155e490e4ac537a32d2a732b Mon Sep 17 00:00:00 2001 From: Aleksi Vesanto Date: Fri, 24 Oct 2025 12:38:03 +0000 Subject: [PATCH 15/15] Retrieve minimum versions from the config rather than use hardcoded values for error messages --- examples/flux_example.py | 5 +++-- examples/flux_usp_example.py | 5 +++-- xfuser/config/diffusers.py | 10 +++++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/flux_example.py b/examples/flux_example.py index f0fc9e49..8d85b0fa 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -3,10 +3,11 @@ import torch import diffusers import torch.distributed -from xfuser.config.diffusers import has_valid_diffusers_version +from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version if not has_valid_diffusers_version("flux"): - raise ImportError("Please install diffusers>=0.35.2 to use Flux.") + minimum_diffusers_version = get_minimum_diffusers_version("flux") + raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.") from transformers import T5EncoderModel from xfuser import xFuserFluxPipeline, xFuserArgs diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index 4d3aea12..7c3fcd6d 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -6,11 +6,12 @@ import logging import time import torch -from xfuser.config.diffusers import has_valid_diffusers_version +from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version from typing import List, Optional if not has_valid_diffusers_version("flux"): - raise ImportError("Please install diffusers>=0.35.2 to use Flux.") + minimum_diffusers_version = get_minimum_diffusers_version("flux") + raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.") from diffusers import DiffusionPipeline, FluxPipeline diff --git a/xfuser/config/diffusers.py b/xfuser/config/diffusers.py index 6e64f88d..5c13c53c 100644 --- a/xfuser/config/diffusers.py +++ b/xfuser/config/diffusers.py @@ -6,7 +6,11 @@ "flux": "0.35.2", } -def has_valid_diffusers_version(adapter_name: str|None = None) -> bool: +def has_valid_diffusers_version(model_name: str|None = None) -> bool: diffusers_version = diffusers.__version__ - minimum_diffusers_version = MINIMUM_DIFFUSERS_VERSIONS.get(adapter_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) - return version.parse(diffusers_version) >= version.parse(minimum_diffusers_version) \ No newline at end of file + minimum_diffusers_version = MINIMUM_DIFFUSERS_VERSIONS.get(model_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) + return version.parse(diffusers_version) >= version.parse(minimum_diffusers_version) + + +def get_minimum_diffusers_version(model_name: 
str|None = None) -> str: + return MINIMUM_DIFFUSERS_VERSIONS.get(model_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) \ No newline at end of file