vllm/_custom_ops.py: 4 changes (2 additions, 2 deletions)
@@ -559,15 +559,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
     scale_a.shape * [1, 128] == a.shape
     scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
 
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")

Review comment (Member), on the new fallback condition: Note that this could change behavior for CUDA, since it will now run the Triton kernel whenever the CUTLASS kernel's shape requirements aren't satisfied.
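To make the reviewer's point concrete, here is a minimal sketch of the dispatch condition after this change (not part of the PR; uses_triton_fallback and the example shapes are illustrative):

import torch

def uses_triton_fallback(b: torch.Tensor, is_rocm: bool) -> bool:
    # Mirrors the new condition in cutlass_scaled_mm: the CUTLASS path
    # requires both dimensions of b to be multiples of 16; on ROCm, or when
    # that constraint fails, the Triton scaled-mm kernel runs instead.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    return is_rocm or not cutlass_compatible_b

# Before this PR, a CUDA call with e.g. a 24x24 b tripped the removed assert;
# now it silently takes the Triton fallback.
print(uses_triton_fallback(torch.empty(24, 24, dtype=torch.int8), is_rocm=False))  # True  -> Triton
print(uses_triton_fallback(torch.empty(32, 64, dtype=torch.int8), is_rocm=False))  # False -> CUTLASS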
vllm/model_executor/layers/quantization/utils/int8_utils.py: 28 changes (27 additions, 1 deletion)
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
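The new helper is meant to be callable from any @triton.jit kernel in this module. The demo kernel below is a hedged sketch (kernel name, grid, and block size are made up; it assumes a GPU build of Triton and that round_int8 is importable from this module):

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.quantization.utils.int8_utils import round_int8

@triton.jit
def _round_demo_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Round one block of float32 values to int8 via the platform-agnostic helper.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, round_int8(x), mask=mask)

x = torch.randn(1000, device="cuda") * 10
out = torch.empty_like(x, dtype=torch.int8)
grid = (triton.cdiv(x.numel(), 128), )
_round_demo_kernel[grid](x, out, x.numel(), BLOCK=128)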
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
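For reference, a rough eager-mode equivalent of what _per_token_quant_int8 computes per row (a sketch assuming a 2-D input; the function name is made up, and torch.round breaks ties to even, whereas llvm.round and libdevice round half away from zero, so exact .5 cases can differ):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # Per-row absmax scaling to int8: scale = absmax / 127,
    # quantized = round(x * 127 / absmax), matching the kernel above.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scale = absmax / 127
    x_q = torch.round(x * (127 / absmax)).to(torch.int8)
    return x_q, scale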