diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 4af4f9f8e..4849e76dc 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -137,7 +137,7 @@ def __init__(self,
             "scales",
             t.zeros(
                 (math.ceil(in_features / self.group_size), out_features),
-                dtype=t.float16, # Scales are always float16
+                dtype=t.float16,
             ),
         )
         self.register_buffer(
diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py
index 559c5ac40..fdc9a1123 100644
--- a/gptqmodel/nn_modules/qlinear/exllama.py
+++ b/gptqmodel/nn_modules/qlinear/exllama.py
@@ -25,6 +25,7 @@
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
 from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger
 
 exllama_import_exception = None
 try:
@@ -32,8 +33,7 @@
 except ImportError as e:
     exllama_import_exception = e
 
-logger = getLogger(__name__)
-
+log = setup_logger()
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 NON_TENSOR = torch.empty((1, 1), device="meta")
@@ -43,9 +43,6 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
     return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)
 
-
-
-
 class ExllamaQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
@@ -130,6 +127,8 @@ def post_init(self):
 
         # if self.bias is not None:
         #     self.bias.resize_(self.out_features)
 
+        # ext_make_q4 only accepts float16 scales
+        self.scales = self.scales.to(dtype=torch.float16)
         self.width = self.qweight.shape[1]
 
@@ -166,7 +165,6 @@ def ext_q4_matmul(self, x, q4, q4_width):
 
         return output.view(outshape)
 
-
     def forward(self, x: torch.Tensor):
         # TODO FIXME: parent should never call us if there is no data to process
         # check: https://github.com/ModelCloud/GPTQModel/issues/1361
@@ -175,11 +173,11 @@ def forward(self, x: torch.Tensor):
 
         x_dtype = x.dtype
         if x_dtype != torch.float16:
-            logger.warn.once(
-                f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
+            #log.warn.once(
+            #    f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+            #)
 
-            x = x.half()
+            x = x.to(dtype=torch.float16)
 
         # TODO: need to run checks to make sure there is no performance regression padding with F.pad
         # if in_features is padded, we need to pad the input as well
@@ -188,4 +186,10 @@ def forward(self, x: torch.Tensor):
 
         out = self.ext_q4_matmul(x, self.q4, self.width)
 
+        if self.bias is not None:
+            out.add_(self.bias)
+
+        if self.adapter:
+            out = self.adapter.apply(x=x, out=out)
+
         return out.to(x_dtype)
diff --git a/gptqmodel/nn_modules/qlinear/exllama_eora.py b/gptqmodel/nn_modules/qlinear/exllama_eora.py
index 98026faa4..1b24efa40 100644
--- a/gptqmodel/nn_modules/qlinear/exllama_eora.py
+++ b/gptqmodel/nn_modules/qlinear/exllama_eora.py
@@ -155,12 +155,13 @@ def post_init(self):
 
     def forward(self, x):
         x_dtype = x.dtype
-        if x_dtype != torch.float16:
-            log.warn.once(
-                f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
-
-            x = x.to(dtype=torch.float16)
+        # if x_dtype != torch.float16:
+        #     # log.warn.once(
+        #     #     f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+        #     # )
+        #
+        #     # TODO FIXME...Exllama EoRA kernel must run in fp16 or else output (bfloat16) is junk
+        #     x = x.to(dtype=torch.float16)
 
         # sync with vllm
         # log.info(f"x shape: {x.shape}")
@@ -181,23 +182,18 @@ def forward(self, x):
 
         # if x.size(-1) != self.in_features:
         #     x = F.pad(x, self.in_features_padding_shape)
 
-        if self.adapter:
-            # only 4 bits fused eora kernel has been validated
-            if self.bits == 4:
-                output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
-            else:
-                output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
-        else:
-            output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)
-
+        out = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)
         if self.bias is not None:
-            output.add_(self.bias)
+            out.add_(self.bias)
+
+        if self.adapter:
+            out = self.adapter.apply(x=x, out=out)
 
         # log.info(f"output: {output.shape}")
         # sync with vllm
-        output = output.reshape(out_shape)
+        out = out.reshape(out_shape)
         # log.info(f"output reshaped: {output.shape}")
 
-        return output.to(dtype=x_dtype)
+        return out.to(dtype=x_dtype)
diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py
index 8297452d3..a621efbce 100644
--- a/gptqmodel/nn_modules/qlinear/exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -206,6 +206,9 @@ def post_init(self, temp_dq):
         # if self.bias is not None:
         #     self.bias.resize_(self.out_features)
 
+        # ext_make_q_matrix only accepts float16
+        self.scales = self.scales.to(dtype=torch.float16)
+
         self.q_tensors = {
             "qweight": self.qweight,
             "qzeros": self.qzeros,
@@ -231,11 +234,11 @@ def forward(self, x: torch.Tensor, force_cuda=False):
 
         x_dtype = x.dtype
         if x_dtype != torch.float16:
-            log.warn.once(
-                f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
+            # log.warn.once(
+            #     f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+            # )
 
-            x = x.half()
+            x = x.to(dtype=torch.float16)
 
         # TODO: need to run checks to make sure there is no performance regression padding with F.pad
         # if in_features is padded, we need to pad the input as well
diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py
index ed17e675c..8c4f74979 100644
--- a/gptqmodel/nn_modules/qlinear/marlin.py
+++ b/gptqmodel/nn_modules/qlinear/marlin.py
@@ -416,9 +416,6 @@ def forward(self, x: torch.Tensor):
         if x.shape[0] == 0:
             return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)
 
-        if x.dtype != torch.float16:
-            x = x.to(torch.float16)
-
         out = apply_gptq_marlin_linear(
             input=x.contiguous() if self.is_lm_head else x,
             weight=self.qweight,
diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py
index 044f1dfa4..29e55053e 100644
--- a/tests/test_q4_marlin.py
+++ b/tests/test_q4_marlin.py
@@ -57,7 +57,7 @@ class TestQ4Marlin(ModelTest):
     )
     def test_generation(self, model_id):
         try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
         except ValueError as e:
             raise e
 
@@ -71,13 +71,13 @@ def test_generation(self, model_id):
 
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-        self.assertInference(model=model_q, tokenizer=tokenizer)
+        self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris", "named after the river", "roman emperor"])
 
     def test_bias(self):
         # TheBloke/Llama-2-7B-Chat-GPTQ has bias, but they are all zeros, use a checkpoint which really uses bias.
         model_id = "/monster/data/model/starcoderbase-1b-GPTQ"
         try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
         except ValueError as e:
             raise e
 
@@ -93,4 +93,4 @@ def test_bias(self):
         model_id = "/monster/data/model/starcoderbase-1b"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-        self.assertInference(model=model_q, tokenizer=tokenizer)
+        self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris"])
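For reference, a minimal usage sketch (not part of the patch; the checkpoint path and prompt are placeholders) of what the updated tests exercise: loading a 4-bit GPTQ checkpoint with the Marlin backend in bfloat16 and running inference, with the kernels handling any float16 casts of scales and activations internally.

import torch
from transformers import AutoTokenizer
from gptqmodel import BACKEND, GPTQModel

model_id = "/path/to/gptq-4bit-model"  # placeholder checkpoint path

# Load in bfloat16; the quant kernels cast scales/activations to float16 internally where required.
model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model_q.generate(**inputs, max_new_tokens=16)[0]))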