From f60f45839be02b70bebc2f26001da344a385fdc8 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 18:29:08 +0000 Subject: [PATCH 01/18] Add mistral target model --- specforge/modeling/target/mistral.py | 566 +++++++++++++++++++++++++++ 1 file changed, 566 insertions(+) create mode 100644 specforge/modeling/target/mistral.py diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py new file mode 100644 index 00000000..f5dd2437 --- /dev/null +++ b/specforge/modeling/target/mistral.py @@ -0,0 +1,566 @@ +# coding=utf-8 +# Copyright 2025 Mistral AI and HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Dict, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn +from transformers import MistralConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.models.mistral.modeling_mistral import ( + MistralRMSNorm, + MistralRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + TransformersKwargs, + auto_docstring, + can_return_tuple, + logging, +) + +from specforge.distributed import get_tp_group +from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear +from specforge.modeling.target.base import DistributedTargetModel + +logger = logging.get_logger(__name__) + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # distributed linear layers + self.tp_group = get_tp_group() + self.gate_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, self.hidden_size, bias=False + ) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * 
self.up_proj(x)) + dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group) + return down_proj + + +class MistralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + # distributed linear layers + self.tp_group = get_tp_group() + self.q_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=False, + ) + self.k_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=False, + ) + self.v_proj = ColumnParallelLinear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=False, + ) + self.o_proj = RowParallelLinear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = 
apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group) + return attn_output, attn_weights + + +class MistralDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # 
necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class MistralPreTrainedModel(PreTrainedModel): + config_class: MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + + +@auto_docstring +class 
MistralModel(MistralPreTrainedModel): + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + MistralDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MistralRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError( + "The `past_key_values` should be either a `Cache` object or `None`." + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = ( + create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + ) + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = 
self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + + # distributed the lm head + self.lm_head = ColumnParallelLinear( + config.hidden_size, config.vocab_size, bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + 
```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self._gather_tensor(logits, get_tp_group()) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def 
load_weights(self, state_dict: Dict[str, torch.Tensor]): + tp_group = get_tp_group() + + updated_state_dict = {} + for key, value in state_dict.items(): + # Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption + # will break recipe code + if not isinstance(value, torch.Tensor): + raise ValueError( + f"Expected all values in the state dict to be torch.Tensor. " + f"Found {type(value)} instead." + ) + + module_key = ".".join(key.split(".")[:-1]) + module = self.get_submodule(module_key) + + # get the module type based on key and shard accordingly + if isinstance(module, RowParallelLinear) and key.endswith(".weight"): + value = self._shard_tensor(value, tp_group, -1) + elif isinstance(module, ColumnParallelLinear) and key.endswith(".weight"): + value = self._shard_tensor(value, tp_group, 0) + elif isinstance(module, ColumnParallelLinear) and key.endswith(".bias"): + value = self._shard_tensor(value, tp_group, 0) + + updated_state_dict[key] = value + + # load state dict + self.load_state_dict(updated_state_dict, strict=False) + + +__all__ = [ + "MistralForCausalLM", + "MistralPreTrainedModel", + "MistralModel", +] From 5b83e92094b04193ccb117c5bb6807eec2889c80 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 18:36:38 +0000 Subject: [PATCH 02/18] Add mistral to AutoDistributedTargetModel _model_mapping --- specforge/modeling/auto.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py index 0e54777b..3da327b1 100644 --- a/specforge/modeling/auto.py +++ b/specforge/modeling/auto.py @@ -10,6 +10,7 @@ Llama4Config, Llama4TextConfig, LlamaConfig, + MistralConfig, Phi3Config, PretrainedConfig, Qwen2_5_VLConfig, @@ -24,6 +25,7 @@ from .draft.llama3_eagle import LlamaForCausalLMEagle3 from .target.llama import LlamaForCausalLM from .target.llama4 import Llama4ForCausalLM +from .target.mistral import MistralForCausalLM from .target.phi3 import 
Phi3ForCausalLM from .target.qwen2 import Qwen2ForCausalLM from .target.qwen3 import Qwen3ForCausalLM @@ -87,6 +89,7 @@ class AutoDistributedTargetModel(AutoModelForCausalLMBase): LlamaConfig: [LlamaForCausalLM], Qwen3Config: [Qwen3ForCausalLM], Phi3Config: [Phi3ForCausalLM], + MistralConfig: [MistralForCausalLM], } @classmethod From 14582abf081fcf6dbc8207a638133aceeb97f2d8 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 22:29:11 +0000 Subject: [PATCH 03/18] Register mistral templates --- specforge/data/preprocessing.py | 22 +++++++++++++++------- specforge/data/template.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index fefa1630..3efdad7c 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -70,12 +70,20 @@ def _apply_loss_mask_from_chat_template( """ loss_mask = torch.zeros(len(offsets), dtype=torch.long) - user_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.user_header}" - ) - assistant_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" - ) + if chat_template.end_of_turn_token is None: + user_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.user_header}" + ) + assistant_message_separator = ( + f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" + ) + else: + user_message_separator = ( + f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header}" + ) + assistant_message_separator = ( + f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header}" + ) # Find spans of assistant responses using regex assistant_pattern = ( @@ -149,7 +157,7 @@ def preprocess_conversations( system_prompt = chat_template.system_prompt # source is a list of conversation messages, need to format - messages = [{"role": "system", "content": 
system_prompt}] + messages = [{"role": "system", "content": system_prompt}] if system_prompt is None else [] if source[0]["role"] != "user": # if the first message is not from user, skip it diff --git a/specforge/data/template.py b/specforge/data/template.py index 12241113..b17fe354 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -1,5 +1,5 @@ # Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/chat_template.py#L13 -from typing import List +from typing import List, Optional from pydantic import BaseModel @@ -11,14 +11,19 @@ class ChatTemplate(BaseModel): Args: assistant_header(str): The header for the assistant. user_header(str): The header for the user. - system_prompt(str): The system prompt. - end_of_turn_token(str): The end token of a turn of conversation. + system_prompt (Optional[str]): The system prompt. + end_of_turn_token (Optional[str]): The end token of a turn of conversation. + If present, end_of_assistant_token and end_of_user_token are ignored. + end_of_assistant_token (Optional[str]): The end token of an assistant turn of conversation. + end_of_user_token (Optional[str]): The end token of a user turn of conversation. 
""" assistant_header: str user_header: str - system_prompt: str - end_of_turn_token: str + system_prompt: Optional[str] = None + end_of_turn_token: Optional[str] = None + end_of_assistant_token: Optional[str] = None + end_of_user_token: Optional[str] = None class TemplateRegistry: @@ -104,6 +109,24 @@ def get_all_template_names(self) -> List[str]: ), ) +TEMPLATE_REGISTRY.register( + name="mistral-v0.1", + template=ChatTemplate( + assistant_header=" [/INST] ", + user_header=" [INST] ", + end_of_assistant_token="", + ), +) + +TEMPLATE_REGISTRY.register( + name="mistral-v0.3", + template=ChatTemplate( + assistant_header="[/INST] ", + user_header="[INST] ", + end_of_assistant_token="", + ), +) + TEMPLATE_REGISTRY.register( name="qwen", template=ChatTemplate( From 37bf89ee786482e94628dd0b4400800b81c83daf Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 22:43:07 +0000 Subject: [PATCH 04/18] Add target model unit test --- tests/test_target_modeling/test_mistral_tp.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/test_target_modeling/test_mistral_tp.py diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py new file mode 100644 index 00000000..a8e1bf58 --- /dev/null +++ b/tests/test_target_modeling/test_mistral_tp.py @@ -0,0 +1,85 @@ +import os +import tempfile +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from accelerate.utils import set_seed +from transformers import MistralConfig, MistralForCausalLM + +from specforge.distributed import init_distributed + + +def test_mistral_tp(rank, world_size, temp_dir): + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29501" + + init_distributed(tp_size=2) + set_seed(42) + config = MistralConfig( + vocab_size=1000, + hidden_size=384, + 
intermediate_size=512, + num_hidden_layers=2, + max_position_embeddings=1024, + num_attention_heads=10, + num_key_value_heads=2, + tie_word_embeddings=False, + initializer_range=0.02, + hidden_act="silu", + rms_norm_eps=1e-05, + ) + + # create the single-gpu + model = MistralForCausalLM(config).cuda() + + from specforge.modeling.target.mistral import MistralForCausalLM as DistMistralForCausalLM + + dist_model = DistMistralForCausalLM(config).cuda() + + # save the model weights to a temp directory + if dist.get_rank() == 0: + model.save_pretrained(temp_dir) + print(f"Saved model to {temp_dir}") + dist.barrier() + + # load the model weights to the distributed model + print(f"Loading model from {temp_dir}") + dist_model.load_checkpoint(temp_dir) + dist.barrier() + + # create data + input_ids = torch.randint(0, 1000, (1, 256)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + + expected_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits + dist_logits = dist_model(input_ids=input_ids, attention_mask=attention_mask).logits + + assert torch.allclose( + expected_logits, + dist_logits, + rtol=1e-5, + atol=1e-5, + ), f"Logits are not close, {expected_logits} vs {dist_logits}" + + +class TestMistralTP(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_mistral_tp(self): + mp.spawn(test_mistral_tp, nprocs=2, args=(2, self.temp_dir.name)) + + +if __name__ == "__main__": + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestMistralTP)) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) From 1678ae531e9591d5823b9e7fecbe452c134b33bf Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 22:51:26 +0000 Subject: [PATCH 05/18] Add head_dim to test MistralConfig --- tests/test_target_modeling/test_mistral_tp.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py index a8e1bf58..ae02faa6 100644 --- a/tests/test_target_modeling/test_mistral_tp.py +++ b/tests/test_target_modeling/test_mistral_tp.py @@ -27,6 +27,7 @@ def test_mistral_tp(rank, world_size, temp_dir): max_position_embeddings=1024, num_attention_heads=10, num_key_value_heads=2, + head_dim=64, tie_word_embeddings=False, initializer_range=0.02, hidden_act="silu", From 7d6de07e4f9c488dce639c1be37217af6cf71d15 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 23:22:15 +0000 Subject: [PATCH 06/18] Remove mistral v0.1 and v0.3 templates, add mistral small 24B template --- specforge/data/preprocessing.py | 2 +- specforge/data/template.py | 33 ++++++++++++++++----------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 3efdad7c..2b29c9e3 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -157,7 +157,7 @@ def preprocess_conversations( system_prompt = chat_template.system_prompt # source is a list of conversation messages, need to format - messages = [{"role": "system", "content": system_prompt}] if system_prompt is None else [] + messages = [{"role": "system", "content": system_prompt}] if source[0]["role"] != "user": # if the first message is not from user, skip it diff --git a/specforge/data/template.py b/specforge/data/template.py index b17fe354..55913934 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -11,16 +11,16 @@ class ChatTemplate(BaseModel): Args: assistant_header(str): The header for the assistant. user_header(str): The header for the user. - system_prompt (Optional[str]): The system prompt. - end_of_turn_token (Optional[str]): The end token of a turn of conversation. + system_prompt(str): The system prompt. 
+ end_of_turn_token(Optional[str]): The end token of a turn of conversation. If present, end_of_assistant_token and end_of_user_token are ignored. - end_of_assistant_token (Optional[str]): The end token of an assistant turn of conversation. - end_of_user_token (Optional[str]): The end token of a user turn of conversation. + end_of_assistant_token(Optional[str]): The end token of an assistant turn of conversation. + end_of_user_token(Optional[str]): The end token of a user turn of conversation. """ assistant_header: str user_header: str - system_prompt: Optional[str] = None + system_prompt: str end_of_turn_token: Optional[str] = None end_of_assistant_token: Optional[str] = None end_of_user_token: Optional[str] = None @@ -110,19 +110,18 @@ def get_all_template_names(self) -> List[str]: ) TEMPLATE_REGISTRY.register( - name="mistral-v0.1", + name="mistral-small-24B", template=ChatTemplate( - assistant_header=" [/INST] ", - user_header=" [INST] ", - end_of_assistant_token="", - ), -) - -TEMPLATE_REGISTRY.register( - name="mistral-v0.3", - template=ChatTemplate( - assistant_header="[/INST] ", - user_header="[INST] ", + assistant_header="[/INST]", + user_header="[INST]", + system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup " + "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date" + "is 2025-08-31. When you're not sure about some information, you say that you don't have the " + "information and don't make up anything. If the user's question is not clear, ambiguous, or " + "does not provide enough context for you to accurately answer the question, you do not try to " + "answer it right away and you rather ask the user to clarify their request (e.g. 
\"What are " + "some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to " + "Tokyo\" => \"Where do you travel from?\")", end_of_assistant_token="", ), ) From 898e471da0c2fa7d5624125a7cc48cbabf912965 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 23:26:45 +0000 Subject: [PATCH 07/18] Add mistral-small-24B eagle3 config --- configs/mistral-small-24B-eagle3.json | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 configs/mistral-small-24B-eagle3.json diff --git a/configs/mistral-small-24B-eagle3.json b/configs/mistral-small-24B-eagle3.json new file mode 100644 index 00000000..7db7f501 --- /dev/null +++ b/configs/mistral-small-24B-eagle3.json @@ -0,0 +1,27 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 32768, + "max_position_embeddings": 32768, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 100000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.47.0", + "use_cache": true, + "vocab_size": 131072, + "draft_vocab_size": 32000 +} From 6b4ac96a16ab2f1024b864ef9c23bcf0f17ed18b Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 23:46:14 +0000 Subject: [PATCH 08/18] Fix wrong chat_template.end_of_turn_token None check --- specforge/data/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 2b29c9e3..98b01629 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -70,7 +70,7 @@ def _apply_loss_mask_from_chat_template( """ 
loss_mask = torch.zeros(len(offsets), dtype=torch.long) - if chat_template.end_of_turn_token is None: + if chat_template.end_of_turn_token is not None: user_message_separator = ( f"{chat_template.end_of_turn_token}{chat_template.user_header}" ) From 1eb91d724472d71ea7cf7958c388d8296d7e7a85 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 23:48:58 +0000 Subject: [PATCH 09/18] Test mistral-small-24B preprocessing --- tests/test_preprocessing.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 80908dbe..9fe32a80 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -65,10 +65,10 @@ class TestPreprocessing(unittest.TestCase): """Test suite for conversation preprocessing and loss mask generation.""" def setUp(self): - """Set up test fixtures with Qwen3-8B tokenizer and template.""" - self.model_path = "Qwen/Qwen3-8B" + """Set up test fixtures with mistralai/Mistral-Small-24B-Instruct-2501 tokenizer and template.""" + self.model_path = "mistralai/Mistral-Small-24B-Instruct-2501" self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) - self.chat_template = TEMPLATE_REGISTRY.get("qwen") + self.chat_template = TEMPLATE_REGISTRY.get("mistral-small-24B") self.max_length = 512 def test_conversation_preprocessing_basic(self): @@ -122,7 +122,7 @@ def test_conversation_preprocessing_basic(self): assistant_tokens, skip_special_tokens=False ) expected_assistant_text = ( - "\n\n\n\nThe answer is 4.<|im_end|>\n" + "The answer is 4." ) self.assertEqual( assistant_text, @@ -165,7 +165,7 @@ def test_multiple_turns_conversation(self): ) # Exact match for the complete assistant text from both turns - expected_assistant_text = "The answer is 4.<|im_end|>\n\n\n\nYes, I'm certain.<|im_end|>\n" + expected_assistant_text = "The answer is 4.Yes, I'm certain." 
self.assertEqual( assistant_text, expected_assistant_text, @@ -175,7 +175,14 @@ def test_multiple_turns_conversation(self): def test_preformatted_conversation(self): """Test preprocessing of pre-formatted conversation strings.""" preformatted_conversations = [ - "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n" + "[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French " + "startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date is " + "2025-08-31. When you're not sure about some information, you say that you don't have the information and " + "don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context " + "for you to accurately answer the question, you do not try to answer it right away and you rather ask the " + "user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" " + "or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")[/SYSTEM_PROMPT]" + "[INST]What is Python?[/INST]Python is a programming language." ] results = preprocess_conversations( @@ -206,7 +213,7 @@ def test_preformatted_conversation(self): ) # Check for exact match of the expected assistant response - expected_assistant_text = "Python is a programming language.<|im_end|>\n" + expected_assistant_text = "Python is a programming language." 
self.assertEqual( assistant_text, expected_assistant_text, @@ -222,7 +229,7 @@ def test_assistant_span_boundaries(self): {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}, ], - "expected_assistant_text": "\n\n\n\nHello!<|im_end|>\n", + "expected_assistant_text": "Hello!", }, { "name": "Response with punctuation", @@ -230,7 +237,7 @@ def test_assistant_span_boundaries(self): {"role": "user", "content": "What's your name?"}, {"role": "assistant", "content": "I'm Claude, an AI assistant."}, ], - "expected_assistant_text": "\n\n\n\nI'm Claude, an AI assistant.<|im_end|>\n", + "expected_assistant_text": "I'm Claude, an AI assistant.", }, { "name": "Multi-sentence response", @@ -241,7 +248,7 @@ def test_assistant_span_boundaries(self): "content": "Python is a programming language. It's very popular for AI.", }, ], - "expected_assistant_text": "\n\n\n\nPython is a programming language. It's very popular for AI.<|im_end|>\n", + "expected_assistant_text": "Python is a programming language. It's very popular for AI.", }, { "name": "Response with special characters", @@ -252,7 +259,7 @@ def test_assistant_span_boundaries(self): "content": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.", }, ], - "expected_assistant_text": "\n\n\n\nSure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.<|im_end|>\n", + "expected_assistant_text": "Sure! 
Here's an example: 2 + 2 = 4, and π ≈ 3.14159.", }, ] From 892278814f982aa8dde7c5919be7bd43978923a6 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Sun, 31 Aug 2025 23:52:40 +0000 Subject: [PATCH 10/18] Restore Qwen3-8B preprocessing test --- tests/test_preprocessing.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 9fe32a80..80908dbe 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -65,10 +65,10 @@ class TestPreprocessing(unittest.TestCase): """Test suite for conversation preprocessing and loss mask generation.""" def setUp(self): - """Set up test fixtures with mistralai/Mistral-Small-24B-Instruct-2501 tokenizer and template.""" - self.model_path = "mistralai/Mistral-Small-24B-Instruct-2501" + """Set up test fixtures with Qwen3-8B tokenizer and template.""" + self.model_path = "Qwen/Qwen3-8B" self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) - self.chat_template = TEMPLATE_REGISTRY.get("mistral-small-24B") + self.chat_template = TEMPLATE_REGISTRY.get("qwen") self.max_length = 512 def test_conversation_preprocessing_basic(self): @@ -122,7 +122,7 @@ def test_conversation_preprocessing_basic(self): assistant_tokens, skip_special_tokens=False ) expected_assistant_text = ( - "The answer is 4." + "\n\n\n\nThe answer is 4.<|im_end|>\n" ) self.assertEqual( assistant_text, @@ -165,7 +165,7 @@ def test_multiple_turns_conversation(self): ) # Exact match for the complete assistant text from both turns - expected_assistant_text = "The answer is 4.Yes, I'm certain." 
+ expected_assistant_text = "The answer is 4.<|im_end|>\n\n\n\nYes, I'm certain.<|im_end|>\n" self.assertEqual( assistant_text, expected_assistant_text, @@ -175,14 +175,7 @@ def test_multiple_turns_conversation(self): def test_preformatted_conversation(self): """Test preprocessing of pre-formatted conversation strings.""" preformatted_conversations = [ - "[SYSTEM_PROMPT]You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French " - "startup headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date is " - "2025-08-31. When you're not sure about some information, you say that you don't have the information and " - "don't make up anything. If the user's question is not clear, ambiguous, or does not provide enough context " - "for you to accurately answer the question, you do not try to answer it right away and you rather ask the " - "user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" " - "or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")[/SYSTEM_PROMPT]" - "[INST]What is Python?[/INST]Python is a programming language." + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is Python?<|im_end|>\n<|im_start|>assistant\nPython is a programming language.<|im_end|>\n" ] results = preprocess_conversations( @@ -213,7 +206,7 @@ def test_preformatted_conversation(self): ) # Check for exact match of the expected assistant response - expected_assistant_text = "Python is a programming language." 
+ expected_assistant_text = "Python is a programming language.<|im_end|>\n" self.assertEqual( assistant_text, expected_assistant_text, @@ -229,7 +222,7 @@ def test_assistant_span_boundaries(self): {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}, ], - "expected_assistant_text": "Hello!", + "expected_assistant_text": "\n\n\n\nHello!<|im_end|>\n", }, { "name": "Response with punctuation", @@ -237,7 +230,7 @@ def test_assistant_span_boundaries(self): {"role": "user", "content": "What's your name?"}, {"role": "assistant", "content": "I'm Claude, an AI assistant."}, ], - "expected_assistant_text": "I'm Claude, an AI assistant.", + "expected_assistant_text": "\n\n\n\nI'm Claude, an AI assistant.<|im_end|>\n", }, { "name": "Multi-sentence response", @@ -248,7 +241,7 @@ def test_assistant_span_boundaries(self): "content": "Python is a programming language. It's very popular for AI.", }, ], - "expected_assistant_text": "Python is a programming language. It's very popular for AI.", + "expected_assistant_text": "\n\n\n\nPython is a programming language. It's very popular for AI.<|im_end|>\n", }, { "name": "Response with special characters", @@ -259,7 +252,7 @@ def test_assistant_span_boundaries(self): "content": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.", }, ], - "expected_assistant_text": "Sure! Here's an example: 2 + 2 = 4, and π ≈ 3.14159.", + "expected_assistant_text": "\n\n\n\nSure! 
Here's an example: 2 + 2 = 4, and π ≈ 3.14159.<|im_end|>\n", }, ] From 037980f82678511dfc66a3ac866ac29386a7bce0 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 1 Sep 2025 16:24:00 +0000 Subject: [PATCH 11/18] Add train script for mistral-Small-24B --- .../run_mistral_small_24B_eagle3_online.sh | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 examples/run_mistral_small_24B_eagle3_online.sh diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh new file mode 100644 index 00000000..0ee773f3 --- /dev/null +++ b/examples/run_mistral_small_24B_eagle3_online.sh @@ -0,0 +1,22 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for mistral-Small-24B +NUM_GPUS=${1:-8} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3_online.py \ + --target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \ + --draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \ + --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \ + --num-epochs 2 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template mistral-small-24B \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend flex_attention From 877247b15dc518ab0233dd22aa7b4151b009ac68 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 1 Sep 2025 18:24:23 +0000 Subject: [PATCH 12/18] Lint fix --- specforge/data/template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specforge/data/template.py b/specforge/data/template.py index d50ba26b..6358fd82 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -13,7 +13,7 @@ class ChatTemplate(BaseModel): user_header(str): The 
header for the user. system_prompt(str): The system prompt. end_of_turn_token(str): The end token of a turn of conversation. - If present, end_of_assistant_token and end_of_user_token are ignored. + If present, end_of_assistant_token and end_of_user_token are ignored. end_of_assistant_token(str): The end token of an assistant turn of conversation. end_of_user_token(str): The end token of a user turn of conversation. """ From a059679dd7f64997ad44131f1dd464619361dc5d Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 1 Sep 2025 18:31:13 +0000 Subject: [PATCH 13/18] Fix misleading return type --- specforge/modeling/target/mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py index f5dd2437..be8a1cb0 100644 --- a/specforge/modeling/target/mistral.py +++ b/specforge/modeling/target/mistral.py @@ -129,7 +129,7 @@ def forward( past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) From 06cdfeb6af6c6cd661479762a004767bd4b521de Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Thu, 4 Sep 2025 09:19:50 +0000 Subject: [PATCH 14/18] Fix code format --- specforge/data/template.py | 14 ++--- specforge/modeling/target/mistral.py | 60 +++++++++++-------- tests/test_target_modeling/test_mistral_tp.py | 4 +- 3 files changed, 44 insertions(+), 34 deletions(-) diff --git a/specforge/data/template.py b/specforge/data/template.py index 6358fd82..45c3b57a 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -116,13 +116,13 @@ def get_all_template_names(self) -> List[str]: assistant_header="[/INST]", user_header="[INST]", 
system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup " - "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date" - "is 2025-08-31. When you're not sure about some information, you say that you don't have the " - "information and don't make up anything. If the user's question is not clear, ambiguous, or " - "does not provide enough context for you to accurately answer the question, you do not try to " - "answer it right away and you rather ask the user to clarify their request (e.g. \"What are " - "some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to " - "Tokyo\" => \"Where do you travel from?\")", + "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date " + "is 2025-08-31. When you're not sure about some information, you say that you don't have the " + "information and don't make up anything. If the user's question is not clear, ambiguous, or " + "does not provide enough context for you to accurately answer the question, you do not try to " + 'answer it right away and you rather ask the user to clarify their request (e.g. "What are ' + 'some good restaurants around me?" => "Where are you?"
or "When is the next flight to ' + 'Tokyo" => "Where do you travel from?")', end_of_assistant_token="", ), ) diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py index be8a1cb0..d2f92b22 100644 --- a/specforge/modeling/target/mistral.py +++ b/specforge/modeling/target/mistral.py @@ -92,7 +92,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): config, "head_dim", config.hidden_size // config.num_attention_heads ) self.num_key_value_groups = ( - config.num_attention_heads // config.num_key_value_heads + config.num_attention_heads // config.num_key_value_heads ) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout @@ -122,13 +122,13 @@ def __init__(self, config: MistralConfig, layer_idx: int): ) def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -163,7 +163,9 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + sliding_window=getattr( + self.config, "sliding_window", None + ), # main diff with Llama **kwargs, ) @@ -181,24 +183,26 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) self.mlp = MistralMLP(config) - 
self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.post_attention_layernorm = MistralRMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[ - tuple[torch.Tensor, torch.Tensor] - ] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[ torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] ]: @@ -347,14 +351,16 @@ def forward( cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device + device=inputs_embeds.device, ) if position_ids is None: position_ids = cache_position.unsqueeze(0) mask_function = ( - create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + create_causal_mask + if self.config.sliding_window is None + else create_sliding_window_causal_mask ) causal_mask = mask_function( config=self.config, @@ -409,7 +415,9 @@ def forward( @auto_docstring -class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel): +class 
MistralForCausalLM( + MistralPreTrainedModel, GenerationMixin, DistributedTargetModel +): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} @@ -518,7 +526,7 @@ def forward( logits=logits, labels=labels, vocab_size=self.config.vocab_size, - **kwargs + **kwargs, ) return CausalLMOutputWithPast( diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py index ae02faa6..e3526184 100644 --- a/tests/test_target_modeling/test_mistral_tp.py +++ b/tests/test_target_modeling/test_mistral_tp.py @@ -37,7 +37,9 @@ def test_mistral_tp(rank, world_size, temp_dir): # create the single-gpu model = MistralForCausalLM(config).cuda() - from specforge.modeling.target.mistral import MistralForCausalLM as DistMistralForCausalLM + from specforge.modeling.target.mistral import ( + MistralForCausalLM as DistMistralForCausalLM, + ) dist_model = DistMistralForCausalLM(config).cuda() From ab36686db3a8aeb2aebc55d8f8a04d6b05e58122 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 22 Sep 2025 18:36:27 +0000 Subject: [PATCH 15/18] Fix preprocessing --- specforge/data/parse.py | 12 ++++++------ specforge/data/preprocessing.py | 16 +++++----------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/specforge/data/parse.py b/specforge/data/parse.py index a4c095df..07734c3b 100644 --- a/specforge/data/parse.py +++ b/specforge/data/parse.py @@ -41,12 +41,12 @@ class GeneralParser(Parser): def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): super().__init__(tokenizer, chat_template) self.system_prompt = chat_template.system_prompt - self.user_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.user_header}" - ) - self.assistant_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" - ) + if chat_template.end_of_turn_token: + 
self.user_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}" + self.assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}" + else: + self.user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}" + self.assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}" def parse( self, conversation: "Conversation", max_length: int, preformatted: bool = False diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 37092864..b86f3392 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -71,20 +71,14 @@ def _apply_loss_mask_from_chat_template( """ loss_mask = torch.zeros(len(offsets), dtype=torch.long) - if chat_template.end_of_turn_token is not None: + if chat_template.end_of_turn_token: user_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.user_header}" - ) - assistant_message_separator = ( - f"{chat_template.end_of_turn_token}{chat_template.assistant_header}" + f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}" ) + assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}" else: - user_message_separator = ( - f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header}" - ) - assistant_message_separator = ( - f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header}" - ) + user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}" + assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}" # Find spans of assistant responses using regex assistant_pattern = ( From e8a43a29fd286ab884c531b63511a923840c249f Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 22 
Sep 2025 18:57:18 +0000 Subject: [PATCH 16/18] Fix linting --- scripts/train_eagle3_online.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py index c4df940d..1d2b7eb0 100644 --- a/scripts/train_eagle3_online.py +++ b/scripts/train_eagle3_online.py @@ -291,22 +291,17 @@ def main(): # load model with resume if draft_model_last_checkpoint: - draft_model = ( - AutoEagle3DraftModel.from_pretrained( - draft_model_last_checkpoint, attention_backend=args.attention_backend, - torch_dtype=torch.bfloat16 - ) - .cuda() - - ) + draft_model = AutoEagle3DraftModel.from_pretrained( + draft_model_last_checkpoint, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() else: - draft_model = ( - AutoEagle3DraftModel.from_config( - draft_model_config, attention_backend=args.attention_backend, - torch_dtype=torch.bfloat16 - ) - .cuda() - ) + draft_model = AutoEagle3DraftModel.from_config( + draft_model_config, + attention_backend=args.attention_backend, + torch_dtype=torch.bfloat16, + ).cuda() draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) draft_model.freeze_embedding() print_with_rank("Initialized draft model") From 26022f14b723ed65c69f1c54f85ad3aad1bca9eb Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Mon, 22 Sep 2025 19:09:43 +0000 Subject: [PATCH 17/18] Increase TP to 2 to fit on H100 with 96GB --- examples/run_mistral_small_24B_eagle3_online.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh index 0ee773f3..7a8a445b 100644 --- a/examples/run_mistral_small_24B_eagle3_online.sh +++ b/examples/run_mistral_small_24B_eagle3_online.sh @@ -15,6 +15,7 @@ torchrun \ --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \ --num-epochs 2 \ --batch-size 1 \ + --tp 2 \ --learning-rate 1e-4 \ 
--max-length 2048 \ --chat-template mistral-small-24B \ From 33c4bd7062904a1fa718c613f0929be1427aca29 Mon Sep 17 00:00:00 2001 From: ValeGian <46gianninivalerio@gmail.com> Date: Wed, 1 Oct 2025 15:23:54 +0000 Subject: [PATCH 18/18] Set default NUM_GPUS to 2 --- examples/run_mistral_small_24B_eagle3_online.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh index 7a8a445b..89bb1314 100644 --- a/examples/run_mistral_small_24B_eagle3_online.sh +++ b/examples/run_mistral_small_24B_eagle3_online.sh @@ -3,7 +3,7 @@ ROOT_DIR=$(dirname $SCRIPT_DIR) export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels # train eagle3 for mistral-Small-24B -NUM_GPUS=${1:-8} +NUM_GPUS=${1:-2} torchrun \ --standalone \