
Commit 2d977a7

Authored by maleksan85 (Aleksandr Malyshev) and a co-author
[ROCm] gemm_a16w16 upstreaming (vllm-project#26969)
Signed-off-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
1 parent 1fb4217 commit 2d977a7

File tree

2 files changed: +43, -9 lines


vllm/model_executor/layers/utils.py

Lines changed: 34 additions & 8 deletions
@@ -103,12 +103,41 @@ def default_unquantized_gemm(
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def use_aiter_triton_gemm(n, m, k, dtype):
+    if (
+        envs.VLLM_ROCM_USE_AITER == 0
+        # MI300's - fp8nuz=True
+        or current_platform.is_fp8_fnuz()
+        or dtype not in [torch.float16, torch.bfloat16]
+    ):
+        return False
+
+    # use hipblaslt for the larger GEMMs
+    if n > 2048 and m > 512:
+        return False
+    return (
+        (m == 5120 and k == 2880)
+        or (m == 2880 and k == 4096)
+        or (m == 128 and k == 2880)
+        or (m == 640 and k == 2880)
+        or (m == 2880 and k == 512)
+    )
+
+
 def rocm_unquantized_gemm_impl(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
 
+    n = x.numel() / x.size(-1)
+    m = weight.shape[0]
     k = weight.shape[1]
+
+    if use_aiter_triton_gemm(n, m, k, x.dtype):
+        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+
+        return gemm_a16w16(x, weight, bias)
+
     use_skinny = (
         envs.VLLM_ROCM_USE_SKINNY_GEMM
         and on_gfx9()
@@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
         return torch.nn.functional.linear(x, weight, bias)
 
     x_view = x.reshape(-1, x.size(-1))
-    n = x_view.shape[0]
-    m = weight.shape[0]
-    cu_count = current_platform.get_cu_count()
-
     if m > 8 and 0 < n <= 4:
+        cu_count = current_platform.get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.reshape(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
@@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl_fake(
+def rocm_unquantized_gemm_fake(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     return x.new_empty((*x.shape[:-1], weight.shape[0]))
@@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
 
 
 direct_register_custom_op(
-    op_name="rocm_unquantized_gemm_impl",
+    op_name="rocm_unquantized_gemm",
     op_func=rocm_unquantized_gemm_impl,
-    fake_impl=rocm_unquantized_gemm_impl_fake,
+    fake_impl=rocm_unquantized_gemm_fake,
 )
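For reference, the new use_aiter_triton_gemm gate routes only fp16/bf16 GEMMs with a small allow-list of (m, k) weight shapes, and only at small batch sizes, to AITER's Triton kernel. Below is a minimal standalone sketch of that decision logic; the two module-level flags are hypothetical stand-ins for vLLM's envs.VLLM_ROCM_USE_AITER and current_platform.is_fp8_fnuz() checks, so this illustrates the shape gating rather than the real runtime wiring:

    import torch

    # Hypothetical stand-ins for vLLM's env/platform checks (assumptions, not vLLM API).
    VLLM_ROCM_USE_AITER = 1  # env flag; 0 disables the AITER path
    IS_FP8_FNUZ = False      # True on MI300-class GPUs, which skip this path

    def use_aiter_triton_gemm_sketch(n, m, k, dtype):
        if (
            VLLM_ROCM_USE_AITER == 0
            or IS_FP8_FNUZ
            or dtype not in (torch.float16, torch.bfloat16)
        ):
            return False
        # Larger GEMMs stay on hipBLASLt.
        if n > 2048 and m > 512:
            return False
        # Only an allow-list of (m, k) weight shapes goes to gemm_a16w16.
        return (m, k) in {(5120, 2880), (2880, 4096), (128, 2880), (640, 2880), (2880, 512)}

    # n is the input row count, x.numel() / x.size(-1) in the diff.
    print(use_aiter_triton_gemm_sketch(64.0, 5120, 2880, torch.bfloat16))    # True  -> AITER kernel
    print(use_aiter_triton_gemm_sketch(4096.0, 5120, 2880, torch.bfloat16))  # False -> hipBLASLt

Note that in the diff n comes from x.numel() / x.size(-1), i.e. true division, so it is a float; the gate's comparisons and the later 0 < n <= 4 skinny-GEMM check still behave as intended for whole-number row counts.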

vllm/model_executor/models/gpt_oss.py

Lines changed: 9 additions & 1 deletion
@@ -25,12 +25,14 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.utils import rocm_unquantized_gemm
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv
 
@@ -153,6 +155,7 @@ def __init__(
 
         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
+        self.hidden_size = config.hidden_size
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
         self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
@@ -177,7 +180,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             x = sequence_parallel_chunk(x)
 
-        g = self.router(x)
+        if current_platform.is_rocm():
+            g = rocm_unquantized_gemm(
+                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
+            )
+        else:
+            g = self.router(x)
         x = self.experts(hidden_states=x, router_logits=g)
 
         if self.is_sequence_parallel:
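With this change the GPT-OSS router matmul on ROCm goes through rocm_unquantized_gemm instead of nn.Linear.forward, so a router of 128 experts over a 2880-wide hidden state lines up with the (m == 128 and k == 2880) entry in the allow-list above. The self passed as first argument follows the wrapper's layer-first parameter list (only partially visible in the hunk), and the x[:, : self.hidden_size] slice keeps the GEMM input at the router's expected width. A minimal CPU sketch of what the branch computes, with illustrative sizes; rocm_unquantized_gemm itself falls back to torch.nn.functional.linear when no specialized kernel applies:

    import torch
    import torch.nn.functional as F

    hidden_size, num_experts, num_tokens = 2880, 128, 4  # GPT-OSS-like sizes (illustrative)
    router = torch.nn.Linear(hidden_size, num_experts)
    x = torch.randn(num_tokens, hidden_size)

    g_default = router(x)  # the non-ROCm branch
    # What the ROCm branch computes, minus the custom-op dispatch:
    g_rocm_like = F.linear(x[:, :hidden_size], router.weight, router.bias)

    print(torch.allclose(g_default, g_rocm_like))  # True: same logits, different kernel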
