diff --git a/python/cog/command/call_graph.py b/python/cog/command/call_graph.py
new file mode 100644
index 0000000000..c8f24a7055
--- /dev/null
+++ b/python/cog/command/call_graph.py
@@ -0,0 +1,102 @@
+"""A CLI for determining the call graph of a pipeline."""
+
+import ast
+import sys
+from pathlib import Path
+from typing import List
+
+
+class IncludeAnalyzer(ast.NodeVisitor):
+    """Collect the string-literal argument of each module-level `cog.ext.pipelines.include(...)` call."""
+
+    def __init__(self, file_path: Path) -> None:
+        self.file_path = file_path
+        # First positional argument of every include call found so far, in visit order.
+        self.includes: List[str] = []
+        # Import alias (or name, when there is no alias) -> fully qualified imported name.
+        self.imports: dict[str, str] = {}
+        # One entry per enclosing function/class/lambda; empty while at module level.
+        self.scope_stack: List[str] = []
+
+    def visit_Import(self, node: ast.Import) -> None:
+        for alias in node.names:
+            self.imports[alias.asname or alias.name] = alias.name
+        self.generic_visit(node)
+
+    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+        module = node.module or ""
+        for alias in node.names:
+            full_name = f"{module}.{alias.name}" if module else alias.name
+            self.imports[alias.asname or alias.name] = full_name
+        self.generic_visit(node)
+
+    # Scope tracking
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+        self.scope_stack.append("function")
+        self.generic_visit(node)
+        self.scope_stack.pop()
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
+        self.scope_stack.append("function")
+        self.generic_visit(node)
+        self.scope_stack.pop()
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> None:
+        self.scope_stack.append("class")
+        self.generic_visit(node)
+        self.scope_stack.pop()
+
+    def visit_Lambda(self, node: ast.Lambda) -> None:
+        self.scope_stack.append("lambda")
+        self.generic_visit(node)
+        self.scope_stack.pop()
+
+    def visit_Call(self, node: ast.Call) -> None:
+        """Record an include call; raise ValueError on bad scope or non-literal argument."""
+        target = None
+
+        if isinstance(node.func, ast.Attribute):
+            # Handles `<name>.include(...)`, resolving <name> through the recorded imports
+            if isinstance(node.func.value, ast.Name):
+                target = f"{self.imports.get(node.func.value.id, node.func.value.id)}.{node.func.attr}"
+        elif isinstance(node.func, ast.Name):
+            # Handles `from cog.ext.pipelines import include` then `include(...)`
+            target = self.imports.get(node.func.id, node.func.id)
+
+        if target == "cog.ext.pipelines.include":
+            # Check scope
+            if self.scope_stack:
+                raise ValueError(
+                    f"[{self.file_path}] Invalid scope at line {node.lineno}: `cog.ext.pipelines.include(...)` must be in global scope"
+                )
+            elif node.args:
+                arg = node.args[0]
+                # String literals parse as ast.Constant; ast.Str and `.s` are deprecated aliases.
+                if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
+                    self.includes.append(arg.value)
+                else:
+                    raise ValueError(
+                        f"[{self.file_path}] Unresolvable argument at line {node.lineno}: Not a string literal"
+                    )
+        self.generic_visit(node)
+
+
+def analyze_python_file(
+    file_path: Path,
+) -> List[str]:
+    """Parse `file_path` as Python source and return the include targets the analyzer collects."""
+    source = file_path.read_text(encoding="utf-8")
+    tree = ast.parse(source, filename=str(file_path))
+    analyzer = IncludeAnalyzer(file_path)
+    analyzer.visit(tree)
+    return analyzer.includes
+
+
+def main(filepath: str) -> None:
+    """Run the main code for determining the call graph of a pipeline."""
+    includes = analyze_python_file(Path(filepath))
+    print(",".join(includes))
+
+
+if __name__ == "__main__":
+    main(sys.argv[1])
diff --git a/python/tests/command/__init__.py b/python/tests/command/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/tests/command/call_graph_test.py b/python/tests/command/call_graph_test.py
new file mode 100644
index 0000000000..8c79716a84
--- /dev/null
+++ b/python/tests/command/call_graph_test.py
@@ -0,0 +1,73 @@
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+
+from cog.command.call_graph import analyze_python_file
+
+
+def test_call_graph():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        filepath = os.path.join(tmpdir, "predict.py")
+        with open(filepath, "w", encoding="utf8") as handle:
+            handle.write("""from cog import Path, Input
+from cog.ext.pipelines import include
+
+flux_schnell = include("black-forest-labs/flux-schnell")
+
+def run(
+    prompt: str = Input(description="Describe the image to generate"),
+    seed: int = Input(description="A seed", default=0)
+) -> Path:
+    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
+    return output_url
+""")
+        includes = analyze_python_file(Path(filepath))
+        assert includes == ["black-forest-labs/flux-schnell"]
+
+
+def test_call_graph_with_dynamic_string():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        filepath = os.path.join(tmpdir, "predict.py")
+        with open(filepath, "w", encoding="utf8") as handle:
+            handle.write("""from cog import Path, Input
+from cog.ext.pipelines import include
+
+i = 2
+flux_schnell = include(f"black-forest-labs/flux-schnell-{i}")
+
+def run(
+    prompt: str = Input(description="Describe the image to generate"),
+    seed: int = Input(description="A seed", default=0)
+) -> Path:
+    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
+    return output_url
+""")
+        with pytest.raises(ValueError) as excinfo:
+            analyze_python_file(Path(filepath))
+        assert str(excinfo.value).endswith(
+            "Unresolvable argument at line 5: Not a string literal"
+        )
+
+
+def test_call_graph_include_constructed_in_local_scope():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        filepath = os.path.join(tmpdir, "predict.py")
+        with open(filepath, "w", encoding="utf8") as handle:
+            handle.write("""from cog import Path, Input
+from cog.ext.pipelines import include
+
+def run(
+    prompt: str = Input(description="Describe the image to generate"),
+    seed: int = Input(description="A seed", default=0)
+) -> Path:
+    flux_schnell = include("black-forest-labs/flux-schnell")
+    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
+    return output_url
+""")
+        with pytest.raises(ValueError) as excinfo:
+            analyze_python_file(Path(filepath))
+        assert str(excinfo.value).endswith(
+            "Invalid scope at line 8: `cog.ext.pipelines.include(...)` must be in global scope"
+        )