From 82e3a6dd23624b14e578d2d945213764e51f5660 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 5 Aug 2025 19:47:45 -0400 Subject: [PATCH] Enable gpt-oss mxfp4 on older hardware (sm75+) --- src/transformers/integrations/mxfp4.py | 5 +++- .../quantizers/quantizer_mxfp4.py | 28 +++++++++++++------ tests/quantization/mxfp4/test_mxfp4.py | 17 +++++++++-- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 86517671b5f3..6e10535951d6 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -280,7 +280,10 @@ def mlp_forward(self, hidden_states): batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + + with torch.cuda.device(router_logits.device): + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) return routed_out, router_logits diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 061ca072f029..4e54ea2aed7e 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -66,23 +66,33 @@ def validate_environment(self, *args, **kwargs): return compute_capability = torch.cuda.get_device_capability() - major, minor = compute_capability + gpu_is_supported = compute_capability >= (7, 5) + kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble() - if not is_triton_available("3.4.0") or not is_triton_kernels_availalble(): - if 
self.pre_quantized and not self.quantization_config.dequantize: + if self.pre_quantized: + # On unsupported GPUs or without kernels, we will dequantize the model to bf16 + if not gpu_is_supported: logger.warning_once( - "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" + "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g. T4, A100, L4, H100, or B200). " + "We will default to dequantizing the model to bf16." ) self.quantization_config.dequantize = True return - else: - # we can't quantize the model in this case so we raise an error - raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed") - if major < 9: + if not kernels_available: + logger.warning_once( + "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16" + ) + self.quantization_config.dequantize = True + return + elif not gpu_is_supported: + # we can't quantize the model in this case so we raise an error raise ValueError( - "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)" + "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g. T4, A100, L4, H100, or B200)" ) + elif not kernels_available: + # we can't quantize the model in this case so we raise an error + raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed") device_map = kwargs.get("device_map", None) if device_map is None: diff --git a/tests/quantization/mxfp4/test_mxfp4.py b/tests/quantization/mxfp4/test_mxfp4.py index 2194c2d3219e..1d2d32ab72c4 100644 --- a/tests/quantization/mxfp4/test_mxfp4.py +++ b/tests/quantization/mxfp4/test_mxfp4.py @@ -107,18 +107,31 @@ def test_quantizer_validation_no_cuda(self): def test_quantizer_validation_low_compute_capability(self): """Test quantizer validation with low compute capability""" - with 
patch("torch.cuda.get_device_capability", return_value=(8, 0)): + with patch("torch.cuda.get_device_capability", return_value=(7, 0)): from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer config = Mxfp4Config() quantizer = Mxfp4HfQuantizer(config) + quantizer.pre_quantized = False with self.assertRaises(ValueError): quantizer.validate_environment() + def test_quantizer_validation_low_compute_capability_with_prequantized(self): + """Test quantizer validation with low compute capability on a pre-quantized model""" + with patch("torch.cuda.get_device_capability", return_value=(7, 0)): + from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer + + config = Mxfp4Config() + quantizer = Mxfp4HfQuantizer(config) + + # Should automatically set dequantize=True and warn + quantizer.validate_environment() + self.assertTrue(quantizer.quantization_config.dequantize) + def test_quantizer_validation_low_compute_capability_with_dequantize(self): """Test quantizer validation with low compute capability but dequantize enabled""" - with patch("torch.cuda.get_device_capability", return_value=(8, 0)): + with patch("torch.cuda.get_device_capability", return_value=(7, 0)): from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer config = Mxfp4Config(dequantize=True)