diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index b92b822c1d19..adb9b08a6573 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
     )
 
 
+def flash_attn_supports_sinks() -> bool:
+    if current_platform.is_xpu():
+        return True
+    else:
+        return get_flash_attn_version() == 3
+
+
 def flash_attn_supports_mla():
     from vllm.platforms import current_platform
 
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index bf34ec0f3899..ca6c9502d405 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         else:
             logger.info_once("Using Triton backend")
             return Mxfp4Backend.TRITON
+    elif current_platform.is_xpu():
+        logger.info_once("Using ipex marlin backend on XPU")
+        return Mxfp4Backend.MARLIN
     elif current_platform.is_rocm() and has_triton_kernels():
         logger.info_once("Using Triton backend")
         return Mxfp4Backend.TRITON
@@ -188,7 +191,10 @@ def get_quant_method(
                 return UnquantizedLinearMethod()
             raise NotImplementedError("Mxfp4 linear layer is not implemented")
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(layer.moe_config)
+            if current_platform.is_xpu():
+                return IpexMxfp4MoEMethod(layer.moe_config)
+            else:
+                return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -247,7 +253,10 @@ def create_weights(
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128
            )
-            hidden_size = round_up(hidden_size, 256)
+            if current_platform.is_xpu():
+                hidden_size = round_up(hidden_size, 128)
+            else:
+                hidden_size = round_up(hidden_size, 256)
 
         layer.params_dtype = params_dtype
         layer.num_experts = num_experts
@@ -1146,3 +1155,84 @@ def apply(
             )
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
+
+
+class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
+        self.moe_config = moe_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        super().create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+        self.original_hidden_size = hidden_size
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        import intel_extension_for_pytorch as ipex
+
+        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
+        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
+        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+            layer.w13_weight,
+            layer.w2_weight,
+            w1_scale_inv=layer.w13_weight_scale,
+            w2_scale_inv=layer.w2_weight_scale,
+            w13_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
+            is_mxfp4=True,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: torch.Tensor | None = None,
+        logical_to_physical_map: torch.Tensor | None = None,
+        logical_replica_count: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert activation == "swigluoai", (
+            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+        )  # noqa:
+        hidden_size_pad = round_up(self.original_hidden_size, 128)
+        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
+        hidden_states = layer.ipex_fusion(
+            x_pad,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            activation="swiglu_oai",
+        )
+        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
+        return hidden_states
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 44f6824b5212..bc7c94973541 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -329,9 +329,6 @@ def _load_weights_mxfp4(
             if is_pp_missing_parameter(name, self):
                 continue
 
-            # FIXME(woosuk): Remove this after testing.
-            weight = weight.cuda()
-
             if ".w13_weight_scale" in name:
                 # Handle MLP gate and up projection weights scale
                 if use_ep:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 07f9ef173b4e..a41d318c3c2c 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -27,6 +27,7 @@
 
 if is_flash_attn_varlen_func_available():
     from vllm.attention.utils.fa_utils import (
+        flash_attn_supports_sinks,
         flash_attn_varlen_func,
         get_scheduler_metadata,
         reshape_and_cache_flash,
@@ -497,7 +498,7 @@ def __init__(
 
         self.sinks = sinks
         if self.sinks is not None:
-            assert self.vllm_flash_attn_version == 3, (
+            assert flash_attn_supports_sinks(), (
                 "Sinks are only supported in FlashAttention 3"
            )
             assert self.sinks.shape[0] == num_heads, (
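Note on the XPU MoE path above: `create_weights` pads `hidden_size` to a multiple of 128 on XPU, so `IpexMxfp4MoEMethod.apply` must pad the incoming activation to the same width and slice the fused kernel's output back to the original hidden size. Below is a minimal, self-contained sketch of that pad-then-slice pattern; the identity stand-in for `layer.ipex_fusion` and the `pad_run_slice` name are illustrative assumptions, not part of the patch or the IPEX API.

import torch


def round_up(x: int, multiple: int) -> int:
    # Same rounding helper the patch relies on: smallest multiple of `multiple` >= x.
    return (x + multiple - 1) // multiple * multiple


def pad_run_slice(x: torch.Tensor, original_hidden_size: int) -> torch.Tensor:
    # Pad the hidden (last) dimension up to the 128-aligned width the
    # XPU weights were created with.
    hidden_size_pad = round_up(original_hidden_size, 128)
    x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))

    # Hypothetical stand-in for layer.ipex_fusion(...): any kernel that keeps
    # the padded hidden dimension would slot in here; identity for the sketch.
    out = x_pad

    # Drop the padded columns so callers see the original hidden size again.
    return out[..., :original_hidden_size].contiguous()


# Example: a hidden size that is not 128-aligned is padded to 384, then restored.
x = torch.randn(4, 300)
assert pad_run_slice(x, original_hidden_size=300).shape == (4, 300)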