From 9a99d349405ca99ee70f3244192868559917576c Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Mon, 26 May 2025 16:54:01 +0300
Subject: [PATCH 1/4] Enable triangular softmax with merged prefill

---
 vllm_hpu_extension/flags.py | 5 +++++
 vllm_hpu_extension/ops.py   | 5 +++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/vllm_hpu_extension/flags.py b/vllm_hpu_extension/flags.py
index 1c3e1ba7c..686ba4323 100644
--- a/vllm_hpu_extension/flags.py
+++ b/vllm_hpu_extension/flags.py
@@ -156,6 +156,11 @@ def enabled_flags():
                                 & Kernel(fsdpa)
                                 & EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA",
                                           Not(ModelType('mllama')))),
+        "merged_prefill_triangular_mask": (Not(Hardware("cpu"))
+                                & Kernel(fsdpa)
+                                & VersionRange(">=1.22.0.294")
+                                & EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA",
+                                          Not(ModelType('mllama')))),
         "compile_one_hot": (VersionRange(">=1.20.0.370")
                                 & Not(EnvFlag("PT_HPU_LAZY_MODE", "1"))),
         "flex_attention": (Not(Hardware("cpu")) & Not(EnvFlag("PT_HPU_LAZY_MODE", "1")) & ModelType("llama")
diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index 2453afeca..ac83aa4b9 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -274,8 +274,9 @@ def _fsdpa_prompt_attention(
     assert attn_bias is not None or valid_seq_lengths is not None, \
         'Either attn_bias or valid_seq_lengths must be != None'
     if is_causal and attn_bias is not None:
-        # TODO: causal + attn_bias is not yet supported
-        is_causal = False
+        if 'merged_prefill_triangular_mask' not in enabled_flags():
+            is_causal = False
+        # TODO: valid_seq_lengths is not yet supported for causal with attn_bias
         valid_seq_lengths = None
     attn_weights = fsdpa_op(query, key, value, attn_bias,
                             0.0, is_causal, scale, softmax_mode, recompute_mode,

From b6f7e949325c15bd6c27ec4b7c4658f2e246bb04 Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Tue, 3 Jun 2025 09:21:30 +0200
Subject: [PATCH 2/4] Change to generic name

---
 vllm_hpu_extension/flags.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_hpu_extension/flags.py b/vllm_hpu_extension/flags.py
index 686ba4323..79a46e69d 100644
--- a/vllm_hpu_extension/flags.py
+++ b/vllm_hpu_extension/flags.py
@@ -156,7 +156,7 @@ def enabled_flags():
                                 & Kernel(fsdpa)
                                 & EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA",
                                           Not(ModelType('mllama')))),
-        "merged_prefill_triangular_mask": (Not(Hardware("cpu"))
+        "triangular_mask": (Not(Hardware("cpu"))
                                 & Kernel(fsdpa)
                                 & VersionRange(">=1.22.0.294")
                                 & EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA",

From 0178798765517973e771433019c45c1282e965cc Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Tue, 3 Jun 2025 09:21:56 +0200
Subject: [PATCH 3/4] Change to generic name

---
 vllm_hpu_extension/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index ac83aa4b9..dd25b8360 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -274,7 +274,7 @@ def _fsdpa_prompt_attention(
     assert attn_bias is not None or valid_seq_lengths is not None, \
         'Either attn_bias or valid_seq_lengths must be != None'
     if is_causal and attn_bias is not None:
-        if 'merged_prefill_triangular_mask' not in enabled_flags():
+        if 'triangular_mask' not in enabled_flags():
             is_causal = False
         # TODO: valid_seq_lengths is not yet supported for causal with attn_bias
         valid_seq_lengths = None
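Background on what the flag changes: with merged prefill, several prompts are packed into one sequence, so attn_bias is block-diagonal and each block is lower-triangular (causal). Before this series, passing such a bias forced is_causal=False and causality lived entirely in the mask; patch 1 lets the fused kernel keep is_causal=True (triangular softmax) behind the new flag. Below is a minimal plain-PyTorch sketch of the intended mask semantics; it uses F.scaled_dot_product_attention (PyTorch 2.1+) as a stand-in for the Habana fsdpa_op kernel, and the helper name, shapes, and block-diagonal construction are illustrative assumptions, not code from the patch.

import torch
import torch.nn.functional as F

def merged_prefill_reference(query, key, value, seq_lens, scale):
    # Hypothetical reference for the mask semantics: each packed prompt's
    # tokens attend only to earlier tokens of the same prompt, so the
    # combined mask is block-diagonal with causal (lower-triangular) blocks.
    total = sum(seq_lens)
    allowed = torch.zeros(total, total, dtype=torch.bool)
    offset = 0
    for n in seq_lens:
        allowed[offset:offset + n, offset:offset + n] = torch.tril(
            torch.ones(n, n, dtype=torch.bool))
        offset += n
    # Additive bias: 0.0 where attention is allowed, -inf where it is not.
    bias = torch.zeros(total, total).masked_fill(~allowed, float('-inf'))
    # query/key/value: (batch, num_heads, total, head_dim)
    return F.scaled_dot_product_attention(query, key, value,
                                          attn_mask=bias, scale=scale)

For seq_lens=[3, 2], for example, the allowed pattern is two causal blocks on the diagonal: token 3 (the first token of the second prompt) attends only to itself. Since every block is lower-triangular, a kernel that already applies a triangular softmax only needs the bias to separate the packed prompts, which is presumably what makes is_causal=True safe to keep here.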
From 06756a9e9911b3f8218af15551dfa99a8cdb5a56 Mon Sep 17 00:00:00 2001
From: Kamil Kaczor
Date: Mon, 14 Jul 2025 12:43:37 +0200
Subject: [PATCH 4/4] Enable valid_seq_len in triangular merged prefill

---
 vllm_hpu_extension/ops.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index dd25b8360..b6bb198da 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -276,8 +276,7 @@ def _fsdpa_prompt_attention(
     if is_causal and attn_bias is not None:
         if 'triangular_mask' not in enabled_flags():
             is_causal = False
-        # TODO: valid_seq_lengths is not yet supported for causal with attn_bias
-        valid_seq_lengths = None
+        valid_seq_lengths = torch.sum(valid_seq_lengths)
     attn_weights = fsdpa_op(query, key, value, attn_bias,
                             0.0, is_causal, scale, softmax_mode, recompute_mode,
                             valid_seq_lengths, 'right')
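A note on patch 4's reduction: instead of discarding valid_seq_lengths in the causal-with-bias path, the series now collapses the per-prompt lengths into their scalar total before calling the kernel. A tiny worked example of that reduction (the tensor values are illustrative, not from the patch):

import torch

# Hypothetical per-prompt lengths of three prompts packed into one
# merged-prefill sequence.
valid_seq_lengths = torch.tensor([5, 3, 8])

# Patch 4's reduction: the fused kernel now receives the total number of
# valid (non-padded) tokens rather than None.
valid_seq_lengths = torch.sum(valid_seq_lengths)  # tensor(16)

Since the padding side passed to fsdpa_op is 'right' and the per-prompt boundaries are already encoded in attn_bias, a single total of valid tokens is presumably all the kernel needs in order to skip the trailing padding.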