Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vllm_hpu_extension/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from vllm_hpu_extension.environment import get_environment
from vllm_hpu_extension.kernels import fsdpa
from vllm_hpu_extension.kernels import block_softmax_adjustment, block_softmax


detected = None
Expand Down Expand Up @@ -160,6 +161,15 @@ def enabled_flags():
& ModelType("llama")
& Not(EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA", "false"))
& EnvFlag("VLLM_PROMPT_USE_FLEX_ATTENTION", "false")),
"fused_block_softmax_adjustment": (Not(Hardware("cpu"))
& VersionRange(">=1.22.0.494")
& Kernel(block_softmax_adjustment)
& EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT",
Not(ModelType('qwen2')) & Hardware("gaudi3"))),
"fused_block_softmax": (Not(Hardware("cpu"))
& VersionRange(">=1.22.0.494")
& Kernel(block_softmax)
& EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX", "false")),
}
environment = get_environment()
detected = Flags(supported_flags, environment)
Expand Down
50 changes: 34 additions & 16 deletions vllm_hpu_extension/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,42 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

from .utils import logger
from functools import cache


@cache
def _kernel(name):
def loader(fn):
@cache
def loader_impl():
try:
print("Load", name, fn)
return fn()
except (ImportError, AttributeError):
from .utils import logger
logger().warning(f"Could not import HPU {name} kernel. "
"vLLM will use native implementation")
return loader_impl
return loader


@_kernel("FusedSDPA")
def fsdpa():
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
return FusedSDPA
except ImportError:
logger().warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

@cache
from habana_frameworks.torch.hpex.kernels import FusedSDPA
return FusedSDPA


@_kernel("FusedRMSNorm")
def rms_norm():
try:
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
return FusedRMSNorm
except ImportError:
logger().warning("Could not import HPU FusedRMSNorm kernel. "
"vLLM will use forward_native implementation of RMSNorm.")
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
return FusedRMSNorm


@_kernel("block_softmax_adjustment")
def block_softmax_adjustment():
import torch
return torch.ops.hpu.block_softmax_adjustment

@_kernel("block_softmax")
def block_softmax():
import torch
return torch.ops.hpu.block_softmax
83 changes: 54 additions & 29 deletions vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,35 +61,60 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
# We can return to native dtype after we renormalize and calculate the adjustments

# Normalize the attention scores and cast attn to native dtype
block_max = attn.amax(dim=-1, keepdim=True)
if 'fused_block_softmax' in enabled_flags() and attn.dim() == 5:
print('INFO: run with fused_block_softmax ====================')
attn_shape = attn.shape
block_bias = torch.zeros(attn_shape[0], 1, 1, attn_shape[3], attn_shape[4],
device=attn.device,
dtype=attn.dtype) # torch.ops.hpu.block_softmax can't take None block_bias
attn, block_max, block_sums = torch.ops.hpu.block_softmax(attn, block_bias, block_groups)
# To make block_max and block_sums same output shape with none-fused block softmax
block_max = block_max.view(block_max.shape[0], 1, block_max.shape[1], 1, 1)
block_sums = block_sums.view(block_sums.shape[0], 1, block_sums.shape[1], 1, 1)
else:
print('INFO: run with normal block softmax ====================')
block_max = attn.amax(dim=-1, keepdim=True)
attn = attn.sub(block_max)
attn = attn.exp()
if attn.dtype == torch.float32:
attn = attn.to(value.dtype)
block_sums = attn.sum(dim=-1, keepdim=True)

adjustment_target_shape = block_max.shape
attn = attn.sub(block_max)
attn = attn.exp()
attn = attn.to(value.dtype)
block_sums = attn.sum(dim=-1, keepdim=True)

print(f'INFO: {attn.shape=}, {block_max.shape=}, {block_sums.shape=}')
attn = matmul_av_op(attn, value)
block_max = block_max.squeeze()
block_sums = block_sums.squeeze()

# Calculate maximum of blocks that belong to the same sequences
# and cast adjustments to native dtype
group_max = grouped_max(block_max, batch_size, block_groups)
block_adjustment = (block_max - group_max).exp()
block_adjustment = block_adjustment.to(value.dtype)
sum_adjusted = block_sums.mul(block_adjustment)

# Sum block's sums that belongs to the same sequences
group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
block_adjustment = block_adjustment.view(*adjustment_target_shape)

# For stability in case some of the sums have been zeroed out during block aggretation
group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)

# Post processing for the attention scores
rescale = block_adjustment.div(group_sum_adjusted)

if 'fused_block_softmax_adjustment' in enabled_flags() and block_max.dtype != torch.float16:
print('INFO: run with fused_block_softmax_adjustment ====================')
rescale = torch.ops.hpu.block_softmax_adjustment(block_max,
block_sums.to(block_max.dtype),
block_groups,
batch_size).to(attn.dtype)
else:
print('INFO: run with normal block softmax adjustment ====================')
block_max = block_max.squeeze()
block_sums = block_sums.squeeze()

# Calculate maximum of blocks that belong to the same sequences
# and cast adjustments to native dtype
group_max = grouped_max(block_max, batch_size, block_groups)
block_adjustment = (block_max - group_max).exp()
if block_adjustment.dtype == torch.float32:
block_adjustment = block_adjustment.to(value.dtype)
sum_adjusted = block_sums.mul(block_adjustment)

# Sum block's sums that belongs to the same sequences
group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
block_adjustment = block_adjustment.view(*adjustment_target_shape)

# For stability in case some of the sums have been zeroed out during block aggretation
group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
# Post processing for the attention scores
rescale = block_adjustment.div(group_sum_adjusted)
attn = attn.mul(rescale)
return attn

Expand Down Expand Up @@ -405,8 +430,8 @@ def forward(self, hidden_states, score, topk):
htorch.core.mark_step()
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
topk,
dim=-1)
topk,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

Expand Down