
Commit be54ada

Fix bfloat16 kernels (#1420)
* fix bfloat16 forward for exllama v1
* fix bfloat16 compat for marlin
* fix bfloat16 compat for exllama v2
* fix bfloat16 compat for exllama eora
* test bfloat16
* fix ci

Signed-off-by: Qubitium <[email protected]>
1 parent 9dbeae4 commit be54ada

6 files changed: 40 additions & 40 deletions
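The kernel-side changes below share one pattern: the exllama CUDA kernels compute only in float16, so non-float16 activations (e.g. bfloat16) are cast on entry and the result is cast back to the caller's dtype on return. A minimal sketch of that pattern; kernel_fn is a hypothetical stand-in for the real CUDA matmul:

import torch

def fp16_only_forward(x: torch.Tensor, kernel_fn) -> torch.Tensor:
    # Record the caller's dtype so the output can be restored to it.
    x_dtype = x.dtype
    if x_dtype != torch.float16:
        # The CUDA kernel only understands float16; cast on entry.
        x = x.to(dtype=torch.float16)
    out = kernel_fn(x)  # hypothetical stand-in for the quantized matmul
    # Cast back so bfloat16 (or float32) callers see their own dtype.
    return out.to(dtype=x_dtype)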


gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def __init__(self,
             "scales",
             t.zeros(
                 (math.ceil(in_features / self.group_size), out_features),
-                dtype=t.float16, # Scales are always float16
+                dtype=t.float16,
             ),
         )
         self.register_buffer(

gptqmodel/nn_modules/qlinear/exllama.py

Lines changed: 14 additions & 10 deletions
@@ -25,15 +25,15 @@
 from ...models._const import DEVICE, PLATFORM
 from ...nn_modules.qlinear import BaseQuantLinear
 from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger

 exllama_import_exception = None
 try:
     from gptqmodel_exllama_kernels import make_q4, q4_matmul
 except ImportError as e:
     exllama_import_exception = e

-logger = getLogger(__name__)
-
+log = setup_logger()

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 NON_TENSOR = torch.empty((1, 1), device="meta")

@@ -43,9 +43,6 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     """Construct Q4Matrix, return handle"""
     return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)

-
-
-
 class ExllamaQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]

@@ -130,6 +127,8 @@ def post_init(self):
         # if self.bias is not None:
         #     self.bias.resize_(self.out_features)

+        # ext_make_q4 only accept float16 scales
+        self.scales = self.scales.to(dtype=torch.float16)

         self.width = self.qweight.shape[1]

@@ -166,7 +165,6 @@ def ext_q4_matmul(self, x, q4, q4_width):

         return output.view(outshape)

-
     def forward(self, x: torch.Tensor):
         # TODO FIXME: parent should never call us if there is no data to process
         # check: https://github.com/ModelCloud/GPTQModel/issues/1361

@@ -175,11 +173,11 @@ def forward(self, x: torch.Tensor):

         x_dtype = x.dtype
         if x_dtype != torch.float16:
-            logger.warn.once(
-                f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
+            #log.warn.once(
+            #    f"Exllama kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+            #)

-            x = x.half()
+            x = x.to(dtype=torch.float16)

         # TODO: need to run checks to make sure there is no performance regression padding with F.pad
         # if in_features is padded, we need to pad the input as well

@@ -188,4 +186,10 @@ def forward(self, x: torch.Tensor):

         out = self.ext_q4_matmul(x, self.q4, self.width)

+        if self.bias is not None:
+            out.add_(self.bias)
+
+        if self.adapter:
+            out = self.adapter.apply(x=x, out=out)
+
         return out.to(x_dtype)
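The new tail of forward fuses the bias in place and then hands the biased float16 output to the adapter hook before the final cast back. A sketch of that ordering, assuming bias and adapter follow the interface used in the diff (adapter.apply(x=x, out=out)):

import torch

def forward_epilogue(x, out, bias=None, adapter=None, x_dtype=torch.bfloat16):
    # In-place bias add first, so the adapter sees the biased activations.
    if bias is not None:
        out.add_(bias)
    # Adapter (EoRA/LoRA) is applied to the float16 kernel output ...
    if adapter is not None:
        out = adapter.apply(x=x, out=out)
    # ... and only the final result is cast back to the caller's dtype.
    return out.to(x_dtype)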

gptqmodel/nn_modules/qlinear/exllama_eora.py

Lines changed: 14 additions & 18 deletions
@@ -155,12 +155,13 @@ def post_init(self):

     def forward(self, x):
         x_dtype = x.dtype
-        if x_dtype != torch.float16:
-            log.warn.once(
-                f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
-
-            x = x.to(dtype=torch.float16)
+        # if x_dtype != torch.float16:
+        #     # log.warn.once(
+        #     #     f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+        #     # )
+        #
+        #     # TODO FIXME...Exllam EoRA kernel must run in fp16 or else output (bfloat16) is junk
+        #     x = x.to(dtype=torch.float16)

         # sync with vllm
         # log.info(f"x shape: {x.shape}")

@@ -181,23 +182,18 @@ def forward(self, x):
         # if x.size(-1) != self.in_features:
         #     x = F.pad(x, self.in_features_padding_shape)

-        if self.adapter:
-            # only 4 bits fused eora kernel has been validated
-            if self.bits == 4:
-                output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
-            else:
-                output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
-        else:
-            output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)
-
+        out = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)

         if self.bias is not None:
-            output.add_(self.bias)
+            out.add_(self.bias)
+
+        if self.adapter:
+            out = self.adapter.apply(x=x, out=out)

         # log.info(f"output: {output.shape}")

         # sync with vllm
-        output = output.reshape(out_shape)
+        out = out.reshape(out_shape)
         # log.info(f"output reshaped: {output.shape}")

-        return output.to(dtype=x_dtype)
+        return out.to(dtype=x_dtype)
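The fused gptq_gemm_lora branch is gone; per the new TODO comment, the fused EoRA kernel produces junk output under bfloat16, so the adapter is now applied after the plain gptq_gemm. The unfused math matches the removed .add_(...) branch: a low-rank delta added to the GEMM output. A sketch, with shapes chosen only for illustration:

import torch

def apply_lora_delta(x: torch.Tensor, out: torch.Tensor,
                     lora_A: torch.Tensor, lora_B: torch.Tensor) -> torch.Tensor:
    # (x @ A) @ B keeps the intermediate at rank r instead of
    # materializing the full (in_features, out_features) delta.
    return out.add_((x @ lora_A) @ lora_B)

# Illustration-only shapes: in_features=64, out_features=32, rank=8.
x = torch.randn(4, 64)
out = torch.randn(4, 32)
lora_A = torch.randn(64, 8)
lora_B = torch.randn(8, 32)
out = apply_lora_delta(x, out, lora_A, lora_B)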

gptqmodel/nn_modules/qlinear/exllamav2.py

Lines changed: 7 additions & 4 deletions
@@ -206,6 +206,9 @@ def post_init(self, temp_dq):
         # if self.bias is not None:
         #     self.bias.resize_(self.out_features)

+        # ext_make_q_matrix only accepts float16
+        self.scales = self.scales.to(dtype=torch.float16)
+
         self.q_tensors = {
             "qweight": self.qweight,
             "qzeros": self.qzeros,

@@ -231,11 +234,11 @@ def forward(self, x: torch.Tensor, force_cuda=False):

         x_dtype = x.dtype
         if x_dtype != torch.float16:
-            log.warn.once(
-                f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
-            )
+            # log.warn.once(
+            #     f"Exllama v2 kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
+            # )

-            x = x.half()
+            x = x.to(dtype=torch.float16)

         # TODO: need to run checks to make sure there is no performance regression padding with F.pad
         # if in_features is padded, we need to pad the input as well
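Both the v1 and v2 patches move the float16 requirement for scales out of the hot path: the C++ constructors (ext_make_q4 / ext_make_q_matrix) only accept float16 scales, so a checkpoint loaded in bfloat16 gets its scales cast once at post_init instead of on every forward call. A sketch of the idea, with a hypothetical module holding a scales buffer:

import torch

class QuantLinearSketch(torch.nn.Module):
    def __init__(self, groups: int, out_features: int,
                 dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # Checkpoint may hand us scales in the model's dtype (e.g. bfloat16).
        self.register_buffer("scales", torch.zeros(groups, out_features, dtype=dtype))

    def post_init(self):
        # One-time cast: the C++ matrix constructors accept only float16.
        self.scales = self.scales.to(dtype=torch.float16)

m = QuantLinearSketch(groups=8, out_features=32)
m.post_init()
assert m.scales.dtype == torch.float16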

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 0 additions & 3 deletions
@@ -416,9 +416,6 @@ def forward(self, x: torch.Tensor):
         if x.shape[0] == 0:
             return torch.empty((0, self.out_features), dtype=x.dtype, device=x.device)

-        if x.dtype != torch.float16:
-            x = x.to(torch.float16)
-
         out = apply_gptq_marlin_linear(
             input=x.contiguous() if self.is_lm_head else x,
             weight=self.qweight,
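Marlin goes the opposite direction from the exllama kernels: the unconditional float16 cast is deleted, so apply_gptq_marlin_linear now runs in the activation's own dtype (the tests below exercise bfloat16). If a dtype guard were still wanted, it would look roughly like this hypothetical check rather than a silent down-cast:

import torch

# Hypothetical guard; the supported set here is an assumption based on
# the tests exercising bfloat16 against the Marlin backend.
MARLIN_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16)

def check_marlin_input(x: torch.Tensor) -> torch.Tensor:
    if x.dtype not in MARLIN_SUPPORTED_DTYPES:
        raise ValueError(
            f"Marlin kernel expects one of {MARLIN_SUPPORTED_DTYPES}, got {x.dtype}"
        )
    return x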

tests/test_q4_marlin.py

Lines changed: 4 additions & 4 deletions
@@ -57,7 +57,7 @@ class TestQ4Marlin(ModelTest):
     )
     def test_generation(self, model_id):
         try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
         except ValueError as e:
             raise e

@@ -71,13 +71,13 @@ def test_generation(self, model_id):

         tokenizer = AutoTokenizer.from_pretrained(model_id)

-        self.assertInference(model=model_q, tokenizer=tokenizer)
+        self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris", "named after the river", "roman emperor"])

     def test_bias(self):
         # TheBloke/Llama-2-7B-Chat-GPTQ has bias, but they are all zeros, use a checkpoint which really uses bias.
         model_id = "/monster/data/model/starcoderbase-1b-GPTQ"
         try:
-            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN)
+            model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
         except ValueError as e:
             raise e

@@ -93,4 +93,4 @@ def test_bias(self):
         model_id = "/monster/data/model/starcoderbase-1b"
         tokenizer = AutoTokenizer.from_pretrained(model_id)

-        self.assertInference(model=model_q, tokenizer=tokenizer)
+        self.assertInference(model=model_q, tokenizer=tokenizer, keywords=["french king", "paris"])
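For reference, the updated loading call in isolation; imports are assumed to match what the test module already uses, and the model path is the one from test_bias above:

import torch
from gptqmodel import BACKEND, GPTQModel  # assumed top-level exports, per the test
from transformers import AutoTokenizer

model_id = "/monster/data/model/starcoderbase-1b-GPTQ"  # path from test_bias

# Load the Marlin-backed quantized model directly in bfloat16 now that
# the kernel path no longer force-casts activations to float16.
model_q = GPTQModel.load(model_id, device="cuda:0",
                         backend=BACKEND.MARLIN, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)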
