diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index e37d78d49e..b7778940d4 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -2950,8 +2950,9 @@ def test_bloomz_loftq_4bit(self, device, tmp_path): assert mse_loftq > 0.0 # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin - assert mse_loftq < (mse_quantized / self.error_factor) - assert mae_loftq < (mae_quantized / self.error_factor) + error_factor = self.get_error_factor(device) + assert mse_loftq < (mse_quantized / error_factor) + assert mae_loftq < (mae_quantized / error_factor) @pytest.mark.parametrize("device", [torch_device, "cpu"]) def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):