diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py index 84717a165..43a2e98a5 100644 --- a/vllm_spyre/model_executor/model_loader/spyre.py +++ b/vllm_spyre/model_executor/model_loader/spyre.py @@ -688,7 +688,7 @@ def forward( **extra_kwargs, ) -> torch.Tensor: # specify attention type for static batching - extra_kwargs['attn_name'] = "sdpa_causal" + extra_kwargs['attn_name'] = self.attention_name if envs_spyre.VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS: # In order to calculate prompt logprobs, we have to return the