Conversation
```diff
  vocab_size=32000,
  head_dim=64,
- hidden_act="gelu",
+ hidden_act="silu",
```
This might be a bit breaking for configs that were saved without hidden_act and therefore defaulted to GeLU previously.
I think it was actually broken before; this is a fix.
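Concretely, the change only affects what the default resolves to when hidden_act is absent; a minimal sketch of how to keep the old behaviour if someone really depends on it, assuming the standard config-loading API (the checkpoint path is a placeholder):

```python
from transformers import AutoConfig

# Configs saved without hidden_act will now pick up the new "silu" default.
# Anyone relying on the previous value can still pin the activation explicitly
# when loading, since from_pretrained kwargs override saved/default values:
config = AutoConfig.from_pretrained("path/to/old-checkpoint", hidden_act="gelu")  # placeholder path
print(config.hidden_act)  # "gelu"
```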
```python
image_processor=image_processor,
image_token="[IMG]",
patch_size=patch_size,
chat_template=chat_template,
```
model has no chat template anymore? 🥲
It makes no sense to put one here anymore, unfortunately: depending on the model, the chat template looks very different (tokenizer version, thinking or not, ...). So having a default one is arguably worse than none at all, imo.
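For downstream users this just means supplying a template themselves; a minimal sketch, assuming the standard apply_chat_template API (the checkpoint name and the Jinja template are placeholders):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder checkpoint

# With no default template shipped, pass one explicitly per call
# (or assign it once via tokenizer.chat_template = ...):
my_template = "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}"
messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(messages, chat_template=my_template, tokenize=False)
```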
```python
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if self.config.rope_parameters.llama_4_scaling_beta is not None:
```
rope_parameters is a simple dict, so we'd better safe-get it: rope_parameters.get('llama_4_scaling_beta').
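A minimal sketch of the difference, under the assumption that rope_parameters stays a plain dict (the attention body is abbreviated):

```python
# Attribute access would raise AttributeError on a plain dict, and item access
# would raise KeyError for configs saved before this feature existed:
#     beta = self.config.rope_parameters.llama_4_scaling_beta
#     beta = self.config.rope_parameters["llama_4_scaling_beta"]
# Safe-get returns None when the key is absent:
beta = self.config.rope_parameters.get("llama_4_scaling_beta")
if beta is not None:
    ...  # apply the long-context scaling only when configured
```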
Looks good, let's just address the one comment and fix CI so it's ✅
[For maintainers] Suggested jobs to run (before merge):
run-slow: mistral, mistral3
LGTM! We will need to add a new model for this change to get through, as we never change old models to add new features 🤗
The best way is to write a simple modular_new_model.py with something like this:
```python
...

class MyModelAttention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states = query_states * self._get_llama4_attn_scale(cache_position).to(query_states.dtype)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
```

The best would be to modify the RopeEmbedding to do the scaling in there, this way you don't have to change the forward pass!
What does this PR do?
Adds Llama 4 scaling for long context to Mistral models.
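For reference, the scaling multiplies the query states by a per-position temperature in the spirit of Llama 4's attention temperature tuning; a rough sketch of that scale, assuming it follows the Llama 4 formulation (the floor_scale constant and the mapping of llama_4_scaling_beta onto the scale factor are assumptions, not taken from this PR):

```python
import torch

def llama4_style_attn_scale(cache_position: torch.Tensor, beta: float, floor_scale: float = 8192.0) -> torch.Tensor:
    # Grows logarithmically with absolute position, so queries far into the context
    # get a slightly higher temperature; early positions stay unscaled (scale == 1).
    return torch.log(torch.floor((cache_position.float() + 1.0) / floor_scale) + 1.0) * beta + 1.0
```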
Who can review?
@patrickvonplaten @ArthurZucker @zucchini-nlp