Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions utils/check_modular_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process_file(
# Read the actual modeling file
with open(file_path, "r", encoding="utf-8") as modeling_file:
content = modeling_file.read()
output_buffer = StringIO(generated_modeling_content[file_type][0])
output_buffer = StringIO(generated_modeling_content[file_type])
output_buffer.seek(0)
output_content = output_buffer.read()
diff = difflib.unified_diff(
Expand All @@ -54,7 +54,7 @@ def process_file(
shutil.copy(file_path, file_path + BACKUP_EXT)
# we always save the generated content, to be able to update dependent files
with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
modeling_file.write(generated_modeling_content[file_type][0])
modeling_file.write(generated_modeling_content[file_type])
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
if show_diff:
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
Expand Down
66 changes: 39 additions & 27 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import argparse
import glob
import importlib
import multiprocessing as mp
import os
import re
import subprocess
Expand Down Expand Up @@ -1226,7 +1227,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if m.matches(node.module, m.Attribute()):
for imported_ in node.names:
_import = re.search(
rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
rf"(?:transformers\.models\.)|(?:\.\.\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*",
Comment on lines -1229 to +1230
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was missing, which caused `from ...models.name.modeling_name import ...` statements to fail in #40431

import_statement,
)
if _import:
source = _import.group(1)
Expand Down Expand Up @@ -1688,7 +1690,8 @@ def run_ruff(code, check=False):
return stdout.decode()


def convert_modular_file(modular_file):
def convert_modular_file(modular_file: str) -> dict[str, str]:
"""Convert a `modular_file` into all the different model-specific files it depicts."""
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}
if pattern is not None:
Expand All @@ -1712,34 +1715,30 @@ def convert_modular_file(modular_file):
)
ruffed_code = run_ruff(header + module.code, True)
formatted_code = run_ruff(ruffed_code, False)
output[file] = [formatted_code, ruffed_code]
output[file] = formatted_code
return output
else:
print(f"modular pattern not found in {modular_file}, exiting")
return {}


def save_modeling_file(modular_file, converted_file):
for file_type in converted_file:
def save_modeling_files(modular_file: str, converted_files: dict[str, str]):
    """Save all the `converted_files` generated from the `modular_file`.

    Each key of `converted_files` is a file-type pattern, e.g. "modeling" or
    "image_processing*_fast". An optional "*" splits the pattern into a prefix
    and a suffix, which are used to derive the output file name from the
    modular file name; the value is the generated source code to write.
    """
    for file_type in converted_files:
        # "*" separates the output-name prefix from an optional suffix,
        # e.g. "image_processing*_fast" -> "image_processing_<model>_fast.py"
        file_name_prefix = file_type.split("*")[0]
        file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
        new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace(
            ".py", f"{file_name_suffix}.py"
        )
        with open(new_file_name, "w", encoding="utf-8") as f:
            f.write(converted_files[file_type])


def run_converter(modular_file: str):
    """Convert one modular file and write out every model-specific file it produces."""
    print(f"Converting {modular_file} to a single model single file format")
    # Convert, then immediately persist the generated files.
    save_modeling_files(modular_file, convert_modular_file(modular_file))


if __name__ == "__main__":
Expand All @@ -1759,9 +1758,17 @@ def save_modeling_file(modular_file, converted_file):
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--num_workers",
"-w",
default=-1,
type=int,
help="The number of workers to use. Default is -1, which means the number of CPU cores.",
)
args = parser.parse_args()
# Both arg represent the same data, but as positional and optional
files_to_parse = args.files if len(args.files) > 0 else args.files_to_parse
num_workers = mp.cpu_count() if args.num_workers == -1 else args.num_workers

if files_to_parse == ["all"]:
files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
Expand All @@ -1779,12 +1786,17 @@ def save_modeling_file(modular_file, converted_file):
raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
files_to_parse[i] = full_path

priority_list, _ = find_priority_list(files_to_parse)
priority_list = [item for sublist in priority_list for item in sublist] # flatten the list of lists
assert len(priority_list) == len(files_to_parse), "Some files will not be converted"
# This finds the correct order in which we should convert the modular files, so that a model relying on another one
# is necessarily converted after its dependencies
ordered_files, _ = find_priority_list(files_to_parse)
if sum(len(level_files) for level_files in ordered_files) != len(files_to_parse):
raise ValueError(
"Some files will not be converted because they do not appear in the dependency graph."
"This usually means that at least one modular file does not import any model-specific class"
)

for file_name in priority_list:
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name)
converter = save_modeling_file(file_name, converted_files)
for dependency_level_files in ordered_files:
# Process files with diff
workers = min(num_workers, len(dependency_level_files))
with mp.Pool(workers) as pool:
pool.map(run_converter, dependency_level_files)