Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/sglang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,22 @@
user_end,
video,
)
from sglang.global_config import global_config
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
from sglang.utils import LazyImport
from sglang.version import __version__

ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

# Other configs
from sglang.global_config import global_config
from sglang.version import __version__

__all__ = [
"Engine",
"Runtime",
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,6 @@ def sample_random_requests(

# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
print(
"If you do not want to randomly sample from a dataset,"
" please use --dataset-name random-ids."
)
dataset_path = download_and_cache_file(SHAREGPT_URL)

# Load the dataset.
Expand Down
Empty file removed python/sglang/lang/__init__.py
Empty file.
4 changes: 0 additions & 4 deletions python/sglang/lang/backend/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from typing import List, Optional, Union

import numpy as np

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/lang/backend/base_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/lang/backend/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union

import numpy as np

Expand Down
1 change: 0 additions & 1 deletion python/sglang/lang/backend/vertexai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import warnings
from typing import Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
Expand Down
8 changes: 1 addition & 7 deletions python/sglang/lang/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@

from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglSamplingParams,
SglVariable,
)
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


def compile_func(function, backend):
Expand Down
10 changes: 3 additions & 7 deletions python/sglang/lang/tracer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
Expand Down Expand Up @@ -230,8 +226,8 @@ def _execute_role_end(self, expr: SglRoleEnd):
self.cur_role = None

def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(name, source=self.last_node)
self.variables[name] = new_node
new_node = SglVariable(expr.name, source=self.last_node)
self.variables[expr.name] = new_node

def get_var(self, name):
ret = self.arguments.get(name, None)
Expand Down
2 changes: 0 additions & 2 deletions python/sglang/srt/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
import os
from typing import List, Tuple

import torch
import torch.library

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

Expand Down
62 changes: 0 additions & 62 deletions python/sglang/srt/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,65 +42,3 @@ def dispatch_forward(self):
return self.forward_hip
else:
return self.forward_native


if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 (8-bit floating point) format.

    Args:
        input (torch.Tensor): Input tensor to be quantized
        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
            If None, scales will be computed dynamically.
        num_token_padding (Optional[int]): If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
            determines the quantization granularity:
            - True: compute scale per token
            - False: compute single scale per tensor

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - quantized_tensor: The FP8 quantized version of input
            - scale_tensor: The scaling factors used for quantization

    Raises:
        AssertionError: If input is not 2D or if static scale's numel != 1
    """
    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"

    num_rows, num_cols = input.shape
    # Optionally pad the token (row) dimension; never shrink below the input.
    if num_token_padding:
        num_rows = max(num_token_padding, num_rows)

    # ROCm uses the e4m3fnuz variant; CUDA uses e4m3fn.
    fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
    output = torch.empty((num_rows, num_cols), device=input.device, dtype=fp8_dtype)

    if scale is not None:
        # Static quantization: a single caller-provided scalar scale.
        assert (
            scale.numel() == 1
        ), f"Expected scalar scale, got numel={scale.numel()}"
        sgl_per_tensor_quant_fp8(
            input, output, scale, is_static=True
        )  # True for static
    elif use_per_token_if_dynamic:
        # Dynamic per-token quantization: one scale per (possibly padded) row.
        scale = torch.empty((num_rows, 1), device=input.device, dtype=torch.float32)
        sgl_per_token_quant_fp8(input, output, scale)
    else:
        # Dynamic per-tensor quantization: the kernel fills in the single scale.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        sgl_per_tensor_quant_fp8(
            input, output, scale, is_static=False
        )  # False for dynamic

    return output, scale
3 changes: 1 addition & 2 deletions python/sglang/srt/entrypoints/verl_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
from PIL.Image import Image
from torch.distributed.tensor import DeviceMesh, DTensor

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.server import Engine
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj


Expand Down
14 changes: 6 additions & 8 deletions python/sglang/srt/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,19 @@
import torch.nn as nn
import torch.nn.functional as F

from sglang.srt.utils import is_cuda_available

_is_cuda = is_cuda_available()

if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_cuda_available, set_weight_attrs

_is_cuda = is_cuda_available()

if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available

_is_cuda = is_cuda_available()
Expand All @@ -31,7 +32,6 @@
rmsnorm,
)

from sglang.srt.custom_op import CustomOp

logger = logging.getLogger(__name__)

Expand Down
38 changes: 12 additions & 26 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn import Module

try:
from deep_gemm import (
Expand All @@ -13,8 +14,6 @@
except ImportError:
use_deep_gemm = False

from torch.nn import Module

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
Expand All @@ -37,22 +36,17 @@
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant

logger = logging.getLogger(__name__)

_is_hip = is_hip()

_buffer = None


class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
Expand Down Expand Up @@ -740,20 +734,12 @@ def process_weights_after_loading(self, layer: Module) -> None:
)

for expert in range(layer.num_experts_per_partition):
if _is_cuda:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
else:
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
Expand Down
31 changes: 12 additions & 19 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import triton.language as tl

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
Expand All @@ -22,28 +23,25 @@
)

_is_hip = is_hip()


logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)

_is_cuda = is_cuda()

if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul

from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant

if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)


@triton.jit
def write_zeros_to_output(
c_ptr,
Expand Down Expand Up @@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
# activation tensor-wise fp8 quantization, dynamic or static
padded_size = padding_size
# activations apply per-token quantization when weights apply per-channel quantization by default
if _is_cuda:
A, A_scale = sgl_scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
A, A_scale = vllm_ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
A, A_scale = scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_channel_quant
)
else:
# activation block-wise fp8 quantization
assert len(block_shape) == 2
Expand Down
Loading
Loading