@@ -92,7 +92,7 @@ def __init__(self, config: MistralConfig, layer_idx: int):
             config, "head_dim", config.hidden_size // config.num_attention_heads
         )
         self.num_key_value_groups = (
-            config.num_attention_heads // config.num_key_value_heads
+            config.num_attention_heads // config.num_key_value_heads
         )
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
@@ -122,13 +122,13 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         )

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_value: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -163,7 +163,9 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            sliding_window=getattr(
+                self.config, "sliding_window", None
+            ),  # main diff with Llama
             **kwargs,
         )

@@ -181,24 +183,26 @@ def __init__(self, config: MistralConfig, layer_idx: int):
         self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)

         self.mlp = MistralMLP(config)
-        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = MistralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
         self.post_attention_layernorm = MistralRMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[
-            tuple[torch.Tensor, torch.Tensor]
-        ] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[
         torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
     ]:
@@ -347,14 +351,16 @@ def forward(
             cache_position = torch.arange(
                 past_seen_tokens,
                 past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device
+                device=inputs_embeds.device,
             )

         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

         mask_function = (
-            create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
         )
         causal_mask = mask_function(
             config=self.config,
@@ -409,7 +415,9 @@ def forward(


 @auto_docstring
-class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin, DistributedTargetModel):
+class MistralForCausalLM(
+    MistralPreTrainedModel, GenerationMixin, DistributedTargetModel
+):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@@ -518,7 +526,7 @@ def forward(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **kwargs
+                **kwargs,
             )

         return CausalLMOutputWithPast(
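
For reference, a minimal sketch (not part of the diff) of the quantities these hunks reformat: the attention head size and softmax scale, the grouped-query-attention key/value group count, and the causal- vs. sliding-window-mask selection. The hyperparameter values below are assumed Mistral-7B-style numbers used only for illustration, not read from this diff.

# Minimal sketch with assumed Mistral-7B-style hyperparameters (illustrative only).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
sliding_window = 4096  # set to None to exercise the plain causal-mask path

# Fallback used when the config does not define head_dim explicitly.
head_dim = hidden_size // num_attention_heads  # 128
scaling = head_dim**-0.5  # 1/sqrt(head_dim), the attention softmax scale

# Grouped-query attention: how many query heads share each key/value head.
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4

# Mirrors the mask_function selection above: sliding-window masking is used
# only when the config actually carries a sliding_window value.
mask_kind = "causal" if sliding_window is None else "sliding_window_causal"

print(head_dim, scaling, num_key_value_groups, mask_kind)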