diff --git a/gptqmodel/nn_modules/triton_utils/dequant.py b/gptqmodel/nn_modules/triton_utils/dequant.py
index 0dccc532e..eefcf79f9 100644
--- a/gptqmodel/nn_modules/triton_utils/dequant.py
+++ b/gptqmodel/nn_modules/triton_utils/dequant.py
@@ -93,7 +93,7 @@ def dequant_kernel(
     tl.store(out_ptr + (x_index), weights, mask=xmask)
 
 
-def dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq):
+def dequant(dtype, qweight, scales, qzeros, g_idx, bits, pack_bits, maxq):
     """
     Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8
     """
@@ -102,7 +102,7 @@ def dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq):
     out_features = scales.shape[1]
     in_features = g_idx.shape[0]
 
-    out = torch.empty((in_features, out_features), device=qweight.device, dtype=torch.float16)
+    out = torch.empty((in_features, out_features), device=qweight.device, dtype=dtype)
     numels = out.numel()
     grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
 
@@ -121,14 +121,12 @@ def dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq):
     )
     return out
 
-
 def quant_matmul(input, qweight, scales, qzeros, g_idx, bits, pack_bits, maxq, transpose=False):
-    W = dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq)
+    W = dequant(input.dtype, qweight, scales, qzeros, g_idx, bits, pack_bits, maxq)
     if transpose:
         return input @ W.t()
     return input @ W
 
-
 class QuantLinearFunction(torch.autograd.Function):
     @staticmethod
     @custom_fwd(device_type="cuda")
diff --git a/tests/test_olora_finetuning_xpu.py b/tests/test_olora_finetuning_xpu.py
index a71ff2ad1..175ddfe68 100644
--- a/tests/test_olora_finetuning_xpu.py
+++ b/tests/test_olora_finetuning_xpu.py
@@ -50,17 +50,17 @@ def train(
     cutoff_len: int = 256,
     val_set_size: int = 16,
     quantize: bool = False,
-    eval_step: int = 100,
-    save_step: int = 100,
+    eval_step: int = 10,
+    save_step: int = 10000,
     device_map: str = "auto",
     lora_r: int = 32,
     lora_alpha: int = 16,
     lora_dropout: float = 0.05,
     lora_target_modules: List[str] = None,
-    torch_dtype: str = "bloat16",
+    torch_dtype: torch.dtype = torch.bfloat16,
     init_lora_weights="olora",
 ):
-    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": DEVICE}
+    model_kwargs = {"torch_dtype": torch_dtype, "device_map": DEVICE}
     if quantize:
         model_kwargs["quantization_config"] = GPTQConfig(bits=4, true_sequential=False, dataset=['/monster/data/model/dataset/c4-train.00000-of-01024.json.gz'], backend="triton")
 
@@ -165,13 +165,13 @@ def test_peft(self):
             cutoff_len=256,
             val_set_size=16,
             quantize=True,
-            eval_step=100,
-            save_step=100,
+            eval_step=10,
+            save_step=10000,
             device_map="cuda",
             lora_r=32,
             lora_alpha=16,
             lora_dropout=0.05,
             lora_target_modules=None,
-            torch_dtype="bfloat16",
+            torch_dtype=torch.bfloat16,
            init_lora_weights="olora",
        )
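The dequant.py change above un-hardcodes the float16 output buffer so the dequantized weight matrix follows the caller's dtype (quant_matmul now threads input.dtype through to dequant). A minimal sketch of that pattern, using a hypothetical dequant_out_buffer helper in place of the real Triton launcher:

import torch

# Hypothetical stand-in for the Triton launcher's output allocation; it only
# illustrates the dtype propagation introduced by this patch, not the kernel.
def dequant_out_buffer(dtype, in_features, out_features, device="cpu"):
    # Before this change the buffer was pinned to torch.float16, forcing an
    # implicit cast whenever the activations were bfloat16 or float32.
    return torch.empty((in_features, out_features), device=device, dtype=dtype)

x = torch.randn(8, 64, dtype=torch.bfloat16)
W = dequant_out_buffer(x.dtype, 64, 32)   # buffer dtype follows the input
assert (x @ W).dtype == torch.bfloat16    # matmul stays in a single dtype

This is also why the test file switches torch_dtype from a (previously misspelled) string "bloat16" to a real torch.dtype: the value is now passed straight into model_kwargs rather than resolved via getattr(torch, ...).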