Skip to content

Commit b998066

Browse files
committed
Enable SIM rules
Signed-off-by: cyy <[email protected]>
1 parent 4f93cc9 commit b998066

27 files changed

+49
-78
lines changed

examples/modular-transformers/modeling_dummy_bert.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,7 @@ def forward(
302302
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
303303
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
304304
# a causal mask in case tgt_len == 1.
305-
is_causal = (
306-
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
307-
)
305+
is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
308306

309307
attn_output = torch.nn.functional.scaled_dot_product_attention(
310308
query_layer,

examples/modular-transformers/modeling_roberta.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,7 @@ def forward(
305305
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
306306
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
307307
# a causal mask in case tgt_len == 1.
308-
is_causal = (
309-
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
310-
)
308+
is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
311309

312310
attn_output = torch.nn.functional.scaled_dot_product_attention(
313311
query_layer,

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ line-length = 119
1919

2020
[tool.ruff.lint]
2121
# Never enforce `E501` (line length violations).
22-
ignore = ["C901", "E501", "E741", "F402", "F823"]
22+
# SIM300: Yoda condition detected
23+
# SIM212: Checks for if expressions that check against a negated condition.
24+
# SIM905: Consider using a list literal instead of `str.split`
25+
ignore = ["C901", "E501", "E741", "F402", "F823", "SIM1", "SIM300", "SIM212", "SIM905"]
2326
# RUF013: Checks for the use of implicit Optional
2427
# in type annotations when the default parameter value is None.
25-
select = ["C", "E", "F", "I", "W", "RUF013", "UP006", "PERF102", "PLC1802", "PLC0208"]
28+
select = ["C", "E", "F", "I", "W", "RUF013", "UP006", "PERF102", "PLC1802", "PLC0208","SIM"]
2629
extend-safe-fixes = ["UP006"]
2730

2831
# Ignore import violations in all `__init__.py` files.

src/transformers/commands/serving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
891891
inputs = processor.apply_chat_template(
892892
processor_inputs,
893893
add_generation_prompt=True,
894-
tools=req.get("tools", None),
894+
tools=req.get("tools"),
895895
return_tensors="pt",
896896
return_dict=True,
897897
tokenize=True,

src/transformers/data/data_collator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def tf_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]:
183183
if label_col_name is not None:
184184
if isinstance(first[label_col_name], tf.Tensor):
185185
dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
186-
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
186+
elif isinstance(first[label_col_name], (np.ndarray, np.generic)):
187187
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
188188
elif isinstance(first[label_col_name], (tuple, list)):
189189
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32

src/transformers/generation/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,8 @@ def prepare_inputs_for_generation(
645645

646646
# If it's not defined, it means the model uses the new general mask API
647647
if causal_mask_creation_function is None: # can't be found
648-
token_type_ids = model_inputs.get("token_type_ids", None)
649-
position_ids = model_inputs.get(position_ids_key, None)
648+
token_type_ids = model_inputs.get("token_type_ids")
649+
position_ids = model_inputs.get(position_ids_key)
650650
# Some models may overwrite the general one
651651
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
652652
attention_mask = causal_mask_creation_function(

src/transformers/integrations/flex_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
277277
num_local_query_heads = query.shape[1]
278278

279279
# When running TP this helps:
280-
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
280+
if (num_local_query_heads & (num_local_query_heads - 1)) != 0:
281281
key = repeat_kv(key, query.shape[1] // key.shape[1])
282282
value = repeat_kv(value, query.shape[1] // value.shape[1])
283283
enable_gqa = False

src/transformers/integrations/vptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def replace_with_vptq_linear(
4545
should not be passed by the user.
4646
"""
4747

48-
modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert
48+
modules_to_not_convert = modules_to_not_convert if modules_to_not_convert else ["lm_head"]
4949

5050
for name, module in model.named_children():
5151
if current_key_name is None:

src/transformers/keras_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _postprocess_predictions_or_labels(self, inputs):
167167
# If it's a dict with only one key, just return the array
168168
if len(outputs) == 1:
169169
outputs = list(outputs.values())[0]
170-
elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
170+
elif isinstance(inputs[0], (tuple, list)):
171171
outputs = []
172172
for input_list in zip(*inputs):
173173
outputs.append(self._concatenate_batches(input_list))

src/transformers/models/biogpt/convert_biogpt_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def rewrite_dict_keys(d):
148148
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
149149
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
150150
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
151-
keep_keys = "<s> <pad> </s> <unk>".split()
151+
keep_keys = ["<s>", "<pad>", "</s>", "<unk>"]
152152
# restore the special tokens
153153
for k in keep_keys:
154154
del d2[f"{k}</w>"]

0 commit comments

Comments
 (0)