Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions python/sglang/srt/custom_op.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Optional

import torch
from torch import nn

from sglang.srt.utils import is_cuda, is_hip
Expand All @@ -14,6 +11,26 @@ def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()

def enter_torch_compile(self, num_tokens: int):
    """Redirect dispatch to an implementation suitable for torch.compile.

    Ordinary ops switch to ``forward_native``. Layers whose class name
    contains "FusedMoE" are special-cased (NOTE: temporary workaround):
    they swap in the native fused-MoE forward only when ``num_tokens == 1``,
    because torch.compile on that layer is not reliably faster for bs > 1;
    for larger batches their current dispatch target is left untouched.
    Always marks the op with ``is_torch_compile = True``.
    """
    is_fused_moe = "FusedMoE" in self.__class__.__name__
    if not is_fused_moe:
        self._forward_method = self.forward_native
    elif num_tokens == 1:
        # Imported lazily so the MoE module is only pulled in when needed.
        from sglang.srt.layers.moe.fused_moe_native import (
            fused_moe_forward_native,
        )

        # torch.compile only helps this layer at bs == 1, so the compiled
        # native path is restricted to that case.
        self._forward_method = fused_moe_forward_native
    self.is_torch_compile = True

def leave_torch_compile(self):
    """Restore the default CUDA dispatch target after torch compile mode."""
    self.is_torch_compile = False
    self._forward_method = self.forward_cuda

# Do not override this method: `self._forward_method` is swapped at runtime
# (see enter_torch_compile / leave_torch_compile), so all calls must go
# through this single dispatch point.
def forward(self, *args, **kwargs):
    """Invoke the currently selected forward implementation."""
    target = self._forward_method
    return target(*args, **kwargs)

Expand Down
14 changes: 2 additions & 12 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
Expand Down Expand Up @@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
sub._forward_method = sub.forward_cuda
setattr(sub, "is_torch_compile", False)
sub.leave_torch_compile()
else:
# NOTE: Temporarily workaround MoE
if "FusedMoE" in sub.__class__.__name__:
if num_tokens == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs =1
sub._forward_method = fused_moe_forward_native
else:
sub._forward_method = sub.forward_native
setattr(sub, "is_torch_compile", True)
sub.enter_torch_compile(num_tokens=num_tokens)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse, num_tokens)

Expand Down
Loading