Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import optimum.exporters.openvino.model_configs

from .__main__ import main_export
from .convert import export, export_from_model, export_models, export_pytorch_via_onnx
from .convert import export, export_from_model, export_models, export_pytorch_via_onnx, _resolve_flux_text_encoder_model_type
from .stateful import ensure_stateful_is_available, patch_stateful


Expand Down
150 changes: 124 additions & 26 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,21 +532,43 @@ def export_models(
output_name = output_names[i] if output_names is not None else Path(model_name + ".xml")
output_path = output_dir / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
outputs.append(
export(
model=submodel,
config=sub_export_config,
output=output_path,
opset=opset,
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
ov_config=ov_config,
stateful=stateful[i] if isinstance(stateful, (list, tuple)) else stateful,
patch_16bit_model=patch_16bit_model,
library_name=library_name,
try:
outputs.append(
export(
model=submodel,
config=sub_export_config,
output=output_path,
opset=opset,
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
ov_config=ov_config,
stateful=stateful[i] if isinstance(stateful, (list, tuple)) else stateful,
patch_16bit_model=patch_16bit_model,
library_name=library_name,
)
)
except Exception as e:
if "prim::TupleConstruct" not in str(e):
raise

resolved_opset = opset or getattr(sub_export_config, "DEFAULT_ONNX_OPSET", 14)
logger.warning(
f"Falling back to ONNX export for submodel `{model_name}` due to PyTorch frontend limitation: {e}"
)
outputs.append(
export_pytorch_via_onnx(
model=submodel,
config=sub_export_config,
opset=resolved_opset,
output=output_path,
device=device,
input_shapes=input_shapes,
model_kwargs=model_kwargs,
ov_config=ov_config,
library_name=library_name,
)
)
)

outputs = list(map(list, zip(*outputs)))
return outputs
Expand Down Expand Up @@ -1286,28 +1308,79 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
return models_for_export


def _resolve_flux_text_encoder_model_type(text_encoder, default_model_type: str, tokenizer=None) -> str:
config = getattr(text_encoder, "config", None)
model_type = str(getattr(config, "model_type", "") or "").lower()
architectures = [str(x) for x in (getattr(config, "architectures", []) or [])]
encoder_cls_name = text_encoder.__class__.__name__
tokenizer_cls_name = tokenizer.__class__.__name__ if tokenizer is not None else ""

looks_like_gemma = (
model_type in {"gemma", "gemma2", "gemma3", "gemma3_text"}
or any("Gemma" in arch for arch in architectures)
or "Gemma" in encoder_cls_name
or "Gemma" in tokenizer_cls_name
)
if looks_like_gemma:
return "gemma2-text-encoder"

looks_like_qwen = (
model_type in {"qwen", "qwen2", "qwen3"}
or any("Qwen" in arch for arch in architectures)
or "Qwen" in encoder_cls_name
or "Qwen" in tokenizer_cls_name
)
if looks_like_qwen:
return "qwen3-text-encoder"

return default_model_type


def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export = {}

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
text_encoder_for_export = text_encoder
if "CausalLM" in text_encoder.__class__.__name__ and hasattr(text_encoder, "model"):
text_encoder_for_export = text_encoder.model

text_encoder_model_type = _resolve_flux_text_encoder_model_type(
text_encoder,
"clip-text",
getattr(pipeline, "tokenizer", None),
)

text_encoder_library_name = "diffusers"
if text_encoder_model_type in {"qwen3", "qwen2", "qwen"}:
text_encoder_library_name = "transformers"

if hasattr(text_encoder_for_export, "config"):
text_encoder_for_export.config.output_hidden_states = True
text_encoder_for_export.config.return_dict = True

text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=text_encoder,
model=text_encoder_for_export,
exporter=exporter,
library_name="diffusers",
library_name=text_encoder_library_name,
task="feature-extraction",
model_type="clip-text",
model_type=text_encoder_model_type,
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
text_encoder_for_export.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
models_for_export["text_encoder"] = (text_encoder_for_export, text_encoder_export_config)

transformer = pipeline.transformer
transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = None

transformer_forward_inputs = inspect.signature(transformer.forward).parameters
if "pooled_projections" in transformer_forward_inputs and not hasattr(transformer.config, "pooled_projection_dim"):
transformer.config.pooled_projection_dim = transformer.config.joint_attention_dim

export_config_constructor = TasksManager.get_exporter_config_constructor(
model=transformer,
exporter=exporter,
Expand All @@ -1321,8 +1394,14 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
transformer_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["transformer"] = (transformer, transformer_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_scaling_factor = None
if hasattr(pipeline, "vae") and hasattr(pipeline.vae, "config"):
vae_scaling_factor = getattr(pipeline.vae.config, "scaling_factor", None)

# VAE Encoder
vae_encoder = copy.deepcopy(pipeline.vae)
if vae_scaling_factor is not None:
vae_encoder.register_to_config(scaling_factor=float(vae_scaling_factor))
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
Expand All @@ -1337,8 +1416,18 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
vae_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
# VAE Decoder
vae_decoder = copy.deepcopy(pipeline.vae)
if vae_scaling_factor is not None:
vae_decoder.register_to_config(scaling_factor=float(vae_scaling_factor))
if hasattr(vae_decoder, "bn") and hasattr(vae_decoder.bn, "running_mean") and hasattr(vae_decoder.bn, "running_var"):
vae_decoder.register_to_config(
**{
"bn_running_mean_data": vae_decoder.bn.running_mean.detach().cpu().tolist(),
"bn_running_var_data": vae_decoder.bn.running_var.detach().cpu().tolist(),
"bn_eps": float(getattr(vae_decoder.bn, "eps", 1e-5)),
}
)
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
Expand All @@ -1355,24 +1444,33 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2_for_export = text_encoder_2
if "CausalLM" in text_encoder_2.__class__.__name__ and hasattr(text_encoder_2, "model"):
text_encoder_2_for_export = text_encoder_2.model

text_encoder_2_model_type = _resolve_flux_text_encoder_model_type(
text_encoder_2,
"t5-encoder-model",
getattr(pipeline, "tokenizer_2", None),
)

export_config_constructor = TasksManager.get_exporter_config_constructor(
model=text_encoder_2,
model=text_encoder_2_for_export,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="t5-encoder-model",
model_type=text_encoder_2_model_type,
)
export_config = export_config_constructor(
text_encoder_2.config,
text_encoder_2_for_export.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
)
export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)
models_for_export["text_encoder_2"] = (text_encoder_2_for_export, export_config)

return models_for_export


def _get_encoder_decoder_stateful_models_for_export(
model: "PreTrainedModel",
task: str,
Expand Down
108 changes: 103 additions & 5 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,34 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
return common_inputs

@register_in_tasks_manager("qwen3-text-encoder", "feature-extraction", library_name="diffusers")
class Qwen3TextEncoderOpenVINOConfig(Qwen3OpenVINOConfig):
    """Export config for a Qwen3 LM used as a FLUX text encoder (diffusers pipelines).

    Compared to the base causal-LM config, the encoder is traced without a KV
    cache and must expose every hidden state so the diffusion pipeline can pick
    an intermediate layer.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Text-encoder tracing only needs token ids and the attention mask;
        # no past-key-values / position ids.
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        common_outputs = {"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}

        # Prefer `num_layers`; fall back to the HF-transformers `num_hidden_layers`.
        num_layers = getattr(self._normalized_config, "num_layers", None)
        if num_layers is None:
            num_layers = getattr(self._normalized_config, "num_hidden_layers", 0)

        # num_layers + 1 hidden states: the embedding output plus one per layer.
        for i in range(int(num_layers) + 1):
            common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

        return common_outputs

    @property
    def values_override(self) -> Optional[Dict[str, Any]]:
        # Force hidden-state emission and disable the cache on top of whatever
        # the base config already overrides.
        values = super().values_override or {}
        values.update({"output_hidden_states": True, "return_dict": True, "use_cache": False})
        return values


class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
Expand Down Expand Up @@ -2533,6 +2561,27 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


def _get_flux_ids_dim(config) -> int:
for attr_name in ("axes_dims_rope", "axes_dim", "axes_dims"):
value = getattr(config, attr_name, None)
if value is not None:
if isinstance(value, (list, tuple)):
return len(value)
if isinstance(value, int):
return value

if hasattr(config, "get"):
for key in ("axes_dims_rope", "axes_dim", "axes_dims"):
value = config.get(key, None)
if value is not None:
if isinstance(value, (list, tuple)):
return len(value)
if isinstance(value, int):
return value

return 3


class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = (
"pixel_values",
Expand All @@ -2551,12 +2600,12 @@ def __init__(
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"] // 4,
height: int = DEFAULT_DUMMY_SHAPES["height"] // 4,
# Reduce img shape by 4 for FLUX to reduce memory usage on conversion
**kwargs,
):
super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
if getattr(normalized_config, "in_channels", None):
self.num_channels = normalized_config.in_channels // 4
self.ids_dim = _get_flux_ids_dim(normalized_config.config)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name in ["hidden_states", "sample"]:
Expand All @@ -2567,9 +2616,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
img_ids_width = self.width // 2
return self.random_int_tensor(
(
[self.batch_size, img_ids_height * img_ids_width, 3]
[self.batch_size, img_ids_height * img_ids_width, self.ids_dim]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3]
else [img_ids_height * img_ids_width, self.ids_dim]
),
min_value=0,
max_value=min(img_ids_height, img_ids_width),
Expand All @@ -2589,14 +2638,35 @@ class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
"txt_ids",
)

def __init__(
    self,
    task: str,
    normalized_config: NormalizedTextConfig,
    batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
    sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
    random_batch_size_range: Optional[Tuple[int, int]] = None,
    random_sequence_length_range: Optional[Tuple[int, int]] = None,
    **kwargs,
):
    """Forward all arguments to the base dummy-text generator, then cache
    `ids_dim` — the trailing dimension used for the `txt_ids` dummy input —
    resolved from the transformer config via `_get_flux_ids_dim`.
    """
    super().__init__(
        task=task,
        normalized_config=normalized_config,
        batch_size=batch_size,
        sequence_length=sequence_length,
        random_batch_size_range=random_batch_size_range,
        random_sequence_length_range=random_sequence_length_range,
        **kwargs,
    )
    # NOTE(review): assumes `normalized_config.config` carries the FLUX
    # axes_dims_rope / axes_dim information — confirm against the pipeline.
    self.ids_dim = _get_flux_ids_dim(normalized_config.config)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "txt_ids":
import torch

shape = (
[self.batch_size, self.sequence_length, 3]
[self.batch_size, self.sequence_length, self.ids_dim]
if is_diffusers_version("<", "0.31.0")
else [self.sequence_length, 3]
else [self.sequence_length, self.ids_dim]
)
dtype = DTYPE_MAPPER.pt(float_dtype)
return torch.full(shape, 0, dtype=dtype)
Expand All @@ -2614,10 +2684,38 @@ class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
)
_MODEL_PATCHER = FluxTransfromerModelPatcher

def __init__(self, *args, **kwargs):
    """Select the dummy-input generator set depending on whether this FLUX
    transformer takes pooled text projections (config defines
    `pooled_projection_dim`).
    """
    super().__init__(*args, **kwargs)

    config = self._normalized_config.config
    # The config may be attribute-style or dict-like; try both access patterns.
    pooled_projection_dim = getattr(config, "pooled_projection_dim", None)
    if pooled_projection_dim is None and hasattr(config, "get"):
        pooled_projection_dim = config.get("pooled_projection_dim", None)

    self._use_pooled_projections = pooled_projection_dim is not None

    if self._use_pooled_projections:
        # Include the pooled-projections generator for variants that take it.
        self.DUMMY_INPUT_GENERATOR_CLASSES = (
            DummyTransformerTimestpsInputGenerator,
            DummyFluxTransformerInputGenerator,
            DummyFluxTextInputGenerator,
            PooledProjectionsDummyInputGenerator,
        )
    else:
        # Variants without a `pooled_projections` input skip that generator.
        self.DUMMY_INPUT_GENERATOR_CLASSES = (
            DummyTransformerTimestpsInputGenerator,
            DummyFluxTransformerInputGenerator,
            DummyFluxTextInputGenerator,
        )

@property
def inputs(self):
common_inputs = super().inputs
common_inputs.pop("sample", None)

if not getattr(self, "_use_pooled_projections", True):
common_inputs.pop("pooled_projections", None)

common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = (
{0: "batch_size", 1: "sequence_length"} if is_diffusers_version("<", "0.31.0") else {0: "sequence_length"}
Expand Down
Loading
Loading