Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/compile_deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
multiprocessing.set_start_method("spawn", force=True)

# Reduce warning
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
Expand Down
25 changes: 16 additions & 9 deletions python/sglang/srt/layers/quantization/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

# Force redirect deep_gemm cache_dir
os.environ["DG_CACHE_DIR"] = os.getenv(
Expand All @@ -46,7 +47,8 @@

def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _BUILTIN_M_LIST
global _DO_COMPILE
global _DO_COMPILE_ALL
global _IS_FIRST_RANK_ON_NODE

# Generate m_max
m_max = 1024 * 16
Expand All @@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST = list(range(1, m_max + 1))

# Check if is the first rank on node
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
_IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id

# Check if is the first rank on node.
# Default each rank will try compile all Ms to
# load all symbols at the launch stages.
# Avoid loading symbols at the serving stages.
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE


class DeepGemmKernelType(IntEnum):
Expand Down Expand Up @@ -89,7 +96,7 @@ class DeepGemmKernelHelper:


def _compile_warning_1():
if not _IN_PRE_COMPILE_STAGE:
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
logger.warning(
"Entering DeepGEMM JIT Pre-Complie session. "
"And it may takes a long time(Typically 10-20 mins) "
Expand Down Expand Up @@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
query_key = (kernel_type, n, k, num_groups)
if (
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
and _DO_COMPILE
and _DO_COMPILE_ALL
and _INITIALIZATION_DICT.get(query_key) is None
):
_INITIALIZATION_DICT[query_key] = True
Expand All @@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
logger.info(
f"Try DeepGEMM JIT Compiling for "
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
)

# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
Expand Down Expand Up @@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(

@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
if _IN_PRE_COMPILE_STAGE:
if _IN_PRECOMPILE_STAGE:
yield
return

Expand Down
Loading