Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions python/cog/command/call_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""A CLI for determining the call graph of a pipeline."""

import ast
import sys
from pathlib import Path
from typing import List


class IncludeAnalyzer(ast.NodeVisitor):
def __init__(self, file_path: Path) -> None:
self.file_path = file_path
self.includes: List[str] = []
self.errors: List[str] = []
self.imports: dict[str, str] = {}

def visit_Import(self, node: ast.Import) -> None:
for alias in node.names:
self.imports[alias.asname or alias.name] = alias.name
self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
module = node.module or ""
for alias in node.names:
full_name = f"{module}.{alias.name}" if module else alias.name
self.imports[alias.asname or alias.name] = full_name
self.generic_visit(node)

def visit_Call(self, node: ast.Call) -> None:
target = None

if isinstance(node.func, ast.Attribute):
if isinstance(node.func.value, ast.Name):
target = f"{self.imports.get(node.func.value.id, node.func.value.id)}.{node.func.attr}"
elif isinstance(node.func, ast.Name):
target = self.imports.get(node.func.id, node.func.id)

if target == "cog.ext.pipelines.include":
if node.args:
arg = node.args[0]
if isinstance(arg, ast.Str):
self.includes.append(arg.s)
else:
raise ValueError(
f"[{self.file_path}] Unresolvable argument at line {node.lineno}: Not a string literal"
)
self.generic_visit(node)


def analyze_python_file(
file_path: Path,
) -> List[str]:
source = file_path.read_text()
tree = ast.parse(source, filename=str(file_path))
analyzer = IncludeAnalyzer(file_path)
analyzer.visit(tree)
return analyzer.includes


def main(filepath: str) -> None:
"""Run the main code for determining the call graph of a pipeline."""
includes = analyze_python_file(Path(filepath))
print(",".join(includes))


if __name__ == "__main__":
main(sys.argv[1])
Empty file.
51 changes: 51 additions & 0 deletions python/tests/command/call_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import tempfile
from pathlib import Path

import pytest

from cog.command.call_graph import analyze_python_file


def test_call_graph():
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, "predict.py")
with open(filepath, "w", encoding="utf8") as handle:
handle.write("""from cog import Path, Input
from cog.ext.pipelines import include

flux_schnell = include("black-forest-labs/flux-schnell")

def run(
prompt: str = Input(description="Describe the image to generate"),
seed: int = Input(description="A seed", default=0)
) -> Path:
output_url = flux_schnell(prompt=prompt, seed=seed)[0]
return output_url
""")
includes = analyze_python_file(Path(filepath))
assert includes == ["black-forest-labs/flux-schnell"]


def test_call_graph_with_dynamic_string():
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, "predict.py")
with open(filepath, "w", encoding="utf8") as handle:
handle.write("""from cog import Path, Input
from cog.ext.pipelines import include

for i in range(50):
flux_schnell = include(f"black-forest-labs/flux-schnell-{i}")

def run(
prompt: str = Input(description="Describe the image to generate"),
seed: int = Input(description="A seed", default=0)
) -> Path:
output_url = flux_schnell(prompt=prompt, seed=seed)[0]
return output_url
""")
with pytest.raises(ValueError) as excinfo:
analyze_python_file(Path(filepath))
assert str(excinfo.value).endswith(
"Unresolvable argument at line 5: Not a string literal"
)