@@ -30,19 +30,36 @@ def benchmark_config(
     hidden_size: int,
     topk: int,
     dtype: torch.dtype,
-    use_fp8: bool,
+    use_fp8_w8a8: bool,
+    use_int8_w8a16: bool,
     num_iters: int = 100,
 ) -> float:
-    init_dtype = torch.float16 if use_fp8 else dtype
+    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    w1 = torch.randn(num_experts,
-                     shard_intermediate_size,
-                     hidden_size,
-                     dtype=init_dtype)
-    w2 = torch.randn(num_experts,
-                     hidden_size,
-                     shard_intermediate_size // 2,
-                     dtype=init_dtype)
+    if use_int8_w8a16:
+        w1 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               shard_intermediate_size,
+                               hidden_size,
+                           ),
+                           dtype=torch.int8)
+        w2 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               hidden_size,
+                               shard_intermediate_size // 2,
+                           ),
+                           dtype=torch.int8)
+    else:
+        w1 = torch.randn(num_experts,
+                         shard_intermediate_size,
+                         hidden_size,
+                         dtype=init_dtype)
+        w2 = torch.randn(num_experts,
+                         hidden_size,
+                         shard_intermediate_size // 2,
+                         dtype=init_dtype)
     gating_output = torch.randn(num_iters,
                                 num_tokens,
                                 num_experts,
@@ -52,7 +69,11 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
-    if use_fp8:
+    if use_int8_w8a16:
+        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
+                               dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+    if use_fp8_w8a8:
         w1_scale = torch.randn(num_experts, dtype=torch.float32)
         w2_scale = torch.randn(num_experts, dtype=torch.float32)
         a1_scale = torch.randn(1, dtype=torch.float32)
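For the new int8_w8a16 path the benchmark only needs tensors with the right shapes and dtypes, so it fills the weights with random int8 values and the per-channel weight scales with random float32 values; kernel timing does not depend on the values being a faithful quantization, and unlike the fp8 w8a8 path no activation scales are needed because activations stay in fp16/bf16. As a point of reference, the sketch below (not part of this change; quantize_w8a16 is a hypothetical helper) shows the symmetric per-output-channel int8 quantization that such a weight/scale pair would normally come from.

    import torch

    def quantize_w8a16(weight: torch.Tensor):
        """Hypothetical per-channel symmetric int8 quantization (illustration only)."""
        # weight: (out_features, in_features) in fp16/bf16; one scale per output channel.
        scale = weight.abs().amax(dim=1, keepdim=True).float().clamp_min(1e-8) / 127.0
        q = torch.clamp(torch.round(weight.float() / scale), -127, 127).to(torch.int8)
        # A w8a16 kernel dequantizes roughly as q * scale while the activations
        # remain in 16-bit floating point.
        return q, scale.squeeze(1)  # int8 weight + float32 per-channel scales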
@@ -76,7 +97,8 @@ def run():
             renormalize=True,
             inplace=True,
             override_config=config,
-            use_fp8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
@@ -155,11 +177,13 @@ def benchmark(
         hidden_size: int,
         topk: int,
         dtype: torch.dtype,
-        use_fp8: bool,
+        use_fp8_w8a8: bool,
+        use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(self.seed)
-
-        dtype_str = "float8" if use_fp8 else None
+        dtype_str = get_config_dtype_str(dtype,
+                                         use_int8_w8a16=use_int8_w8a16,
+                                         use_fp8_w8a8=use_fp8_w8a8)
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
         op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
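The dtype string used for config lookup now comes from get_config_dtype_str, a helper imported from vLLM's fused MoE module (the import is outside this hunk), instead of the hard-coded "float8"/None choice. Its implementation is not shown in this diff; a plausible sketch, under the assumption that it simply maps the quantization flags to the dtype tag used when looking up and saving config files, would be:

    from typing import Optional
    import torch

    def get_config_dtype_str_sketch(dtype: torch.dtype,
                                    use_int8_w8a16: bool = False,
                                    use_fp8_w8a8: bool = False) -> Optional[str]:
        # Assumption: the tags mirror the new --dtype CLI choices; the real helper
        # in vLLM may use different strings (the old code used "float8" for fp8).
        if use_fp8_w8a8:
            return "fp8_w8a8"
        if use_int8_w8a16:
            return "int8_w8a16"
        return None  # unquantized fp16/bf16 path: no dtype suffix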
@@ -173,7 +197,8 @@ def benchmark(
                                    key=lambda x: abs(x - num_tokens))]
         kernel_time = benchmark_config(config, num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
-                                       topk, dtype, use_fp8)
+                                       topk, dtype, use_fp8_w8a8,
+                                       use_int8_w8a16)
         return config, kernel_time

     def tune(
@@ -184,9 +209,10 @@ def tune(
         hidden_size: int,
         topk: int,
         dtype: torch.dtype,
-        use_fp8: bool,
-        search_space: List[BenchmarkConfig],
-    ) -> BenchmarkConfig:
+        use_fp8_w8a8: bool,
+        use_int8_w8a16: bool,
+        search_space: List[Dict[str, int]],
+    ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
         for config in tqdm(search_space):
@@ -198,7 +224,8 @@ def tune(
                                                hidden_size,
                                                topk,
                                                dtype,
-                                               use_fp8,
+                                               use_fp8_w8a8,
+                                               use_int8_w8a16,
                                                num_iters=10)
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     }


-def save_configs(
-    configs: Dict[int, BenchmarkConfig],
-    num_experts: int,
-    shard_intermediate_size: int,
-    hidden_size: int,
-    topk: int,
-    dtype: torch.dtype,
-    use_fp8: bool,
-) -> None:
-    dtype_str = "float8" if use_fp8 else None
+def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
+                 shard_intermediate_size: int, hidden_size: int, topk: int,
+                 dtype: torch.dtype, use_fp8_w8a8: bool,
+                 use_int8_w8a16: bool) -> None:
+    dtype_str = get_config_dtype_str(dtype,
+                                     use_int8_w8a16=use_int8_w8a16,
+                                     use_fp8_w8a8=use_fp8_w8a8)
+
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                     dtype_str)
+
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
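save_configs now feeds the same dtype tag into get_config_file_name, so tuned configurations for the int8_w8a16 path land in their own JSON file rather than overwriting the fp16/fp8 ones. The exact naming is owned by get_config_file_name, which is not part of this diff; a hypothetical sketch of the convention assumed here (both the function and the field order are illustrative only):

    from typing import Optional
    import torch

    def config_file_name_sketch(num_experts: int, intermediate_size: int,
                                dtype_str: Optional[str]) -> str:
        # Illustration only; the real naming lives in vLLM's get_config_file_name.
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        dtype_part = f",dtype={dtype_str}" if dtype_str is not None else ""
        return (f"E={num_experts},N={intermediate_size},"
                f"device_name={device_name}{dtype_part}.json")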
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "JambaForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):

     hidden_size = config.hidden_size
     dtype = config.torch_dtype
-    use_fp8 = args.dtype == "fp8"
+    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
+    use_int8_w8a16 = args.dtype == "int8_w8a16"

     if args.batch_size is None:
         batch_sizes = [
@@ -294,21 +326,21 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
         start = time.time()
         configs = _distribute(
             "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8, search_space)
+                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                      for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
-        outputs = _distribute("benchmark",
-                              [(batch_size, E, shard_intermediate_size,
-                                hidden_size, topk, dtype, use_fp8)
-                               for batch_size in batch_sizes])
+        outputs = _distribute(
+            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
+                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                          for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
             print(f"Batch size: {batch_size}, config: {config}")
@@ -323,7 +355,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
323355 parser .add_argument ("--tp-size" , "-tp" , type = int , default = 2 )
324356 parser .add_argument ("--dtype" ,
325357 type = str ,
326- choices = ["auto" , "fp8 " ],
358+ choices = ["auto" , "fp8_w8a8" , "int8_w8a16 " ],
327359 default = "auto" )
328360 parser .add_argument ("--seed" , type = int , default = 0 )
329361 parser .add_argument ("--batch-size" , type = int , required = False )
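With the expanded --dtype choices, "auto" keeps the model's native torch_dtype, "fp8_w8a8" exercises the fp8 weight-and-activation path, and "int8_w8a16" exercises the new int8-weight, 16-bit-activation path, e.g. python benchmark_moe.py --dtype int8_w8a16 --tp-size 2 (script name assumed; model selection and the tune-vs-benchmark switch come from flags defined elsewhere in the file).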