diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py
index a9642d4d2b..59aefb0fe0 100644
--- a/src/peft/utils/integrations.py
+++ b/src/peft/utils/integrations.py
@@ -93,14 +93,19 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     """
     import bitsandbytes as bnb
 
-    # BNB requires CUDA weights
+    if state.SCB is None:
+        state.SCB = weight.SCB
+
+    # BNB requires accelerator weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
     if is_cpu:
         if torch.cuda.is_available():
             weight = weight.to(torch.device("cuda"))
+            state.SCB = state.SCB.to(torch.device("cuda"))
         elif is_xpu_available():
             weight = weight.to(torch.device("xpu"))
+            state.SCB = state.SCB.to(torch.device("xpu"))
 
     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
@@ -109,9 +114,6 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
             dequantized = dequantized.to(device)
         return dequantized
 
-    if state.SCB is None:
-        state.SCB = weight.SCB
-
     if hasattr(bnb.functional, "int8_vectorwise_dequant"):
         # Use bitsandbytes API if available (requires v0.45.0+)
         dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
@@ -121,6 +123,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
 
     if is_cpu:
         dequantized = dequantized.to(device)
+        state.SCB = state.SCB.to(device)
     return dequantized
 
 
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index 51752e5a66..e37d78d49e 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -2817,28 +2817,29 @@ def test_olora_with_quantized_model(self, bits):
 @pytest.mark.skipif(
     not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
 )
+@pytest.mark.single_gpu_tests
 @require_bitsandbytes
 class TestLoftQ:
     r"""
     Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
     """
 
-    # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
-    # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
-    # conservative value to prevent flakiness, in practice most gains are > 1.5
-    device = infer_device()
-    error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+    def get_error_factor(self, device):
+        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
+        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
+        # conservative value to prevent flakiness, in practice most gains are > 1.5
+        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
+        return error_factor
 
     def get_input(self, model_id, device):
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
-        inputs = inputs.to(self.device)
+        inputs = inputs.to(device)
         return inputs
 
     def get_base_model(self, model_id, device, **kwargs):
         cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
-        model = cls.from_pretrained(model_id, **kwargs).eval()
-        model = model.to(self.device)
+        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
         return model
 
     def get_logits(self, model, inputs):
@@ -2882,7 +2883,7 @@ def get_errors(
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            self.get_base_model(model_id, device=None, **kwargs),
+            self.get_base_model(model_id, device, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
@@ -2901,10 +2902,10 @@ def get_errors(
         )
         model = self.get_base_model(model_id, device)
         if device != "cpu":
-            model = model.to(torch_device)
+            model = model.to(device)
         loftq_model = get_peft_model(model, lora_config)
         if device != "cpu":
-            loftq_model = loftq_model.to(torch_device)
+            loftq_model = loftq_model.to(device)
 
         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2917,7 +2918,7 @@ def get_errors(
         clear_device_cache(garbage_collection=True)
 
         # now load quantized model and apply LoftQ-initialized weights on top
-        base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
+        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
         loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)
 
         # TODO sanity check: model is quantized
@@ -2966,8 +2967,9 @@ def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
@@ -2981,8 +2983,9 @@ def test_bloomz_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
@@ -2998,8 +3001,9 @@ def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
@@ -3013,8 +3017,9 @@ def test_t5_loftq_4bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
@@ -3028,8 +3033,9 @@ def test_t5_loftq_8bit(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mse_loftq < (mse_quantized / self.error_factor)
-        assert mae_loftq < (mae_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mse_loftq < (mse_quantized / error_factor)
+        assert mae_loftq < (mae_quantized / error_factor)
 
     @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
     @pytest.mark.parametrize("device", [torch_device, "cpu"])
@@ -3063,8 +3069,9 @@ def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         assert mse_loftq > 0.0
 
         # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
-        assert mae_loftq < (mae_quantized / self.error_factor)
-        assert mse_loftq < (mse_quantized / self.error_factor)
+        error_factor = self.get_error_factor(device)
+        assert mae_loftq < (mae_quantized / error_factor)
+        assert mse_loftq < (mse_quantized / error_factor)
 
     def test_replace_lora_weights_with_loftq_using_callable(self):
         """
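Note (not part of the patch): below is a minimal sketch of the scenario the integrations.py change addresses, namely dequantizing a CPU-resident bnb 8-bit weight whose SCB scales must travel with the weight to the accelerator and back. It assumes a recent bitsandbytes with Linear8bitLt and an available CUDA device; the layer sizes and variable names are illustrative only.

    import torch
    import bitsandbytes as bnb

    from peft.utils.integrations import dequantize_bnb_weight

    # Build a small int8 layer: weights are quantized on the move to CUDA,
    # then the layer is moved back to CPU (e.g. CPU offloading).
    linear = torch.nn.Linear(16, 16)
    int8_linear = bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False)
    int8_linear.load_state_dict(linear.state_dict())
    int8_linear = int8_linear.to("cuda").to("cpu")

    # state.SCB may still be None here because forward() was never run; the patched
    # function copies weight.SCB into the state, moves it to the accelerator together
    # with the weight, and returns the result on the original (CPU) device.
    dequantized = dequantize_bnb_weight(int8_linear.weight, state=int8_linear.state)
    print(dequantized.shape, dequantized.device)  # expected: torch.Size([16, 16]) cpu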