@@ -31,18 +31,35 @@ def benchmark_config(
     topk: int,
     dtype: torch.dtype,
     use_fp8: bool,
+    use_int8: bool,
     num_iters: int = 100,
 ) -> float:
     init_dtype = torch.float16 if use_fp8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    w1 = torch.randn(num_experts,
-                     shard_intermediate_size,
-                     hidden_size,
-                     dtype=init_dtype)
-    w2 = torch.randn(num_experts,
-                     hidden_size,
-                     shard_intermediate_size // 2,
-                     dtype=init_dtype)
+    if use_int8:
+        w1 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               shard_intermediate_size,
+                               hidden_size,
+                           ),
+                           dtype=torch.int8)
+        w2 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               hidden_size,
+                               shard_intermediate_size // 2,
+                           ),
+                           dtype=torch.int8)
+    else:
+        w1 = torch.randn(num_experts,
+                         shard_intermediate_size,
+                         hidden_size,
+                         dtype=init_dtype)
+        w2 = torch.randn(num_experts,
+                         hidden_size,
+                         shard_intermediate_size // 2,
+                         dtype=init_dtype)
     gating_output = torch.randn(num_iters,
                                 num_tokens,
                                 num_experts,
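The int8 branch above fills the expert weights with random values in [-127, 127]; the next hunk pairs them with float32 scales. For orientation only, here is a minimal sketch of the symmetric dequantization such a weight/scale pair stands for; the helper name and the choice of channel axis are illustrative assumptions, not part of this patch.

```python
import torch

def dequantize_int8(w_q: torch.Tensor, scale: torch.Tensor, dim: int) -> torch.Tensor:
    """Illustrative symmetric dequantization: w ~= w_q * scale, broadcast along `dim`."""
    shape = [1] * w_q.ndim
    shape[dim] = -1
    return w_q.to(torch.float32) * scale.reshape(shape)

# Shapes mirror w1 above (small placeholder sizes): one scale per output channel.
w_q = torch.randint(-127, 127, (8, 64, 32), dtype=torch.int8)  # (E, N, K)
scale = torch.rand(64, dtype=torch.float32)                    # hypothetical per-channel scale
w = dequantize_int8(w_q, scale, dim=1)
print(w.shape)  # torch.Size([8, 64, 32])
```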
@@ -52,6 +69,10 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
+    if use_int8:
+        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
+                               dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8:
         w1_scale = torch.randn(num_experts, dtype=torch.float32)
         w2_scale = torch.randn(num_experts, dtype=torch.float32)
@@ -77,6 +98,7 @@ def run():
             inplace=True,
             override_config=config,
             use_fp8=use_fp8,
+            use_int8=use_int8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
@@ -156,10 +178,11 @@ def benchmark(
         topk: int,
         dtype: torch.dtype,
         use_fp8: bool,
+        use_int8: bool,
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(self.seed)

-        dtype_str = "float8" if use_fp8 else None
+        dtype_str = "float8" if use_fp8 else ("int8" if use_int8 else None)
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
         op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@@ -173,7 +196,7 @@ def benchmark(
                                key=lambda x: abs(x - num_tokens))]
         kernel_time = benchmark_config(config, num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
-                                       topk, dtype, use_fp8)
+                                       topk, dtype, use_fp8, use_int8)
         return config, kernel_time

     def tune(
@@ -185,8 +208,9 @@ def tune(
         topk: int,
         dtype: torch.dtype,
         use_fp8: bool,
-        search_space: List[BenchmarkConfig],
-    ) -> BenchmarkConfig:
+        use_int8: bool,
+        search_space: List[Dict[str, int]],
+    ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
         for config in tqdm(search_space):
@@ -199,6 +223,7 @@ def tune(
                                                topk,
                                                dtype,
                                                use_fp8,
+                                               use_int8,
                                                num_iters=10)
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -224,20 +249,15 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     }


-def save_configs(
-    configs: Dict[int, BenchmarkConfig],
-    num_experts: int,
-    shard_intermediate_size: int,
-    hidden_size: int,
-    topk: int,
-    dtype: torch.dtype,
-    use_fp8: bool,
-) -> None:
-    dtype_str = "float8" if use_fp8 else None
+def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
+                 shard_intermediate_size: int, hidden_size: int, topk: int,
+                 dtype: torch.dtype, use_fp8: bool, use_int8: bool) -> None:
+    dtype_str = "float8" if use_fp8 else "int8" if use_int8 else None
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                     dtype_str)
+
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
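For orientation, the filename written by save_configs comes from get_config_file_name(num_experts, shard_intermediate_size // 2, dtype_str), which can now receive "int8" as the dtype string. The snippet below shows a call with Mixtral-8x7B-style numbers; the import path and the exact on-disk filename format (which also encodes the GPU name) are assumptions based on vLLM's fused_moe utilities and are not shown in this diff.

```python
# Import path assumed; in this script get_config_file_name is already in scope.
from vllm.model_executor.layers.fused_moe.fused_moe import get_config_file_name

# Mixtral-8x7B with tp_size=2: E=8, shard_intermediate_size = 2 * 14336 // 2 = 14336.
filename = get_config_file_name(8, 14336 // 2, "int8")
print(filename)  # e.g. "E=8,N=7168,device_name=...,dtype=int8.json" (format assumed)
```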
@@ -253,6 +273,11 @@ def main(args: argparse.Namespace):
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "JambaForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -263,6 +288,7 @@ def main(args: argparse.Namespace):
     hidden_size = config.hidden_size
     dtype = config.torch_dtype
     use_fp8 = args.dtype == "fp8"
+    use_int8 = args.dtype == "int8"

     if args.batch_size is None:
         batch_sizes = [
@@ -294,20 +320,20 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
         start = time.time()
         configs = _distribute(
             "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8, search_space)
+                      topk, dtype, use_fp8, use_int8, search_space)
                      for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8)
+                     topk, dtype, use_fp8, use_int8)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute("benchmark",
                               [(batch_size, E, shard_intermediate_size,
-                                hidden_size, topk, dtype, use_fp8)
+                                hidden_size, topk, dtype, use_fp8, use_int8)
                                for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
@@ -323,7 +349,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
     parser.add_argument("--tp-size", "-tp", type=int, default=2)
     parser.add_argument("--dtype",
                         type=str,
-                        choices=["auto", "fp8"],
+                        choices=["auto", "fp8", "int8"],
                         default="auto")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
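With "int8" registered as a --dtype choice, the new path is selected from the command line the same way as fp8 (for example by passing --dtype int8 alongside --tp-size). A minimal sketch of how the flag maps onto the two booleans used throughout the script, using only argument names that appear in this diff:

```python
import argparse

# Illustrative only: reproduces the flag-to-boolean mapping added in this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, choices=["auto", "fp8", "int8"], default="auto")
args = parser.parse_args(["--dtype", "int8"])

use_fp8 = args.dtype == "fp8"    # False
use_int8 = args.dtype == "int8"  # True
```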