20 changes: 14 additions & 6 deletions gptqmodel/utils/model.py
@@ -50,7 +50,7 @@
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
-from ..nn_modules.qlinear.ipex import IPEXQuantLinear
+from ..nn_modules.qlinear.ipex import IPEXQuantLinear, HAS_IPEX
from ..quantization import FORMAT, QuantizeConfig
from ..quantization.config import FORMAT_FIELD_JSON, QUANT_METHOD, dynamic_get
from .backend import BACKEND
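
The new HAS_IPEX flag is presumably the usual import-guard boolean exported by the ipex qlinear module. A minimal sketch of that pattern, assuming the flag only records whether intel_extension_for_pytorch is importable (the real definition lives in gptqmodel/nn_modules/qlinear/ipex.py):

# Sketch of a typical availability flag; hypothetical, not copied from this repo.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
    HAS_IPEX = True
except Exception:
    HAS_IPEX = False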
@@ -892,28 +892,36 @@ def auto_dtype(config: PretrainedConfig,

    assert isinstance(device, DEVICE)

    # for inference, DynamicCuda, Exllama, Triton, and Marlin are all fp16 kernels
    if quant_inference and device != DEVICE.CPU:
        return torch.float16

    # TODO: both MPS and XPU are locked to float16
    # XPU stack is missing bfloat16 (the hardware supports it)
    # MPS stack has bfloat16 bugs in PyTorch
    if device in [DEVICE.MPS, DEVICE.XPU]:
        log.info("Loader: Auto dtype (MPS or XPU): `torch.float16`")
        return torch.float16

    # IPEX for CPU is optimized for bfloat16
    if device in [DEVICE.CPU] and HAS_IPEX:
        log.info("Loader: Auto dtype (CPU + IPEX): `torch.bfloat16`")
        return torch.bfloat16

    # get dtype from config
    dtype = getattr(config, "torch_dtype")
    if dtype and not isinstance(dtype, torch.dtype):
        raise ValueError(f"torch_dtype in config must be a torch.dtype, but got {dtype}")

-    if dtype == torch.float32:
+    if dtype in [torch.float32, torch.float64]:
        log.info("Loader: Auto dtype (float32 down-cast): `torch.bfloat16`")
        return torch.bfloat16
    elif dtype == torch.float16:
        log.info("Loader: Auto dtype (native float16): `torch.float16`")
        return torch.float16
    elif dtype == torch.bfloat16:
        log.info("Loader: Auto dtype (native bfloat16): `torch.bfloat16`")
        return torch.bfloat16
    else:
        # TODO: extract weights from the model file to check their original dtype, instead of forcing bfloat16
        # up/down-cast everything else to bfloat16 if not already bfloat16
        log.info(f"Loader: Auto dtype (native = `{dtype}`): `torch.bfloat16`")
        return torch.bfloat16


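Taken together, the new precedence is: fp16 inference kernels first, device quirks second (MPS/XPU pinned to float16, CPU+IPEX preferring bfloat16), and the config's torch_dtype last. A self-contained sketch of that decision order, using plain strings for devices and a hypothetical pick_dtype helper (illustration only, not the library's actual code path):

# Hypothetical standalone reproduction of the selection order above.
import torch

def pick_dtype(device, quant_inference, has_ipex, config_dtype):
    # 1. quantized inference off-CPU: the kernels are all fp16
    if quant_inference and device != "cpu":
        return torch.float16
    # 2. MPS/XPU are locked to float16
    if device in ("mps", "xpu"):
        return torch.float16
    # 3. CPU with IPEX is optimized for bfloat16
    if device == "cpu" and has_ipex:
        return torch.bfloat16
    # 4. config dtype: float32/float64 down-cast, float16/bfloat16 pass through,
    #    everything else coerces to bfloat16
    if config_dtype in (torch.float32, torch.float64):
        return torch.bfloat16
    if config_dtype == torch.float16:
        return torch.float16
    return torch.bfloat16

assert pick_dtype("cuda", True, False, torch.float32) == torch.float16
assert pick_dtype("cpu", False, True, torch.float64) == torch.bfloat16
assert pick_dtype("cpu", False, False, torch.float16) == torch.float16
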
2 changes: 1 addition & 1 deletion tests/test_q4_marlin.py
@@ -57,7 +57,7 @@ class TestQ4Marlin(ModelTest):
    )
    def test_generation(self, model_id):
        try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
        except ValueError as e:
            raise e

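Dropping the explicit torch_dtype here means the test now exercises the auto-dtype path: on CUDA with the Marlin inference backend, the loader should resolve to float16 on its own. A hedged follow-up assertion one could add (assumes the wrapper exposes the underlying torch module as model_q.model; hypothetical, not part of this PR):

# Hypothetical check: confirm auto dtype picked float16 for Marlin on CUDA.
param_dtype = next(model_q.model.parameters()).dtype
assert param_dtype == torch.float16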