tests/quantization/test_fp8.py (19 changes: 15 additions & 4 deletions)
@@ -9,6 +9,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
                                                          Fp8LinearMethod)
+from vllm.platforms import current_platform

 MODELS = [
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
@@ -20,7 +21,12 @@
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", MODELS)
-def test_model_load_and_run(vllm_runner, model_id: str):
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
+                            monkeypatch) -> None:
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
     with vllm_runner(model_id) as llm:
         # note: this does not test accuracy, just that we can run through
         # see lm-eval tests for accuracy
@@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
+                         monkeypatch) -> None:
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
     with vllm_runner("facebook/opt-125m",
                      quantization="fp8",
                      kv_cache_dtype=kv_cache_dtype) as llm:
@@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
         assert attn._k_scale == 1.0
         assert attn._v_scale == 1.0

-        capability = torch.cuda.get_device_capability()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        if capability >= 89:
+        if capability >= 89 and not force_marlin:
             # For GPUs with hardware support, we keep weights in fp8
             assert fc1.weight.dtype == torch.float8_e4m3fn
         else:
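For readers skimming the capability check above, here is a standalone sketch (not part of the PR) of the expectation the updated test encodes: weights stay in native FP8 only when the GPU reports compute capability 8.9 or higher and Marlin is not being forced.

def expect_native_fp8(device_capability: tuple, force_marlin: bool) -> bool:
    # Mirrors the test: (major, minor) -> major*10 + minor, e.g. (8, 9) -> 89.
    capability = device_capability[0] * 10 + device_capability[1]
    return capability >= 89 and not force_marlin

assert expect_native_fp8((9, 0), force_marlin=False)      # H100 (sm_90): native fp8 weights
assert not expect_native_fp8((9, 0), force_marlin=True)   # forced Marlin w8a16 path
assert not expect_native_fp8((8, 0), force_marlin=False)  # A100 (sm_80): Marlin fallback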
vllm/envs.py (8 changes: 8 additions & 0 deletions)
@@ -51,6 +51,7 @@
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
+    VLLM_TEST_FORCE_FP8_MARLIN: bool = False


 def get_default_cache_root():
@@ -341,6 +342,13 @@ def get_default_config_root():
     lambda:
     (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
      ("1", "true")),
+
+    # If set, forces FP8 Marlin to be used for FP8 quantization regardless
+    # of the hardware support for FP8 compute.
+    "VLLM_TEST_FORCE_FP8_MARLIN":
+    lambda:
+    (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
+     ("1", "true")),
 }

 # end-env-vars-definition
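As a usage note, the flag is parsed like the other boolean environment variables: "1" or "true" (case-insensitive, after stripping whitespace) enables it. Below is a minimal sketch of toggling it from Python before loading a model; the model choice mirrors the opt-125m case in the test above, and the prompt is an arbitrary placeholder.

import os

# "1" or "true" enables the flag; anything else leaves it off.
os.environ["VLLM_TEST_FORCE_FP8_MARLIN"] = "1"

from vllm import LLM

# opt-125m is an fp16 checkpoint that vLLM quantizes to FP8 on load; with the
# flag set, Fp8LinearMethod picks the Marlin w8a16 kernel even on GPUs that
# have native FP8 compute support.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
print(llm.generate("Hello, my name is")[0].outputs[0].text)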
vllm/model_executor/layers/quantization/fp8.py (11 changes: 10 additions & 1 deletion)
@@ -4,6 +4,7 @@
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
@@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config):
         # kernel for fast weight-only FP8 quantization
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89
+        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN

     def create_weights(
         self,
@@ -174,6 +175,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                          scale=None)

+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)),
+                    layer.logical_widths)
+
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
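The channelwise conversion is the subtle piece of the Marlin path: the Marlin w8a16 kernel expects one scale per output channel, while the dynamic quantization above produces a single per-tensor scale. Below is a rough approximation of that expansion (my own sketch for illustration, not vLLM's convert_to_channelwise implementation; the widths are hypothetical).

import torch

def to_channelwise(per_shard_scales: torch.Tensor,
                   logical_widths: list) -> torch.Tensor:
    # One scale per fused partition -> one scale per output channel,
    # repeated across that partition's width.
    channel_scales = torch.empty(sum(logical_widths),
                                 dtype=per_shard_scales.dtype)
    start = 0
    for scale, width in zip(per_shard_scales, logical_widths):
        channel_scales[start:start + width] = scale
        start += width
    return channel_scales

# A single per-tensor scale is first broadcast to one copy per partition,
# mirroring weight_scale.expand(len(layer.logical_widths)) in the diff.
scale = torch.tensor([0.02])
logical_widths = [1024, 256, 256]  # e.g. fused QKV projection output widths
channelwise = to_channelwise(scale.expand(len(logical_widths)), logical_widths)
assert channelwise.shape == (1536,)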