7 changes: 7 additions & 0 deletions vllm/attention/utils/fa_utils.py
@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
    )


def flash_attn_supports_sinks() -> bool:
    if current_platform.is_xpu():
        return True
    else:
        return get_flash_attn_version() == 3


def flash_attn_supports_mla():
    from vllm.platforms import current_platform

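As a quick illustration of how the new capability helper is meant to be consumed (mirroring the flash_attn.py change later in this PR), here is a minimal, self-contained sketch; the platform and version functions below are stand-ins for vLLM's `current_platform.is_xpu()` and `get_flash_attn_version()`, not the real implementations:

```python
def is_xpu() -> bool:
    # Stand-in for vllm.platforms.current_platform.is_xpu().
    return True


def get_flash_attn_version() -> int:
    # Stand-in for vllm.attention.utils.fa_utils.get_flash_attn_version().
    return 2


def flash_attn_supports_sinks() -> bool:
    # Mirrors the helper added above: XPU always supports sinks,
    # while other platforms require FlashAttention 3.
    if is_xpu():
        return True
    return get_flash_attn_version() == 3


# Callers gate attention-sink usage on the capability check instead of
# hard-coding a FlashAttention version check.
assert flash_attn_supports_sinks(), "attention sinks are not supported here"
```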
94 changes: 92 additions & 2 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
        else:
            logger.info_once("Using Triton backend")
            return Mxfp4Backend.TRITON
    elif current_platform.is_xpu():
        logger.info_once("Using ipex marlin backend on XPU")
        return Mxfp4Backend.MARLIN
    elif current_platform.is_rocm() and has_triton_kernels():
        logger.info_once("Using Triton backend")
        return Mxfp4Backend.TRITON
@@ -188,7 +191,10 @@ def get_quant_method(
                return UnquantizedLinearMethod()
            raise NotImplementedError("Mxfp4 linear layer is not implemented")
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(layer.moe_config)
            if current_platform.is_xpu():
                return IpexMxfp4MoEMethod(layer.moe_config)
            else:
                return Mxfp4MoEMethod(layer.moe_config)
        elif isinstance(layer, Attention):
            raise NotImplementedError("Mxfp4 attention layer is not implemented")
        return None
@@ -247,7 +253,10 @@ def create_weights(
        intermediate_size_per_partition_after_pad = round_up(
            intermediate_size_per_partition, 128
        )
        hidden_size = round_up(hidden_size, 256)
        if current_platform.is_xpu():
            hidden_size = round_up(hidden_size, 128)
        else:
            hidden_size = round_up(hidden_size, 256)

        layer.params_dtype = params_dtype
        layer.num_experts = num_experts
@@ -1146,3 +1155,84 @@ def apply(
            )
        else:
            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")


class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        import intel_extension_for_pytorch as ipex

        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert activation == "swigluoai", (
            "Only the 'swigluoai' activation is supported for IPEX MXFP4 MoE"
        )
        hidden_size_pad = round_up(self.original_hidden_size, 128)
        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
        hidden_states = layer.ipex_fusion(
            x_pad,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
Contributor

critical

The activation parameter is ignored, and a hardcoded value "swiglu_oai" is used. This appears to be a typo for "swigluoai", which is the value passed into this method for the gpt-oss model. This could lead to a runtime error or incorrect behavior.

Please use the activation parameter that is passed into the method to avoid this issue and make the implementation more general.

Suggested change
-            activation="swiglu_oai",
+            activation=activation,

Collaborator Author

It's a small difference in naming: vLLM uses "swigluoai", while the ipex interface accepts "swiglu_oai". I prefer to keep this.

Member

@jikunshang can you add an assert that activation is swiglu_oai then?

Collaborator Author

Sure, please take a look again :)

        )
        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
        return hidden_states
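To make the review thread above concrete: on XPU, `create_weights` pads `hidden_size` up to a multiple of 128, so `apply` pads the activations to the same width before calling the fused IPEX op and then slices the output back to the original hidden size. The assert pins vLLM's activation name ("swigluoai"), while the string handed to IPEX uses IPEX's own spelling ("swiglu_oai"). Below is a minimal sketch of the pad/compute/unpad round trip, with a stand-in for the fused op (not the real `ipex.llm.modules.GatedMLPMOE`) and an illustrative hidden size:

```python
import torch


def round_up(x: int, multiple: int) -> int:
    # Same idea as vLLM's round_up helper: smallest multiple of `multiple` >= x.
    return ((x + multiple - 1) // multiple) * multiple


def fused_moe_stub(x_pad: torch.Tensor) -> torch.Tensor:
    # Stand-in for layer.ipex_fusion(...); the real call also takes the routing
    # arguments and activation="swiglu_oai" (IPEX's spelling of "swigluoai").
    return x_pad


original_hidden_size = 2880                            # illustrative value
hidden_size_pad = round_up(original_hidden_size, 128)  # -> 2944

x = torch.randn(4, original_hidden_size)
# Pad the last dim to match the padded weights created on XPU...
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
out = fused_moe_stub(x_pad)
# ...then slice the output back down to the model's true hidden size.
out = out[..., :original_hidden_size].contiguous()
assert out.shape[-1] == original_hidden_size
```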
3 changes: 0 additions & 3 deletions vllm/model_executor/models/gpt_oss.py
@@ -329,9 +329,6 @@ def _load_weights_mxfp4(
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -27,6 +27,7 @@

if is_flash_attn_varlen_func_available():
    from vllm.attention.utils.fa_utils import (
        flash_attn_supports_sinks,
        flash_attn_varlen_func,
        get_scheduler_metadata,
        reshape_and_cache_flash,
@@ -497,7 +498,7 @@ def __init__(

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
            )
            assert self.sinks.shape[0] == num_heads, (