Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f88c6a8
add support gigachat3
Mohamed-Ashraf273 Mar 3, 2026
c26ffe8
support gigacgat3
Mohamed-Ashraf273 Mar 3, 2026
751cd02
add tests & create tiny model
Mohamed-Ashraf273 Mar 3, 2026
0d16b4f
add tests and fix issues
Mohamed-Ashraf273 Mar 4, 2026
f2a1e53
fix version skip test
Mohamed-Ashraf273 Mar 4, 2026
aac19fb
add docs & modify patcher
Mohamed-Ashraf273 Mar 4, 2026
5c134eb
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
f231bca
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
8f18ff5
fix issues
Mohamed-Ashraf273 Mar 5, 2026
5049ce3
update test
Mohamed-Ashraf273 Mar 5, 2026
722dc53
update tests
Mohamed-Ashraf273 Mar 6, 2026
28a6330
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
07efafd
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
c52c62a
fix tests
Mohamed-Ashraf273 Mar 9, 2026
a63a52d
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
f8bdfe5
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
8228058
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 9, 2026
5b32d32
revert conevrt.py changes
Mohamed-Ashraf273 Mar 13, 2026
04b4d9f
revert conevrt.py changes
Mohamed-Ashraf273 Mar 13, 2026
c0ba5d0
revert conevrt.py changes
Mohamed-Ashraf273 Mar 13, 2026
63a956d
revert conevrt.py changes
Mohamed-Ashraf273 Mar 13, 2026
acd8148
revert conevrt.py changes
Mohamed-Ashraf273 Mar 13, 2026
dbc1432
update deepseek's patcher
Mohamed-Ashraf273 Mar 15, 2026
47d2910
modify patcher
Mohamed-Ashraf273 Mar 17, 2026
eb601d9
update patcher
Mohamed-Ashraf273 Mar 20, 2026
fe1b84a
removed unnecessary check
Mohamed-Ashraf273 Mar 23, 2026
e86e7d9
fix pacther
Mohamed-Ashraf273 Mar 23, 2026
cbc2005
fix version
Mohamed-Ashraf273 Mar 23, 2026
199da92
fix version
Mohamed-Ashraf273 Mar 23, 2026
1dee64c
revert refactoring
Mohamed-Ashraf273 Mar 23, 2026
225aed3
update doc
Mohamed-Ashraf273 Mar 23, 2026
4174478
modify based on review
Mohamed-Ashraf273 Mar 24, 2026
9ae1162
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5dbc1c8
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f7043c7
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2877a8e
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2c2d31b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5cb2d8b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f1b92ed
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 30, 2026
39e770c
Remove Flaubert and add GigaChat3 to models list
Mohamed-Ashraf273 Mar 30, 2026
aec4a90
update docs
Mohamed-Ashraf273 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Here is the list of the supported architectures :
- Falcon
- Falcon-Mamba
- FlauBERT
- GigaChat3
- GLM-4
- GLM-Edge
- GPT-2
Expand Down
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4081,8 +4081,8 @@ class M2M100OpenVINOConfig(BartOpenVINOConfig):
)
@register_in_tasks_manager("deepseek", *["text-generation", "text-generation-with-past"], library_name="transformers")
class DeepseekOpenVINOConfig(MiniCPM3OpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.46.0"
MAX_TRANSFORMERS_VERSION = "4.53.3"
MIN_TRANSFORMERS_VERSION = "4.53.0"
MAX_TRANSFORMERS_VERSION = None
_MODEL_PATCHER = DeepseekPatcher


Expand Down
225 changes: 185 additions & 40 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@
override_arguments,
sdpa_mask_without_vmap,
)
from optimum.intel.utils.import_utils import is_diffusers_version, is_torch_version, is_transformers_version
from optimum.intel.utils.import_utils import (
is_diffusers_version,
is_openvino_version,
is_torch_version,
is_transformers_version,
)

from ._ov_ops import convert_recurrent_attention_cell

Expand Down Expand Up @@ -3771,25 +3776,96 @@ def __enter__(self):
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
if hasattr(block.mlp, "moe_infer"):
block.mlp._org_moe_infer = block.mlp.moe_infer
# old interface (transformers < 4.57): moe_infer(self, x, topk_ids, topk_weight)
block.mlp._orig_moe_infer = block.mlp.moe_infer
block.mlp._orig_moe = None
block.mlp.ep_rank = getattr(block.mlp, "ep_rank", 0)
block.mlp.experts_per_rank = getattr(block.mlp, "experts_per_rank", len(block.mlp.experts))
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)
elif hasattr(block.mlp, "moe") and hasattr(block.mlp, "experts"):
# new interface (transformers >= 4.57): moe(self, hidden_states, topk_indices, topk_weights)
block.mlp._orig_moe = block.mlp.moe
block.mlp._orig_moe_infer = None
num_experts = len(block.mlp.experts)

# Concatenate expert weights
gate_projs = torch.concat(
tuple(block.mlp.experts[i].gate_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
up_projs = torch.concat(
tuple(block.mlp.experts[i].up_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
down_projs = torch.concat(
tuple(block.mlp.experts[i].down_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)

if is_openvino_version("<", "2026.1.0"):
logger.warning(
"This model works best with OpenVINO 2026.1 or later. "
"Earlier versions require float() conversion for MoE weights, "
"which may affect performance. "
"OpenVINO 2026.1 includes a fix for torch.bmm dtype handling."
)
block.mlp.gate_projs = gate_projs.float()
block.mlp.up_projs = up_projs.float()
block.mlp.down_projs = down_projs.float()
else:
block.mlp.gate_projs = gate_projs
block.mlp.up_projs = up_projs
block.mlp.down_projs = down_projs

block.mlp.moe = types.MethodType(deepseek_moe, block.mlp)
elif hasattr(block.mlp, "experts"):
# fallback: patch by injecting moe_infer with required attributes
block.mlp._orig_moe_infer = None
block.mlp._orig_moe = None
block.mlp.ep_rank = 0
block.mlp.experts_per_rank = len(block.mlp.experts)
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.self_attn, "_orig_forward"):
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.mlp, "_orig_moe"):
if block.mlp._orig_moe is not None:
block.mlp.moe = block.mlp._orig_moe
if hasattr(block.mlp, "gate_projs"):
del block.mlp.gate_projs
if hasattr(block.mlp, "up_projs"):
del block.mlp.up_projs
if hasattr(block.mlp, "down_projs"):
del block.mlp.down_projs
delattr(block.mlp, "_orig_moe")
if hasattr(block.mlp, "_orig_moe_infer"):
block.mlp.moe_infer = block.mlp._orig_moe_infer
if block.mlp._orig_moe_infer is not None:
block.mlp.moe_infer = block.mlp._orig_moe_infer
else:
if hasattr(block.mlp, "moe_infer"):
delattr(block.mlp, "moe_infer")
if hasattr(block.mlp, "ep_rank"):
delattr(block.mlp, "ep_rank")
if hasattr(block.mlp, "experts_per_rank"):
delattr(block.mlp, "experts_per_rank")
delattr(block.mlp, "_orig_moe_infer")


def deepseek_v3_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings=None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
past_key_values=None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
def rotate_half(x):
Expand All @@ -3813,9 +3889,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
past_key_value=past_key_value,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
kwargs=kwargs,
)

bsz, q_len, _ = hidden_states.size()
Expand All @@ -3824,50 +3904,89 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)

k_pass, k_rot = torch.split(
compressed_kv,
[self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1,
)

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass))
k_pass = k_pass.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
k_pass, value_states = torch.split(
k_pass,
[self.qk_nope_head_dim, self.v_head_dim],
dim=-1,
)

# Difference with original code, k_pe.new_empty create constant tensor in torchscript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
k_rot = k_rot.view(bsz, 1, q_len, self.qk_rope_head_dim)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
new_interface = position_embeddings is not None and not hasattr(self, "rotary_emb")

if new_interface:
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
apply_rotary_pos_emb as deepseek_v3_apply_rotary_pos_emb,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
apply_rotary_pos_emb_interleave as deepseek_v3_apply_rotary_pos_emb_interleave,
)

cos, sin = position_embeddings

if getattr(self.config, "rope_interleave", False):
try:
q_pe, k_rot = deepseek_v3_apply_rotary_pos_emb_interleave(q_pe, k_rot, cos, sin)
except Exception as e:
raise RuntimeError(
"Failed to apply interleaved rotary position embeddings, "
f"may due to incompatible transformers version, try to `pip install transformers>=4.57.1`: {e}"
)
else:
q_pe, k_rot = deepseek_v3_apply_rotary_pos_emb(q_pe, k_rot, cos, sin)

kv_cache = past_key_values if past_key_values is not None else past_key_value

else:
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_rot = apply_rotary_pos_emb(q_pe, k_rot, cos, sin, position_ids)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

kv_cache = past_key_value

k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
query_states = torch.cat((q_nope, q_pe), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)

if kv_cache is not None:
if new_interface:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = kv_cache.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]

else:
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = kv_cache.update(key_states, value_states, self.layer_idx, cache_kwargs)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
Expand All @@ -3884,6 +4003,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
scale=None if not new_interface else self.scaling,
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand All @@ -3892,6 +4012,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):

attn_output = self.o_proj(attn_output)

if new_interface:
return attn_output, None

return attn_output, None, past_key_value


Expand Down Expand Up @@ -4051,6 +4174,28 @@ def deepseek_moe_infer(self, x, topk_ids, topk_weight):
return final_out


def deepseek_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
    """Dense, trace-friendly MoE forward that matches the original routed behavior.

    Every expert is applied to every token via batched matmuls over the
    pre-concatenated expert weights (``gate_projs``/``up_projs``/``down_projs``),
    and per-token routing weights zero out the non-selected experts. The result
    equals dispatching each token only to its top-k experts.
    """
    input_dtype = hidden_states.dtype
    n_experts = len(self.experts)
    n_tokens = hidden_states.shape[0]
    # Compute in a dtype that can hold both the activations and the stacked weights.
    work_dtype = torch.promote_types(input_dtype, self.gate_projs.dtype)

    # Dense routing matrix: weights[token, expert] is the top-k weight, 0 elsewhere.
    weights = torch.zeros(n_tokens, n_experts, dtype=work_dtype, device=hidden_states.device)
    weights.scatter_(1, topk_indices, topk_weights.to(dtype=work_dtype))

    # Broadcast the token batch across experts and run all experts at once.
    tokens = hidden_states.to(dtype=work_dtype).unsqueeze(0).expand(n_experts, -1, -1)
    activation = self.experts[0].act_fn
    gate_out = torch.bmm(tokens, self.gate_projs.to(dtype=work_dtype).transpose(1, 2))
    up_out = torch.bmm(tokens, self.up_projs.to(dtype=work_dtype).transpose(1, 2))
    expert_out = torch.bmm(activation(gate_out) * up_out, self.down_projs.to(dtype=work_dtype).transpose(1, 2))

    # Scale each expert's output by its routing weight, then sum over experts.
    combined = expert_out * weights.transpose(0, 1).unsqueeze(-1)
    return combined.sum(dim=0).to(input_dtype)


class Qwen2VLLanguageModelPatcher(OVDecoderModelPatcher):
def __init__(
self,
Expand Down
32 changes: 16 additions & 16 deletions tests/openvino/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
if is_transformers_version(">=", "4.46.0"):
SUPPORTED_ARCHITECTURES += ("glm", "mistral-nemo", "phimoe")

if is_transformers_version("<", "4.54.0"):
SUPPORTED_ARCHITECTURES += ("deepseek",)

# gptq and awq install disabled for windows test environment
if platform.system() != "Windows" and is_transformers_version("<", "4.56.0"):
SUPPORTED_ARCHITECTURES += ("opt_gptq", "mixtral_awq")

if is_transformers_version(">=", "4.53.0"):
SUPPORTED_ARCHITECTURES += ("deepseek", "gigachat3")

if is_transformers_version(">", "4.47"):
SUPPORTED_ARCHITECTURES += ("olmo2",)

Expand Down Expand Up @@ -230,6 +230,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"minicpm3": 6,
"phimoe": 2,
"deepseek": 2,
"gigachat3": 2,
"opt_gptq": 12,
"mixtral_awq": 2,
"gemma3_text": 2,
Expand Down Expand Up @@ -290,11 +291,8 @@ def test_find_untested_architectures(self):

if "llama4_text" in supported_architectures:
supported_architectures.remove("llama4_text")
if is_transformers_version(">=", str(DeepseekOpenVINOConfig.MAX_TRANSFORMERS_VERSION)):
if "deepseek_v2" in supported_architectures:
supported_architectures.remove("deepseek_v2")
if "deepseek_v3" in supported_architectures:
supported_architectures.remove("deepseek_v3")
if is_transformers_version(">=", str(DeepseekOpenVINOConfig.MIN_TRANSFORMERS_VERSION)):
supported_architectures -= {"deepseek_v2", "deepseek_v3"}
if is_transformers_version("<", str(BitnetOpenVINOConfig.MIN_TRANSFORMERS_VERSION)):
supported_architectures -= {"bitnet"}
if is_transformers_version("<", str(LFM2OpenVINOConfig.MIN_TRANSFORMERS_VERSION)):
Expand Down Expand Up @@ -397,6 +395,16 @@ def test_compare_to_transformers(self, model_arch):
return

tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

# The GigaChat3 tokenizer adds token_type_ids, which DeepSeekV3
# and similar models do not accept in generate(); strip it so both OV and PT calls succeed.
if model_arch in ["gigachat3"]:
tokens.pop("token_type_ids", None)

if model_arch == "deepseek":
ov_model.generation_config.do_sample = False
transformers_model.generation_config.do_sample = False

ov_model.generation_config.eos_token_id = None
transformers_model.generation_config.eos_token_id = None
ov_model.config.eos_token_id = None
Expand All @@ -413,10 +421,6 @@ def test_compare_to_transformers(self, model_arch):

ov_outputs = ov_model.generate(**tokens, generation_config=gen_config)

# TODO: add back once https://huggingface.co/katuni4ka/tiny-random-minicpm3/discussions/1 merged (for all models) as current modeling incompatible with transformers >= v4.49
if model_arch in {"deepseek"} and is_transformers_version(">=", "4.49"):
self.skipTest("Incompatible modeling code")

additional_inputs = {}
# gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache,
# align cache representation in torch model
Expand Down Expand Up @@ -666,10 +670,6 @@ def test_beam_search(self, model_arch):
if model_arch in ["lfm2", "granitemoehybrid"]:
return

# TODO: add back once https://huggingface.co/katuni4ka/tiny-random-minicpm3/discussions/1 merged (for all models) as current modeling incompatible with transformers >= v4.49
if model_arch in {"deepseek"} and is_transformers_version(">=", "4.49"):
self.skipTest("Incompatible modeling code")

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in REMOTE_CODE_MODELS)
if model_arch == "persimmon":
tokenizer.pad_token_id = tokenizer.bos_token_id
Expand Down
Loading