Skip to content
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f88c6a8
add support gigachat3
Mohamed-Ashraf273 Mar 3, 2026
c26ffe8
support gigachat3
Mohamed-Ashraf273 Mar 3, 2026
751cd02
add tests & create tiny model
Mohamed-Ashraf273 Mar 3, 2026
0d16b4f
add tests and fix issues
Mohamed-Ashraf273 Mar 4, 2026
f2a1e53
fix version skip test
Mohamed-Ashraf273 Mar 4, 2026
aac19fb
add docs & modify patcher
Mohamed-Ashraf273 Mar 4, 2026
5c134eb
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
f231bca
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
8f18ff5
fix issues
Mohamed-Ashraf273 Mar 5, 2026
5049ce3
update test
Mohamed-Ashraf273 Mar 5, 2026
722dc53
update tests
Mohamed-Ashraf273 Mar 6, 2026
28a6330
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
07efafd
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
c52c62a
fix tests
Mohamed-Ashraf273 Mar 9, 2026
a63a52d
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
f8bdfe5
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
8228058
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 9, 2026
5b32d32
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
04b4d9f
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
c0ba5d0
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
63a956d
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
acd8148
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
dbc1432
update deepseek's patcher
Mohamed-Ashraf273 Mar 15, 2026
47d2910
modify patcher
Mohamed-Ashraf273 Mar 17, 2026
eb601d9
update patcher
Mohamed-Ashraf273 Mar 20, 2026
fe1b84a
removed unnecessary check
Mohamed-Ashraf273 Mar 23, 2026
e86e7d9
fix patcher
Mohamed-Ashraf273 Mar 23, 2026
cbc2005
fix version
Mohamed-Ashraf273 Mar 23, 2026
199da92
fix version
Mohamed-Ashraf273 Mar 23, 2026
1dee64c
revert refactoring
Mohamed-Ashraf273 Mar 23, 2026
225aed3
update doc
Mohamed-Ashraf273 Mar 23, 2026
4174478
modify based on review
Mohamed-Ashraf273 Mar 24, 2026
9ae1162
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5dbc1c8
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f7043c7
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2877a8e
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2c2d31b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5cb2d8b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f1b92ed
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 30, 2026
39e770c
Remove Flaubert and add GigaChat3 to models list
Mohamed-Ashraf273 Mar 30, 2026
aec4a90
update docs
Mohamed-Ashraf273 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Here is the list of the supported architectures :
- Falcon
- Falcon-Mamba
- Flaubert
- GigaChat3
- GLM-4
- GLM-Edge
- GPT-2
Expand Down
124 changes: 119 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@
override_arguments,
sdpa_mask_without_vmap,
)
from optimum.intel.utils.import_utils import is_diffusers_version, is_torch_version, is_transformers_version
from optimum.intel.utils.import_utils import (
is_diffusers_version,
is_openvino_version,
is_torch_version,
is_transformers_version,
)

from ._ov_ops import convert_recurrent_attention_cell

Expand Down Expand Up @@ -3767,29 +3772,92 @@ def __enter__(self):

self_attn_fwd = self_attn.get(self._model.config.model_type)
for block in self._model.model.layers:
# Patch attention
if self_attn_fwd is not None:
block.self_attn._orig_forward = block.self_attn.forward
block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)

# Patch MoE
if hasattr(block.mlp, "moe_infer"):
block.mlp._org_moe_infer = block.mlp.moe_infer
block.mlp._orig_moe_infer = block.mlp.moe_infer
block.mlp._orig_moe = None
block.mlp.ep_rank = getattr(block.mlp, "ep_rank", 0)
block.mlp.experts_per_rank = getattr(block.mlp, "experts_per_rank", len(block.mlp.experts))
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)

elif hasattr(block.mlp, "moe") and hasattr(block.mlp, "experts"):
block.mlp._orig_moe = block.mlp.moe
block.mlp._orig_moe_infer = None

# Pre-concatenate expert weights for vectorized computation
num_experts = len(block.mlp.experts)
gate_projs = torch.concat(
tuple(block.mlp.experts[i].gate_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
up_projs = torch.concat(
tuple(block.mlp.experts[i].up_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
down_projs = torch.concat(
tuple(block.mlp.experts[i].down_proj.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)

if is_openvino_version("<", "2026.1.0"):
block.mlp.gate_projs = gate_projs.float()
block.mlp.up_projs = up_projs.float()
block.mlp.down_projs = down_projs.float()
else:
block.mlp.gate_projs = gate_projs
block.mlp.up_projs = up_projs
block.mlp.down_projs = down_projs

block.mlp.moe = types.MethodType(deepseek_moe, block.mlp)

def __exit__(self, exc_type, exc_value, traceback):
    """Undo the patches applied in ``__enter__``.

    Restores the original attention ``forward`` and whichever MoE entry
    point was patched (``moe`` or ``moe_infer``), and removes the helper
    attributes the patcher injected on each ``block.mlp``.
    """
    super().__exit__(exc_type, exc_value, traceback)
    for block in self._model.model.layers:
        # Restore attention (only patched when a per-model forward existed).
        if hasattr(block.self_attn, "_orig_forward"):
            block.self_attn.forward = block.self_attn._orig_forward

        # Restore MoE - handle both interfaces.
        if hasattr(block.mlp, "_orig_moe"):
            if block.mlp._orig_moe is not None:
                block.mlp.moe = block.mlp._orig_moe
                # Drop the pre-concatenated expert weights added for export.
                for attr in ("gate_projs", "up_projs", "down_projs"):
                    if hasattr(block.mlp, attr):
                        delattr(block.mlp, attr)
            delattr(block.mlp, "_orig_moe")

        if hasattr(block.mlp, "_orig_moe_infer"):
            if block.mlp._orig_moe_infer is not None:
                block.mlp.moe_infer = block.mlp._orig_moe_infer
            else:
                # ``moe_infer`` and its helpers were injected by the patcher;
                # remove them rather than restoring.
                for attr in ("moe_infer", "ep_rank", "experts_per_rank"):
                    if hasattr(block.mlp, attr):
                        delattr(block.mlp, attr)
            delattr(block.mlp, "_orig_moe_infer")


def deepseek_v3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
def rotate_half(x):
Expand All @@ -3808,6 +3876,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

if not hasattr(self, 'q_head_dim'):
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
Expand All @@ -3816,6 +3887,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

bsz, q_len, _ = hidden_states.size()
Expand Down Expand Up @@ -3846,7 +3919,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

new_interface = False  # flipped to True below when the new rotary-embedding interface (position_embeddings) is used
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
else:
from transformers.models.deepseek_v3.modeling_deepseek_v3 import apply_rotary_pos_emb

cos, sin = position_embeddings
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
new_interface = True


q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

Expand Down Expand Up @@ -3892,6 +3976,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):

attn_output = self.o_proj(attn_output)

if new_interface:
return attn_output, None

return attn_output, None, past_key_value


Expand Down Expand Up @@ -4051,6 +4138,33 @@ def deepseek_moe_infer(self, x, topk_ids, topk_weight):
return final_out


def deepseek_moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
    """
    Vectorized MoE forward for DeepSeek-V3.

    Runs every token through every expert with batched matmuls against the
    pre-concatenated expert weights (``gate_projs``/``up_projs``/``down_projs``),
    then combines the expert outputs with the dense routing weights.
    """
    n_experts = len(self.experts)
    n_tokens, hidden_dim = hidden_states.shape

    # Scatter the sparse top-k routing weights into a dense (tokens, experts) matrix.
    dense_routing = torch.zeros(
        n_tokens, n_experts,
        dtype=topk_weights.dtype,
        device=hidden_states.device,
    )
    dense_routing.scatter_(1, topk_indices, topk_weights)

    # Replicate the token batch once per expert: (experts, tokens, hidden_dim).
    stacked = hidden_states.repeat(n_experts, 1).view(n_experts, n_tokens, hidden_dim)

    # All experts share the same activation; grab it from the first one.
    activation = self.experts[0].act_fn
    gated = activation(torch.bmm(stacked, self.gate_projs.transpose(1, 2)))
    lifted = torch.bmm(stacked, self.up_projs.transpose(1, 2))
    expert_out = torch.bmm(gated * lifted, self.down_projs.transpose(1, 2))

    # Weight each expert's output per token and reduce over the expert axis.
    weighted = expert_out * dense_routing.transpose(0, 1).unsqueeze(-1)
    combined = weighted.sum(dim=0)
    return combined.type(stacked.dtype)


class Qwen2VLLanguageModelPatcher(OVDecoderModelPatcher):
def __init__(
self,
Expand Down
9 changes: 8 additions & 1 deletion tests/openvino/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES += ("glm", "mistral-nemo", "phimoe")

if is_transformers_version("<", "4.54.0"):
SUPPORTED_ARCHITECTURES += ("deepseek",)
SUPPORTED_ARCHITECTURES += ("deepseek", "gigachat3",)

# gptq and awq install disabled for windows test environment
if platform.system() != "Windows" and is_transformers_version("<", "4.56.0"):
Expand Down Expand Up @@ -230,6 +230,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"minicpm3": 6,
"phimoe": 2,
"deepseek": 2,
"gigachat3": 2,
"opt_gptq": 12,
"mixtral_awq": 2,
"gemma3_text": 2,
Expand Down Expand Up @@ -397,6 +398,12 @@ def test_compare_to_transformers(self, model_arch):
return

tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

# Gigachat3's tokenizer adds token_type_ids, which DeepSeekV3
# and similar models do not accept in generate(); strip it so both OV and PT calls succeed.
if model_arch in ["gigachat3"]:
tokens.pop("token_type_ids", None)

ov_model.generation_config.eos_token_id = None
transformers_model.generation_config.eos_token_id = None
ov_model.config.eos_token_id = None
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class ExportModelTest(unittest.TestCase):
if is_transformers_version(">=", "4.48.0"):
SUPPORTED_ARCHITECTURES.update({"cohere2": OVModelForCausalLM})

if is_transformers_version(">=", "4.46.0") and is_transformers_version("<=", "4.53.3"):
SUPPORTED_ARCHITECTURES.update({"deepseek": OVModelForCausalLM, "gigachat3": OVModelForCausalLM})

if is_transformers_version(">=", "4.49"):
SUPPORTED_ARCHITECTURES.update({"zamba2": OVModelForCausalLM})

Expand Down
8 changes: 8 additions & 0 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ class OVCLIExportTestCase(unittest.TestCase):
]
)

if is_transformers_version(">=", "4.46.0") and is_transformers_version("<=", "4.53.3"):
SUPPORTED_ARCHITECTURES.extend(
[
("text-generation-with-past", "gigachat3"),
]
)

if is_transformers_version(">=", "4.57.0"):
SUPPORTED_ARCHITECTURES.extend(
[
Expand Down Expand Up @@ -198,6 +205,7 @@ class OVCLIExportTestCase(unittest.TestCase):
"exaone4": 2,
"bitnet": 2,
"granitemoehybrid": 2,
"gigachat3": 2,
}

TOKENIZER_CHAT_TEMPLATE_TESTS_MODELS = {
Expand Down
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"deberta-v2": "optimum-intel-internal-testing/tiny-random-DebertaV2Model",
"decilm": "optimum-intel-internal-testing/tiny-random-decilm",
"deepseek": "optimum-intel-internal-testing/tiny-random-deepseek-v3",
"gigachat3": "optimum-intel-internal-testing/tiny-random-gigachat3",
"deit": "optimum-intel-internal-testing/tiny-random-DeiTModel",
"convnext": "optimum-intel-internal-testing/tiny-random-convnext",
"convnextv2": "optimum-intel-internal-testing/tiny-random-ConvNextV2Model",
Expand Down Expand Up @@ -372,6 +373,7 @@
"hunyuan_v1_dense": {"model": 32},
"qwen3_eagle3": {"model": 20},
"qwen3_next": {"model": 100},
"gigachat3": {"model": 58},
}

TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
Expand Down
Loading