From c103a86149c617dd063f26fdb2fa6c2ce26373f6 Mon Sep 17 00:00:00 2001 From: Bob Zhu Date: Wed, 16 Jul 2025 13:44:56 +0800 Subject: [PATCH 1/3] Introduce block_softmax_adjustment kernel Cherry-pick from https://github.com/HabanaAI/vllm-hpu-extension/pull/263 --- vllm_hpu_extension/flags.py | 6 ++++ vllm_hpu_extension/kernels.py | 45 +++++++++++++++++---------- vllm_hpu_extension/ops.py | 58 ++++++++++++++++++++--------------- 3 files changed, 68 insertions(+), 41 deletions(-) diff --git a/vllm_hpu_extension/flags.py b/vllm_hpu_extension/flags.py index 3dc9f885e..1cd94496c 100644 --- a/vllm_hpu_extension/flags.py +++ b/vllm_hpu_extension/flags.py @@ -11,6 +11,7 @@ from vllm_hpu_extension.environment import get_environment from vllm_hpu_extension.kernels import fsdpa +from vllm_hpu_extension.kernels import block_softmax_adjustment detected = None @@ -160,6 +161,11 @@ def enabled_flags(): & ModelType("llama") & Not(EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA", "false")) & EnvFlag("VLLM_PROMPT_USE_FLEX_ATTENTION", "false")), + "fused_block_softmax_adjustment": (Not(Hardware("cpu")) + & VersionRange(">=1.22.0.101") + & Kernel(block_softmax_adjustment) + & EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT", + Not(ModelType('qwen2')) & Hardware("gaudi3"))), } environment = get_environment() detected = Flags(supported_flags, environment) diff --git a/vllm_hpu_extension/kernels.py b/vllm_hpu_extension/kernels.py index 77164985a..5aadd3c1f 100644 --- a/vllm_hpu_extension/kernels.py +++ b/vllm_hpu_extension/kernels.py @@ -5,24 +5,37 @@ # LICENSE file in the root directory of this source tree. ############################################################################### -from .utils import logger from functools import cache -@cache +def _kernel(name): + def loader(fn): + @cache + def loader_impl(): + try: + print("Load", name, fn) + return fn() + except (ImportError, AttributeError): + from .utils import logger + logger().warning(f"Could not import HPU {name} kernel. 
" + "vLLM will use native implementation") + return loader_impl + return loader + + +@_kernel("FusedSDPA") def fsdpa(): - try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - return FusedSDPA - except ImportError: - logger().warning("Could not import HPU FusedSDPA kernel. " - "vLLM will use native implementation.") - -@cache + from habana_frameworks.torch.hpex.kernels import FusedSDPA + return FusedSDPA + + +@_kernel("FusedRMSNorm") def rms_norm(): - try: - from habana_frameworks.torch.hpex.normalization import FusedRMSNorm - return FusedRMSNorm - except ImportError: - logger().warning("Could not import HPU FusedRMSNorm kernel. " - "vLLM will use forward_native implementation of RMSNorm.") + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm + return FusedRMSNorm + + +@_kernel("block_softmax_adjustment") +def block_softmax_adjustment(): + import torch + return torch.ops.hpu.block_softmax_adjustment diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index dbc7c4457..d17885fbf 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -65,31 +65,39 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s adjustment_target_shape = block_max.shape attn = attn.sub(block_max) attn = attn.exp() - attn = attn.to(value.dtype) + if attn.dtype == torch.float32: + attn = attn.to(value.dtype) block_sums = attn.sum(dim=-1, keepdim=True) attn = matmul_av_op(attn, value) - block_max = block_max.squeeze() - block_sums = block_sums.squeeze() - - # Calculate maximum of blocks that belong to the same sequences - # and cast adjustments to native dtype - group_max = grouped_max(block_max, batch_size, block_groups) - block_adjustment = (block_max - group_max).exp() - block_adjustment = block_adjustment.to(value.dtype) - sum_adjusted = block_sums.mul(block_adjustment) - - # Sum block's sums that belongs to the same sequences - group_sum_adjusted = block2batch(sum_adjusted, block_mapping, 
block2batch_matmul_op)
- group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
- sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
- group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
- block_adjustment = block_adjustment.view(*adjustment_target_shape)
-
- # For stability in case some of the sums have been zeroed out during block aggretation
- group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
-
- # Post processing for the attention scores
- rescale = block_adjustment.div(group_sum_adjusted)
+
+ if 'fused_block_softmax_adjustment' in enabled_flags() and block_max.dtype != torch.float16:
+ rescale = torch.ops.hpu.block_softmax_adjustment(block_max,
+ block_sums.to(block_max.dtype),
+ block_groups,
+ batch_size).to(attn.dtype)
+ else:
+ block_max = block_max.squeeze()
+ block_sums = block_sums.squeeze()
+
+ # Calculate maximum of blocks that belong to the same sequences
+ # and cast adjustments to native dtype
+ group_max = grouped_max(block_max, batch_size, block_groups)
+ block_adjustment = (block_max - group_max).exp()
+ if block_adjustment.dtype == torch.float32:
+ block_adjustment = block_adjustment.to(value.dtype)
+ sum_adjusted = block_sums.mul(block_adjustment)
+
+ # Sum block's sums that belong to the same sequences
+ group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
+ group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
+ sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
+ group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
+ block_adjustment = block_adjustment.view(*adjustment_target_shape)
+
+ # For stability in case some of the sums have been zeroed out during block aggregation
+ group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
+ # Post processing for the attention scores
+ rescale = block_adjustment.div(group_sum_adjusted)
 attn = attn.mul(rescale)
return attn @@ -405,8 +413,8 @@ def forward(self, hidden_states, score, topk): htorch.core.mark_step() routing_weights = F.softmax(score, dim=1, dtype=torch.float32) routing_weights, selected_experts = torch.topk(routing_weights, - topk, - dim=-1) + topk, + dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) From af6ea07060e23c1476160db71f052fea4905b770 Mon Sep 17 00:00:00 2001 From: Bob Zhu Date: Wed, 16 Jul 2025 14:19:26 +0800 Subject: [PATCH 2/3] Add block_softmax kernel support --- vllm_hpu_extension/flags.py | 8 ++++++-- vllm_hpu_extension/kernels.py | 5 +++++ vllm_hpu_extension/ops.py | 27 ++++++++++++++++++++------- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/vllm_hpu_extension/flags.py b/vllm_hpu_extension/flags.py index 1cd94496c..721461608 100644 --- a/vllm_hpu_extension/flags.py +++ b/vllm_hpu_extension/flags.py @@ -11,7 +11,7 @@ from vllm_hpu_extension.environment import get_environment from vllm_hpu_extension.kernels import fsdpa -from vllm_hpu_extension.kernels import block_softmax_adjustment +from vllm_hpu_extension.kernels import block_softmax_adjustment, block_softmax detected = None @@ -162,10 +162,14 @@ def enabled_flags(): & Not(EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA", "false")) & EnvFlag("VLLM_PROMPT_USE_FLEX_ATTENTION", "false")), "fused_block_softmax_adjustment": (Not(Hardware("cpu")) - & VersionRange(">=1.22.0.101") + & VersionRange(">=1.22.0.494") & Kernel(block_softmax_adjustment) & EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT", Not(ModelType('qwen2')) & Hardware("gaudi3"))), + "fused_block_softmax": (Not(Hardware("cpu")) + & VersionRange(">=1.22.0.494") + & Kernel(block_softmax) + & EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX", "false")), } environment = get_environment() detected = Flags(supported_flags, environment) diff --git a/vllm_hpu_extension/kernels.py b/vllm_hpu_extension/kernels.py index 5aadd3c1f..701be4413 100644 --- a/vllm_hpu_extension/kernels.py 
+++ b/vllm_hpu_extension/kernels.py @@ -39,3 +39,8 @@ def rms_norm(): def block_softmax_adjustment(): import torch return torch.ops.hpu.block_softmax_adjustment + +@_kernel("block_softmax") +def block_softmax(): + import torch + return torch.ops.hpu.block_softmax \ No newline at end of file diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index d17885fbf..4a3b8f5f2 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -61,21 +61,34 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s # We can return to native dtype after we renormalize and calculate the adjustments # Normalize the attention scores and cast attn to native dtype - block_max = attn.amax(dim=-1, keepdim=True) - adjustment_target_shape = block_max.shape - attn = attn.sub(block_max) - attn = attn.exp() - if attn.dtype == torch.float32: - attn = attn.to(value.dtype) - block_sums = attn.sum(dim=-1, keepdim=True) + if 'fused_block_softmax' in enabled_flags() and attn.dim() == 5: + print('INFO: run with fused_block_softmax ====================') + block_bias = torch.zeros_like(attn) # torch.ops.hpu.block_softmax can't take None block_bias + attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups) + # To make block_max and block_sums same output shape with none-fused block softmax + block_max = block_max.view(block_max.shape[0], 1, block_max.shape[1], 1, 1) + block_sums = block_sums.view(block_sums.shape[0], 1, block_sums.shape[1], 1, 1) + else: + print('INFO: run with normal block softmax ====================') + block_max = attn.amax(dim=-1, keepdim=True) + adjustment_target_shape = block_max.shape + attn = attn.sub(block_max) + attn = attn.exp() + if attn.dtype == torch.float32: + attn = attn.to(value.dtype) + block_sums = attn.sum(dim=-1, keepdim=True) + + print(f'INFO: {attn.shape=}, {block_max.shape=}, {block_sums.shape=}') attn = matmul_av_op(attn, value) if 'fused_block_softmax_adjustment' in 
enabled_flags() and block_max.dtype != torch.float16: + print('INFO: run with fused_block_softmax_adjustment ====================') rescale = torch.ops.hpu.block_softmax_adjustment(block_max, block_sums.to(block_max.dtype), block_groups, batch_size).to(attn.dtype) else: + print('INFO: run with normal block softmax adjustment ====================') block_max = block_max.squeeze() block_sums = block_sums.squeeze() From 7689041c9ad684ab786dbfa282aff9cf2058911e Mon Sep 17 00:00:00 2001 From: Bob Zhu Date: Wed, 16 Jul 2025 14:36:30 +0800 Subject: [PATCH 3/3] Fix the GC error due to block_bias shape mismatch --- vllm_hpu_extension/ops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index 4a3b8f5f2..3556a655d 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -63,7 +63,10 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s # Normalize the attention scores and cast attn to native dtype if 'fused_block_softmax' in enabled_flags() and attn.dim() == 5: print('INFO: run with fused_block_softmax ====================') - block_bias = torch.zeros_like(attn) # torch.ops.hpu.block_softmax can't take None block_bias + attn_shape = attn.shape + block_bias = torch.zeros(attn_shape[0], 1, 1, attn_shape[3], attn_shape[4], + device=attn.device, + dtype=attn.dtype) # torch.ops.hpu.block_softmax can't take None block_bias attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups) # To make block_max and block_sums same output shape with none-fused block softmax block_max = block_max.view(block_max.shape[0], 1, block_max.shape[1], 1, 1) @@ -71,13 +74,14 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s else: print('INFO: run with normal block softmax ====================') block_max = attn.amax(dim=-1, keepdim=True) - adjustment_target_shape = block_max.shape attn = attn.sub(block_max) attn 
= attn.exp() if attn.dtype == torch.float32: attn = attn.to(value.dtype) block_sums = attn.sum(dim=-1, keepdim=True) + adjustment_target_shape = block_max.shape + print(f'INFO: {attn.shape=}, {block_max.shape=}, {block_sums.shape=}') attn = matmul_av_op(attn, value)