Commit 9e94801

SunMarc, matthewdouglas, and gante authored
enable/disable compile for quants methods (huggingface#36519)
* disable compile for most quants methods
* fix
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Matthew Douglas <[email protected]>
* Update tests/quantization/bnb/test_mixed_int8.py
  Co-authored-by: Matthew Douglas <[email protected]>
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Joao Gante <[email protected]>
* changes from joao suggestions

---------

Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
1 parent c53d53d commit 9e94801

File tree

6 files changed: +80 / -4 lines changed

src/transformers/generation/configuration_utils.py

Lines changed: 1 addition & 2 deletions
@@ -379,8 +379,7 @@ class GenerationConfig(PushToHubMixin):
             If using a static cache, this controls how `generate` will `compile` the forward pass for performance
             gains.
 
-        disable_compile (`bool`, *optional*): Whether to disable the compilation of the forward pass when using 'statis' cache
-            implementation.
+        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.
 
         > Wild card
 

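For reference, a minimal sketch of how the new flag is meant to be used from user code; the model name and prompt are borrowed from the tests added below, and whether compilation actually happens still depends on the other criteria (compileable cache, device, quantizer support):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello my name is", return_tensors="pt")

# Opt out of automatic compilation of the forward pass, even though the
# static cache implementation is compileable.
outputs = model.generate(
    **inputs,
    max_new_tokens=10,
    cache_implementation="static",
    disable_compile=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
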
src/transformers/generation/utils.py

Lines changed: 3 additions & 2 deletions
@@ -1613,7 +1613,6 @@ def _prepare_generation_config(
             model_kwargs = generation_config.update(**kwargs)
         else:
             model_kwargs = kwargs
-
         return generation_config, model_kwargs
 
     def _get_initial_cache_position(self, input_ids, model_kwargs):
@@ -3281,7 +3280,9 @@ def _sample(
         model_forward = self.__call__
         if isinstance(model_kwargs.get("past_key_values"), Cache):
             is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            is_compileable = is_compileable and not self.generation_config.disable_compile
+            if getattr(self, "hf_quantizer", None) is not None:
+                is_compileable &= self.hf_quantizer.is_compileable
+            is_compileable = is_compileable and not generation_config.disable_compile
             if is_compileable and (
                 self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
             ):

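Spelled out, the compile gate in `_sample` now reads roughly as follows (a paraphrase of the diff above, not the literal implementation):

cache = model_kwargs.get("past_key_values")
is_compileable = (
    isinstance(cache, Cache)
    and cache.is_compileable
    and self._supports_static_cache
)
if getattr(self, "hf_quantizer", None) is not None:
    # Quantized models must opt in; HfQuantizer.is_compileable defaults to False.
    is_compileable = is_compileable and self.hf_quantizer.is_compileable
# The user can always opt out explicitly.
is_compileable = is_compileable and not generation_config.disable_compile
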
src/transformers/quantizers/base.py

Lines changed: 5 additions & 0 deletions
@@ -271,6 +271,11 @@ def is_qat_trainable(self) -> bool:
         """Flag indicating whether the quantized model can carry out quantization aware training"""
         return False
 
+    @property
+    def is_compileable(self) -> bool:
+        """Flag indicating whether the quantized model can be compiled"""
+        return False
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...

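A quantization backend opts in by overriding the new property. A minimal sketch of a custom quantizer doing so (the class name is illustrative and the remaining HfQuantizer hooks are omitted); the torchao change below is the real in-tree example:

from transformers.quantizers.base import HfQuantizer


class CompileFriendlyQuantizer(HfQuantizer):
    """Illustrative backend that allows the compiled generate path."""

    # A real backend would also implement _process_model_before_weight_loading
    # and the other abstract hooks of HfQuantizer.

    @property
    def is_compileable(self) -> bool:
        return True
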
src/transformers/quantizers/quantizer_torchao.py

Lines changed: 4 additions & 0 deletions
@@ -243,3 +243,7 @@ def is_trainable(self):
             "int8_dynamic_activation_int8_weight",
         ]
         return self.quantization_config.quant_type in supported_quant_types_for_training
+
+    @property
+    def is_compileable(self) -> bool:
+        return True

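Because the torchao quantizer now reports itself as compileable, a torchao-quantized model keeps the compiled generate path when a static cache is requested. A sketch, assuming torchao is installed (model id and quantization settings are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
quant_config = TorchAoConfig("int8_weight_only")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)

# hf_quantizer.is_compileable is True for torchao, so generate may still
# compile the forward pass instead of falling back to eager execution.
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
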
tests/quantization/bnb/test_4bit.py

Lines changed: 33 additions & 0 deletions
@@ -771,3 +771,36 @@ def test_set_load_in_8_bit(self):
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
         with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
             quantization_config.load_in_8bit = True
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch_gpu_if_bnb_not_multi_backend_enabled
+@slow
+@apply_skip_if_not_implemented
+class Bnb4bitCompile(unittest.TestCase):
+    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+    input_text = "Hello my name is"
+
+    def setUp(self):
+        # Models and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
+
+    def test_generate_compile(self):
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        # if nothing is set, compile will be disabled for bnb
+        self.model_4bit.generate(
+            input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
+            max_new_tokens=10,
+            cache_implementation="static",
+        )
+        with self.assertRaises(Exception):
+            # overwrite property
+            object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True)
+            self.model_4bit.generate(
+                input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
+                max_new_tokens=10,
+                cache_implementation="static",
+            )

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 34 additions & 0 deletions
@@ -966,3 +966,37 @@ def test_int8_from_pretrained(self):
         output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)
 
         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu_if_bnb_not_multi_backend_enabled
+@slow
+@apply_skip_if_not_implemented
+class Bnb8bitCompile(unittest.TestCase):
+    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+    input_text = "Hello my name is"
+
+    def setUp(self):
+        # Models and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
+
+    def test_generate_compile(self):
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        # if nothing is set, compile will be disabled for bnb
+        self.model_8bit.generate(
+            input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
+            max_new_tokens=10,
+            cache_implementation="static",
+        )
+
+        with self.assertRaises(Exception):
+            object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True)
+            self.model_8bit.generate(
+                input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
+                max_new_tokens=10,
+                cache_implementation="static",
+            )
