From f60f45839be02b70bebc2f26001da344a385fdc8 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 18:29:08 +0000
Subject: [PATCH 01/18] Add mistral target model
---
specforge/modeling/target/mistral.py | 566 +++++++++++++++++++++++++++
1 file changed, 566 insertions(+)
create mode 100644 specforge/modeling/target/mistral.py
diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py
new file mode 100644
index 00000000..f5dd2437
--- /dev/null
+++ b/specforge/modeling/target/mistral.py
@@ -0,0 +1,566 @@
+# coding=utf-8
+# Copyright 2025 Mistral AI and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from transformers import MistralConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.masking_utils import (
+ create_causal_mask,
+ create_sliding_window_causal_mask,
+)
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.models.mistral.modeling_mistral import (
+ MistralRMSNorm,
+ MistralRotaryEmbedding,
+ apply_rotary_pos_emb,
+ eager_attention_forward,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import (
+ TransformersKwargs,
+ auto_docstring,
+ can_return_tuple,
+ logging,
+)
+
+from specforge.distributed import get_tp_group
+from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear
+from specforge.modeling.target.base import DistributedTargetModel
+
+logger = logging.get_logger(__name__)
+
+
+class MistralMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ # distributed linear layers
+ self.tp_group = get_tp_group()
+ self.gate_proj = ColumnParallelLinear(
+ self.hidden_size, self.intermediate_size, bias=False
+ )
+ self.up_proj = ColumnParallelLinear(
+ self.hidden_size, self.intermediate_size, bias=False
+ )
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size, self.hidden_size, bias=False
+ )
+
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group)
+ return down_proj
+
+
+class MistralAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: MistralConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+ self.num_key_value_groups = (
+ config.num_attention_heads // config.num_key_value_heads
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ # distributed linear layers
+ self.tp_group = get_tp_group()
+ self.q_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.num_attention_heads * self.head_dim,
+ bias=False,
+ )
+ self.k_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=False,
+ )
+ self.v_proj = ColumnParallelLinear(
+ config.hidden_size,
+ config.num_key_value_heads * self.head_dim,
+ bias=False,
+ )
+ self.o_proj = RowParallelLinear(
+ config.num_attention_heads * self.head_dim,
+ config.hidden_size,
+ bias=False,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
+ self.config._attn_implementation
+ ]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
+ return attn_output, attn_weights
+
+
+class MistralDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: MistralConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = MistralMLP(config)
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = MistralRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ tuple[torch.Tensor, torch.Tensor]
+ ] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[
+ torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class MistralPreTrainedModel(PreTrainedModel):
+    config_class = MistralConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MistralDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, MistralRMSNorm):
+ module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class MistralModel(MistralPreTrainedModel):
+ def __init__(self, config: MistralConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
+ self.layers = nn.ModuleList(
+ [
+ MistralDecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = MistralRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You must specify exactly one of input_ids or inputs_embeds"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+ if not isinstance(past_key_values, (type(None), Cache)):
+ raise ValueError(
+ "The `past_key_values` should be either a `Cache` object or `None`."
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ mask_function = (
+ create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ )
+ causal_mask = mask_function(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring
+class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = MistralModel(config)
+ self.vocab_size = config.vocab_size
+
+        # distribute the lm head across tensor-parallel ranks
+ self.lm_head = ColumnParallelLinear(
+ config.hidden_size, config.vocab_size, bias=False
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = (
+ slice(-logits_to_keep, None)
+ if isinstance(logits_to_keep, int)
+ else logits_to_keep
+ )
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ logits = self._gather_tensor(logits, get_tp_group())
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def load_weights(self, state_dict: Dict[str, torch.Tensor]):
+ tp_group = get_tp_group()
+
+ updated_state_dict = {}
+ for key, value in state_dict.items():
+ # Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption
+ # will break recipe code
+ if not isinstance(value, torch.Tensor):
+ raise ValueError(
+ f"Expected all values in the state dict to be torch.Tensor. "
+ f"Found {type(value)} instead."
+ )
+
+ module_key = ".".join(key.split(".")[:-1])
+ module = self.get_submodule(module_key)
+
+ # get the module type based on key and shard accordingly
+ if isinstance(module, RowParallelLinear) and key.endswith(".weight"):
+ value = self._shard_tensor(value, tp_group, -1)
+ elif isinstance(module, ColumnParallelLinear) and key.endswith(".weight"):
+ value = self._shard_tensor(value, tp_group, 0)
+ elif isinstance(module, ColumnParallelLinear) and key.endswith(".bias"):
+ value = self._shard_tensor(value, tp_group, 0)
+
+ updated_state_dict[key] = value
+
+ # load state dict
+ self.load_state_dict(updated_state_dict, strict=False)
+
+
+__all__ = [
+ "MistralForCausalLM",
+ "MistralPreTrainedModel",
+ "MistralModel",
+]
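The `ColumnParallelLinear` / `RowParallelLinear` layers above keep activations sharded and the `dist.all_reduce` calls recombine the partial results. A minimal single-process sketch of that identity for the MLP path, assuming (as the names suggest) that `ColumnParallelLinear` splits the output dimension and `RowParallelLinear` splits the input dimension:

```python
import torch
import torch.nn.functional as F

tp = 2
x = torch.randn(4, 8)            # [tokens, hidden]
w_gate = torch.randn(16, 8)      # gate_proj weight [intermediate, hidden]
w_up = torch.randn(16, 8)        # up_proj weight   [intermediate, hidden]
w_down = torch.randn(8, 16)      # down_proj weight [hidden, intermediate]

dense = (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

partials = []
for g, u, d in zip(w_gate.chunk(tp, 0), w_up.chunk(tp, 0), w_down.chunk(tp, 1)):
    # each "rank" only sees its slice of the intermediate dimension
    partials.append((F.silu(x @ g.T) * (x @ u.T)) @ d.T)

# the dist.all_reduce(SUM) after down_proj is exactly this sum
assert torch.allclose(sum(partials), dense, atol=1e-5)
```

The attention path works the same way: the q/k/v shards hold a subset of heads, and `o_proj` followed by the all_reduce stitches them back together.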
From 5b83e92094b04193ccb117c5bb6807eec2889c80 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 18:36:38 +0000
Subject: [PATCH 02/18] Add mistral to AutoDistributedTargetModel
_model_mapping
---
specforge/modeling/auto.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py
index 0e54777b..3da327b1 100644
--- a/specforge/modeling/auto.py
+++ b/specforge/modeling/auto.py
@@ -10,6 +10,7 @@
Llama4Config,
Llama4TextConfig,
LlamaConfig,
+ MistralConfig,
Phi3Config,
PretrainedConfig,
Qwen2_5_VLConfig,
@@ -24,6 +25,7 @@
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.llama import LlamaForCausalLM
from .target.llama4 import Llama4ForCausalLM
+from .target.mistral import MistralForCausalLM
from .target.phi3 import Phi3ForCausalLM
from .target.qwen2 import Qwen2ForCausalLM
from .target.qwen3 import Qwen3ForCausalLM
@@ -87,6 +89,7 @@ class AutoDistributedTargetModel(AutoModelForCausalLMBase):
LlamaConfig: [LlamaForCausalLM],
Qwen3Config: [Qwen3ForCausalLM],
Phi3Config: [Phi3ForCausalLM],
+ MistralConfig: [MistralForCausalLM],
}
@classmethod
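With this entry registered, resolving a `MistralConfig` yields the tensor-parallel implementation added in the previous patch. An illustrative lookup against the mapping (not the public loading path):

```python
from transformers import MistralConfig

from specforge.modeling.auto import AutoDistributedTargetModel
from specforge.modeling.target.mistral import MistralForCausalLM

config = MistralConfig()  # in practice this comes from AutoConfig.from_pretrained(...)
assert AutoDistributedTargetModel._model_mapping[type(config)] == [MistralForCausalLM]
```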
From 14582abf081fcf6dbc8207a638133aceeb97f2d8 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 22:29:11 +0000
Subject: [PATCH 03/18] Register mistral templates
---
specforge/data/preprocessing.py | 22 +++++++++++++++-------
specforge/data/template.py | 33 ++++++++++++++++++++++++++++-----
2 files changed, 43 insertions(+), 12 deletions(-)
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index fefa1630..3efdad7c 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -70,12 +70,20 @@ def _apply_loss_mask_from_chat_template(
"""
loss_mask = torch.zeros(len(offsets), dtype=torch.long)
- user_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.user_header}"
- )
- assistant_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
- )
+ if chat_template.end_of_turn_token is None:
+ user_message_separator = (
+ f"{chat_template.end_of_turn_token}{chat_template.user_header}"
+ )
+ assistant_message_separator = (
+ f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
+ )
+ else:
+ user_message_separator = (
+ f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header}"
+ )
+ assistant_message_separator = (
+ f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header}"
+ )
# Find spans of assistant responses using regex
assistant_pattern = (
@@ -149,7 +157,7 @@ def preprocess_conversations(
system_prompt = chat_template.system_prompt
# source is a list of conversation messages, need to format
- messages = [{"role": "system", "content": system_prompt}]
+ messages = [{"role": "system", "content": system_prompt}] if system_prompt is None else []
if source[0]["role"] != "user":
# if the first message is not from user, skip it
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 12241113..b17fe354 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -1,5 +1,5 @@
# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/chat_template.py#L13
-from typing import List
+from typing import List, Optional
from pydantic import BaseModel
@@ -11,14 +11,19 @@ class ChatTemplate(BaseModel):
Args:
assistant_header(str): The header for the assistant.
user_header(str): The header for the user.
- system_prompt(str): The system prompt.
- end_of_turn_token(str): The end token of a turn of conversation.
+ system_prompt (Optional[str]): The system prompt.
+ end_of_turn_token (Optional[str]): The end token of a turn of conversation.
+ If present, end_of_assistant_token and end_of_user_token are ignored.
+ end_of_assistant_token (Optional[str]): The end token of an assistant turn of conversation.
+ end_of_user_token (Optional[str]): The end token of a user turn of conversation.
"""
assistant_header: str
user_header: str
- system_prompt: str
- end_of_turn_token: str
+ system_prompt: Optional[str] = None
+ end_of_turn_token: Optional[str] = None
+ end_of_assistant_token: Optional[str] = None
+ end_of_user_token: Optional[str] = None
class TemplateRegistry:
@@ -104,6 +109,24 @@ def get_all_template_names(self) -> List[str]:
),
)
+TEMPLATE_REGISTRY.register(
+ name="mistral-v0.1",
+ template=ChatTemplate(
+ assistant_header=" [/INST] ",
+ user_header=" [INST] ",
+ end_of_assistant_token="",
+ ),
+)
+
+TEMPLATE_REGISTRY.register(
+ name="mistral-v0.3",
+ template=ChatTemplate(
+ assistant_header="[/INST] ",
+ user_header="[INST] ",
+ end_of_assistant_token="",
+ ),
+)
+
TEMPLATE_REGISTRY.register(
name="qwen",
template=ChatTemplate(
From 37bf89ee786482e94628dd0b4400800b81c83daf Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 22:43:07 +0000
Subject: [PATCH 04/18] Add target model unit test
---
tests/test_target_modeling/test_mistral_tp.py | 85 +++++++++++++++++++
1 file changed, 85 insertions(+)
create mode 100644 tests/test_target_modeling/test_mistral_tp.py
diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py
new file mode 100644
index 00000000..a8e1bf58
--- /dev/null
+++ b/tests/test_target_modeling/test_mistral_tp.py
@@ -0,0 +1,85 @@
+import os
+import tempfile
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from accelerate.utils import set_seed
+from transformers import MistralConfig, MistralForCausalLM
+
+from specforge.distributed import init_distributed
+
+
+def test_mistral_tp(rank, world_size, temp_dir):
+ os.environ["RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "29501"
+
+ init_distributed(tp_size=2)
+ set_seed(42)
+ config = MistralConfig(
+ vocab_size=1000,
+ hidden_size=384,
+ intermediate_size=512,
+ num_hidden_layers=2,
+ max_position_embeddings=1024,
+ num_attention_heads=10,
+ num_key_value_heads=2,
+ tie_word_embeddings=False,
+ initializer_range=0.02,
+ hidden_act="silu",
+ rms_norm_eps=1e-05,
+ )
+
+    # create the single-GPU reference model
+ model = MistralForCausalLM(config).cuda()
+
+ from specforge.modeling.target.mistral import MistralForCausalLM as DistMistralForCausalLM
+
+ dist_model = DistMistralForCausalLM(config).cuda()
+
+ # save the model weights to a temp directory
+ if dist.get_rank() == 0:
+ model.save_pretrained(temp_dir)
+ print(f"Saved model to {temp_dir}")
+ dist.barrier()
+
+ # load the model weights to the distributed model
+ print(f"Loading model from {temp_dir}")
+ dist_model.load_checkpoint(temp_dir)
+ dist.barrier()
+
+ # create data
+ input_ids = torch.randint(0, 1000, (1, 256)).cuda()
+ attention_mask = torch.ones_like(input_ids).cuda()
+
+ expected_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
+ dist_logits = dist_model(input_ids=input_ids, attention_mask=attention_mask).logits
+
+ assert torch.allclose(
+ expected_logits,
+ dist_logits,
+ rtol=1e-5,
+ atol=1e-5,
+ ), f"Logits are not close, {expected_logits} vs {dist_logits}"
+
+
+class TestMistralTP(unittest.TestCase):
+
+ def setUp(self):
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ def tearDown(self):
+ self.temp_dir.cleanup()
+
+ def test_mistral_tp(self):
+ mp.spawn(test_mistral_tp, nprocs=2, args=(2, self.temp_dir.name))
+
+
+if __name__ == "__main__":
+ suite = unittest.TestSuite()
+ suite.addTest(unittest.makeSuite(TestMistralTP))
+ runner = unittest.TextTestRunner(verbosity=2)
+ runner.run(suite)
From 1678ae531e9591d5823b9e7fecbe452c134b33bf Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 22:51:26 +0000
Subject: [PATCH 05/18] Add head_dim to test MistralConfig
---
tests/test_target_modeling/test_mistral_tp.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py
index a8e1bf58..ae02faa6 100644
--- a/tests/test_target_modeling/test_mistral_tp.py
+++ b/tests/test_target_modeling/test_mistral_tp.py
@@ -27,6 +27,7 @@ def test_mistral_tp(rank, world_size, temp_dir):
max_position_embeddings=1024,
num_attention_heads=10,
num_key_value_heads=2,
+ head_dim=64,
tie_word_embeddings=False,
initializer_range=0.02,
hidden_act="silu",
From 7d6de07e4f9c488dce639c1be37217af6cf71d15 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 23:22:15 +0000
Subject: [PATCH 06/18] Remove mistral v0.1 and v0.3 templates, add mistral
small 24B template
---
specforge/data/preprocessing.py | 2 +-
specforge/data/template.py | 33 ++++++++++++++++-----------------
2 files changed, 17 insertions(+), 18 deletions(-)
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 3efdad7c..2b29c9e3 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -157,7 +157,7 @@ def preprocess_conversations(
system_prompt = chat_template.system_prompt
# source is a list of conversation messages, need to format
- messages = [{"role": "system", "content": system_prompt}] if system_prompt is None else []
+ messages = [{"role": "system", "content": system_prompt}]
if source[0]["role"] != "user":
# if the first message is not from user, skip it
diff --git a/specforge/data/template.py b/specforge/data/template.py
index b17fe354..55913934 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -11,16 +11,16 @@ class ChatTemplate(BaseModel):
Args:
assistant_header(str): The header for the assistant.
user_header(str): The header for the user.
- system_prompt (Optional[str]): The system prompt.
- end_of_turn_token (Optional[str]): The end token of a turn of conversation.
+ system_prompt(str): The system prompt.
+ end_of_turn_token(Optional[str]): The end token of a turn of conversation.
If present, end_of_assistant_token and end_of_user_token are ignored.
- end_of_assistant_token (Optional[str]): The end token of an assistant turn of conversation.
- end_of_user_token (Optional[str]): The end token of a user turn of conversation.
+ end_of_assistant_token(Optional[str]): The end token of an assistant turn of conversation.
+ end_of_user_token(Optional[str]): The end token of a user turn of conversation.
"""
assistant_header: str
user_header: str
- system_prompt: Optional[str] = None
+ system_prompt: str
end_of_turn_token: Optional[str] = None
end_of_assistant_token: Optional[str] = None
end_of_user_token: Optional[str] = None
@@ -110,19 +110,18 @@ def get_all_template_names(self) -> List[str]:
)
TEMPLATE_REGISTRY.register(
- name="mistral-v0.1",
+ name="mistral-small-24B",
template=ChatTemplate(
- assistant_header=" [/INST] ",
- user_header=" [INST] ",
- end_of_assistant_token="",
- ),
-)
-
-TEMPLATE_REGISTRY.register(
- name="mistral-v0.3",
- template=ChatTemplate(
- assistant_header="[/INST] ",
- user_header="[INST] ",
+ assistant_header="[/INST]",
+ user_header="[INST]",
+ system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup "
+                      "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date "
+ "is 2025-08-31. When you're not sure about some information, you say that you don't have the "
+ "information and don't make up anything. If the user's question is not clear, ambiguous, or "
+ "does not provide enough context for you to accurately answer the question, you do not try to "
+ "answer it right away and you rather ask the user to clarify their request (e.g. \"What are "
+ "some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to "
+ "Tokyo\" => \"Where do you travel from?\")",
end_of_assistant_token="",
),
)
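For orientation, the markers this template targets look roughly like the following; the exact rendering comes from the model's own tokenizer chat template, and the string below only mirrors the preformatted example used later in this series:

```python
# Rough shape of a rendered Mistral-Small-24B conversation (illustrative only)
system = "You are Mistral Small 3, ..."  # the system_prompt registered above
rendered = (
    f"[SYSTEM_PROMPT]{system}[/SYSTEM_PROMPT]"
    "[INST]What is Python?[/INST]"
    "Python is a programming language."
)
# user_header "[INST]" opens user turns, assistant_header "[/INST]" opens the
# assistant span, and end_of_assistant_token "" means the span runs until the
# next "[INST]".
```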
From 898e471da0c2fa7d5624125a7cc48cbabf912965 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 23:26:45 +0000
Subject: [PATCH 07/18] Add mistral-small-24B eagle3 config
---
configs/mistral-small-24B-eagle3.json | 27 +++++++++++++++++++++++++++
1 file changed, 27 insertions(+)
create mode 100644 configs/mistral-small-24B-eagle3.json
diff --git a/configs/mistral-small-24B-eagle3.json b/configs/mistral-small-24B-eagle3.json
new file mode 100644
index 00000000..7db7f501
--- /dev/null
+++ b/configs/mistral-small-24B-eagle3.json
@@ -0,0 +1,27 @@
+{
+ "architectures": [
+ "LlamaForCausalLMEagle3"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 5120,
+ "initializer_range": 0.02,
+ "intermediate_size": 32768,
+ "max_position_embeddings": 32768,
+ "model_type": "llama",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 1,
+ "num_key_value_heads": 8,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 100000000.0,
+ "sliding_window": null,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.47.0",
+ "use_cache": true,
+ "vocab_size": 131072,
+ "draft_vocab_size": 32000
+}
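A quick sanity check one might run on this draft config (a sketch; the path is relative to the repo root and the assertions only restate values visible above):

```python
import json

with open("configs/mistral-small-24B-eagle3.json") as f:
    draft_cfg = json.load(f)

# the draft layer mirrors the 24B target's hidden size and head layout
assert draft_cfg["hidden_size"] == 5120 and draft_cfg["head_dim"] == 128
# Eagle3 pairs a reduced draft vocabulary with the full target vocabulary
assert draft_cfg["draft_vocab_size"] < draft_cfg["vocab_size"]
```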
From 6b4ac96a16ab2f1024b864ef9c23bcf0f17ed18b Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 23:46:14 +0000
Subject: [PATCH 08/18] Fix wrong chat_template.end_of_turn_token None check
---
specforge/data/preprocessing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 2b29c9e3..98b01629 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -70,7 +70,7 @@ def _apply_loss_mask_from_chat_template(
"""
loss_mask = torch.zeros(len(offsets), dtype=torch.long)
- if chat_template.end_of_turn_token is None:
+ if chat_template.end_of_turn_token is not None:
user_message_separator = (
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
)
From 1eb91d724472d71ea7cf7958c388d8296d7e7a85 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 23:48:58 +0000
Subject: [PATCH 09/18] Test mistral-small-24B preprocessing
---
tests/test_preprocessing.py | 29 ++++++++++++++++++-----------
1 file changed, 18 insertions(+), 11 deletions(-)
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 80908dbe..9fe32a80 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -65,10 +65,10 @@ class TestPreprocessing(unittest.TestCase):
"""Test suite for conversation preprocessing and loss mask generation."""
def setUp(self):
- """Set up test fixtures with Qwen3-8B tokenizer and template."""
- self.model_path = "Qwen/Qwen3-8B"
+ """Set up test fixtures with mistralai/Mistral-Small-24B-Instruct-2501 tokenizer and template."""
+ self.model_path = "mistralai/Mistral-Small-24B-Instruct-2501"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
- self.chat_template = TEMPLATE_REGISTRY.get("qwen")
+ self.chat_template = TEMPLATE_REGISTRY.get("mistral-small-24B")
self.max_length = 512
def test_conversation_preprocessing_basic(self):
@@ -122,7 +122,7 @@ def test_conversation_preprocessing_basic(self):
assistant_tokens, skip_special_tokens=False
)
expected_assistant_text = (
- "\n\n\n\nThe answer is 4.<|im_end|>\n"
+ "The answer is 4."
)
self.assertEqual(
assistant_text,
@@ -165,7 +165,7 @@ def test_multiple_turns_conversation(self):
)
# Exact match for the complete assistant text from both turns
- expected_assistant_text = "The answer is 4.<|im_end|>\n\n\n\nYes, I'm certain.<|im_end|>\n"
+ expected_assistant_text = "The answer is 4.Yes, I'm certain."
self.assertEqual(
assistant_text,
expected_assistant_text,
@@ -175,7 +175,14 @@ def test_multiple_turns_conversation(self):
def test_preformatted_conversation(self):
"""Test preprocessing of pre-formatted conversation strings."""
preformatted_conversations = [
- "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n"
+ "[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French "
+ "startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date is "
+ "2025-08-31. When you're not sure about some information, you say that you don't have the information and "
+ "don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context "
+ "for you to accurately answer the question, you do not try to answer it right away and you rather ask the "
+ "user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" "
+ "or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")[/SYSTEM_PROMPT]"
+ "[INST]What is Python?[/INST]Python is a programming language."
]
results = preprocess_conversations(
@@ -206,7 +213,7 @@ def test_preformatted_conversation(self):
)
# Check for exact match of the expected assistant response
- expected_assistant_text = "Python is a programming language.<|im_end|>\n"
+ expected_assistant_text = "Python is a programming language."
self.assertEqual(
assistant_text,
expected_assistant_text,
@@ -222,7 +229,7 @@ def test_assistant_span_boundaries(self):
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
],
- "expected_assistant_text": "\n\n\n\nHello!<|im_end|>\n",
+ "expected_assistant_text": "Hello!",
},
{
"name": "Response with punctuation",
@@ -230,7 +237,7 @@ def test_assistant_span_boundaries(self):
{"role": "user", "content": "What's your name?"},
{"role": "assistant", "content": "I'm Claude, an AI assistant."},
],
- "expected_assistant_text": "\n\n\n\nI'm Claude, an AI assistant.<|im_end|>\n",
+ "expected_assistant_text": "I'm Claude, an AI assistant.",
},
{
"name": "Multi-sentence response",
@@ -241,7 +248,7 @@ def test_assistant_span_boundaries(self):
"content": "Python is a programming language. It's very popular for AI.",
},
],
- "expected_assistant_text": "\n\n\n\nPython is a programming language. It's very popular for AI.<|im_end|>\n",
+ "expected_assistant_text": "Python is a programming language. It's very popular for AI.",
},
{
"name": "Response with special characters",
@@ -252,7 +259,7 @@ def test_assistant_span_boundaries(self):
"content": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.",
},
],
- "expected_assistant_text": "\n\n\n\nSure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.<|im_end|>\n",
+ "expected_assistant_text": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.",
},
]
From 892278814f982aa8dde7c5919be7bd43978923a6 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Sun, 31 Aug 2025 23:52:40 +0000
Subject: [PATCH 10/18] Restore Qwen3-8B preprocessing test
---
tests/test_preprocessing.py | 29 +++++++++++------------------
1 file changed, 11 insertions(+), 18 deletions(-)
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 9fe32a80..80908dbe 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -65,10 +65,10 @@ class TestPreprocessing(unittest.TestCase):
"""Test suite for conversation preprocessing and loss mask generation."""
def setUp(self):
- """Set up test fixtures with mistralai/Mistral-Small-24B-Instruct-2501 tokenizer and template."""
- self.model_path = "mistralai/Mistral-Small-24B-Instruct-2501"
+ """Set up test fixtures with Qwen3-8B tokenizer and template."""
+ self.model_path = "Qwen/Qwen3-8B"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
- self.chat_template = TEMPLATE_REGISTRY.get("mistral-small-24B")
+ self.chat_template = TEMPLATE_REGISTRY.get("qwen")
self.max_length = 512
def test_conversation_preprocessing_basic(self):
@@ -122,7 +122,7 @@ def test_conversation_preprocessing_basic(self):
assistant_tokens, skip_special_tokens=False
)
expected_assistant_text = (
- "The answer is 4."
+ "\n\n\n\nThe answer is 4.<|im_end|>\n"
)
self.assertEqual(
assistant_text,
@@ -165,7 +165,7 @@ def test_multiple_turns_conversation(self):
)
# Exact match for the complete assistant text from both turns
- expected_assistant_text = "The answer is 4.Yes, I'm certain."
+ expected_assistant_text = "The answer is 4.<|im_end|>\n\n\n\nYes, I'm certain.<|im_end|>\n"
self.assertEqual(
assistant_text,
expected_assistant_text,
@@ -175,14 +175,7 @@ def test_multiple_turns_conversation(self):
def test_preformatted_conversation(self):
"""Test preprocessing of pre-formatted conversation strings."""
preformatted_conversations = [
- "[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French "
- "startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date is "
- "2025-08-31. When you're not sure about some information, you say that you don't have the information and "
- "don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context "
- "for you to accurately answer the question, you do not try to answer it right away and you rather ask the "
- "user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" "
- "or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")[/SYSTEM_PROMPT]"
- "[INST]What is Python?[/INST]Python is a programming language."
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n"
]
results = preprocess_conversations(
@@ -213,7 +206,7 @@ def test_preformatted_conversation(self):
)
# Check for exact match of the expected assistant response
- expected_assistant_text = "Python is a programming language."
+ expected_assistant_text = "Python is a programming language.<|im_end|>\n"
self.assertEqual(
assistant_text,
expected_assistant_text,
@@ -229,7 +222,7 @@ def test_assistant_span_boundaries(self):
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
],
- "expected_assistant_text": "Hello!",
+ "expected_assistant_text": "\n\n\n\nHello!<|im_end|>\n",
},
{
"name": "Response with punctuation",
@@ -237,7 +230,7 @@ def test_assistant_span_boundaries(self):
{"role": "user", "content": "What's your name?"},
{"role": "assistant", "content": "I'm Claude, an AI assistant."},
],
- "expected_assistant_text": "I'm Claude, an AI assistant.",
+ "expected_assistant_text": "\n\n\n\nI'm Claude, an AI assistant.<|im_end|>\n",
},
{
"name": "Multi-sentence response",
@@ -248,7 +241,7 @@ def test_assistant_span_boundaries(self):
"content": "Python is a programming language. It's very popular for AI.",
},
],
- "expected_assistant_text": "Python is a programming language. It's very popular for AI.",
+ "expected_assistant_text": "\n\n\n\nPython is a programming language. It's very popular for AI.<|im_end|>\n",
},
{
"name": "Response with special characters",
@@ -259,7 +252,7 @@ def test_assistant_span_boundaries(self):
"content": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.",
},
],
- "expected_assistant_text": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.",
+ "expected_assistant_text": "\n\n\n\nSure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.<|im_end|>\n",
},
]
From 037980f82678511dfc66a3ac866ac29386a7bce0 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 1 Sep 2025 16:24:00 +0000
Subject: [PATCH 11/18] Add train script for mistral-Small-24B
---
.../run_mistral_small_24B_eagle3_online.sh | 22 +++++++++++++++++++
1 file changed, 22 insertions(+)
create mode 100644 examples/run_mistral_small_24B_eagle3_online.sh
diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh
new file mode 100644
index 00000000..0ee773f3
--- /dev/null
+++ b/examples/run_mistral_small_24B_eagle3_online.sh
@@ -0,0 +1,22 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
+
+# train eagle3 for mistral-Small-24B
+NUM_GPUS=${1:-8}
+
+torchrun \
+ --standalone \
+ --nproc_per_node $NUM_GPUS \
+ $ROOT_DIR/scripts/train_eagle3_online.py \
+ --target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \
+ --draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \
+ --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+ --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
+ --num-epochs 2 \
+ --batch-size 1 \
+ --learning-rate 1e-4 \
+ --max-length 2048 \
+ --chat-template mistral-small-24B \
+ --cache-dir $ROOT_DIR/cache \
+ --attention-backend flex_attention
From 877247b15dc518ab0233dd22aa7b4151b009ac68 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 1 Sep 2025 18:24:23 +0000
Subject: [PATCH 12/18] Lint fix
---
specforge/data/template.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/specforge/data/template.py b/specforge/data/template.py
index d50ba26b..6358fd82 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -13,7 +13,7 @@ class ChatTemplate(BaseModel):
user_header(str): The header for the user.
system_prompt(str): The system prompt.
end_of_turn_token(str): The end token of a turn of conversation.
- If present, end_of_assistant_token and end_of_user_token are ignored.
+ If present, end_of_assistant_token and end_of_user_token are ignored.
end_of_assistant_token(str): The end token of an assistant turn of conversation.
end_of_user_token(str): The end token of a user turn of conversation.
"""
From a059679dd7f64997ad44131f1dd464619361dc5d Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 1 Sep 2025 18:31:13 +0000
Subject: [PATCH 13/18] Fix misleading return type
---
specforge/modeling/target/mistral.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py
index f5dd2437..be8a1cb0 100644
--- a/specforge/modeling/target/mistral.py
+++ b/specforge/modeling/target/mistral.py
@@ -129,7 +129,7 @@ def forward(
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
From 06cdfeb6af6c6cd661479762a004767bd4b521de Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Thu, 4 Sep 2025 09:19:50 +0000
Subject: [PATCH 14/18] Fix code format
---
specforge/data/template.py | 14 ++---
specforge/modeling/target/mistral.py | 60 +++++++++++--------
tests/test_target_modeling/test_mistral_tp.py | 4 +-
3 files changed, 44 insertions(+), 34 deletions(-)
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 6358fd82..45c3b57a 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -116,13 +116,13 @@ def get_all_template_names(self) -> List[str]:
assistant_header="[/INST]",
user_header="[INST]",
system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup "
-                      "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date "
- "is 2025-08-31. When you're not sure about some information, you say that you don't have the "
- "information and don't make up anything. If the user's question is not clear, ambiguous, or "
- "does not provide enough context for you to accurately answer the question, you do not try to "
- "answer it right away and you rather ask the user to clarify their request (e.g. \"What are "
- "some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to "
- "Tokyo\" => \"Where do you travel from?\")",
+        "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date "
+ "is 2025-08-31. When you're not sure about some information, you say that you don't have the "
+ "information and don't make up anything. If the user's question is not clear, ambiguous, or "
+ "does not provide enough context for you to accurately answer the question, you do not try to "
+ 'answer it right away and you rather ask the user to clarify their request (e.g. "What are '
+ 'some good restaurants around me?" => "Where are you?" or "When is the next flight to '
+ 'Tokyo" => "Where do you travel from?")',
end_of_assistant_token="",
),
)
diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py
index be8a1cb0..d2f92b22 100644
--- a/specforge/modeling/target/mistral.py
+++ b/specforge/modeling/target/mistral.py
@@ -92,7 +92,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
- config.num_attention_heads // config.num_key_value_heads
+ config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
@@ -122,13 +122,13 @@ def __init__(self, config: MistralConfig, layer_idx: int):
)
def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
@@ -163,7 +163,9 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
- sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
+ sliding_window=getattr(
+ self.config, "sliding_window", None
+ ), # main diff with Llama
**kwargs,
)
@@ -181,24 +183,26 @@ def __init__(self, config: MistralConfig, layer_idx: int):
self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
self.mlp = MistralMLP(config)
- self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.input_layernorm = MistralRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
self.post_attention_layernorm = MistralRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[
- tuple[torch.Tensor, torch.Tensor]
- ] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[
+ tuple[torch.Tensor, torch.Tensor]
+ ] = None, # necessary, but kept here for BC
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
@@ -347,14 +351,16 @@ def forward(
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device
+ device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
mask_function = (
- create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+ create_causal_mask
+ if self.config.sliding_window is None
+ else create_sliding_window_causal_mask
)
causal_mask = mask_function(
config=self.config,
@@ -409,7 +415,9 @@ def forward(
@auto_docstring
-class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel):
+class MistralForCausalLM(
+ MistralPreTrainedModel, GenerationMixin, DistributedTargetModel
+):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@@ -518,7 +526,7 @@ def forward(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
- **kwargs
+ **kwargs,
)
return CausalLMOutputWithPast(
diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py
index ae02faa6..e3526184 100644
--- a/tests/test_target_modeling/test_mistral_tp.py
+++ b/tests/test_target_modeling/test_mistral_tp.py
@@ -37,7 +37,9 @@ def test_mistral_tp(rank, world_size, temp_dir):
# create the single-gpu
model = MistralForCausalLM(config).cuda()
- from specforge.modeling.target.mistral import MistralForCausalLM as DistMistralForCausalLM
+ from specforge.modeling.target.mistral import (
+ MistralForCausalLM as DistMistralForCausalLM,
+ )
dist_model = DistMistralForCausalLM(config).cuda()
From ab36686db3a8aeb2aebc55d8f8a04d6b05e58122 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 22 Sep 2025 18:36:27 +0000
Subject: [PATCH 15/18] Fix preprocessing
---
specforge/data/parse.py | 12 ++++++------
specforge/data/preprocessing.py | 16 +++++-----------
2 files changed, 11 insertions(+), 17 deletions(-)
diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index a4c095df..07734c3b 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -41,12 +41,12 @@ class GeneralParser(Parser):
def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
super().__init__(tokenizer, chat_template)
self.system_prompt = chat_template.system_prompt
- self.user_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.user_header}"
- )
- self.assistant_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
- )
+ if chat_template.end_of_turn_token:
+ self.user_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
+ self.assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
+ else:
+ self.user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+ self.assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
def parse(
self, conversation: "Conversation", max_length: int, preformatted: bool = False
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 37092864..b86f3392 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -71,20 +71,14 @@ def _apply_loss_mask_from_chat_template(
"""
loss_mask = torch.zeros(len(offsets), dtype=torch.long)
- if chat_template.end_of_turn_token is not None:
+ if chat_template.end_of_turn_token:
user_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.user_header}"
- )
- assistant_message_separator = (
- f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
+ f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
)
+ assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
else:
- user_message_separator = (
- f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header}"
- )
- assistant_message_separator = (
- f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header}"
- )
+ user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+ assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
# Find spans of assistant responses using regex
assistant_pattern = (
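A small sketch of how the two branches resolve for templates registered in this series; it mirrors the f-strings above rather than calling the real parser:

```python
from specforge.data.template import TEMPLATE_REGISTRY

def separators(t):
    if t.end_of_turn_token:
        return (
            f"{t.end_of_turn_token or ''}{t.user_header or ''}",
            f"{t.end_of_turn_token or ''}{t.assistant_header or ''}",
        )
    return (
        f"{t.end_of_assistant_token or ''}{t.user_header or ''}",
        f"{t.end_of_user_token or ''}{t.assistant_header or ''}",
    )

# qwen-style templates take the end_of_turn_token branch
print(separators(TEMPLATE_REGISTRY.get("qwen")))
# mistral-small-24B has no end_of_turn_token and empty/unset end tokens,
# so the separators collapse to the bare headers: ('[INST]', '[/INST]')
print(separators(TEMPLATE_REGISTRY.get("mistral-small-24B")))
```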
From e8a43a29fd286ab884c531b63511a923840c249f Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 22 Sep 2025 18:57:18 +0000
Subject: [PATCH 16/18] Fix linting
---
scripts/train_eagle3_online.py | 25 ++++++++++---------------
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py
index c4df940d..1d2b7eb0 100644
--- a/scripts/train_eagle3_online.py
+++ b/scripts/train_eagle3_online.py
@@ -291,22 +291,17 @@ def main():
# load model with resume
if draft_model_last_checkpoint:
- draft_model = (
- AutoEagle3DraftModel.from_pretrained(
- draft_model_last_checkpoint, attention_backend=args.attention_backend,
- torch_dtype=torch.bfloat16
- )
- .cuda()
-
- )
+ draft_model = AutoEagle3DraftModel.from_pretrained(
+ draft_model_last_checkpoint,
+ attention_backend=args.attention_backend,
+ torch_dtype=torch.bfloat16,
+ ).cuda()
else:
- draft_model = (
- AutoEagle3DraftModel.from_config(
- draft_model_config, attention_backend=args.attention_backend,
- torch_dtype=torch.bfloat16
- )
- .cuda()
- )
+ draft_model = AutoEagle3DraftModel.from_config(
+ draft_model_config,
+ attention_backend=args.attention_backend,
+ torch_dtype=torch.bfloat16,
+ ).cuda()
draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
print_with_rank("Initialized draft model")
From 26022f14b723ed65c69f1c54f85ad3aad1bca9eb Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Mon, 22 Sep 2025 19:09:43 +0000
Subject: [PATCH 17/18] Increase TP to 2 to fit on H100 with 96GB
---
examples/run_mistral_small_24B_eagle3_online.sh | 1 +
1 file changed, 1 insertion(+)
diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh
index 0ee773f3..7a8a445b 100644
--- a/examples/run_mistral_small_24B_eagle3_online.sh
+++ b/examples/run_mistral_small_24B_eagle3_online.sh
@@ -15,6 +15,7 @@ torchrun \
--output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
--num-epochs 2 \
--batch-size 1 \
+ --tp 2 \
--learning-rate 1e-4 \
--max-length 2048 \
--chat-template mistral-small-24B \
From 33c4bd7062904a1fa718c613f0929be1427aca29 Mon Sep 17 00:00:00 2001
From: ValeGian <46gianninivalerio@gmail.com>
Date: Wed, 1 Oct 2025 15:23:54 +0000
Subject: [PATCH 18/18] Set default NUM_GPUS to 2
---
examples/run_mistral_small_24B_eagle3_online.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh
index 7a8a445b..89bb1314 100644
--- a/examples/run_mistral_small_24B_eagle3_online.sh
+++ b/examples/run_mistral_small_24B_eagle3_online.sh
@@ -3,7 +3,7 @@ ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
# train eagle3 for mistral-Small-24B
-NUM_GPUS=${1:-8}
+NUM_GPUS=${1:-2}
torchrun \
--standalone \