Skip to content

Commit 6606718

Browse files
YyWangCS and MahmoudAshraf97
authored and committed
Replace torch.jit.script with torch.compile in get_masked_input_and_mask to fix benchmark underreporting (sgl-project#8733)
1 parent b1cc296 commit 6606718

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
method_has_implemented_embedding,
2727
)
2828
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
29-
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
29+
from sglang.srt.utils import (
30+
cpu_has_amx_support,
31+
get_compiler_backend,
32+
is_cpu,
33+
set_weight_attrs,
34+
)
3035

3136
DEFAULT_VOCAB_PADDING_SIZE = 64
3237

@@ -117,7 +122,7 @@ def __post_init__(self):
117122
assert self.num_added_elements <= self.num_added_elements_padded
118123

119124

120-
@torch.jit.script
125+
@torch.compile(dynamic=True, backend=get_compiler_backend())
121126
def get_masked_input_and_mask(
122127
input_: torch.Tensor,
123128
org_vocab_start_index: int,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
126131
added_vocab_start_index: int,
127132
added_vocab_end_index: int,
128133
) -> Tuple[torch.Tensor, torch.Tensor]:
129-
# torch.jit.script will fuse all of the pointwise ops below
134+
# torch.compile will fuse all of the pointwise ops below
130135
# into a single kernel, making it very fast
131136
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
132137
added_vocab_mask = (input_ >= added_vocab_start_index) & (

0 commit comments

Comments (0)