Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4955,10 +4955,12 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
def inputs(self) -> Dict[str, Dict[int, str]]:
    """Return the export input spec: input name -> {axis index: dynamic axis name}.

    The dynamic sequence axis of ``attention_mask`` depends on whether cached
    (past) key/values are fed back as inputs: with a cache, the mask must span
    both the previously processed tokens and the new ones; without a cache it
    only spans the current sequence.
    """
    common_inputs = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
    }
    if self.use_past_in_inputs:
        # Mask covers cached tokens plus the newly fed tokens.
        common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
        self.add_past_key_values(common_inputs, direction="inputs")
    else:
        common_inputs["attention_mask"] = {0: "batch_size", 1: "sequence_length"}
    return common_inputs


Expand Down
2 changes: 0 additions & 2 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7339,8 +7339,6 @@ def granite_moe_hybrid_update_causal_mask(
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

Expand Down
101 changes: 97 additions & 4 deletions tests/openvino/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,12 @@ def test_pipeline(self, model_arch):
if is_transformers_version("<=", "4.46") and model_arch == "qwen"
# in older transformers versions, remote code tokenizers (and granite/granitemoe)
# were not loaded in pipelines because they were not registered in TOKENIZER_MAPPING
else model_id
if is_transformers_version("<=", "4.46")
and model_arch in REMOTE_CODE_MODELS + ("granite", "granitemoe")
else None
else (
model_id
if is_transformers_version("<=", "4.46")
and model_arch in REMOTE_CODE_MODELS + ("granite", "granitemoe")
else None
)
),
)
set_seed(SEED)
Expand Down Expand Up @@ -872,3 +874,94 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair):

del ov_model
gc.collect()

# Hybrid (recurrent/state-space + attention) decoder architectures to test.
# Each entry is gated on the minimum transformers version that ships the
# corresponding model implementation.
HYBRID_ARCHITECTURES = []
if is_transformers_version(">=", "4.53"):
    HYBRID_ARCHITECTURES.append("granitemoehybrid")
if is_transformers_version(">=", "4.54"):
    HYBRID_ARCHITECTURES.append("lfm2")
if is_transformers_version(">=", "4.57"):
    HYBRID_ARCHITECTURES.append("qwen3_next")
# not including zamba2 - the Mamba mixer's torch_forward crashes on the second chunk

@parameterized.expand(HYBRID_ARCHITECTURES, skip_on_empty=True)
@pytest.mark.run_slow
@slow
def test_hybrid_model_multi_step_generation(self, model_arch):
    """
    Validates that hybrid models with mixed recurrent/attention layers produce correct results
    over multiple sequential generation calls with cache.
    """
    model_id = MODEL_NAMES[model_arch]
    tokenizer = self.get_tokenizer(model_arch)

    ov_model = OVModelForCausalLM.from_pretrained(
        model_id, export=True, ov_config=F32_CONFIG, device=OPENVINO_DEVICE
    )
    self.assertTrue(ov_model.stateful, "Hybrid model should be exported as stateful")

    set_seed(SEED)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_id)

    # Disable EOS in both configs so neither model can stop early.
    ov_model.generation_config.eos_token_id = None
    transformers_model.generation_config.eos_token_id = None
    ov_model.config.eos_token_id = None
    transformers_model.config.eos_token_id = None

    full_text = "Today is a nice day and I am happy"
    full_tokens = tokenizer(full_text, return_tensors="pt")
    full_input_ids = full_tokens["input_ids"]
    num_tokens = full_input_ids.shape[1]

    # Split the prompt into ~3 chunks to exercise multiple sequential forward
    # calls that must carry the recurrent/attention cache between steps.
    chunk_size = max(num_tokens // 3, 1)
    chunks = [full_input_ids[:, i : i + chunk_size] for i in range(0, num_tokens, chunk_size)]

    # OV chunked prefill
    ov_cache = None
    ov_past_len = 0
    for chunk_ids in chunks:
        cur_len = chunk_ids.shape[1]
        # Mask must cover all previously fed tokens plus the current chunk.
        attn_mask = torch.ones((1, ov_past_len + cur_len), dtype=torch.int64)
        ov_out = ov_model(input_ids=chunk_ids, attention_mask=attn_mask, cache_params=ov_cache)
        ov_cache = ov_out.cache_params
        ov_past_len += cur_len

    # Transformers chunked prefill with the model-specific hybrid cache
    if model_arch == "granitemoehybrid":
        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
            HybridMambaAttentionDynamicCache,
        )

        cache = HybridMambaAttentionDynamicCache(config=transformers_model.config, batch_size=1)
    elif model_arch == "lfm2":
        from transformers.models.lfm2.modeling_lfm2 import Lfm2HybridConvCache

        cache = Lfm2HybridConvCache(config=transformers_model.config, max_batch_size=1)
    elif model_arch == "qwen3_next":
        from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextDynamicCache

        cache = Qwen3NextDynamicCache(config=transformers_model.config)
    else:
        # Guard against `cache` being referenced unbound if a new architecture
        # is added to HYBRID_ARCHITECTURES without a cache branch here.
        raise ValueError(f"No reference cache implementation wired up for {model_arch}")

    past_len = 0
    for chunk_ids in chunks:
        cur_len = chunk_ids.shape[1]
        attn_mask = torch.ones((1, past_len + cur_len), dtype=torch.int64)
        with torch.no_grad():
            tf_out = transformers_model(
                input_ids=chunk_ids, attention_mask=attn_mask, past_key_values=cache, use_cache=True
            )
        cache = tf_out.past_key_values
        past_len += cur_len

    # Only the last chunk's logits are compared; earlier chunks matter solely
    # through the state they leave in each model's cache.
    self.assertTrue(
        torch.allclose(
            ov_out.logits,
            tf_out.logits,
            atol=5e-2,  # qwen3-next max diff is 0.04301672801375389
        ),
        f"Chunked prefill OV vs transformers mismatch:\n"
        f" max diff: {(ov_out.logits - tf_out.logits).abs().max().item()}",
    )

    del transformers_model
    del ov_model
    gc.collect()