diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 13afbc1e058e..261efa333995 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -232,3 +232,28 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \ --model_export hf_format \ --tasks gsm8k ``` + +## Using MXFP4 models + +vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). + +The scheme currently only supports dynamic quantization for activations. + +Example usage, after installing the latest AMD Quark release: + +```bash +vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1 +``` + +A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16). + +To generate offline models quantized using MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), as an example: + +```bash +python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \ + --quant_scheme w_mxfp4_a_mxfp4_sym \ + --output_dir qwen_1.5-moe-a2.7b-mxfp4 \ + --skip_evaluation \ + --model_export hf_format \ + --group_size 32 +``` diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 96e3f29b3d79..0f1c78704642 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -174,6 +174,7 @@ def test_fused_moe( use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, + use_mxfp4_w4a4=False, per_act_token_quant=False, block_shape=None) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py new file mode 100644 index 000000000000..824b072a9f93 --- /dev/null +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import pytest +import torch +from packaging import version + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( + "quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@pytest.mark.parametrize('model_case', [ + ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), + ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), + ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1) +]) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): + if torch.cuda.device_count() < model_case.tp: + pytest.skip(f"This test requires >={model_case.tp} gpus, got only " + f"{torch.cuda.device_count()}") + + with vllm_runner(model_case.model_id, + tensor_parallel_size=model_case.tp, + load_format="dummy") as llm: + + # TODO: llm.apply_model(check_model) currently relies on V0 internals. + # Re-enable this later. 
+ # def check_model(model): + # layer = model.model.layers[0] + + # qkv_proj = layer.self_attn.qkv_proj + + # assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + # assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) + + # assert isinstance(layer.mlp.experts.quant_method, + # QuarkW4A4MXFp4MoEMethod) + + # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": + # llm.apply_model(check_model) + + output = llm.generate_greedy("Today I am in the French Alps and", + max_tokens=20) + assert output \ No newline at end of file diff --git a/tests/quantization/reference_mxfp4.py b/tests/quantization/reference_mxfp4.py new file mode 100644 index 000000000000..2ef251933f68 --- /dev/null +++ b/tests/quantization/reference_mxfp4.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +BFLOAT16_EXP_BIAS = 127 +BFLOAT16_MANTISSA_BITS = 7 +BFLOAT16_EXP_BITS = 8 + +FLOAT16_EXP_BIAS = 15 +FLOAT16_MANTISSA_BITS = 10 +FLOAT16_EXP_BITS = 5 + +FLOAT8_E8M0_MAX_EXP = 127 +FLOAT4_EXP_BIAS = 1 +FLOAT4_MANTISSA_BITS = 1 + +FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) +FLOAT16_SIGN_EXPONENT_MASK = (( + (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS) + +BFLOAT16_VAL_TO_ADD = (1 << + (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) +BFLOAT16_SIGN_EXPONENT_MASK = (( + (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS) + + +def e8m0_to_half(scale, half_dtype: torch.dtype): + assert scale.dtype == torch.uint8 + + scale_exp = scale.to(torch.int16) - 127 + + # This can be implemented with bitwise operations in a proper kernel. + scale_half = 2.0**(scale_exp.to(torch.float)) + + return scale_half.to(half_dtype) + + +def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, + half_exp_bias: int, half_mantissa_bits: int): + assert val.dtype == torch.uint8 + + unpacked = torch.zeros(*val.shape[:-1], + val.shape[-1] * 2, + dtype=torch.uint8, + device=val.device) + unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits. + unpacked[..., ::2] = val & 0x0F # Extract low 4 bits. + + # Takes one float4 values represented as b0000xxxx, + # and converts it to the corresponding float16 value. + + sign = unpacked >> 3 + + exp = (unpacked >> 1) & 3 + new_mantissa = unpacked & 1 + + # if exp == 0 and new_mantissa == 0: + # new_exp = 0 + # else: + # new_exp = exp - FLOAT4_EXP_BIAS + FLOAT16_EXP_BIAS + + # int8_t works with float16, but may overflow with bfloat16. + new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias + + # Cast b0000 to 0. in fp16/bf16. + new_exp = new_exp * torch.logical_or(exp > 0, new_mantissa > 0) + + # Cast b0001 to 0.5 in fp16/bf16. 
+ new_mantissa = torch.logical_and(new_mantissa, exp > 0) + + new_mantissa = new_mantissa.to(torch.int32) + new_exp = new_exp.to(torch.int32) + sign = sign.to(torch.int32) + + qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( + new_mantissa << (half_mantissa_bits - 1)) + + assert qdq_val.max() <= 65535 + assert qdq_val.min() >= 0 + qdq_val = qdq_val.to(torch.uint16) + + result = qdq_val.view(float_dtype) + + return result + + +def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: + assert x.dtype == torch.uint8 + assert scale.dtype == torch.uint8 + + if float_dtype == torch.float16: + half_exp_bias = FLOAT16_EXP_BIAS + half_mantissa_bits = FLOAT16_MANTISSA_BITS + elif float_dtype == torch.bfloat16: + half_exp_bias = BFLOAT16_EXP_BIAS + half_mantissa_bits = BFLOAT16_MANTISSA_BITS + + scale_half = e8m0_to_half(scale, half_dtype=float_dtype) + + x_half = upcast_fp4_to_fp16_or_bf16(x, + float_dtype=float_dtype, + half_exp_bias=half_exp_bias, + half_mantissa_bits=half_mantissa_bits) + + x_half = x_half.reshape(*x_half.shape[:-1], -1, 32) + x_half = x_half * scale_half[..., None] + x_half = x_half.reshape(*x_half.shape[:-2], -1) + + return x_half + + +def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, + half_exp_bias: int): + # Casts an fp16/bf16 input to the restricted values of float4_e2m1, + # that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, + # -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]. + + float_type = val.dtype + + # "rshift_cuda" not implemented for 'UInt16' + val_view = val.view(torch.int16) #.to(torch.int32) + + exp = val_view >> half_mantissa_bits + exp = exp & ((1 << half_exp_bits) - 1) + + exp = exp.view(torch.uint16).to(torch.int32) + + sign = (val_view >> (half_mantissa_bits + half_exp_bits)) & 1 + + mantissa_last = (val_view >> (half_mantissa_bits - 1)) & 1 + + exp_unbias = exp - half_exp_bias + new_exp = exp_unbias + FLOAT4_EXP_BIAS + + exp_shift = (new_exp <= 0) * (1 - new_exp) + + # Typically 9. + # Take the min to prevent overflow on `uint16_t half`. This is the case for + # very small values, correctly mapped to `round_close`. + tail_bits = half_mantissa_bits - FLOAT4_MANTISSA_BITS + exp_shift + tail_bits[tail_bits >= 16] = 16 + + mantissa_plus_one = val_view & ((1 << (half_mantissa_bits + 1)) - 1) + + half = 1 << (tail_bits - 1) + + tail = mantissa_plus_one & ((1 << tail_bits) - 1) + + round_close = (tail < half) # round towards 0 + round_away = (tail > half) # round away from 0 + tie = tail == half + + new_mantissa_close = torch.zeros(val.shape, + device=val.device, + dtype=torch.bool) + new_exp_close = torch.zeros(val.shape, + device=val.device, + dtype=torch.uint16) + + new_mantissa_away = torch.zeros(val.shape, + device=val.device, + dtype=torch.bool) + new_exp_away = torch.zeros(val.shape, + device=val.device, + dtype=torch.uint16) + + new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) + + # 1. round down + # if new_exp == 0: # case [0.5, 0.749999] + # new_mantissa = 0 + # elif new_exp < 0: # case [0, 0.24999] + # new_mantissa = 0 + # else: + # new_mantissa = mantissa_last + + new_mantissa_close = (new_exp > 0) * mantissa_last + new_exp_close = exp + + # # 2. 
round up + # if new_exp <= 0: # case [0.250001, 0.499999] and [0.75001, 0.99999] + # new_mantissa = 0 + # new_exp += 1 + # elif mantissa_last == 0: + # new_mantissa = 1 + # else: + # new_mantissa = 0 + # new_exp += 1 + + new_mantissa_away = torch.logical_and(new_exp > 0, mantissa_last == 0) + new_exp_away = exp + torch.logical_or(new_exp <= 0, mantissa_last == 1) + + # # 3. tie + # 0.25 -> 0. (handled by `exp > (half_exp_bias - 2)`) + # 0.75 -> 1. + # 1.25 -> 1. + # 1.75 -> 2. + # 2.5 -> 2. + # 3.5 -> 4. + # 5. -> 4. + new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1)) + + # Gather round up, round down and tie. + new_exp = round_away * new_exp_away \ + + round_close * new_exp_close \ + + tie * new_exp_tie + + new_mantissa = round_away * new_mantissa_away \ + + round_close * new_mantissa_close + + # if new_exp > 3: + # new_mantissa = 1 + new_mantissa = new_mantissa + (new_exp > + (2 + half_exp_bias)) * (new_mantissa == 0) + + # Clamp the exponent to acceptable values. + new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp( + new_exp, half_exp_bias - 2, half_exp_bias + 2) + + sign = sign.to(torch.int32) + new_mantissa = new_mantissa.to(torch.int32) + + qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( + new_mantissa << (half_mantissa_bits - 1)) + + assert qdq_val.max() <= 65535 + assert qdq_val.min() >= 0 + assert qdq_val.dtype == torch.int32 + qdq_val = qdq_val.to(torch.uint16) + + result = qdq_val.view(float_type) + return result + + +def qdq_mxfp4_torch(x: torch.Tensor, + scale_calculation_mode: str = "even") -> torch.Tensor: + half_dtype = x.dtype + + if half_dtype == torch.float16: + half_mantissa_bits = FLOAT16_MANTISSA_BITS + half_exp_bits = FLOAT16_EXP_BITS + half_exp_bias = FLOAT16_EXP_BIAS + val_to_add = FLOAT16_VAL_TO_ADD + sign_exponent_mask = FLOAT16_SIGN_EXPONENT_MASK + elif half_dtype == torch.bfloat16: + half_mantissa_bits = BFLOAT16_MANTISSA_BITS + half_exp_bits = BFLOAT16_EXP_BITS + half_exp_bias = BFLOAT16_EXP_BIAS + val_to_add = BFLOAT16_VAL_TO_ADD + sign_exponent_mask = BFLOAT16_SIGN_EXPONENT_MASK + else: + raise ValueError("not implemented") + + x = x.reshape(*x.shape[:-1], -1, 32) + + block_max = torch.max(torch.abs(x), dim=-1).values + + block_max = block_max.view(torch.uint16).to(torch.int32) + + block_max_uint = torch.bitwise_and(block_max + val_to_add, + sign_exponent_mask) + + assert block_max_uint.max() <= 65535 + assert block_max_uint.min() >= 0 + assert block_max_uint.dtype == torch.int32 + block_max_uint = block_max_uint.to(torch.uint16) + + block_max = block_max_uint.view(half_dtype) + + scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to( + torch.int32) - 2 + + scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP) + + scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP) + scale = scale.to(half_dtype) + + x = x / scale[..., None] + + x_fp4 = fp16_to_fp4_simulate(x, + half_exp_bits=half_exp_bits, + half_mantissa_bits=half_mantissa_bits, + half_exp_bias=half_exp_bias) + + x_fp4 = x_fp4 * scale[..., None] + return x_fp4.reshape(*x_fp4.shape[:-2], -1) diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 3571f773fb02..2db11cb997d1 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -3,15 +3,44 @@ """Test model set-up and weight loading for quark-quantized models. Run `pytest tests/quantization/test_quark.py`. + +See also `tests/kernels/moe/test_mxfp4_moe.py`. 
""" +import importlib +import importlib.metadata +import os +from dataclasses import dataclass + +import huggingface_hub +import lm_eval import pytest import torch +from packaging import version from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.platforms import current_platform +from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( + "quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + +if QUARK_MXFP4_AVAILABLE: + from quark.torch.export.nn.modules.realquantizer import ( + StaticScaledRealQuantizer) + from quark.torch.kernel import mx as mx_kernel + from quark.torch.quantization.config.config import FP4PerGroupSpec + +try: + huggingface_hub.list_repo_refs( + "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ") + HF_HUB_AMD_ORG_ACCESS = True +except huggingface_hub.errors.RepositoryNotFoundError: + HF_HUB_AMD_ORG_ACCESS = False + @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch): @@ -90,3 +119,145 @@ def test_quark_fp8_parity(vllm_runner): for key in fp8_state_dict: assert torch.equal(fp8_state_dict[key], quark_state_dict[key]) + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class GSM8KAccuracyTestConfig: + model_name: str + excepted_value: float + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768" + ) + + +ACCURACY_CONFIGS = [ + # Private model. + GSM8KAccuracyTestConfig( + model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", + excepted_value=0.96), +] + + +@pytest.mark.parametrize("config", ACCURACY_CONFIGS) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not HF_HUB_AMD_ORG_ACCESS, + reason="Read access to huggingface.co/amd is required for this test.") +def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): + if torch.cuda.device_count() < 8: + pytest.skip( + f"This test requires >=8 gpus, got only {torch.cuda.device_count()}" + ) + + task = "gsm8k" + rtol = 0.03 + + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=config.get_model_args(), + tasks=task, + batch_size=64, + num_fewshot=8, + ) + + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][task]["exact_match,strict-match"] + assert (measured_value - rtol < EXPECTED_VALUE + and measured_value + rtol > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + del os.environ["VLLM_USE_TRITON_FLASH_ATTN"] + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("scalings", + [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, + scalings: list[int]): + torch.manual_seed(0) + + hidden_size = 64 * 32 + inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - + 0.5) * 2 + for i in range(hidden_size // 32): + inp[:, i * 32:(i + 1) * + 32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + + inp_kernel = inp.clone() + inp_kernel_clone = inp_kernel.clone() + + res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, 
"even") + res_torch = qdq_mxfp4_torch(inp_kernel, "even") + + for i in range(hidden_size // 32): + assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32])) + assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32])) + + torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32], + res_torch[:, i * 32:(i + 1) * 32]) + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("scalings", + [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, + scalings: list[int]): + qspec = FP4PerGroupSpec( + ch_axis=-1, + group_size=32, + scale_format="e8m0", + scale_calculation_mode="even", + is_dynamic=False, + ).to_quantization_spec() + + weight_quantizer = StaticScaledRealQuantizer( + qspec=qspec, + quantizer=None, + reorder=False, + real_quantized=True, + float_dtype=float_dtype, + device="cuda", + ) + + observer = qspec.observer_cls(qspec, device="cuda") + + hidden_size = 512 + shape = (11008, hidden_size) + + w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2 + + # Make it so that different groups have different scales. + for i in range(hidden_size // 32): + w[:, i * 32:(i + 1) * + 32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + + observer(w) + scale, _ = observer._calculate_qparams() + weight_quantizer.scale = scale + + w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda") + weight_quantizer.maybe_convert_and_transpose_scale() + + scale = weight_quantizer.scale + + out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype) + + out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype) + + assert torch.equal(out_hip, out_torch) diff --git a/vllm/envs.py b/vllm/envs.py index ec6a4896774f..d7ba43c8251f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -94,7 +94,6 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_QUARK_EMU_MEM_OPT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -723,14 +722,6 @@ def get_vllm_port() -> Optional[int]: lambda: maybe_convert_int( os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), - # If set, when running in Quark emulation mode, do not dequantize the - # weights at load time. Instead, dequantize weights on-the-fly during - # kernel execution. - # This allows running larger models at the cost of slower inference. - # This flag has no effect when not running in Quark emulation mode. 
- "VLLM_QUARK_EMU_MEM_OPT": - lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), - # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6c03732030d1..432617ba046e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -50,11 +50,14 @@ def get_config_quant_dtype( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, -) -> Optional[torch.dtype]: + use_mxfp4_w4a4: bool, +) -> Union[None, torch.dtype, str]: if use_fp8_w8a8: return torch.float8_e4m3fn elif use_int8_w8a8: return torch.int8 + elif use_mxfp4_w4a4: + return "mxfp4" return None @@ -126,6 +129,7 @@ def make( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, @@ -144,6 +148,7 @@ def make( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, ) return FusedMoEQuantConfig( quant_dtype, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0355abbf1d2b..cf8d77063046 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -632,6 +632,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, ): @@ -641,12 +642,14 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + assert not use_mxfp4_w4a4, "NYI" self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -838,6 +841,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, ): @@ -847,18 +851,21 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + assert not use_mxfp4_w4a4, "NYI" assert max_num_tokens > 0 assert num_dispatchers > 0 self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int8_w8a8 = use_int8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 + self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -941,6 +948,7 @@ def apply( config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, dtype=hidden_states.dtype) config = try_get_optimal_moe_config( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 
fbbccbb34d90..4e578196fc00 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -27,6 +27,8 @@ MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + dequant_mxfp4) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -973,13 +975,16 @@ def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False) -> Optional[str]: + use_fp8_w8a8: Optional[bool] = False, + use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" elif use_int4_w4a16: return "int4_w4a16" + elif use_mxfp4_w4a4: + return "mxfp4_w4a4" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -998,6 +1003,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1011,9 +1017,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - per_channel_quant, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + use_mxfp4_w4a4, per_channel_quant, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) def inplace_fused_experts_fake( @@ -1028,6 +1034,7 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1062,6 +1069,7 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1075,10 +1083,10 @@ def outplace_fused_experts( return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, - use_int4_w4a16, per_channel_quant, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + use_int4_w4a16, use_mxfp4_w4a4, + per_channel_quant, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( @@ -1092,6 +1100,7 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1145,6 +1154,7 @@ def fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = 
-1, expert_map: Optional[torch.Tensor] = None, @@ -1203,6 +1213,7 @@ def fused_experts( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1228,6 +1239,7 @@ def fused_experts_impl( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1243,6 +1255,9 @@ def fused_experts_impl( if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), ( "Hidden size mismatch") + elif use_mxfp4_w4a4: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" else: assert hidden_states.size(1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") @@ -1268,12 +1283,14 @@ def fused_experts_impl( config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, dtype=hidden_states.dtype) qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1313,6 +1330,13 @@ def fused_experts_impl( else: out_hidden_states = torch.empty_like(hidden_states) + if use_mxfp4_w4a4: + # Weight has to be dequantized for mxfp4 emulation. + w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1429,6 +1453,7 @@ def fused_moe( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1470,6 +1495,9 @@ def fused_moe( - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. + - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and + OCP MXFP4 activation to compute the inner products for w1 and w2. + Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. 
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices @@ -1513,6 +1541,7 @@ def fused_moe( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1533,6 +1562,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, ): @@ -1542,6 +1572,7 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) @@ -1550,6 +1581,7 @@ def __init__( self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a16 = use_int8_w8a16 + self.use_mxfp4_w4a4 = use_mxfp4_w4a4 @property def activation_formats( @@ -1627,6 +1659,7 @@ def apply( config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, dtype=hidden_states.dtype) config = try_get_optimal_moe_config( @@ -1718,6 +1751,7 @@ def modular_triton_fused_moe( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + use_mxfp4_w4a4: bool, per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> mk.FusedMoEModularKernel: @@ -1728,6 +1762,7 @@ def modular_triton_fused_moe( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ), diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e660376ebe6b..db3b485888a4 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -19,6 +19,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, allow_deep_gemm: bool = False, @@ -29,6 +30,7 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) @@ -37,6 +39,7 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index a90cce719b48..1eb949790060 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Optional +from typing import Optional, Union import torch @@ -10,6 +10,9 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + quant_dequant_mxfp4) +from vllm.platforms 
import current_platform from vllm.utils import cdiv @@ -74,10 +77,25 @@ def _int8_quantize( return A, A_scale +def _mxfp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, None]: + assert block_shape is None + if not current_platform.supports_mx(): + A = quant_dequant_mxfp4(A) + else: + raise NotImplementedError() + + return A, None + + def moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[None, torch.dtype, str], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -85,6 +103,8 @@ def moe_kernel_quantize_input( return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == "mxfp4": + return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 05dff4bae395..b67ee5cf453d 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -237,12 +237,6 @@ def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], "Quark model is not in MX-FP4 format: not group_size=32") return False - # Weights need to use static quantization. - if weight_quant.get("is_dynamic") is True: - logger.debug( - "Quark model is not in MX-FP4 format: not weight static") - return False - # Activations need to use dynamic quantization. if input_quant.get("is_dynamic") is False: logger.debug( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index a040c430cbca..6f69210d0861 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -5,11 +5,12 @@ import torch -import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -17,7 +18,9 @@ logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] +__all__ = [ + "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod" +] class QuarkMoEMethod(FusedMoEMethodBase): @@ -40,6 +43,8 @@ def get_moe_method( if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + elif quant_config._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -242,4 +247,163 @@ def apply( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + a2_scale=layer.w2_input_scale, + activation=activation) + + +class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: dict[str, Any], 
input_config: dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group" + and input_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4A4MXFp4MoEMethod with static input scales is currently " + "not implemented. Please open an issue.") + + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. " + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + params_dtype = torch.uint8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + 
expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + out = fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_mxfp4_w4a4=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, + activation=activation, + ) + return out diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 3c56251b7a00..880438a22a69 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -6,14 +6,16 @@ import torch import torch.nn.functional as F -import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) + OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) from vllm.platforms import current_platform +logger = init_logger(__name__) + __all__ = ["QuarkW4A4MXFP4"] @@ -25,7 +27,29 @@ def __init__(self, weight_quant_spec: dict[str, Any], self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - self.emulate = not current_platform.supports_mx() + + self.static_input_scales = not input_quant_spec.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4A4MXFP4 with static input scales is currently not " + "implemented. Please open an issue.") + + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. 
" + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") @classmethod def get_min_capability(cls) -> int: @@ -37,43 +61,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) - if self.emulate: - try: - from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) - except ImportError as err: - raise ImportError( - "The package `amd-quark` is required to use AMD Quark " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err - - weight_quant_spec = QuantizationSpec.from_dict( - self.weight_quant_spec) - - weight_quantizer = realquantizer.get_real_quantizer( - qspec=weight_quant_spec, - quantizer=None, - real_quantized=True, - reorder=False, - float_dtype=self.out_dtype, - scale_shape=layer.weight_scale.shape, - zero_point_shape=None, - ) - weight_quantizer.scale.data = layer.weight_scale.data - - if not envs.VLLM_QUARK_EMU_MEM_OPT: - layer.weight = torch.nn.Parameter( - weight_quantizer(layer.weight.data).to(self.out_dtype), - requires_grad=False, - ) - else: - self.weight_quantizer = weight_quantizer - layer.weight_scale = None - - # This call is necessary to release the scales memory. - torch.cuda.empty_cache() - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, @@ -116,11 +103,10 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: if self.emulate: - if envs.VLLM_QUARK_EMU_MEM_OPT: - dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) - else: - dq_w = layer.weight - qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) - return F.linear(qdq_x, dq_w, bias) + dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) + + x = quant_dequant_mxfp4(x) + + return F.linear(x, dq_w, bias) else: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9d4a188f52df..1119045db072 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,45 +1,67 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch +from vllm.utils import direct_register_custom_op + OCP_MX_BLOCK_SIZE = 32 -def per_token_group_quant_mxfp4(x: torch.Tensor, - block_k: int, - scale_calculation_mode: str = "even" - ) -> tuple[torch.Tensor, torch.Tensor]: +def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: try: - from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( - fake_quantize_fp4_fp6_per_group_with_scale) - from quark.torch.quantization.utils import (even_round, - reshape_to_blocks) + from quark.torch.kernel import mx except ImportError as err: raise ImportError("The package `amd-quark` is required to use " "MX-FP4 models. Please install it with `pip install " "amd-quark`.") from err - axis = -1 - block_x = reshape_to_blocks(x, block_k, axis) - amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) - amax = amax.squeeze(-1) - - # TODO: there are other rounding strategies supported in quark and in the - # config.json that we do not check for here! 
- if scale_calculation_mode != "even": - raise NotImplementedError( - f"Scale calculation mode {scale_calculation_mode} is not yet " - "supported in MX-FP4 quantization") - scale = even_round(amax, "fp4") - - # Apply dequantize(quantize(x)). - x = fake_quantize_fp4_fp6_per_group_with_scale( - x, - scale.to(x.device), - axis=axis, - group_size=block_k, - quant_dtype="fp4", + return mx.dq_mxfp4(x, scale, float_dtype) + + +def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: + return torch.empty((*x.shape[:-1], x.shape[-1] * 2), + dtype=float_dtype, + device=x.device) + + +def _quant_dequant_mxfp4(x: torch.Tensor, + scale_calculation_mode: str = "even") -> torch.Tensor: + try: + from quark.torch.kernel import mx + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + return mx.qdq_mxfp4(x, scale_calculation_mode) + + +def _quant_dequant_mxfp4_fake(x: torch.Tensor, + scale_calculation_mode: str = "even" + ) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="dequant_mxfp4", + op_func=_dequant_mxfp4, + mutates_args=[], + fake_impl=_dequant_mxfp4_fake, ) + dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 +except AttributeError as error: + raise error - return x, scale +try: + direct_register_custom_op( + op_name="quant_dequant_mxfp4", + op_func=_quant_dequant_mxfp4, + mutates_args=[], + fake_impl=_quant_dequant_mxfp4_fake, + ) + quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 +except AttributeError as error: + raise error
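For reference (not part of the diff), below is a minimal, self-contained sketch of the block-level quantize-dequantize that the MXFP4 emulation path approximates: each group of 32 elements shares one power-of-two (E8M0) scale, and each element is snapped to an FP4 (E2M1) value. The helper name `qdq_mxfp4_block`, the floor-based scale selection, and the nearest-value rounding are illustrative simplifications only; the actual path relies on amd-quark's fused `qdq_mxfp4`/`dq_mxfp4` kernels, exposed through the `torch.ops.vllm.quant_dequant_mxfp4` and `torch.ops.vllm.dequant_mxfp4` custom ops registered above, whose "even" scale mode and tie-breaking can differ slightly from this sketch.

```python
import torch

# Representable magnitudes of FP4 E2M1 (the sign is handled separately).
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def qdq_mxfp4_block(block: torch.Tensor) -> torch.Tensor:
    """Quantize-dequantize one block of 32 values to MXFP4 (illustrative only)."""
    assert block.numel() == 32
    amax = block.abs().max()
    if amax == 0:
        return torch.zeros_like(block)
    # Shared E8M0 scale: a power of two chosen so the block maximum lands near
    # the FP4 maximum magnitude (6.0 ~ 2^2). Quark's "even" mode rounds the
    # exponent differently; this floor-based choice is a simplification.
    scale = 2.0 ** (torch.floor(torch.log2(amax)) - 2)
    scaled = block / scale
    # Snap each element to the nearest representable FP4 magnitude, keep the sign.
    idx = (scaled.abs().unsqueeze(-1) - FP4_VALUES).abs().argmin(dim=-1)
    return scaled.sign() * FP4_VALUES[idx] * scale


x = torch.randn(32) * 3.7
x_qdq = qdq_mxfp4_block(x)
print("max abs quantization error:", (x - x_qdq).abs().max().item())
```

Restricting scales to powers of two is what keeps dequantization cheap: applying the scale is a pure exponent shift, which is why `e8m0_to_half` in the reference implementation can reconstruct it as `2.0 ** (scale - 127)` and why the fused kernels can dequantize MXFP4 weights on the fly at low cost.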