diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py
index dad6b1f0a22c..5b9258cbd3f1 100644
--- a/utils/check_modular_conversion.py
+++ b/utils/check_modular_conversion.py
@@ -36,7 +36,7 @@ def process_file(
     # Read the actual modeling file
     with open(file_path, "r", encoding="utf-8") as modeling_file:
         content = modeling_file.read()
-    output_buffer = StringIO(generated_modeling_content[file_type][0])
+    output_buffer = StringIO(generated_modeling_content[file_type])
     output_buffer.seek(0)
     output_content = output_buffer.read()
     diff = difflib.unified_diff(
@@ -54,7 +54,7 @@ def process_file(
             shutil.copy(file_path, file_path + BACKUP_EXT)
             # we always save the generated content, to be able to update dependant files
             with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
-                modeling_file.write(generated_modeling_content[file_type][0])
+                modeling_file.write(generated_modeling_content[file_type])
             console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
         if show_diff:
             console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 86eecc83a5ab..3c2206cccc28 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -15,6 +15,7 @@
 import argparse
 import glob
 import importlib
+import multiprocessing as mp
 import os
 import re
 import subprocess
@@ -1226,7 +1227,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
         if m.matches(node.module, m.Attribute()):
             for imported_ in node.names:
                 _import = re.search(
-                    rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
+                    rf"(?:transformers\.models\.)|(?:\.\.\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*",
+                    import_statement,
                 )
                 if _import:
                     source = _import.group(1)
@@ -1688,7 +1690,8 @@ def run_ruff(code, check=False):
     return stdout.decode()
 
 
-def convert_modular_file(modular_file):
+def convert_modular_file(modular_file: str) -> dict[str, str]:
+    """Convert a `modular_file` into all the different model-specific files it depicts."""
     pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
     output = {}
     if pattern is not None:
@@ -1712,34 +1715,30 @@ def convert_modular_file(modular_file):
                 )
                 ruffed_code = run_ruff(header + module.code, True)
                 formatted_code = run_ruff(ruffed_code, False)
-                output[file] = [formatted_code, ruffed_code]
+                output[file] = formatted_code
         return output
     else:
         print(f"modular pattern not found in {modular_file}, exiting")
         return {}
 
 
-def save_modeling_file(modular_file, converted_file):
-    for file_type in converted_file:
+def save_modeling_files(modular_file: str, converted_files: dict[str, str]):
+    """Save all the `converted_files` from the `modular_file`."""
+    for file_type in converted_files:
         file_name_prefix = file_type.split("*")[0]
         file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
         new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace(
             ".py", f"{file_name_suffix}.py"
         )
-        non_comment_lines = len(
-            [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
-        )
-        if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
-            with open(new_file_name, "w", encoding="utf-8") as f:
-                f.write(converted_file[file_type][0])
-        else:
-            non_comment_lines = len(
-                [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
-            )
-            if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
-                logger.warning("The modeling code contains errors, it's written without formatting")
-                with open(new_file_name, "w", encoding="utf-8") as f:
-                    f.write(converted_file[file_type][1])
+        with open(new_file_name, "w", encoding="utf-8") as f:
+            f.write(converted_files[file_type])
+
+
+def run_converter(modular_file: str):
+    """Convert a modular file, and save resulting files."""
+    print(f"Converting {modular_file} to a single model single file format")
+    converted_files = convert_modular_file(modular_file)
+    save_modeling_files(modular_file, converted_files)
 
 
 if __name__ == "__main__":
@@ -1759,9 +1758,17 @@ def save_modeling_file(modular_file, converted_file):
         nargs="+",
         help="A list of `modular_xxxx` files that should be converted to single model file",
     )
+    parser.add_argument(
+        "--num_workers",
+        "-w",
+        default=-1,
+        type=int,
+        help="The number of workers to use. Default is -1, which means the number of CPU cores.",
+    )
     args = parser.parse_args()
     # Both arg represent the same data, but as positional and optional
     files_to_parse = args.files if len(args.files) > 0 else args.files_to_parse
+    num_workers = mp.cpu_count() if args.num_workers == -1 else args.num_workers
 
     if files_to_parse == ["all"]:
         files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@@ -1779,12 +1786,17 @@ def save_modeling_file(modular_file, converted_file):
                     raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
                 files_to_parse[i] = full_path
 
-    priority_list, _ = find_priority_list(files_to_parse)
-    priority_list = [item for sublist in priority_list for item in sublist]  # flatten the list of lists
-    assert len(priority_list) == len(files_to_parse), "Some files will not be converted"
+    # This finds the correct order in which we should convert the modular files, so that a model relying on another one
+    # is necessarily converted after its dependencies
+    ordered_files, _ = find_priority_list(files_to_parse)
+    if sum(len(level_files) for level_files in ordered_files) != len(files_to_parse):
+        raise ValueError(
+            "Some files will not be converted because they do not appear in the dependency graph."
+            "This usually means that at least one modular file does not import any model-specific class"
+        )
 
-    for file_name in priority_list:
-        print(f"Converting {file_name} to a single model single file format")
-        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
-        converted_files = convert_modular_file(file_name)
-        converter = save_modeling_file(file_name, converted_files)
+    for dependency_level_files in ordered_files:
+        # Process files with diff
+        workers = min(num_workers, len(dependency_level_files))
+        with mp.Pool(workers) as pool:
+            pool.map(run_converter, dependency_level_files)
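
The core of the change is the new `__main__` loop: `find_priority_list` returns the modular files grouped by dependency level, and each level is converted in parallel with `multiprocessing.Pool` only after the previous level has finished, so a model is regenerated strictly after the models it inherits from. Below is a minimal, self-contained sketch of that level-by-level pooling pattern; the `convert` helper and the `levels` data are invented here purely for illustration, whereas the real script maps `run_converter` over the actual modular files.

import multiprocessing as mp


def convert(name: str) -> str:
    # Stand-in for `run_converter` (illustrative only): pretend to regenerate one file.
    return f"converted {name}"


if __name__ == "__main__":
    # Toy dependency levels: files in a later level may inherit from files in an
    # earlier one, so each level only starts once the previous one has finished.
    levels = [["modular_a.py"], ["modular_b.py", "modular_c.py"]]
    num_workers = mp.cpu_count()
    for level in levels:
        # Never spawn more workers than there are files in the current level.
        workers = min(num_workers, len(level))
        with mp.Pool(workers) as pool:
            print(pool.map(convert, level))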