Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions vllm_omni/diffusion/layers/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,29 @@ def forward_native(
sin,
interleaved=self.interleaved,
)


def apply_rope_to_qk(
    rope: RotaryEmbedding,
    query: torch.Tensor,
    key: torch.Tensor,
    image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate the query and key tensors with a shared rotary embedding.

    When ``image_rotary_emb`` is ``None`` the inputs are handed back
    untouched. Otherwise its ``(cos, sin)`` pair is cast to the dtype of
    ``query`` and ``rope`` is invoked once for ``query`` and once for
    ``key`` with that same pair.

    Args:
        rope: Callable rotary-embedding module, invoked as
            ``rope(tensor, cos, sin)``.
        query: Query tensor [B, S, H, D].
        key: Key tensor [B, S, H, D].
        image_rotary_emb: ``(cos, sin)`` tensors, or ``None`` to skip
            the rotation entirely.

    Returns:
        The ``(query, key)`` pair, rotated if embeddings were supplied.
    """
    # Nothing to apply: return the very same tensor objects.
    if image_rotary_emb is None:
        return query, key

    cos, sin = image_rotary_emb
    # Both tables follow the query's dtype; the key reuses the same casts.
    target_dtype = query.dtype
    cos, sin = cos.to(target_dtype), sin.to(target_dtype)

    return rope(query, cos, sin), rope(key, cos, sin)
9 changes: 2 additions & 7 deletions vllm_omni/diffusion/models/flux/flux_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
from vllm_omni.diffusion.layers.rope import RotaryEmbedding, apply_rope_to_qk

logger = init_logger(__name__)

Expand Down Expand Up @@ -224,12 +224,7 @@ def forward(
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)

if image_rotary_emb is not None:
cos, sin = image_rotary_emb # [S, D/2]
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)
query, key = apply_rope_to_qk(self.rope, query, key, image_rotary_emb) # [S, D/2]

hidden_states = self.attn(
query,
Expand Down
Loading
Loading