diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index a03c98430..b9acd973a 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -50,7 +50,7 @@
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
 from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
-from ..nn_modules.qlinear.ipex import IPEXQuantLinear
+from ..nn_modules.qlinear.ipex import IPEXQuantLinear, HAS_IPEX
 from ..quantization import FORMAT, QuantizeConfig
 from ..quantization.config import FORMAT_FIELD_JSON, QUANT_METHOD, dynamic_get
 from .backend import BACKEND
@@ -892,28 +892,36 @@ def auto_dtype(config: PretrainedConfig,

     assert isinstance(device, DEVICE)

-    # for inference, DynamicCuda, Exllama, Triton, and Marlin are all fp16 kernels
-    if quant_inference and device != DEVICE.CPU:
-        return torch.float16
-
     # TODO: both MPS and XPU are locked to float16
     # XPU stack is missing bfloat16 (hardware supports it)
     # MPS stack has bfloat16 bugs in pytorch
     if device in [DEVICE.MPS, DEVICE.XPU]:
+        log.info("Loader: Auto dtype (MPS or XPU): `torch.float16`")
         return torch.float16

+    # IPEX for CPU is optimized for bfloat16
+    if device in [DEVICE.CPU] and HAS_IPEX:
+        log.info("Loader: Auto dtype (CPU + IPEX): `torch.bfloat16`")
+        return torch.bfloat16
+
     # get dtype from config
     dtype = getattr(config, "torch_dtype")
     if dtype and not isinstance(dtype, torch.dtype):
         raise ValueError(f"torch_dtype in config must be a torch.dtype, but got {dtype}")

-    if dtype == torch.float32:
+    if dtype in [torch.float32, torch.float64]:
+        log.info("Loader: Auto dtype (float32 down-cast): `torch.bfloat16`")
         return torch.bfloat16
     elif dtype == torch.float16:
+        log.info("Loader: Auto dtype (native float16): `torch.float16`")
         return torch.float16
+    elif dtype == torch.bfloat16:
+        log.info("Loader: Auto dtype (native bfloat16): `torch.bfloat16`")
+        return torch.bfloat16
     else:
         # TODO: extract weights from model file to check their original type, instead of forcing bfloat16
         # up/down-cast everything else to bfloat16 if not already in bfloat16
+        log.info(f"Loader: Auto dtype (native = `{dtype}`): `torch.bfloat16`")
         return torch.bfloat16


diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py
index 29e55053e..7f28a1fd5 100644
--- a/tests/test_q4_marlin.py
+++ b/tests/test_q4_marlin.py
@@ -57,7 +57,7 @@ class TestQ4Marlin(ModelTest):
     )
     def test_generation(self, model_id):
         try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
         except ValueError as e:
             raise e
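
Note on the new dtype selection (not part of the patch): the second hunk drops the blanket fp16 shortcut for non-CPU inference and instead picks a dtype per device and per config dtype. The standalone sketch below mirrors that selection order under stated assumptions; `pick_auto_dtype`, the plain string device names, and the `find_spec`-based IPEX probe are illustrative only and are not part of the GPTQModel API (the patch itself uses the module-level `HAS_IPEX` flag and the `DEVICE` enum).

import importlib.util
from typing import Optional

import torch


def pick_auto_dtype(device: str, config_dtype: Optional[torch.dtype]) -> torch.dtype:
    # MPS and XPU stacks are currently limited to float16.
    if device in ("mps", "xpu"):
        return torch.float16
    # CPU with IPEX available prefers bfloat16; availability is probed here with
    # find_spec, whereas the patch imports HAS_IPEX from the IPEX qlinear module.
    if device == "cpu" and importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return torch.bfloat16
    # Native float16 configs stay float16; float32/float64 are down-cast to
    # bfloat16; everything else (including native bfloat16) resolves to bfloat16.
    if config_dtype == torch.float16:
        return torch.float16
    return torch.bfloat16


if __name__ == "__main__":
    print(pick_auto_dtype("cuda", torch.float32))   # torch.bfloat16
    print(pick_auto_dtype("cuda", torch.float16))   # torch.float16
    print(pick_auto_dtype("xpu", torch.bfloat16))   # torch.float16

The test change in tests/test_q4_marlin.py follows from this: with the explicit torch_dtype removed from GPTQModel.load, the Marlin test now exercises the auto dtype path above instead of forcing bfloat16.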