 import deep_gemm
 from deep_gemm import get_num_sms
 from deep_gemm.jit_kernels.gemm import get_best_configs
-from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-from deep_gemm.jit_kernels.m_grouped_gemm import (
-    template as deep_gemm_grouped_gemm_template,
-)
+from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 from deep_gemm.jit_kernels.tuner import jit_tuner

 sm_version = get_device_sm()
@@ -45,10 +41,15 @@ def get_enable_jit_deepgemm():
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may cause a performance loss in some cases,
+# and NVCC JIT compilation is also ~9x faster per the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
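A minimal usage sketch (not part of the diff) of the two environment variables touched above, assuming they are set before this module is imported, since the os.environ assignments run at import time; the cache path below is a made-up example.

import os

# Hypothetical override values; any writable directory works for the cache.
os.environ["SGL_DG_CACHE_DIR"] = "/tmp/deep_gemm_cache"
# Opt back into NVRTC JIT despite the "0" default chosen in this commit.
os.environ["SGL_DG_USE_NVRTC"] = "1"

# Importing the module afterwards propagates these into DeepGEMM's
# DG_JIT_CACHE_DIR and DG_JIT_USE_NVRTC variables, as shown in the hunk above.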
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


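For orientation, a sketch of the config tuple these compile helpers unpack. The numbers below are illustrative placeholders, not tuned values; the real tuple is presumably produced via the get_best_configs import at the top of the file.

# Illustrative placeholder values only.
example_config = (
    132,             # num_sms, passed through as NUM_SMS
    64,              # block_m
    128,             # block_n
    5,               # num_stages
    (2, True),       # tma_multicast_config -> (NUM_TMA_MULTICAST, IS_TMA_MULTICAST_ON_A)
    (147456, 0, 0),  # smem_config; smem_config[0] is SMEM_SIZE in bytes
)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = example_config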
@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


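The three compile helpers above now build nearly identical kwargs dicts, so a shared builder along these lines could factor out the duplication. This is a hypothetical refactor sketch, not part of the commit; note that GEMM_TYPE and NUM_GROUPS land in kwargs only for the plain GEMM, while the grouped variants carry GEMM_TYPE in keys instead.

def _build_fp8_gemm_kwargs(num_sms, smem_config, gemm_type=None, num_groups=None):
    # Mirrors the dicts constructed in the three helpers above.
    kwargs = {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_config[0],
    }
    if gemm_type is not None:
        kwargs["GEMM_TYPE"] = gemm_type
    if num_groups is not None:
        kwargs["NUM_GROUPS"] = num_groups
    return kwargs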
@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def __patched_func(self, *args, **kwargs):
             )
         return ret

-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
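The hunk above retargets the logging patch from RuntimeCache.__getitem__ to RuntimeCache.get. As a standalone illustration of the same patch-then-restore pattern, here is a minimal sketch with a stand-in cache class; every name below is hypothetical, not sglang or DeepGEMM API.

from contextlib import contextmanager

class _StubRuntimeCache:
    def get(self, key):
        return None  # simulate a cache miss, which would trigger a JIT build

@contextmanager
def _log_cache_misses(cache_cls):
    origin_get = cache_cls.get

    def patched_get(self, *args, **kwargs):
        ret = origin_get(self, *args, **kwargs)
        if ret is None:
            print("kernel not cached; JIT build will start")
        return ret

    cache_cls.get = patched_get
    try:
        yield
    finally:
        cache_cls.get = origin_get

# Usage:
# with _log_cache_misses(_StubRuntimeCache):
#     _StubRuntimeCache().get("some-key")  # prints the miss message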