Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4955,10 +4955,12 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
def inputs(self) -> Dict[str, Dict[int, str]]:
    """Return the dynamic-axis layout for the exported model inputs.

    ``input_ids`` is always (batch_size, sequence_length). The second axis of
    ``attention_mask`` covers the full visible context: once past key/values
    are fed back in it spans ``past_sequence_length + sequence_length``, and
    the past-KV inputs themselves are appended via ``add_past_key_values``;
    otherwise it only spans the current ``sequence_length``.
    """
    dynamic_axes: Dict[str, Dict[int, str]] = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
    }
    if not self.use_past_in_inputs:
        dynamic_axes["attention_mask"] = {0: "batch_size", 1: "sequence_length"}
        return dynamic_axes
    # With a KV cache present, the mask must cover past + current tokens.
    dynamic_axes["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
    self.add_past_key_values(dynamic_axes, direction="inputs")
    return dynamic_axes


Expand Down
2 changes: 0 additions & 2 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7339,8 +7339,6 @@ def granite_moe_hybrid_update_causal_mask(
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

Expand Down
98 changes: 98 additions & 0 deletions tests/openvino/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,3 +872,101 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair):

del ov_model
gc.collect()

# Hybrid (recurrent/conv + attention) decoder architectures exercised by the
# multi-step generation test below, gated on the minimum transformers release
# that ships each model implementation.
HYBRID_ARCHITECTURES = []
if is_transformers_version(">=", "4.53.0"):
    HYBRID_ARCHITECTURES.append("granitemoehybrid")
if is_transformers_version(">=", "4.54.0"):
    HYBRID_ARCHITECTURES.append("lfm2")
if is_transformers_version(">=", "4.57.0"):
    HYBRID_ARCHITECTURES.append("qwen3_next")
# not including zamba2 - the Mamba mixer's torch_forward crashes on the second chunk

@parameterized.expand(HYBRID_ARCHITECTURES)
@pytest.mark.run_slow
@slow
def test_hybrid_model_multi_step_generation(self, model_arch):
    """
    Validates that hybrid models with mixed recurrent/attention layers produce correct results
    over multiple sequential generation calls with cache.

    The prompt is split into chunks and fed sequentially (chunked prefill) to both the
    OpenVINO model and the eager transformers reference, each carrying its own cache
    between calls; the logits of the final chunk must match within tolerance.
    """
    model_id = MODEL_NAMES[model_arch]
    tokenizer = self.get_tokenizer(model_arch)

    # Export in fp32 so logits are directly comparable to the eager reference.
    ov_model = OVModelForCausalLM.from_pretrained(
        model_id, export=True, ov_config=F32_CONFIG, device=OPENVINO_DEVICE
    )
    self.assertTrue(ov_model.stateful, "Hybrid model should be exported as stateful")

    set_seed(SEED)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_id)

    # Disable EOS on both sides so neither model can terminate early.
    # (Fix: the previous version also built an unused GenerationConfig here.)
    ov_model.generation_config.eos_token_id = None
    transformers_model.generation_config.eos_token_id = None
    ov_model.config.eos_token_id = None
    transformers_model.config.eos_token_id = None

    full_text = "Today is a nice day and I am happy"
    full_tokens = tokenizer(full_text, return_tensors="pt")
    full_input_ids = full_tokens["input_ids"]
    num_tokens = full_input_ids.shape[1]

    # Split the prompt into ~3 chunks, at least one token each.
    chunk_size = max(num_tokens // 3, 1)
    chunks = [full_input_ids[:, i : i + chunk_size] for i in range(0, num_tokens, chunk_size)]

    # OV chunked prefill: the attention mask always covers past + current tokens.
    ov_cache = None
    ov_past_len = 0
    for chunk_ids in chunks:
        cur_len = chunk_ids.shape[1]
        attn_mask = torch.ones((1, ov_past_len + cur_len), dtype=torch.int64)
        ov_out = ov_model(input_ids=chunk_ids, attention_mask=attn_mask, cache_params=ov_cache)
        ov_cache = ov_out.cache_params
        ov_past_len += cur_len

    # Transformers chunked prefill with the model-specific hybrid cache
    if model_arch == "granitemoehybrid":
        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
            HybridMambaAttentionDynamicCache,
        )

        cache = HybridMambaAttentionDynamicCache(config=transformers_model.config, batch_size=1)
    elif model_arch == "lfm2":
        from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache

        cache = Lfm2HybridConvCache(config=transformers_model.config, max_batch_size=1)
    elif model_arch == "qwen3_next":
        from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache

        cache = Qwen3NextDynamicCache(config=transformers_model.config)
    else:
        # Guard: a new HYBRID_ARCHITECTURES entry without a matching cache
        # constructor would otherwise fail later with a confusing NameError.
        raise ValueError(f"No hybrid cache constructor registered for {model_arch}")

    past_len = 0
    for chunk_ids in chunks:
        cur_len = chunk_ids.shape[1]
        attn_mask = torch.ones((1, past_len + cur_len), dtype=torch.int64)
        with torch.no_grad():
            tf_out = transformers_model(
                input_ids=chunk_ids, attention_mask=attn_mask, past_key_values=cache, use_cache=True
            )
        cache = tf_out.past_key_values
        past_len += cur_len

    # Only the last chunk's logits are compared; they depend on the whole cached
    # history, so a cache bug in any earlier step would surface here.
    self.assertTrue(
        torch.allclose(
            ov_out.logits,
            tf_out.logits,
            atol=5e-2,  # qwen3-next max diff is 0.04301672801375389
        ),
        f"Chunked prefill OV vs transformers mismatch:\n"
        f" max diff: {(ov_out.logits - tf_out.logits).abs().max().item()}",
    )

    del transformers_model
    del ov_model
    gc.collect()
Loading