Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions python/cog/command/call_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""A CLI for determining the call graph of a pipeline."""

import ast
import sys
from pathlib import Path
from typing import List


class IncludeAnalyzer(ast.NodeVisitor):
    """Collect the arguments of ``cog.ext.pipelines.include(...)`` calls.

    The visitor records every import it sees so that aliased names can be
    mapped back to their fully qualified form, tracks whether it is currently
    inside a function, class or lambda, and appends the first positional
    argument of each matching call to ``includes``.

    Raises (from ``visit``):
        ValueError: if a matching call appears inside a function, class or
            lambda, or if its first positional argument is not a plain
            string literal.
    """

    # Fully qualified name of the call this analyzer looks for.
    INCLUDE_TARGET = "cog.ext.pipelines.include"

    def __init__(self, file_path: Path) -> None:
        # Only used to prefix error messages.
        self.file_path = file_path
        # String literals passed to include(), in source order.
        self.includes: List[str] = []
        # Local name -> fully qualified imported name.
        self.imports: dict[str, str] = {}
        # One entry per enclosing function/class/lambda; empty at module level.
        self.scope_stack: List[str] = []

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            # `import a.b` without an alias binds only `a`; such a key never
            # matches a bare name, and lookups then fall back to the name
            # itself, which still resolves correctly.
            self.imports[alias.asname or alias.name] = alias.name
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        module = node.module or ""
        for alias in node.names:
            full_name = f"{module}.{alias.name}" if module else alias.name
            self.imports[alias.asname or alias.name] = full_name
        self.generic_visit(node)

    # Scope tracking
    def _visit_in_scope(self, node: ast.AST, kind: str) -> None:
        """Visit the children of *node* with *kind* pushed on the scope stack."""
        self.scope_stack.append(kind)
        try:
            self.generic_visit(node)
        finally:
            self.scope_stack.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._visit_in_scope(node, "function")

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._visit_in_scope(node, "function")

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._visit_in_scope(node, "class")

    def visit_Lambda(self, node: ast.Lambda) -> None:
        self._visit_in_scope(node, "lambda")

    def _qualified_name(self, func: ast.expr) -> str:
        """Return the dotted, import-resolved name of *func*, or "" if unknown.

        Handles a bare name (`include`) as well as attribute chains of any
        depth (`pipelines.include`, `cog.ext.pipelines.include`). Anything
        whose root is not a simple name (a call result, a subscript, ...)
        cannot be resolved statically and yields the empty string.
        """
        attrs: List[str] = []
        current = func
        while isinstance(current, ast.Attribute):
            attrs.append(current.attr)
            current = current.value
        if not isinstance(current, ast.Name):
            return ""
        root = self.imports.get(current.id, current.id)
        return ".".join([root, *reversed(attrs)])

    def visit_Call(self, node: ast.Call) -> None:
        if self._qualified_name(node.func) == self.INCLUDE_TARGET:
            # Check scope
            if self.scope_stack:
                raise ValueError(
                    f"[{self.file_path}] Invalid scope at line {node.lineno}: `cog.ext.pipelines.include(...)` must be in global scope"
                )
            if node.args:
                arg = node.args[0]
                # ast.Str was removed in Python 3.12; string literals are
                # ast.Constant nodes whose value is a str.
                if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
                    self.includes.append(arg.value)
                else:
                    raise ValueError(
                        f"[{self.file_path}] Unresolvable argument at line {node.lineno}: Not a string literal"
                    )
        self.generic_visit(node)


def analyze_python_file(
    file_path: Path,
) -> List[str]:
    """Return the string arguments of the include calls found in *file_path*.

    The file is parsed with ``ast`` and walked by ``IncludeAnalyzer``; the
    analyzer's ``includes`` list is returned. Exceptions raised while reading,
    parsing or analyzing the file are not caught here.
    """
    # Hand raw bytes to the parser so the encoding is chosen by Python's own
    # source rules (UTF-8 by default, BOM / coding cookie honoured) instead of
    # whatever the locale's default text encoding happens to be.
    source = file_path.read_bytes()
    tree = ast.parse(source, filename=str(file_path))
    analyzer = IncludeAnalyzer(file_path)
    analyzer.visit(tree)
    return analyzer.includes


def main(filepath: str) -> None:
    """Print the includes found in the file at *filepath*, comma-separated."""
    source_path = Path(filepath)
    found = analyze_python_file(source_path)
    joined = ",".join(found)
    print(joined)


if __name__ == "__main__":
    # Fail with a short usage message instead of an IndexError traceback when
    # the file argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <path-to-python-file>")
    main(sys.argv[1])
Empty file.
73 changes: 73 additions & 0 deletions python/tests/command/call_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import tempfile
from pathlib import Path

import pytest

from cog.command.call_graph import analyze_python_file


def test_call_graph():
    """A module-level include() with a literal string is reported."""
    source = """from cog import Path, Input
from cog.ext.pipelines import include

flux_schnell = include("black-forest-labs/flux-schnell")

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
"""
    with tempfile.TemporaryDirectory() as workdir:
        predict_path = Path(os.path.join(workdir, "predict.py"))
        predict_path.write_text(source, encoding="utf8")
        found = analyze_python_file(predict_path)
        assert found == ["black-forest-labs/flux-schnell"]


def test_call_graph_with_dynamic_string():
    """An f-string argument to include() is rejected as unresolvable."""
    source = """from cog import Path, Input
from cog.ext.pipelines import include

i = 2
flux_schnell = include(f"black-forest-labs/flux-schnell-{i}")

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
"""
    expected_suffix = "Unresolvable argument at line 5: Not a string literal"
    with tempfile.TemporaryDirectory() as workdir:
        predict_path = Path(os.path.join(workdir, "predict.py"))
        predict_path.write_text(source, encoding="utf8")
        with pytest.raises(ValueError) as excinfo:
            analyze_python_file(predict_path)
        # Checked after the raises-block so the assertion actually runs.
        message = str(excinfo.value)
        assert message.endswith(expected_suffix)


def test_call_graph_include_constructed_in_local_scope():
    """An include() call inside a function body is rejected."""
    source = """from cog import Path, Input
from cog.ext.pipelines import include

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    flux_schnell = include("black-forest-labs/flux-schnell")
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
"""
    expected_suffix = "Invalid scope at line 8: `cog.ext.pipelines.include(...)` must be in global scope"
    with tempfile.TemporaryDirectory() as workdir:
        predict_path = Path(os.path.join(workdir, "predict.py"))
        predict_path.write_text(source, encoding="utf8")
        with pytest.raises(ValueError) as excinfo:
            analyze_python_file(predict_path)
        # Checked after the raises-block so the assertion actually runs.
        message = str(excinfo.value)
        assert message.endswith(expected_suffix)