@@ -103,12 +103,41 @@ def default_unquantized_gemm(
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def use_aiter_triton_gemm(n, m, k, dtype):
+    if (
+        envs.VLLM_ROCM_USE_AITER == 0
+        # MI300's - fp8nuz=True
+        or current_platform.is_fp8_fnuz()
+        or dtype not in [torch.float16, torch.bfloat16]
+    ):
+        return False
+
+    # use hipblaslt for the larger GEMMs
+    if n > 2048 and m > 512:
+        return False
+    return (
+        (m == 5120 and k == 2880)
+        or (m == 2880 and k == 4096)
+        or (m == 128 and k == 2880)
+        or (m == 640 and k == 2880)
+        or (m == 2880 and k == 512)
+    )
+
+
 def rocm_unquantized_gemm_impl(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
 
+    n = x.numel() / x.size(-1)
+    m = weight.shape[0]
     k = weight.shape[1]
+
+    if use_aiter_triton_gemm(n, m, k, x.dtype):
+        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+
+        return gemm_a16w16(x, weight, bias)
+
     use_skinny = (
         envs.VLLM_ROCM_USE_SKINNY_GEMM
         and on_gfx9()
@@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
         return torch.nn.functional.linear(x, weight, bias)
 
     x_view = x.reshape(-1, x.size(-1))
-    n = x_view.shape[0]
-    m = weight.shape[0]
-    cu_count = current_platform.get_cu_count()
-
     if m > 8 and 0 < n <= 4:
+        cu_count = current_platform.get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.reshape(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
@@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl_fake(
+def rocm_unquantized_gemm_fake(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     return x.new_empty((*x.shape[:-1], weight.shape[0]))
@@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
 
 
 direct_register_custom_op(
-    op_name="rocm_unquantized_gemm_impl",
+    op_name="rocm_unquantized_gemm",
     op_func=rocm_unquantized_gemm_impl,
-    fake_impl=rocm_unquantized_gemm_impl_fake,
+    fake_impl=rocm_unquantized_gemm_fake,
 )
 
 
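For reference, a minimal sketch of how the renamed custom op is exercised once this module has been imported; the device, shapes, dtype, and environment below are illustrative assumptions, not part of the commit:

# Assumptions: a ROCm build of vLLM with the aiter package installed,
# VLLM_ROCM_USE_AITER=1, and a gfx9 GPU that is not fp8-fnuz (i.e. not
# MI300-class), so use_aiter_triton_gemm() can return True. The patched
# module must already be imported so direct_register_custom_op() has
# registered torch.ops.vllm.rocm_unquantized_gemm.
import torch

# (m, k) = (5120, 2880) is on the shape allowlist above, and n = 4 <= 2048.
x = torch.randn(4, 2880, dtype=torch.bfloat16, device="cuda")           # n = 4
weight = torch.randn(5120, 2880, dtype=torch.bfloat16, device="cuda")   # m = 5120, k = 2880

# Under the assumptions above this call routes to aiter.ops.triton.gemm_a16w16;
# on other shapes or platforms it falls through to the skinny-GEMM kernels or
# torch.nn.functional.linear.
out = torch.ops.vllm.rocm_unquantized_gemm(x, weight, None)
print(out.shape)  # torch.Size([4, 5120])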