diff --git a/README.md b/README.md
index c3b18547..7e7de4b8 100644
--- a/README.md
+++ b/README.md
@@ -150,29 +150,28 @@ The blog article is also available: [Supercharge Your AIGC Experience: Leverage
 ### 1. Install from pip
 
-We set `diffusers` and `flash_attn` as two optional installation requirements.
-
-About `diffusers` version:
-- If you only use the USP interface, `diffusers` is not required. Models are typically released as `nn.Module`
-  first, before being integrated into diffusers. xDiT sometimes is applied as an USP plugin to existing projects.
-- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
-- While we specify a diffusers version in `setup.py`, newer models may require later versions or even need to be installed from main branch.
+We set `flash_attn` as an optional installation requirement.
 
 About `flash_attn` version:
 - Without `flash_attn` installed, xDiT falls back to a PyTorch implementation of ring attention, which helps NPU users with compatibility.
 - However, not using `flash_attn` on GPUs may result in suboptimal performance. For best GPU performance, we strongly recommend installing `flash_attn`.
 
+About `diffusers` version:
+- Different models may require different diffusers versions. Model implementations can vary between diffusers versions, especially for the latest models, which affects parallel processing. When encountering model execution errors, you may need to try several recent diffusers versions.
+- While we specify a diffusers version in `setup.py`, newer models may require later versions or may even need to be installed from the main branch.
+- The list of validated diffusers versions can be found [here](#6-limitations).
+
 ```
 pip install xfuser  # Basic installation
-pip install "xfuser[diffusers,flash-attn]"  # With both diffusers and flash attention
+pip install "xfuser[flash-attn]"  # With flash attention
 ```
 
 ### 2. Install from source
 
 ```
 pip install -e .
-# Or optionally, with diffusers
-pip install -e ".[diffusers,flash-attn]"
+# Or optionally, with flash attention
+pip install -e ".[flash-attn]"
 ```
 
 Note that we use two self-maintained packages:
@@ -226,6 +225,15 @@ You can also launch an HTTP service to generate images with xDiT.
 
 ### 6. Limitations
 
+#### Diffusers version
+
+Below is a list of validated diffusers version requirements. If a model is not in the list, you may need to try several diffusers versions to find a working configuration. You can also check programmatically via the helper added in `xfuser/config/diffusers.py` (a minimal sketch, mirroring the updated examples, follows; the validated versions are in the table after it):
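+
+```
+from xfuser.config.diffusers import (
+    get_minimum_diffusers_version,
+    has_valid_diffusers_version,
+)
+
+# Compare the installed diffusers version against the validated minimum for a model
+if not has_valid_diffusers_version("flux"):
+    minimum = get_minimum_diffusers_version("flux")
+    raise ImportError(f"Please install diffusers>={minimum} to use Flux.")
+```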
+ +| Model Name | Diffusers version | +| --- | --- | +| [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) | >= 0.35.2 | + + #### HunyuanVideo - Supports `diffusers<=0.32.2` (breaking commit diffusers @ [8907a70](https://github.com/huggingface/diffusers/commit/8907a70a366c96b2322656f57b24e442ea392c7b)) diff --git a/examples/flux_example.py b/examples/flux_example.py index 9e0ac849..8d85b0fa 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -1,7 +1,14 @@ import logging import time import torch +import diffusers import torch.distributed +from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version + +if not has_valid_diffusers_version("flux"): + minimum_diffusers_version = get_minimum_diffusers_version("flux") + raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.") + from transformers import T5EncoderModel from xfuser import xFuserFluxPipeline, xFuserArgs from xfuser.config import FlexibleArgumentParser diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index 1f3287cd..7c3fcd6d 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -2,11 +2,17 @@ # from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py import functools -from typing import List, Optional import logging import time import torch +from xfuser.config.diffusers import has_valid_diffusers_version, get_minimum_diffusers_version +from typing import List, Optional + +if not has_valid_diffusers_version("flux"): + minimum_diffusers_version = get_minimum_diffusers_version("flux") + raise ImportError(f"Please install diffusers>={minimum_diffusers_version} to use Flux.") + from diffusers import DiffusionPipeline, FluxPipeline from xfuser import xFuserArgs @@ -27,7 +33,7 @@ get_pipeline_parallel_world_size, ) -from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0 +from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxAttnProcessor def parallelize_transformer(pipe: DiffusionPipeline): transformer = pipe.transformer @@ -50,7 +56,7 @@ def new_forward( get_runtime_state().split_text_embed_in_sp = False else: get_runtime_state().split_text_embed_in_sp = True - + if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] @@ -61,10 +67,8 @@ def new_forward( img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] if get_runtime_state().split_text_embed_in_sp: txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] - - for block in transformer.transformer_blocks + transformer.single_transformer_blocks: - block.attn.processor = xFuserFluxAttnProcessor2_0() - + + output = original_forward( hidden_states, encoder_hidden_states, @@ -86,6 +90,9 @@ def new_forward( new_forward = new_forward.__get__(transformer) transformer.forward = new_forward + for block in transformer.transformer_blocks + transformer.single_transformer_blocks: + block.attn.processor = xFuserFluxAttnProcessor() + def main(): parser = FlexibleArgumentParser(description="xFuser Arguments") @@ -119,7 +126,7 @@ def main(): 
max_condition_sequence_length=512, split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) - + parallelize_transformer(pipe) if engine_config.runtime_config.use_torch_compile: @@ -139,7 +146,7 @@ def main(): torch.cuda.reset_peak_memory_stats() start_time = time.time() - + output = pipe( height=input_config.height, width=input_config.width, diff --git a/setup.py b/setup.py index eb822a52..fdbce167 100644 --- a/setup.py +++ b/setup.py @@ -34,11 +34,9 @@ def get_cuda_version(): "distvae", "yunchang>=0.6.0", "einops", + "diffusers>=0.33.0", ], extras_require={ - "diffusers": [ - "diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux - ], "flash-attn": [ "flash-attn>=2.6.0", # NOTE: flash-attn is necessary if ring_degree > 1 ], diff --git a/xfuser/config/diffusers.py b/xfuser/config/diffusers.py new file mode 100644 index 00000000..5c13c53c --- /dev/null +++ b/xfuser/config/diffusers.py @@ -0,0 +1,16 @@ +import diffusers +from packaging import version + +DEFAULT_MINIMUM_DIFFUSERS_VERSION = "0.33.0" +MINIMUM_DIFFUSERS_VERSIONS = { + "flux": "0.35.2", +} + +def has_valid_diffusers_version(model_name: str|None = None) -> bool: + diffusers_version = diffusers.__version__ + minimum_diffusers_version = MINIMUM_DIFFUSERS_VERSIONS.get(model_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) + return version.parse(diffusers_version) >= version.parse(minimum_diffusers_version) + + +def get_minimum_diffusers_version(model_name: str|None = None) -> str: + return MINIMUM_DIFFUSERS_VERSIONS.get(model_name, DEFAULT_MINIMUM_DIFFUSERS_VERSION) \ No newline at end of file diff --git a/xfuser/model_executor/cache/diffusers_adapters/registry.py b/xfuser/model_executor/cache/diffusers_adapters/registry.py index 2232832a..389345c5 100644 --- a/xfuser/model_executor/cache/diffusers_adapters/registry.py +++ b/xfuser/model_executor/cache/diffusers_adapters/registry.py @@ -2,15 +2,17 @@ adapted from https://github.com/ali-vilab/TeaCache.git adapted from https://github.com/chengzeyi/ParaAttention.git """ +from xfuser.config.diffusers import has_valid_diffusers_version from typing import Type, Dict -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel -from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {} def register_transformer_adapter(transformer_class: Type, adapter_name: str): TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name -register_transformer_adapter(FluxTransformer2DModel, "flux") -register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux") +if has_valid_diffusers_version("flux"): + from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel + from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper + register_transformer_adapter(FluxTransformer2DModel, "flux") + register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux") diff --git a/xfuser/model_executor/layers/__init__.py b/xfuser/model_executor/layers/__init__.py index 664a2d6a..cca6367f 100644 --- a/xfuser/model_executor/layers/__init__.py +++ b/xfuser/model_executor/layers/__init__.py @@ -1,6 +1,7 @@ from .register import xFuserLayerWrappersRegister from .base_layer import xFuserLayerBaseWrapper from .attention_processor import xFuserAttentionWrapper +from .attention_processor import xFuserAttentionBaseWrapper from .conv import xFuserConv2dWrapper from .embeddings 
import xFuserPatchEmbedWrapper from .feedforward import xFuserFeedForwardWrapper @@ -8,6 +9,7 @@ __all__ = [ "xFuserLayerWrappersRegister", "xFuserLayerBaseWrapper", + "xFuserAttentionBaseWrapper", "xFuserAttentionWrapper", "xFuserConv2dWrapper", "xFuserPatchEmbedWrapper", diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 0f410e23..cb6fdc8e 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -11,7 +11,6 @@ from diffusers.models.attention_processor import ( AttnProcessor2_0, JointAttnProcessor2_0, - FluxAttnProcessor2_0, HunyuanAttnProcessor2_0, CogVideoXAttnProcessor2_0, SanaLinearAttnProcessor2_0, @@ -222,7 +221,7 @@ def __init__(self): HAS_LONG_CTX_ATTN and use_long_ctx_attn_kvcache and get_sequence_parallel_world_size() > 1 - ) + ) set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) if get_fast_attn_enable(): @@ -450,7 +449,7 @@ def __call__( query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - + inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -620,199 +619,6 @@ def __call__( return hidden_states -@xFuserAttentionProcessorRegister.register(FluxAttnProcessor2_0) -class xFuserFluxAttnProcessor2_0(FluxAttnProcessor2_0): - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - super().__init__() - use_long_ctx_attn_kvcache = True - self.use_long_ctx_attn_kvcache = ( - HAS_LONG_CTX_ATTN - and use_long_ctx_attn_kvcache - and get_sequence_parallel_world_size() > 1 - ) - set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q( - encoder_hidden_states_query_proj - ) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k( - encoder_hidden_states_key_proj - ) - - num_encoder_hidden_states_tokens = encoder_hidden_states_query_proj.shape[2] - num_query_tokens = query.shape[2] - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - else: - num_encoder_hidden_states_tokens = ( - get_runtime_state().max_condition_sequence_length - ) - num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens - - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - #! ---------------------------------------- KV CACHE ---------------------------------------- - if ( - get_runtime_state().num_pipeline_patch > 1 - and not self.use_long_ctx_attn_kvcache - ): - encoder_hidden_states_key_proj, key = key.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=2 - ) - encoder_hidden_states_value_proj, value = value.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=2 - ) - key, value = get_cache_manager().update_and_get_kv_cache( - new_kv=[key, value], - layer=attn, - slice_dim=2, - layer_type="attn", - ) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - #! ---------------------------------------- KV CACHE ---------------------------------------- - - #! 
---------------------------------------- ATTENTION ---------------------------------------- - if ( - get_pipeline_parallel_world_size() == 1 - and get_runtime_state().split_text_embed_in_sp - ): - hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - if get_runtime_state().split_text_embed_in_sp: - encoder_hidden_states_query_proj = None - encoder_hidden_states_key_proj = None - encoder_hidden_states_value_proj = None - else: - encoder_hidden_states_query_proj, query = query.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_key_proj, key = key.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_value_proj, value = value.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - hidden_states = self.hybrid_seq_parallel_attn( - attn if get_runtime_state().num_pipeline_patch > 1 else None, - query, - key, - value, - dropout_p=0.0, - causal=False, - joint_tensor_query=encoder_hidden_states_query_proj, - joint_tensor_key=encoder_hidden_states_key_proj, - joint_tensor_value=encoder_hidden_states_value_proj, - joint_strategy="front", - ) - hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) - - else: - if HAS_FLASH_ATTN: - from flash_attn import flash_attn_func - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - hidden_states = flash_attn_func( - query, key, value, dropout_p=0.0, causal=False - ) - hidden_states = hidden_states.reshape( - batch_size, -1, attn.heads * head_dim - ) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - #! 
---------------------------------------- ATTENTION ---------------------------------------- - - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - @xFuserAttentionProcessorRegister.register(HunyuanAttnProcessor2_0) class xFuserHunyuanAttnProcessor2_0(HunyuanAttnProcessor2_0): def __init__(self): @@ -1510,7 +1316,7 @@ def __call__( class xFuserSanaAttnProcessor2_0(SanaAttnProcessor2_0): def __init__(self): super().__init__() - + def __call__( self, attn: Attention, @@ -1627,7 +1433,7 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - + batch_size = encoder_hidden_states.size(0) query = attn.to_q(hidden_states) @@ -1638,7 +1444,7 @@ def __call__( query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - + inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads diff --git a/xfuser/model_executor/models/base_model.py b/xfuser/model_executor/models/base_model.py index 3d3315be..5762c9a2 100644 --- a/xfuser/model_executor/models/base_model.py +++ b/xfuser/model_executor/models/base_model.py @@ -120,7 +120,7 @@ def _register_cache( self, ): for layer in self.wrapped_layers: - if isinstance(layer, xFuserAttentionWrapper): + if isinstance(layer, xFuserAttentionBaseWrapper): # if getattr(layer.processor, 'use_long_ctx_attn_kvcache', False): # TODO(Eigensystem): remove use_long_ctx_attn_kvcache flag if get_sequence_parallel_world_size() == 1 or not getattr( diff --git a/xfuser/model_executor/models/transformers/__init__.py b/xfuser/model_executor/models/transformers/__init__.py index 96e3ae8e..8258945c 100644 --- a/xfuser/model_executor/models/transformers/__init__.py +++ b/xfuser/model_executor/models/transformers/__init__.py @@ -1,8 +1,8 @@ +from xfuser.config.diffusers import has_valid_diffusers_version from .register import xFuserTransformerWrappersRegister from .base_transformer import xFuserTransformerBaseWrapper from .pixart_transformer_2d import xFuserPixArtTransformer2DWrapper from .transformer_sd3 import xFuserSD3Transformer2DWrapper -from .transformer_flux import xFuserFluxTransformer2DWrapper from .latte_transformer_3d import xFuserLatteTransformer3DWrapper from .hunyuan_transformer_2d import xFuserHunyuanDiT2DWrapper from .cogvideox_transformer_3d import xFuserCogVideoXTransformer3DWrapper @@ -14,10 +14,14 @@ "xFuserTransformerBaseWrapper", "xFuserPixArtTransformer2DWrapper", "xFuserSD3Transformer2DWrapper", - "xFuserFluxTransformer2DWrapper", "xFuserLatteTransformer3DWrapper", "xFuserCogVideoXTransformer3DWrapper", "xFuserHunyuanDiT2DWrapper", "xFuserConsisIDTransformer3DWrapper", "xFuserSanaTransformer2DWrapper" -] \ No newline at end of file +] + +# Gating some imports based on diffusers version, as they import part of diffusers +if has_valid_diffusers_version("flux"): + from .transformer_flux import xFuserFluxTransformer2DWrapper + __all__.append("xFuserFluxTransformer2DWrapper") \ No newline at end of file diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index 
91869552..00096e14 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -1,10 +1,16 @@ -from typing import Optional, Dict, Any, Union +import inspect import torch import torch.distributed import torch.nn as nn +from typing import Optional, Dict, Any, Union -from diffusers.models.embeddings import PatchEmbed -from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +#from diffusers.models.embeddings import PatchEmbed +from diffusers.models.transformers.transformer_flux import ( + _get_qkv_projections, + FluxTransformer2DModel, + FluxAttnProcessor, + FluxAttention, +) from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput from diffusers.utils import ( is_torch_version, @@ -12,23 +18,197 @@ USE_PEFT_BACKEND, unscale_lora_layers, ) +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import apply_rotary_emb from xfuser.core.distributed.parallel_state import ( get_tensor_model_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, ) +from xfuser.core.distributed import ( + get_sequence_parallel_world_size, +) + +from xfuser.core.cache_manager.cache_manager import get_cache_manager from xfuser.core.distributed.runtime_state import get_runtime_state + from xfuser.logger import init_logger +from xfuser.envs import PACKAGES_CHECKER from xfuser.model_executor.models.transformers.register import ( xFuserTransformerWrappersRegister, ) from xfuser.model_executor.models.transformers.base_transformer import ( xFuserTransformerBaseWrapper, ) +from xfuser.model_executor.layers import xFuserLayerWrappersRegister +from xfuser.model_executor.layers.attention_processor import ( + set_hybrid_seq_parallel_attn, + xFuserAttentionBaseWrapper, + xFuserAttentionProcessorRegister +) +from xfuser.model_executor.layers.usp import USP logger = init_logger(__name__) -from diffusers.models.attention import FeedForward + +env_info = PACKAGES_CHECKER.get_packages_info() +HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"] + + +@xFuserLayerWrappersRegister.register(FluxAttention) +class xFuserFluxAttentionWrapper(xFuserAttentionBaseWrapper): + def __init__( + self, + attention: FluxAttention, + ): + super().__init__(attention=attention) + self.processor = xFuserAttentionProcessorRegister.get_processor( + attention.processor + )() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + +@xFuserAttentionProcessorRegister.register(FluxAttnProcessor) +class xFuserFluxAttnProcessor(FluxAttnProcessor): + + def __init__(self): + super().__init__() + use_long_ctx_attn_kvcache = True + self.use_long_ctx_attn_kvcache = ( + HAS_LONG_CTX_ATTN + and use_long_ctx_attn_kvcache + and get_sequence_parallel_world_size() > 1 + ) + set_hybrid_seq_parallel_attn(self, self.use_long_ctx_attn_kvcache) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + num_encoder_hidden_states_tokens = encoder_query.shape[1] + num_query_tokens = query.shape[1] + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + else: + num_encoder_hidden_states_tokens = ( + get_runtime_state().max_condition_sequence_length + ) + num_query_tokens = query.shape[1] - num_encoder_hidden_states_tokens + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + if ( + get_runtime_state().num_pipeline_patch > 1 + and not self.use_long_ctx_attn_kvcache + ): + encoder_hidden_states_key_proj, key = key.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_value_proj, value = value.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + key, value = get_cache_manager().update_and_get_kv_cache( + new_kv=[key, value], + layer=attn, + slice_dim=1, + layer_type="attn", + ) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + + uses_pipeline_parallelism = get_runtime_state().num_pipeline_patch > 1 + if not uses_pipeline_parallelism: + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + hidden_states = USP(query, key, value) + hidden_states = hidden_states.transpose(1, 2) + else: + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states_query_proj = None + encoder_hidden_states_key_proj = None + encoder_hidden_states_value_proj = None + else: + encoder_hidden_states_query_proj, query = query.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_key_proj, key = key.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_value_proj, value = value.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + 
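+            # Hand the text (encoder) tokens to the hybrid sequence-parallel attention
+            # as joint tensors so every rank can prepend them locally;
+            # joint_strategy="front" matches the [encoder, image] concat order above.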
hidden_states = self.hybrid_seq_parallel_attn( + attn if get_runtime_state().num_pipeline_patch > 1 else None, + query, + key, + value, + dropout_p=0.0, + causal=False, + joint_tensor_query=encoder_hidden_states_query_proj, + joint_tensor_key=encoder_hidden_states_key_proj, + joint_tensor_value=encoder_hidden_states_value_proj, + joint_strategy="front", + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + @xFuserTransformerWrappersRegister.register(FluxTransformer2DModel) @@ -178,7 +358,6 @@ def custom_forward(*inputs): # hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] # if self.stage_info.after_flags["transformer_blocks"]: - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: @@ -195,17 +374,19 @@ def custom_forward(*inputs): ckpt_kwargs: Dict[str, Any] = ( {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ) - hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, + encoder_hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) @@ -219,6 +400,8 @@ def custom_forward(*inputs): # + controlnet_single_block_samples[index_block // interval_control] # ) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...] hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]