diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 0624624a77..cf297b44e0 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -195,6 +195,9 @@ Qwen2VLLanguageModelPatcher, Qwen2VLVisionEmbMergerPatcher, Qwen3MoeModelPatcher, + Qwen3_5MoeModelPatcher, + Qwen3_5MoeTextModelPatcher, + Qwen3_5TextModelPatcher, Qwen3NextModelPatcher, Qwen3VLLanguageModelPatcher, Qwen3VLVisionEmbMergerPatcher, @@ -5451,3 +5454,41 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): ) return dummy_inputs + + +# Qwen3.5 configs — inherit Qwen3Next's hybrid cache/input handling since the +# architecture is structurally identical (GatedDeltaNet + full attention layers). +@register_in_tasks_manager( + "qwen3_5", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5OpenVINOConfig(Qwen3NextOpenVINOConfig): + _MODEL_PATCHER = Qwen3_5TextModelPatcher + + +@register_in_tasks_manager( + "qwen3_5_text", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5TextOpenVINOConfig(Qwen3NextOpenVINOConfig): + _MODEL_PATCHER = Qwen3_5TextModelPatcher + + +@register_in_tasks_manager( + "qwen3_5_moe", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5MoEOpenVINOConfig(Qwen3NextOpenVINOConfig): + _MODEL_PATCHER = Qwen3_5MoeTextModelPatcher + + +@register_in_tasks_manager( + "qwen3_5_moe_text", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5MoETextOpenVINOConfig(Qwen3NextOpenVINOConfig): + _MODEL_PATCHER = Qwen3_5MoeTextModelPatcher diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 32dd2d6c6d..b281f42589 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8319,3 +8319,306 @@ def __exit__(self, exc_type, exc_value, traceback): sparse_moe_block = decoder_layer.mlp decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs + +# Qwen3.5 MoE forward patch — replaces dynamic expert loop with static iteration +# The original uses torch.no_grad() + nonzero() + dynamic for loop which breaks OV tracing +def qwen3_5_moe_experts_forward_patched(self, hidden_states, top_k_index, top_k_weights): + import torch.nn as nn + final_hidden_states = torch.zeros_like(hidden_states) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + for expert_idx in range(self.num_experts): + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states + + +# Patched GatedDeltaNet forward for Qwen3.5 — removes dynamic control flow for OV tracing. 
+# Key differences from Qwen3Next version: +# - Uses separate projections: in_proj_qkv, in_proj_z, in_proj_b, in_proj_a +# - No fix_query_key_value_ordering step +# - Cache uses self.layer_idx directly (not remapped through linear_attn_mapping) +def qwen3_5_gated_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + batch_size, seq_len, _ = hidden_states.shape + + layer_idx = self.layer_idx + recurrent_state = None + if cache_params is not None: + conv_state = cache_params.conv_states[layer_idx] + recurrent_state = cache_params.recurrent_states[layer_idx] + + # Qwen3.5 uses separate projections (not combined like Qwen3Next) + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Convolution — use ov_causal_conv1d for both prefill and decode (no branching) + if cache_params is not None: + new_mixed_qkv, new_conv_state = ov_causal_conv1d(conv_state, mixed_qkv, self.conv1d.weight, self.conv1d.bias) + mixed_qkv = F.silu(new_mixed_qkv) + cache_params.conv_states[layer_idx] = new_conv_state + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + # Split QKV: the flat dim is [key_dim, key_dim, value_dim] contiguous blocks. + # Reshape to 4D BEFORE split to avoid 3D bf16 ops that the GPU plugin can't handle. + # Insert a dummy head dim of 1: [batch, seq, 1, total] → split on dim 3 → reshape per-tensor. 
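+    # Worked example (hypothetical head sizes, not taken from a real Qwen3.5 config): with
+    # num_k_heads=16, head_k_dim=128 (key_dim = 2048) and num_v_heads=32, head_v_dim=64
+    # (value_dim = 2048), the flat last dim is 2048 + 2048 + 2048 = 6144 and the slices
+    # below are [0:2048], [2048:4096], [4096:6144].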
+ mixed_qkv = mixed_qkv.unsqueeze(2) # [batch, seq, 1, 6144] — now 4D + query = mixed_qkv[:, :, :, :self.key_dim] # [batch, seq, 1, key_dim] + key = mixed_qkv[:, :, :, self.key_dim:2 * self.key_dim] # [batch, seq, 1, key_dim] + value = mixed_qkv[:, :, :, 2 * self.key_dim:] # [batch, seq, 1, value_dim] + + query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + # Always use recurrent path — it handles both prefill and decode correctly + # and avoids the chunk_gated_delta_rule which has tracing issues + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + self, + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + if cache_params is not None: + cache_params.recurrent_states[layer_idx] = last_recurrent_state + + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3_5TextModelPatcher(OVDecoderModelPatcher): + """ + Patcher for Qwen3.5 hybrid (GatedDeltaNet + Attention) text models. + Handles cache decomposition, GatedDeltaNet forward patching, and + RecurrentAttentionCell OV conversion — following the same pattern as + Qwen3NextModelPatcher. + """ + + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Optional[Dict[str, Any]] = None, + ): + from openvino.frontend.pytorch import ConversionExtension, ModuleExtension + + super().__init__(config, model, model_kwargs) + + # Detect cache class from the model's module + model_module = type(model).__module__ + if "qwen3_5_moe" in model_module: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeDynamicCache as CacheClass + else: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache as CacheClass + + # Build the cache wrapper class that decomposes flat tensor list + # into the structured cache the model expects. + # Unlike Qwen3Next which uses compact linear_attn_mapping/full_attn_mapping, + # Qwen3.5 stores conv_states/recurrent_states for ALL layers (None for attention layers). 
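+        # Illustrative layout (hypothetical 4-layer pattern, not a real config) of the per-layer
+        # lists handed to the wrapper, for layer_types = ["linear_attention", "full_attention",
+        # "linear_attention", "full_attention"]:
+        #   conv_states      = [c0,   None, c2,   None]
+        #   recurrent_states = [r0,   None, r2,   None]
+        #   key_cache        = [None, k1,   None, k3  ]
+        #   value_cache      = [None, v1,   None, v3  ]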
+ class Qwen3_5DynamicCacheWrap(CacheClass): + def __init__(self, config, conv_states, recurrent_states, key_cache, value_cache): + super().__init__(config=config) + self.conv_states = conv_states + self.recurrent_states = recurrent_states + self.key_cache = key_cache + self.value_cache = value_cache + + def update(self, key_states, value_states, layer_idx, cache_kwargs=None): + # Cast cached states to match incoming dtype to prevent SDPA dtype mismatch + if self.key_cache[layer_idx] is not None: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.dtype) + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.dtype) + return super().update(key_states, value_states, layer_idx, cache_kwargs) + + def patched_forward( + input_ids, + attention_mask=None, + cache_params=None, + ): + model_config = self.real_config._config + layer_types = model_config.layer_types + num_layers = len(layer_types) + num_linear = sum(1 for t in layer_types if t == "linear_attention") + num_full = sum(1 for t in layer_types if t == "full_attention") + + use_cache = False + wrapped_cache_params = None + if cache_params is not None: + use_cache = True + + # Flat cache layout: [conv_0, recur_0, conv_1, recur_1, ..., key_0, val_0, key_1, val_1, ...] + # First 2*num_linear entries are for linear attention layers + # Next 2*num_full entries are for full attention layers + linear_conv_states = [] + linear_recurrent_states = [] + for idx in range(num_linear): + linear_conv_states.append(cache_params[2 * idx]) + linear_recurrent_states.append(cache_params[2 * idx + 1]) + + full_keys = [] + full_values = [] + offset = 2 * num_linear + for idx in range(num_full): + full_keys.append(cache_params[offset + 2 * idx]) + full_values.append(cache_params[offset + 2 * idx + 1]) + + # Expand to per-layer lists (None for non-matching layers) + conv_states = [None] * num_layers + recurrent_states = [None] * num_layers + key_cache = [None] * num_layers + value_cache = [None] * num_layers + + linear_idx = 0 + full_idx = 0 + for i, lt in enumerate(layer_types): + if lt == "linear_attention": + conv_states[i] = linear_conv_states[linear_idx] + recurrent_states[i] = linear_recurrent_states[linear_idx] + linear_idx += 1 + else: + key_cache[i] = full_keys[full_idx] + value_cache[i] = full_values[full_idx] + full_idx += 1 + + wrapped_cache_params = Qwen3_5DynamicCacheWrap( + model_config, conv_states, recurrent_states, key_cache, value_cache + ) + + causal_lm_output = self.model_orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=wrapped_cache_params, + use_cache=use_cache, + ) + outputs = { + "logits": causal_lm_output.logits, + } + + if use_cache: + past_kv = causal_lm_output.past_key_values + # Re-flatten to same layout: linear states first, then attention KV + present_key_values = [] + for i, lt in enumerate(layer_types): + if lt == "linear_attention": + present_key_values.append(past_kv.conv_states[i]) + present_key_values.append(past_kv.recurrent_states[i]) + + for i, lt in enumerate(layer_types): + if lt != "linear_attention": + present_key_values.append(past_kv.key_cache[i]) + present_key_values.append(past_kv.value_cache[i]) + + outputs["present_key_values"] = present_key_values + + return outputs + + self.patched_forward = patched_forward + self.model_orig_forward = self.orig_forward + self.orig_forward = patched_forward + + # Reuse RecurrentAttentionCell and conversion extension from Qwen3Next + self.module_extensions = { + RecurrentAttentionCell: 
ModuleExtension(RecurrentAttentionCell, "RecurrentAttentionCellOp"), + } + self.conversion_extensions = [ + ConversionExtension("RecurrentAttentionCellOp", convert_recurrent_attention_cell), + ] + + def __enter__(self): + super().__enter__() + setattr(self._model, self.orig_forward_name, self.patched_forward) + + for idx, decoder_layer in enumerate(self._model.model.layers): + layer_type = self._model.model.config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + linear_attn_layer._orig_forward = linear_attn_layer.forward + linear_attn_layer.forward = types.MethodType(qwen3_5_gated_delta_net_forward, linear_attn_layer) + linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule + linear_attn_layer.recurrent_attention_cell = RecurrentAttentionCell() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + setattr(self._model, self.orig_forward_name, self.model_orig_forward) + for idx, decoder_layer in enumerate(self._model.model.layers): + layer_type = self._model.model.config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + linear_attn_layer.forward = linear_attn_layer._orig_forward + + + +class Qwen3_5MoeTextModelPatcher(Qwen3_5TextModelPatcher): + """ + Patcher for Qwen3.5 MoE hybrid models (35B-A3B). + Extends Qwen3_5TextModelPatcher with MoE expert forward patching. + """ + + def __enter__(self): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts + + super().__enter__() + self._original_experts_forward = Qwen3_5MoeExperts.forward + Qwen3_5MoeExperts.forward = qwen3_5_moe_experts_forward_patched + + def __exit__(self, exc_type, exc_value, traceback): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts + + super().__exit__(exc_type, exc_value, traceback) + Qwen3_5MoeExperts.forward = self._original_experts_forward + + +# Legacy patcher kept for backward compatibility — only patches MoE experts, no GatedDeltaNet +class Qwen3_5MoeModelPatcher(OVDecoderModelPatcher): + def __enter__(self): + super().__enter__() + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts + self.original_experts_forward = Qwen3_5MoeExperts.forward + Qwen3_5MoeExperts.forward = qwen3_5_moe_experts_forward_patched + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts + Qwen3_5MoeExperts.forward = self.original_experts_forward + diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 3b8642d65a..f730d00913 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -24,6 +24,48 @@ from .utils import MULTI_MODAL_TEXT_GENERATION_MODELS, SSM_MODELS +def _convert_bf16_to_f16(ov_model: ov.Model): + """ + Convert all bf16 tensors in the model to f16 by: + 1. Replacing bf16 Constants with f16 equivalents + 2. Replacing bf16 Convert/ConvertLike ops with f16 targets + Required for GPU devices that support f16 but not bf16. 
+ """ + from openvino import Type as OVType + changed = False + + # Pass 1: Convert bf16 Constants to f16 + for op in list(ov_model.get_ops()): + if op.get_type_name() == "Constant" and op.get_output_element_type(0) == OVType.bf16: + data_f16 = op.get_data().astype(np.float32).astype(np.float16) + new_const = opset13.constant(data_f16, dtype=OVType.f16) + new_const.set_friendly_name(op.get_friendly_name()) + for target in list(op.output(0).get_target_inputs()): + target.replace_source_output(new_const.output(0)) + changed = True + + # Pass 2: Replace Convert ops that target bf16 with f16 + for op in list(ov_model.get_ops()): + if op.get_type_name() == "Convert" and op.get_output_element_type(0) == OVType.bf16: + convert_f16 = opset13.convert(op.input(0).get_source_output(), OVType.f16) + convert_f16.set_friendly_name(op.get_friendly_name()) + for target in list(op.output(0).get_target_inputs()): + target.replace_source_output(convert_f16.output(0)) + changed = True + + # Pass 3: Replace ConvertLike ops that produce bf16 + for op in list(ov_model.get_ops()): + if op.get_type_name() == "ConvertLike" and op.get_output_element_type(0) == OVType.bf16: + convert_f16 = opset13.convert(op.input(0).get_source_output(), OVType.f16) + convert_f16.set_friendly_name(op.get_friendly_name()) + for target in list(op.output(0).get_target_inputs()): + target.replace_source_output(convert_f16.output(0)) + changed = True + + if changed: + ov_model.validate_nodes_and_infer_types() + + def model_has_state(ov_model: ov.Model): if isinstance(ov_model, ov.CompiledModel): return len(ov_model.query_state()) > 0 @@ -301,9 +343,43 @@ def get_kv_ssm_tensor_names(ssm_prefix_names: list, kv_prefix_names: list, ov_te # hybrid models can contain transformer blocks as well # so KV tensors must be handled properly batch_dim = 0 + + # Normalize cache state dtypes to f32 for CPU compatibility. + # bf16 state variables can cause issues with some CPU plugin versions. + # Also reconciles dummy input (f32) vs model output (bf16/f16) mismatches. + from openvino.preprocess import PrePostProcessor + from openvino import Type as OVType + ppp = PrePostProcessor(ov_model) + needs_ppp = False + for inp_name, out_name in zip(cache_inputs, cache_outputs): + inp_type = ov_model.input(inp_name).get_element_type() + out_type = ov_model.output(out_name).get_element_type() + if inp_type != OVType.f32: + ppp.input(inp_name).tensor().set_element_type(OVType.f32) + needs_ppp = True + if out_type != OVType.f32: + ppp.output(out_name).tensor().set_element_type(OVType.f32) + needs_ppp = True + if needs_ppp: + ov_model = ppp.build() + fuse_cache_reorder(ov_model, not_cache_inputs, cache_inputs, batch_dim) make_stateful(ov_model, not_cache_inputs, cache_inputs, cache_outputs, batch_dim) + # Note: bf16 models require GPU devices with bf16 support. + # For GPUs without bf16 (e.g. Intel Arc Xe-LPG), the model runs on CPU only. + + # Ensure logits output is f32 for runtime compatibility (e.g. 
openvino_genai) + logits_output = None + for out in ov_model.outputs: + if "logits" in out.get_any_name(): + logits_output = out + break + if logits_output is not None and logits_output.get_element_type() != OVType.f32: + ppp = PrePostProcessor(ov_model) + ppp.output(logits_output.get_any_name()).tensor().set_element_type(OVType.f32) + ov_model = ppp.build() + def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"): diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index af2f1edaba..f763e26203 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -305,7 +305,10 @@ def get_submodels(model): "minicpmo", ] -SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"] +SSM_MODELS = [ + "mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next", + "qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text", +] # All transformers, diffusers, timm and sentence transformers models that are supported via optimum-onnx OnnxConfigs but that have currently no test # TODO: add tests for all models that are compatible and remove support for all others diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index a97416cea1..648816cd76 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -25,7 +25,10 @@ from transformers import GenerationConfig, PretrainedConfig from transformers.file_utils import add_start_docstrings from transformers.generation import GenerationMixin -from transformers.utils import is_offline_mode +try: + from transformers.utils import is_offline_mode +except ImportError: + from huggingface_hub.constants import is_offline_mode from transformers.utils.hub import cached_file from optimum.exporters.base import ExportConfig @@ -264,7 +267,7 @@ def __init__( # some model configs may have issues with loading without parameters initialization try: misplaced_generation_parameters = self.config._get_non_default_generation_parameters() - except (KeyError, TypeError): + except (KeyError, TypeError, AttributeError): misplaced_generation_parameters = {} if len(misplaced_generation_parameters) > 0: logger.warning( diff --git a/optimum/intel/openvino/modeling_open_clip.py b/optimum/intel/openvino/modeling_open_clip.py index 2e2ee2d63c..d350609a40 100644 --- a/optimum/intel/openvino/modeling_open_clip.py +++ b/optimum/intel/openvino/modeling_open_clip.py @@ -31,7 +31,10 @@ from transformers.file_utils import add_start_docstrings from transformers.modeling_outputs import ModelOutput from transformers.models.clip.modeling_clip import CLIPOutput -from transformers.utils import is_offline_mode +try: + from transformers.utils import is_offline_mode +except ImportError: + from huggingface_hub.constants import is_offline_mode from optimum.exporters.tasks import TasksManager diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index e6e99ffd56..60fdaea2a1 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -27,12 +27,15 @@ AutoConfig, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, - AutoModelForVision2Seq, GenerationConfig, Pix2StructForConditionalGeneration, PretrainedConfig, WhisperForConditionalGeneration, ) +try: + from transformers import AutoModelForVision2Seq +except ImportError: + 
from transformers import AutoModelForImageTextToText as AutoModelForVision2Seq from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index 818eb41726..eda23fcfa6 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -32,7 +32,17 @@ from openvino import Type as OVType from packaging.version import Version from transformers import AutoTokenizer, CLIPTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size +try: + from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size +except ModuleNotFoundError: + # transformers >= 5.x removed transformers.onnx; inline the trivial logic + from enum import Enum + + class ParameterFormat(Enum): + Float = 4 # bytes per float32 + + def compute_serialized_parameters_size(num_parameters: int, fmt: ParameterFormat) -> int: + return num_parameters * fmt.value from optimum.intel.utils.import_utils import is_torch_version diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index cab9e5efa3..bbcd44bd1c 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -23,7 +23,16 @@ from typing import Dict, List, Optional, Type, Union import torch -from huggingface_hub import HfApi, HfFolder, hf_hub_download +from huggingface_hub import HfApi, hf_hub_download +try: + from huggingface_hub import HfFolder +except ImportError: + # HfFolder removed in newer huggingface_hub; provide shim + class HfFolder: + @staticmethod + def get_token(): + from huggingface_hub import get_token + return get_token() from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.hf_api import file_exists from transformers import CLIPConfig, PretrainedConfig, PreTrainedModel
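# --- Usage sketch (not part of the diff) ---
# Minimal smoke test of the new qwen3_5* export path through the public optimum-intel API.
# The model id below is a placeholder, not a published checkpoint; substitute a real
# Qwen3.5 checkpoint once one is available.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "Qwen/Qwen3.5-placeholder"  # hypothetical id, for illustration only
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # runs the OpenVINO export path added above
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
# Equivalent CLI export: optimum-cli export openvino --model <model_id> <output_dir>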