Merged
Commits
33 commits
6b834a3
Add experts int8 config
mzusman Aug 11, 2024
afddd3b
Add support in fusedmoe
mzusman Aug 11, 2024
289367a
Add experts int8 to quantization list
mzusman Aug 11, 2024
084405e
Remove logger
mzusman Aug 12, 2024
0c690fe
Add to optimized quantization
mzusman Aug 12, 2024
3100490
Format
mzusman Aug 12, 2024
413400c
Add startup test for experts_int8
mzusman Aug 12, 2024
9e7bc79
Typo
mzusman Aug 12, 2024
1ebb5d7
Add test
mzusman Aug 12, 2024
44a72d6
Change compute capability to 80
mzusman Aug 12, 2024
39660ca
Format
mzusman Aug 12, 2024
a097b6e
Disable for CPU
mzusman Aug 12, 2024
c12635c
Add use_int8 to the moe benchmarks
mzusman Aug 15, 2024
9436034
Use JambaMoE to implement MLP
mzusman Aug 15, 2024
4b712e4
Use MoE to implement MLP
mzusman Aug 15, 2024
3b6967e
Format
mzusman Aug 15, 2024
5f5b11e
Fix
mzusman Aug 15, 2024
e199b17
Move experts_int8 to quantization subdir and add is quant method
mzusman Aug 15, 2024
9c47ad0
Split if else in benchmark moe
mzusman Aug 15, 2024
97f0585
Rename use_int8 to use_int8_w8a16, use_fp8 to use_fp8_w8a8
mzusman Aug 15, 2024
0025459
Reverse order
mzusman Aug 15, 2024
a1d75cb
Change dtype in configs filename
mzusman Aug 15, 2024
505e3d3
Single function to get dtype config name
mzusman Aug 15, 2024
80d977c
Align experts int8 apply with fp8
mzusman Aug 15, 2024
1c403be
Align with upstream
mzusman Aug 15, 2024
744ecd4
Format
mzusman Aug 15, 2024
a5bf0b3
Change fp8 to fp8_w8a8
mzusman Aug 15, 2024
1c7e689
Correct the args
mzusman Aug 15, 2024
e438b84
Remove experts int8 from ignore cpu
mzusman Aug 15, 2024
c23a2f4
Fix typo
mzusman Aug 15, 2024
7e619c7
Fix Jamba tests since MLP layer is not aligned with HF
mzusman Aug 15, 2024
70a6598
Merge remote-tracking branch 'github/main' into expert_int8_upstream
mzusman Aug 15, 2024
4d6c546
Merge remote-tracking branch 'github/main' into expert_int8_upstream
mzusman Aug 16, 2024
108 changes: 70 additions & 38 deletions benchmarks/kernels/benchmark_moe.py
@@ -30,19 +30,36 @@ def benchmark_config(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8 else dtype
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
@@ -52,7 +69,11 @@ def benchmark_config(
w2_scale = None
a1_scale = None
a2_scale = None
if use_fp8:
if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
@@ -76,7 +97,8 @@ def run():
renormalize=True,
inplace=True,
override_config=config,
use_fp8=use_fp8,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -155,11 +177,13 @@ def benchmark(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)

dtype_str = "float8" if use_fp8 else None
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@@ -173,7 +197,8 @@ def benchmark(
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8,
use_int8_w8a16)
return config, kernel_time

def tune(
@@ -184,9 +209,10 @@ def tune(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
search_space: List[BenchmarkConfig],
) -> BenchmarkConfig:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
@@ -198,7 +224,8 @@ def tune(
hidden_size,
topk,
dtype,
use_fp8,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
}


def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
) -> None:
dtype_str = "float8" if use_fp8 else None
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)

print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):

hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8 = args.dtype == "fp8"
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"

if args.batch_size is None:
batch_sizes = [
@@ -294,21 +326,21 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8, search_space)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute("benchmark",
[(batch_size, E, shard_intermediate_size,
hidden_size, topk, dtype, use_fp8)
for batch_size in batch_sizes])
outputs = _distribute(
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
for batch_size in batch_sizes])

for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
@@ -323,7 +355,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
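The updated benchmark builds the config-filename dtype string through `get_config_dtype_str`, whose definition is not part of this excerpt. A minimal sketch of what such a helper plausibly looks like, inferred from the call sites above (the actual implementation in vLLM may differ):

```python
from typing import Optional

import torch


def get_config_dtype_str(dtype: torch.dtype,
                         use_int8_w8a16: bool = False,
                         use_fp8_w8a8: bool = False) -> Optional[str]:
    # Sketch only: map the benchmark flags to the dtype string embedded in
    # the tuned-config filename.
    if use_fp8_w8a8:
        return "fp8_w8a8"
    if use_int8_w8a16:
        return "int8_w8a16"
    if dtype == torch.float:
        # Spell out float32 so it is not confused with the fp16/bf16 default.
        return "float32"
    # None keeps the legacy filename for the default fp16/bf16 case.
    return None
```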
13 changes: 7 additions & 6 deletions tests/models/test_jamba.py
@@ -6,9 +6,12 @@
MODELS = ["ai21labs/Jamba-tiny-random"]


# Fails due to usage of MoE as MLP (E=1), which is different from the HF impl
# TODO: Fix this with trained model
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_models(
hf_runner,
vllm_runner,
@@ -17,8 +20,6 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -36,8 +37,8 @@


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
def test_batching(
vllm_runner,
example_prompts,
28 changes: 28 additions & 0 deletions tests/quantization/test_experts_int8.py
@@ -0,0 +1,28 @@
# flake8: noqa
"""Tests experts_int8 quantization startup and generation,
doesn't test correctness
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
reason="ExpertsInt8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_experts_int8_startup(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype,
quantization="experts_int8") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
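For reference, the same startup path can be exercised outside the test harness. A minimal usage sketch, assuming the standard `vllm.LLM` entry point (the model name and sampling settings below are illustrative only):

```python
from vllm import LLM, SamplingParams

# Sketch: load a MoE model with on-the-fly int8 quantization of expert weights.
llm = LLM(model="ai21labs/Jamba-tiny-random",
          dtype="bfloat16",
          quantization="experts_int8")

params = SamplingParams(temperature=0.0, max_tokens=10)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```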
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -243,7 +243,8 @@ def _verify_quantization(self) -> None:
rocm_supported_quantization = ["gptq", "squeezellm"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
"fbgemm_fp8", "compressed_tensors", "compressed-tensors",
"experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
if self.quantization is not None:
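The config change simply registers `experts_int8` as an optimized quantization method. A hypothetical, minimal sketch of the kind of check this list feeds (the helper name and warning text here are illustrative, not vLLM's actual code):

```python
optimized_quantization_methods = [
    "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
    "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
    "experts_int8",
]


def warn_if_unoptimized(quantization: str) -> None:
    # Illustrative only: with "experts_int8" in the list above, selecting it
    # no longer takes the "not fully optimized" warning path.
    if quantization not in optimized_quantization_methods:
        print(f"warning: {quantization} quantization is not fully optimized; "
              "inference may be slower than expected.")
```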