Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
f88c6a8
add support gigachat3
Mohamed-Ashraf273 Mar 3, 2026
c26ffe8
support gigachat3
Mohamed-Ashraf273 Mar 3, 2026
751cd02
add tests & create tiny model
Mohamed-Ashraf273 Mar 3, 2026
0d16b4f
add tests and fix issues
Mohamed-Ashraf273 Mar 4, 2026
f2a1e53
fix version skip test
Mohamed-Ashraf273 Mar 4, 2026
aac19fb
add docs & modify patcher
Mohamed-Ashraf273 Mar 4, 2026
5c134eb
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
f231bca
modify patcher
Mohamed-Ashraf273 Mar 4, 2026
8f18ff5
fix issues
Mohamed-Ashraf273 Mar 5, 2026
5049ce3
update test
Mohamed-Ashraf273 Mar 5, 2026
722dc53
update tests
Mohamed-Ashraf273 Mar 6, 2026
28a6330
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
07efafd
fix test issue
Mohamed-Ashraf273 Mar 6, 2026
c52c62a
fix tests
Mohamed-Ashraf273 Mar 9, 2026
a63a52d
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
f8bdfe5
fix conflict
Mohamed-Ashraf273 Mar 9, 2026
8228058
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 9, 2026
5b32d32
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
04b4d9f
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
c0ba5d0
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
63a956d
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
acd8148
revert convert.py changes
Mohamed-Ashraf273 Mar 13, 2026
dbc1432
update deepseek's patcher
Mohamed-Ashraf273 Mar 15, 2026
47d2910
modify patcher
Mohamed-Ashraf273 Mar 17, 2026
eb601d9
update patcher
Mohamed-Ashraf273 Mar 20, 2026
fe1b84a
removed unnecessary check
Mohamed-Ashraf273 Mar 23, 2026
e86e7d9
fix patcher
Mohamed-Ashraf273 Mar 23, 2026
cbc2005
fix version
Mohamed-Ashraf273 Mar 23, 2026
199da92
fix version
Mohamed-Ashraf273 Mar 23, 2026
1dee64c
revert refactoring
Mohamed-Ashraf273 Mar 23, 2026
225aed3
update doc
Mohamed-Ashraf273 Mar 23, 2026
4174478
modify based on review
Mohamed-Ashraf273 Mar 24, 2026
9ae1162
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5dbc1c8
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f7043c7
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2877a8e
fix issues
Mohamed-Ashraf273 Mar 26, 2026
2c2d31b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
5cb2d8b
fix issues
Mohamed-Ashraf273 Mar 26, 2026
f1b92ed
Merge branch 'main' into support_gigachat3
Mohamed-Ashraf273 Mar 30, 2026
39e770c
Remove Flaubert and add GigaChat3 to models list
Mohamed-Ashraf273 Mar 30, 2026
aec4a90
update docs
Mohamed-Ashraf273 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
351 changes: 127 additions & 224 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3758,8 +3758,8 @@ class DeepseekPatcher(OVDecoderModelPatcher):
def __enter__(self):
super().__enter__()
self_attn = {
"deepseek_v3": deepseek_v3_attn_forward,
"deepseek_v2": deepseek_v2_attn_forward,
"deepseek_v3": make_deepseek_attn_forward(version=3),
"deepseek_v2": make_deepseek_attn_forward(version=2),
"deepseek": minicpm3_attn_forward,
}

Expand All @@ -3770,249 +3770,152 @@ def __enter__(self):
block.self_attn.forward = types.MethodType(self_attn_fwd, block.self_attn)
if hasattr(block.mlp, "moe_infer"):
block.mlp._org_moe_infer = block.mlp.moe_infer
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)
elif hasattr(block.mlp, "experts"):
block.mlp._org_moe_infer = None
block.mlp.ep_rank = 0
block.mlp.experts_per_rank = len(block.mlp.experts)
block.mlp.moe_infer = types.MethodType(deepseek_moe_infer, block.mlp)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.mlp, "_orig_moe_infer"):
block.mlp.moe_infer = block.mlp._orig_moe_infer


def deepseek_v3_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L751
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)

bsz, q_len, _ = hidden_states.size()

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

# Difference with original code, k_pe.new_empty create constant tensor in torchscript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value


def deepseek_v2_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value=None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# modified from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py#L806
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
if hasattr(block.self_attn, "_orig_forward"):
block.self_attn.forward = block.self_attn._orig_forward
if hasattr(block.mlp, "_org_moe_infer"):
if block.mlp._org_moe_infer is not None:
block.mlp.moe_infer = block.mlp._org_moe_infer
else:
delattr(block.mlp, "moe_infer")
if hasattr(block.mlp, "ep_rank"):
delattr(block.mlp, "ep_rank")
if hasattr(block.mlp, "experts_per_rank"):
delattr(block.mlp, "experts_per_rank")

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
def make_deepseek_attn_forward(version: int = 3):
"""Return a MLA attention forward function for the given DeepSeek version.

b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
Args:
version: 2 for deepseek_v2 (uses freqs_cis), 3 for deepseek_v3 (uses cos/sin tuple)
"""
from typing import Callable

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

if output_attentions:
return self._orig_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
if version == 3:
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
apply_rotary_pos_emb,
apply_rotary_pos_emb_interleave,
eager_attention_forward,
)

bsz, q_len, _ = hidden_states.shape

if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
elif version == 2:

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
return attn_output.transpose(1, 2).contiguous(), attn_weights
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
raise ValueError(f"Unsupported DeepSeek version: {version}")

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
def deepseek_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
past_key_values=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch_size, seq_length = hidden_states.shape[:-1]

k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
if self.q_lora_rank is None:
q_states = self.q_proj(hidden_states)
else:
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q_states = q_states.view(batch_size, seq_length, -1, self.qk_head_dim).transpose(1, 2)
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass))
k_pass = k_pass.view(batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

if version == 3:
cos, sin = position_embeddings
if self.config.rope_interleave:
q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
else:
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
kv_cache = past_key_value
else:
q_rot, k_rot = apply_rotary_emb(q_rot, k_rot, position_embeddings.to(q_rot.device))
cache_kwargs = {"cache_position": cache_position}
kv_cache = past_key_values

q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)

# Difference with original code, k_pe.new_empty create constant tensor in torchscript
query_states = torch.concat([q_nope, q_pe], dim=-1)
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
# key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if kv_cache is not None:
key_states, value_states = kv_cache.update(key_states, value_states, self.layer_idx, cache_kwargs)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
is_flash_attn = "flash" in self.config._attn_implementation
if is_flash_attn and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if version == 2:
attention_interface = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)

attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
if is_flash_attn and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = self.o_proj(attn_output)
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

return attn_output, None, past_key_value
return deepseek_attn_forward


def deepseek_moe_infer(self, x, topk_ids, topk_weight):
Expand Down
Loading
Loading