2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/__init__.py
@@ -137,7 +137,7 @@ def __init__(self,
"scales",
t.zeros(
(math.ceil(in_features / self.group_size), out_features),
dtype=t.float16, # Scales are always float16
dtype=t.float16,
),
)
self.register_buffer(
24 changes: 14 additions & 10 deletions gptqmodel/nn_modules/qlinear/exllama.py
@@ -25,15 +25,15 @@
from ...models._const import DEVICE, PLATFORM
from ...nn_modules.qlinear import BaseQuantLinear
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger

exllama_import_exception = None
try:
from gptqmodel_exllama_kernels import make_q4, q4_matmul
except ImportError as e:
exllama_import_exception = e

logger = getLogger(__name__)

log = setup_logger()

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
NON_TENSOR = torch.empty((1, 1), device="meta")
@@ -43,9 +43,6 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)




class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTS_BITS = [4]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
@@ -130,6 +127,8 @@ def post_init(self):
# if self.bias is not None:
# self.bias.resize_(self.out_features)

# ext_make_q4 only accepts float16 scales
self.scales = self.scales.to(dtype=torch.float16)

self.width = self.qweight.shape[1]

@@ -166,7 +165,6 @@ def ext_q4_matmul(self, x, q4, q4_width):

return output.view(outshape)


def forward(self, x: torch.Tensor):
# TODO FIXME: parent should never call us if there is no data to process
# check: https://github.com/ModelCloud/GPTQModel/issues/1361
@@ -175,11 +173,11 @@

x_dtype = x.dtype
if x_dtype != torch.float16:
logger.warn.once(
f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
)
#log.warn.once(
# f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
#)

x = x.half()
x = x.to(dtype=torch.float16)

# TODO: need to run checks to make sure there is no performance regression padding with F.pad
# if in_features is padded, we need to pad the input as well
@@ -188,4 +186,10 @@

out = self.ext_q4_matmul(x, self.q4, self.width)

if self.bias is not None:
out.add_(self.bias)

if self.adapter:
out = self.adapter.apply(x=x, out=out)

return out.to(x_dtype)
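
Taken together, the exllama changes cast the scales to float16 once in post_init (the ext_make_q4 extension only accepts float16 scales), silently cast any non-float16 activation to float16 in forward, and then apply bias and the optional adapter to the kernel output before restoring the caller's dtype. A minimal sketch of the resulting forward flow, written as a standalone helper rather than the real method, assuming the attribute names shown in the diff:

# Sketch only: mirrors ExllamaQuantLinear.forward after this PR (layer is a loaded ExllamaQuantLinear).
import torch

def exllama_forward_sketch(layer, x: torch.Tensor) -> torch.Tensor:
    x_dtype = x.dtype                                    # remember the caller's dtype (e.g. bfloat16)
    if x_dtype != torch.float16:
        x = x.to(dtype=torch.float16)                    # kernel only runs on float16 activations
    out = layer.ext_q4_matmul(x, layer.q4, layer.width)  # 4-bit matmul via the exllama extension
    if layer.bias is not None:
        out.add_(layer.bias)                             # bias is now added to the kernel output
    if layer.adapter:
        out = layer.adapter.apply(x=x, out=out)          # adapter (e.g. EoRA/LoRA) applied last
    return out.to(x_dtype)                               # hand back the original dtype
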
32 changes: 14 additions & 18 deletions gptqmodel/nn_modules/qlinear/exllama_eora.py
@@ -155,12 +155,13 @@ def post_init(self):

def forward(self, x):
x_dtype = x.dtype
if x_dtype != torch.float16:
log.warn.once(
f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
)

x = x.to(dtype=torch.float16)
# if x_dtype != torch.float16:
# # log.warn.once(
# # f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
# # )
#
# # TODO FIXME... Exllama EoRA kernel must run in fp16 or else the output (bfloat16) is junk
# x = x.to(dtype=torch.float16)

# sync with vllm
# log.info(f"x shape: {x.shape}")
@@ -181,23 +182,18 @@
# if x.size(-1) != self.in_features:
# x = F.pad(x, self.in_features_padding_shape)

if self.adapter:
# only 4 bits fused eora kernel has been validated
if self.bits == 4:
output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
else:
output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
else:
output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)

out = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)

if self.bias is not None:
output.add_(self.bias)
out.add_(self.bias)

if self.adapter:
out = self.adapter.apply(x=x, out=out)

# log.info(f"output: {output.shape}")

# sync with vllm
output = output.reshape(out_shape)
out = out.reshape(out_shape)
# log.info(f"output reshaped: {output.shape}")

return output.to(dtype=x_dtype)
return out.to(dtype=x_dtype)
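
The EoRA rewrite drops the bits-gated fused gptq_gemm_lora branch and routes every bit-width through plain gptq_gemm, then bias, then adapter.apply, matching the other kernels. If adapter.apply implements the usual low-rank update (an assumption; only the call site is visible in this PR), the unified path reproduces what the fused branch computed:

# Assumed low-rank adapter semantics (not taken from this PR):
#   adapter.apply(x, out)  ~  out + (x @ lora_A) @ lora_B
# so  gptq_gemm(x, W_q) + bias + (x @ lora_A) @ lora_B
# matches the result of the old fused gptq_gemm_lora(x, W_q, ..., x @ lora_A, lora_B) path.
import torch

def apply_lora_sketch(x: torch.Tensor, out: torch.Tensor,
                      lora_A: torch.Tensor, lora_B: torch.Tensor) -> torch.Tensor:
    return out + (x @ lora_A) @ lora_B
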
11 changes: 7 additions & 4 deletions gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -206,6 +206,9 @@ def post_init(self, temp_dq):
# if self.bias is not None:
# self.bias.resize_(self.out_features)

# ext_make_q_matrix only accepts float16
self.scales = self.scales.to(dtype=torch.float16)

self.q_tensors = {
"qweight": self.qweight,
"qzeros": self.qzeros,
@@ -231,11 +234,11 @@ def forward(self, x: torch.Tensor, force_cuda=False):

x_dtype = x.dtype
if x_dtype != torch.float16:
log.warn.once(
f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
)
# log.warn.once(
# f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
# )

x = x.half()
x = x.to(dtype=torch.float16)

# TODO: need to run checks to make sure there is no performance regression padding with F.pad
# if in_features is padded, we need to pad the input as well
3 changes: 0 additions & 3 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -416,9 +416,6 @@ def forward(self, x: torch.Tensor):
if x.shape[0] == 0:
return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)

if x.dtype != torch.float16:
x = x.to(torch.float16)

out = apply_gptq_marlin_linear(
input=x.contiguous() if self.is_lm_head else x,
weight=self.qweight,
8 changes: 4 additions & 4 deletions tests/test_q4_marlin.py
@@ -57,7 +57,7 @@ class TestQ4Marlin(ModelTest):
)
def test_generation(self, model_id):
try:
model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
except ValueError as e:
raise e

@@ -71,13 +71,13 @@ def test_generation(self, model_id):

tokenizer = AutoTokenizer.from_pretrained(model_id)

self.assertInference(model=model_q, tokenizer=tokenizer)
self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris", "named after the river", "roman emperor"])

def test_bias(self):
# TheBloke/Llama-2-7B-Chat-GPTQ has bias tensors, but they are all zeros; use a checkpoint that actually uses bias.
model_id = "/monster/data/model/starcoderbase-1b-GPTQ"
try:
model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
except ValueError as e:
raise e

@@ -93,4 +93,4 @@ def test_bias(self):
model_id = "/monster/data/model/starcoderbase-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

self.assertInference(model=model_q, tokenizer=tokenizer)
self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris"])