File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2626 method_has_implemented_embedding ,
2727)
2828from sglang .srt .layers .quantization .unquant import UnquantizedEmbeddingMethod
29- from sglang .srt .utils import cpu_has_amx_support , is_cpu , set_weight_attrs
29+ from sglang .srt .utils import (
30+ cpu_has_amx_support ,
31+ get_compiler_backend ,
32+ is_cpu ,
33+ set_weight_attrs ,
34+ )
3035
3136DEFAULT_VOCAB_PADDING_SIZE = 64
3237
@@ -117,7 +122,7 @@ def __post_init__(self):
117122 assert self .num_added_elements <= self .num_added_elements_padded
118123
119124
120- @torch .jit . script
125+ @torch .compile ( dynamic = True , backend = get_compiler_backend ())
121126def get_masked_input_and_mask (
122127 input_ : torch .Tensor ,
123128 org_vocab_start_index : int ,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
126131 added_vocab_start_index : int ,
127132 added_vocab_end_index : int ,
128133) -> Tuple [torch .Tensor , torch .Tensor ]:
129- # torch.jit.script will fuse all of the pointwise ops below
134+ # torch.compile will fuse all of the pointwise ops below
130135 # into a single kernel, making it very fast
131136 org_vocab_mask = (input_ >= org_vocab_start_index ) & (input_ < org_vocab_end_index )
132137 added_vocab_mask = (input_ >= added_vocab_start_index ) & (
You can’t perform that action at this time.
0 commit comments