diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 391cab0b1e..259d8155b4 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -23,6 +23,7 @@ th { | `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | |`ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | | `WanPipeline` | Wan2.2 | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | +|`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | ## List of Supported Models for NPU diff --git a/vllm_omni/diffusion/models/longcat_image/__init__.py b/vllm_omni/diffusion/models/longcat_image/__init__.py new file mode 100644 index 0000000000..105f51a261 --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel +from vllm_omni.diffusion.models.longcat_image.pipeline_longcat_image import ( + LongCatImagePipeline, + get_longcat_image_post_process_func, +) + +__all__ = [ + "LongCatImagePipeline", + "LongCatImageTransformer2DModel", + "get_longcat_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py new file mode 100644 index 0000000000..2d282c2f9c --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.utils.platform_utils import is_npu + +logger = init_logger(__name__) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, bias: bool = True): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.w_in = ReplicatedLinear(dim, inner_dim, bias=bias, return_bias=False) + self.act = get_act_fn("gelu_pytorch_tanh") + self.w_out = ReplicatedLinear(inner_dim, dim_out, bias=bias, return_bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.w_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.w_out(hidden_states) + return hidden_states + + +class LongCatImageAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: int | None = None, + added_proj_bias: bool | None = True, + out_bias: bool = True, + eps: 
float = 1e-5, + out_dim: int = None, + context_pre_only: bool | None = None, + pre_only: bool = False, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + + # Fused QKV projection using vLLM's optimized layer + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + disable_tp=True, + bias=bias, + ) + + if not self.pre_only: + self.to_out = torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if self.added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + + self.add_kv_proj = QKVParallelLinear( + hidden_size=self.added_kv_proj_dim, + head_size=self.head_dim, + total_num_heads=self.heads, + disable_tp=True, + bias=added_proj_bias, + ) + + self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias) + + self.attn = Attention( + num_heads=heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.to_qkv(hidden_states) + + query, key, value = qkv.chunk(3, dim=-1) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + if self.added_kv_proj_dim is not None: + encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) + encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1) + + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = self.to_out(hidden_states) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class LongCatImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + 
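+        # Dual-stream (Flux-style) block: image and text tokens each get their
+        # own AdaLayerNormZero modulation and their own feed-forward network,
+        # while attention is computed jointly over the concatenated text+image
+        # sequence (the text stream enters attention via `added_kv_proj_dim`).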
+ self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim) + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
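+        # As with the image stream above: the AdaLayerNormZero gate scales the
+        # attention output before the residual add, the FFN path uses the
+        # shift/scale-modulated LayerNorm, and fp16 activations are clamped to
+        # the fp16 range to avoid overflow.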
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LongCatImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class LongCatImageTimestepEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + return timesteps_emb + + +class LongCatImageSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if 
hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +class LongCatImageTransformer2DModel(nn.Module): + """ + The Transformer model introduced in Flux. + """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 3584, + axes_dims_rope: list[int] = [16, 56, 56], + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pooled_projection_dim = pooled_projection_dim + + self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + LongCatImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + LongCatImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + self.use_checkpoint = [True] * num_layers + self.use_single_checkpoint = [True] * num_single_layers + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> torch.FloatTensor | Transformer2DModelOutput: + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = self.time_embed(timestep, hidden_states.dtype) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=0) + + if is_npu(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + 
hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".to_out.0" in name: + name = name.replace(".to_out.0", ".to_out") + # Handle FeedForward parameter mapping + if ".ff.net." in name: + # Map .ff.net.0.proj -> .ff.w_in + if ".net.0.proj" in name: + name = name.replace(".net.0.proj", ".w_in") + # Map .ff.net.2 -> .ff.w_out + elif ".net.2" in name: + name = name.replace(".net.2", ".w_out") + # Handle FeedForward context parameters + if ".ff_context.net." in name: + # Map .ff_context.net.0.proj -> .ff_context.w_in + if ".net.0.proj" in name: + name = name.replace(".net.0.proj", ".w_in") + # Map .ff_context.net.2 -> .ff_context.w_out + elif ".net.2" in name: + name = name.replace(".net.2", ".w_out") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py new file mode 100644 index 0000000000..5203ff095f --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,647 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import inspect +import json +import os +import re +from collections.abc import Iterable +from typing import Any + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from 
vllm_omni.diffusion.models.longcat_image.longcat_image_transformer import LongCatImageTransformer2DModel +from vllm_omni.diffusion.models.longcat_image.system_prompt import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = init_logger(__name__) + + +def get_longcat_image_post_process_func( + od_config: OmniDiffusionConfig, +): + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters + defined by single or double quote pairs. + + Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." + >>> print(split_quotation(prompt_en)) + >>> # output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" 
+ re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None) -> torch.Tensor: + if type == "text": + assert num_token + if height or width: + logger.warning('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + logger.warning('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknown type {type}, only support "text" or "image".') + # pos_ids = pos_ids[None, :].repeat(batch_size, 1, 1) + return pos_ids + + +def retrieve_timesteps( + scheduler: SchedulerMixin, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def get_prompt_language(prompt): + pattern = re.compile(r"[\u4e00-\u9fff]") + if bool(pattern.search(prompt)): + return "zh" + return "en" + + +class LongCatImagePipeline( + nn.Module, +): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_processor = Qwen2VLProcessor.from_pretrained( + model, subfolder="tokenizer", local_files_only=local_files_only + ) + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + self.transformer = LongCatImageTransformer2DModel(od_config=od_config) + self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.prompt_template_encode_prefix = ( + "<|im_start|>system\n" + "As an image captioning expert, generate a descriptive text prompt based on an image content," + " suitable for input to a text-to-image model.<|im_end|>\n" + "<|im_start|>user\n" + ) + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device) + + self.text_encoder.to(device) + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids_trimmed = 
[out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.text_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + return output_text + + def _encode_prompt(self, prompt: list[str]) -> torch.Tensor: + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f"{self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": batch_all_tokens}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if prompt_embeds is None and prompt is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt) + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * 
num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + "`height` and `width` have to be divisible by " + f"{self.vae_scale_factor * 2} but are {height} and {width}. " + "Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + enable_cfg_renorm: bool | None = True, + cfg_renorm_min: float | None = 0.0, + enable_prompt_rewrite: bool | None = True, + ) -> DiffusionOutput: + prompt = req.prompt if req.prompt is not None else prompt + negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + + height = req.height or height or self.default_sample_size * self.vae_scale_factor + width = req.width or width or self.default_sample_size * self.vae_scale_factor + num_inference_steps = req.num_inference_steps or num_inference_steps + generator = req.generator or generator + guidance_scale = req.guidance_scale if getattr(req, "guidance_scale", None) is not None else guidance_scale + num_images_per_prompt = getattr(req, "num_outputs_per_prompt", None) or num_images_per_prompt + enable_prompt_rewrite = getattr(req, "enable_prompt_rewrite", None) or enable_prompt_rewrite + enable_cfg_renorm = getattr(req, "enable_cfg_renorm", None) or enable_cfg_renorm + cfg_renorm_min = getattr(req, "cfg_renorm_min", None) or cfg_renorm_min + prompt_embeds = getattr(req, "prompt_embeds", None) or prompt_embeds + negative_prompt_embeds = getattr(req, "negative_prompt_embeds", None) or negative_prompt_embeds + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + if enable_prompt_rewrite and prompt is not None: + prompt = self.rewire_prompt(prompt if isinstance(prompt, list) else [prompt], device) + + negative_prompt = "" if negative_prompt is None else negative_prompt + + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. 
Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + prompt_embeds = prompt_embeds.to(device) + if self.do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.to(device) + + # 6. Denoising loop + for i, t in enumerate(timesteps): + if self._interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights using AutoWeightsLoader for vLLM integration.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/longcat_image/system_prompt.py b/vllm_omni/diffusion/models/longcat_image/system_prompt.py new file mode 100644 index 0000000000..07a852d264 --- /dev/null +++ b/vllm_omni/diffusion/models/longcat_image/system_prompt.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copied from https://github.com/meituan-longcat/LongCat-Image/blob/main/longcat_image/pipelines/pipeline_longcat_image.py#L53 + +SYSTEM_PROMPT_EN = """ + +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in +understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's +understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all +information from the user's original prompt without deleting or distorting any details. Specific requirements are as +follows: + +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use + coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as + concise as possible. + +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields + English output. The rewritten token count should not exceed 512. + +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the + original prompt, such as lighting and textures. + +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography + style**. If the user specifies a style, retain the user's style. + +5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge to + convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a + giraffe"). + +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% + OFF"`). + +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no specific + text content is specified, you need to infer appropriate text content and enclose it in double quotes. For example, + if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer with the + image title 'Grassland'." + +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. 
For + example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. + +9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. + +Here are examples of rewrites for different types of prompts: # Examples (Few-Shot Learning) + + 1. User Input: An animal with nine lives. + + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home + environment with light from the window filtering through curtains, creating a warm light and shadow effect. The + shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits + the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + + 2. User Input: Create an anime-style tourism flyer with a grassland theme. + + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped + rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her + left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs covering + her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To the girl's + left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The grass extends + into the distance, forming rolling green hills that fade in color as they recede. The sky occupies the upper half of + the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is a line of text in + italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, and yellow, fluid + lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + + 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. + + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and + left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, + golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two + transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls + scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, + and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, + accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing + strong visual appeal. + + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident + posture. + + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her + shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long + eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She + has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. 
Her + skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a black + spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and metal + bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible knitting + patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a relaxed + posture. The background is a pure dark brown without extra decoration, making the figure the absolute focus. The + figure is located in the center of the frame. Light enters from the upper right, creating bright spots on her left + cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional and soft tone. + Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are dominated by warm + tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. The overall style is + natural, elegant, and artistic. + + 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should + include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. + + Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage + precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark + soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with + green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, + stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden + light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches + and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under a + clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into the + tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the orchard, + with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a realistic style, + focusing on details and harmonious colors to showcase the natural beauty and development of the apple's life cycle. + + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate + a four-color rainbow based on this rule. The color order from top to bottom is 3142. + + Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as + purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the number + "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the bottom + green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast with the + background colors to ensure good readability. The stripes have high color saturation and a slight texture. The + overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing the + numerical information. 
The image is high definition, with accurate colors and a consistent style, offering strong + visual appeal. + + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a + Chinese garden. + + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with + traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the + stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo + forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a + realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the + stone tablet and the classical beauty of the garden. + +# Output Format Please directly output the rewritten and optimized Prompt content. Do not include any explanatory +language or JSON formatting, and do not add opening or closing quotes yourself.""" + + +SYSTEM_PROMPT_ZH = """ +你是一名文生图模型的prompt engineering专家。由于文生图模型对用户prompt的理解能力有限, +你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。 +改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 + +具体要求如下: + +1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达, + 不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 +2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要 + 超过512个。 +3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧,如打光、纹理等。 +4. 如果原始prompt没有指定图片风格时,确保改写后的prompt采用真实摄影风格;如果用户指定了图片风 + 格,则保留用户风格。 +5. 当原始prompt需要推理才能明确用户意图时,根据世界知识进行适当逻辑推理,将模糊抽象描述转化为 + 具体事物(例:将“最高的动物”转化为“一头长颈鹿”)。 +6. 当原始prompt需要生成文字时,请使用双引号圈定文字部分,例:"限时5折"。 +7. 当原始prompt需要生成网页、logo、UI、海报等文字场景时,且没有指定具体文字内容,需要推断合适 + 的文字内容,并使用双引号圈定,如用户输入“一个以草原为主题的旅游宣传单”,应改写为“一个旅游 + 宣传单,图片标题为"草原"”。 +8. 当原始prompt中存在否定词时,需要确保改写后的prompt不存在被否定的对象,如“没有船的湖边”,改 + 写后的prompt不能出现“船”这个词汇。 +9. 除非用户指定生成品牌logo,否则不要增加额外的品牌logo。 +10. 除了用户明确要求书写的文字内容外,禁止增加任何额外的文字内容。 + +以下是针对不同类型prompt改写的示例: + +# Examples (Few-Shot Learning) + + 1. 用户输入: 九条命的动物。 + + 改写输出: + 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过 + 窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸 + 部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + + 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 + + 改写输出: + 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色 + 平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海 + 覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地, + 草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成 + 连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀 + 着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore Nature's + Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛 + 围。 + + 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 + + 改写输出: + 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个 + 松果。海报中央偏上位置,金色立体字样“圣诞节 暖心回馈”居中排列,和红色粗体字“买1送1”。海 + 报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。 + 杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞 + 树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的 + 视觉吸引力。 + + 4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 + + 改写输出: + 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮, + 呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露 + 出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色 + 白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带 + 纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组 + 成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。 + 她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得 + 人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈 + 和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤 + 纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨 + 舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + + 5. 用户输入: 创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段: + 1. 播种,2. 幼苗生长,3. 植物成熟,4. 
果实采摘。 + + 改写输出: + 一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只 + 手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景 + 是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破 + 土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤 + 细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背 + 景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手 + 伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象, + 背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和 + 谐,展现了苹果生命周期的自然之美和发展过程。 + + 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的 + 颜色顺序从上到下是3142。 + + 改写输出: + 图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放 + 置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上 + 是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色 + 形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了, + 视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致, + 具有较强的视觉吸引力。 + + 7. 用户输入: 石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林。 + + 改写输出: + 一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线 + 从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有 + 翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节 + 丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 + +# 输出格式 请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头 +或结尾的引号。 +""" diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index f473332096..05f07f7fb5 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -29,6 +29,11 @@ "pipeline_wan2_2", "Wan22Pipeline", ), + "LongCatImagePipeline": ( + "longcat_image", + "pipeline_longcat_image", + "LongCatImagePipeline", + ), } @@ -67,6 +72,7 @@ def initialize_model( "QwenImageEditPipeline": "get_qwen_image_edit_post_process_func", "ZImagePipeline": "get_post_process_func", "WanPipeline": "get_wan22_post_process_func", + "LongCatImagePipeline": "get_longcat_image_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = {