diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 77ef0bedb4..08614c09a3 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -52,6 +52,7 @@ from .jit.gemm import ( gen_gemm_module, gen_gemm_sm90_module, + gen_fp8_blockscale_gemm_sm90_module, gen_gemm_sm100_module, gen_gemm_sm100_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, @@ -476,6 +477,8 @@ def gen_all_modules( jit_specs.append(gen_gemm_module()) if has_sm90: jit_specs.append(gen_gemm_sm90_module()) + # fp8 blockscale GEMM (SM90) + jit_specs.append(gen_fp8_blockscale_gemm_sm90_module()) jit_specs.append(gen_fp4_quantization_sm90_module()) jit_specs.append(gen_cutlass_fused_moe_sm90_module()) if has_sm100: