
Commit c12635c

Add use_int8 to the moe benchmarks
1 parent a097b6e commit c12635c

1 file changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 52 additions & 26 deletions
@@ -31,18 +31,35 @@ def benchmark_config(
     topk: int,
     dtype: torch.dtype,
     use_fp8: bool,
+    use_int8: bool,
     num_iters: int = 100,
 ) -> float:
     init_dtype = torch.float16 if use_fp8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    w1 = torch.randn(num_experts,
-                     shard_intermediate_size,
-                     hidden_size,
-                     dtype=init_dtype)
-    w2 = torch.randn(num_experts,
-                     hidden_size,
-                     shard_intermediate_size // 2,
-                     dtype=init_dtype)
+    if use_int8:
+        w1 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               shard_intermediate_size,
+                               hidden_size,
+                           ),
+                           dtype=torch.int8)
+        w2 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               hidden_size,
+                               shard_intermediate_size // 2,
+                           ),
+                           dtype=torch.int8)
+    else:
+        w1 = torch.randn(num_experts,
+                         shard_intermediate_size,
+                         hidden_size,
+                         dtype=init_dtype)
+        w2 = torch.randn(num_experts,
+                         hidden_size,
+                         shard_intermediate_size // 2,
+                         dtype=init_dtype)
     gating_output = torch.randn(num_iters,
                                 num_tokens,
                                 num_experts,
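For context (not part of the commit): the randint range of -127 to 127 mirrors symmetric int8 quantization, which usually avoids -128 so that a single scale maps the largest absolute weight onto the int8 grid. A minimal, purely illustrative sketch of that quantization step, with hypothetical names not taken from the benchmark:

import torch

def quantize_symmetric_int8(w: torch.Tensor):
    # Illustrative per-tensor symmetric quantization; the benchmark itself
    # only draws random int8 values, it does not quantize real weights.
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    w_q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return w_q, scale

w = torch.randn(8, 1024, 512)  # (num_experts, intermediate, hidden); made-up sizes
w_int8, w_scale = quantize_symmetric_int8(w)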
@@ -52,6 +69,10 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
+    if use_int8:
+        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
+                               dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8:
         w1_scale = torch.randn(num_experts, dtype=torch.float32)
         w2_scale = torch.randn(num_experts, dtype=torch.float32)
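Aside (an assumption about how such scales are consumed, not something the diff states): per-channel weight scales are typically applied by multiplying each output channel of the int8 matmul by its scale. A rough, non-fused reference computation, with hypothetical shapes chosen only to echo the tensors above:

import torch

hidden, inter = 512, 1024
w1_int8 = torch.randint(-127, 127, (inter, hidden), dtype=torch.int8)  # one expert's weight
w1_scale = torch.rand(inter, dtype=torch.float32)                      # one scale per output row

x = torch.randn(4, hidden, dtype=torch.float32)

# Reference path: dequantize the weight, then matmul.
w1_deq = w1_int8.to(torch.float32) * w1_scale.unsqueeze(1)
y = x @ w1_deq.t()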
@@ -77,6 +98,7 @@ def run():
             inplace=True,
             override_config=config,
             use_fp8=use_fp8,
+            use_int8=use_int8,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
@@ -156,10 +178,11 @@ def benchmark(
         topk: int,
         dtype: torch.dtype,
         use_fp8: bool,
+        use_int8: bool,
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(self.seed)
 
-        dtype_str = "float8" if use_fp8 else None
+        dtype_str = "float8" if use_fp8 else ("int8" if use_int8 else None)
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
         op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@@ -173,7 +196,7 @@ def benchmark(
                 key=lambda x: abs(x - num_tokens))]
         kernel_time = benchmark_config(config, num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
-                                       topk, dtype, use_fp8)
+                                       topk, dtype, use_fp8, use_int8)
         return config, kernel_time
 
     def tune(
@@ -185,8 +208,9 @@ def tune(
         topk: int,
         dtype: torch.dtype,
         use_fp8: bool,
-        search_space: List[BenchmarkConfig],
-    ) -> BenchmarkConfig:
+        use_int8: bool,
+        search_space: List[Dict[str, int]],
+    ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
         for config in tqdm(search_space):
@@ -199,6 +223,7 @@ def tune(
                                                topk,
                                                dtype,
                                                use_fp8,
+                                               use_int8,
                                                num_iters=10)
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -224,20 +249,15 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     }
 
 
-def save_configs(
-    configs: Dict[int, BenchmarkConfig],
-    num_experts: int,
-    shard_intermediate_size: int,
-    hidden_size: int,
-    topk: int,
-    dtype: torch.dtype,
-    use_fp8: bool,
-) -> None:
-    dtype_str = "float8" if use_fp8 else None
+def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
+                 shard_intermediate_size: int, hidden_size: int, topk: int,
+                 dtype: torch.dtype, use_fp8: bool, use_int8: bool) -> None:
+    dtype_str = "float8" if use_fp8 else "int8" if use_int8 else None
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                     dtype_str)
+
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
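For readers who have not seen these files: save_configs writes a JSON dict keyed by batch size, where each value is the Triton kernel configuration the tuner selected. A hypothetical illustration of that structure (key names follow the usual fused-MoE Triton configs; the numbers are invented):

# Purely illustrative: the shape of the dict that json.dump writes above.
example_best_configs = {
    1: {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4,
    },
    64: {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 3,
    },
}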
@@ -253,6 +273,11 @@ def main(args: argparse.Namespace):
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "JambaForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -263,6 +288,7 @@ def main(args: argparse.Namespace):
     hidden_size = config.hidden_size
     dtype = config.torch_dtype
     use_fp8 = args.dtype == "fp8"
+    use_int8 = args.dtype == "int8"
 
     if args.batch_size is None:
         batch_sizes = [
@@ -294,20 +320,20 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
         start = time.time()
         configs = _distribute(
             "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8, search_space)
+                      topk, dtype, use_fp8, use_int8, search_space)
                      for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8)
+                     topk, dtype, use_fp8, use_int8)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute("benchmark",
                               [(batch_size, E, shard_intermediate_size,
-                                hidden_size, topk, dtype, use_fp8)
+                                hidden_size, topk, dtype, use_fp8, use_int8)
                                for batch_size in batch_sizes])
 
     for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
@@ -323,7 +349,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
     parser.add_argument("--tp-size", "-tp", type=int, default=2)
     parser.add_argument("--dtype",
                         type=str,
-                        choices=["auto", "fp8"],
+                        choices=["auto", "fp8", "int8"],
                         default="auto")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
