
Commit 06cdfeb: Fix code format
1 parent: a059679

File tree: 3 files changed (+44, -34 lines)


specforge/data/template.py

Lines changed: 7 additions & 7 deletions
@@ -116,13 +116,13 @@ def get_all_template_names(self) -> List[str]:
 assistant_header="[/INST]",
 user_header="[INST]",
 system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup "
-"headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date"
-"is 2025-08-31. When you're not sure about some information, you say that you don't have the "
-"information and don't make up anything. If the user's question is not clear, ambiguous, or "
-"does not provide enough context for you to accurately answer the question, you do not try to "
-"answer it right away and you rather ask the user to clarify their request (e.g. \"What are "
-"some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to "
-"Tokyo\" => \"Where do you travel from?\")",
+"headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date"
+"is 2025-08-31. When you're not sure about some information, you say that you don't have the "
+"information and don't make up anything. If the user's question is not clear, ambiguous, or "
+"does not provide enough context for you to accurately answer the question, you do not try to "
+'answer it right away and you rather ask the user to clarify their request (e.g. "What are '
+'some good restaurants around me?" => "Where are you?" or "When is the next flight to '
+'Tokyo" => "Where do you travel from?")',
 end_of_assistant_token="</s>",
 ),
 )
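The substantive change above is quote style: the last three fragments of the implicitly concatenated system prompt now use single-quoted outer literals, so the embedded double quotes no longer need backslash escapes. A minimal, self-contained sketch of the same pattern, with illustrative strings that are not from the repository:

# Adjacent string literals are joined into one string at compile time, whatever their quote style.
prompt = (
    "first fragment "
    'second fragment with an embedded "quoted" phrase '  # single quotes avoid \" escapes
    "third fragment"
)
assert prompt == (
    'first fragment second fragment with an embedded "quoted" phrase third fragment'
)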

specforge/modeling/target/mistral.py

Lines changed: 34 additions & 26 deletions
@@ -92,7 +92,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
 config, "head_dim", config.hidden_size // config.num_attention_heads
 )
 self.num_key_value_groups = (
-config.num_attention_heads // config.num_key_value_heads
+config.num_attention_heads // config.num_key_value_heads
 )
 self.scaling = self.head_dim**-0.5
 self.attention_dropout = config.attention_dropout
@@ -122,13 +122,13 @@ def __init__(self, config: MistralConfig, layer_idx: int):
 )

 def forward(
-self,
-hidden_states: torch.Tensor,
-position_embeddings: tuple[torch.Tensor, torch.Tensor],
-attention_mask: Optional[torch.Tensor],
-past_key_value: Optional[Cache] = None,
-cache_position: Optional[torch.LongTensor] = None,
-**kwargs: Unpack[FlashAttentionKwargs],
+self,
+hidden_states: torch.Tensor,
+position_embeddings: tuple[torch.Tensor, torch.Tensor],
+attention_mask: Optional[torch.Tensor],
+past_key_value: Optional[Cache] = None,
+cache_position: Optional[torch.LongTensor] = None,
+**kwargs: Unpack[FlashAttentionKwargs],
 ) -> tuple[torch.Tensor, torch.Tensor]:
 input_shape = hidden_states.shape[:-1]
 hidden_shape = (*input_shape, -1, self.head_dim)
@@ -163,7 +163,9 @@ def forward(
 attention_mask,
 dropout=0.0 if not self.training else self.attention_dropout,
 scaling=self.scaling,
-sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
+sliding_window=getattr(
+self.config, "sliding_window", None
+), # main diff with Llama
 **kwargs,
 )
@@ -181,24 +183,26 @@ def __init__(self, config: MistralConfig, layer_idx: int):
 self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)

 self.mlp = MistralMLP(config)
-self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+self.input_layernorm = MistralRMSNorm(
+config.hidden_size, eps=config.rms_norm_eps
+)
 self.post_attention_layernorm = MistralRMSNorm(
 config.hidden_size, eps=config.rms_norm_eps
 )

 def forward(
-self,
-hidden_states: torch.Tensor,
-attention_mask: Optional[torch.Tensor] = None,
-position_ids: Optional[torch.LongTensor] = None,
-past_key_value: Optional[Cache] = None,
-output_attentions: Optional[bool] = False,
-use_cache: Optional[bool] = False,
-cache_position: Optional[torch.LongTensor] = None,
-position_embeddings: Optional[
-tuple[torch.Tensor, torch.Tensor]
-] = None, # necessary, but kept here for BC
-**kwargs: Unpack[FlashAttentionKwargs],
+self,
+hidden_states: torch.Tensor,
+attention_mask: Optional[torch.Tensor] = None,
+position_ids: Optional[torch.LongTensor] = None,
+past_key_value: Optional[Cache] = None,
+output_attentions: Optional[bool] = False,
+use_cache: Optional[bool] = False,
+cache_position: Optional[torch.LongTensor] = None,
+position_embeddings: Optional[
+tuple[torch.Tensor, torch.Tensor]
+] = None, # necessary, but kept here for BC
+**kwargs: Unpack[FlashAttentionKwargs],
 ) -> tuple[
 torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
 ]:
@@ -347,14 +351,16 @@ def forward(
 cache_position = torch.arange(
 past_seen_tokens,
 past_seen_tokens + inputs_embeds.shape[1],
-device=inputs_embeds.device
+device=inputs_embeds.device,
 )

 if position_ids is None:
 position_ids = cache_position.unsqueeze(0)

 mask_function = (
-create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+create_causal_mask
+if self.config.sliding_window is None
+else create_sliding_window_causal_mask
 )
 causal_mask = mask_function(
 config=self.config,
@@ -409,7 +415,9 @@ def forward(

 @auto_docstring
-class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel):
+class MistralForCausalLM(
+MistralPreTrainedModel, GenerationMixin, DistributedTargetModel
+):
 _tied_weights_keys = ["lm_head.weight"]
 _tp_plan = {"lm_head": "colwise_rep"}
 _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@@ -518,7 +526,7 @@ def forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**kwargs
+**kwargs,
 )

 return CausalLMOutputWithPast(
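The edits in this file are mechanical: over-long signatures, calls, and a class definition are wrapped across lines and trailing commas are added, the typical output of an autoformatter. A hedged way to confirm a tree already conforms, assuming black is the formatter behind this commit and that specforge and tests are the relevant paths (both assumptions, not stated in the diff):

# Hypothetical check (assumes black is installed and is the project's formatter):
# black's --check flag exits non-zero if any file would be reformatted.
import subprocess

subprocess.run(["black", "--check", "specforge", "tests"], check=True)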

tests/test_target_modeling/test_mistral_tp.py

Lines changed: 3 additions & 1 deletion
@@ -37,7 +37,9 @@ def test_mistral_tp(rank, world_size, temp_dir):
 # create the single-gpu
 model = MistralForCausalLM(config).cuda()

-from specforge.modeling.target.mistral import MistralForCausalLM as DistMistralForCausalLM
+from specforge.modeling.target.mistral import (
+MistralForCausalLM as DistMistralForCausalLM,
+)

 dist_model = DistMistralForCausalLM(config).cuda()