Skip to content

Commit 66b9fec

Browse files
xin3hexinhe3
authored andcommitted
Add check_neural_compressor_min_version for 4 bit behavior (#1500)
Signed-off-by: Xin <[email protected]> Signed-off-by: xinhe3 <[email protected]> Co-authored-by: xinhe3 <[email protected]>
1 parent ac8bc72 commit 66b9fec

2 files changed

Lines changed: 12 additions & 3 deletions

File tree

examples/text-generation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from optimum.habana.utils import (
4040
check_habana_frameworks_version,
41+
check_neural_compressor_min_version,
4142
check_optimum_habana_min_version,
4243
get_habana_frameworks_version,
4344
set_seed,
@@ -267,9 +268,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
267268
original_model=org_model,
268269
**model_kwargs,
269270
)
270-
# TODO: This will be removed in v1.19 Synapse release
271-
# the loaded model should have the same dtype as original_model
272-
model = model.to(model_kwargs["torch_dtype"])
271+
if not check_neural_compressor_min_version("3.2"):
272+
model = model.to(model_kwargs["torch_dtype"])
273273
else:
274274
if args.assistant_model is not None:
275275
assistant_model = AutoModelForCausalLM.from_pretrained(

optimum/habana/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,15 @@ def check_habana_frameworks_version(req_version):
384384
)
385385

386386

387+
def check_neural_compressor_min_version(req_version):
388+
"""
389+
Checks if the installed version of `neural_compressor` is larger than or equal to `req_version`.
390+
"""
391+
import neural_compressor
392+
393+
return version.Version(neural_compressor.__version__) >= version.Version(req_version)
394+
395+
387396
def get_device_name():
388397
"""
389398
Returns the name of the current device: Gaudi or Gaudi2.

0 commit comments

Comments
 (0)