-
Notifications
You must be signed in to change notification settings - Fork 686
Expand file tree
/
Copy pathcall_graph_test.py
More file actions
73 lines (60 loc) · 2.44 KB
/
call_graph_test.py
File metadata and controls
73 lines (60 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import tempfile
from pathlib import Path
import pytest
from cog.command.call_graph import analyze_python_file
def test_call_graph():
    """A module-level ``include()`` with a literal argument is reported.

    Writes a minimal ``predict.py`` that calls
    ``include("black-forest-labs/flux-schnell")`` at global scope. The test then
    checks that ``analyze_python_file`` returns exactly that one pipeline name.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        predict_py = Path(tmpdir) / "predict.py"
        predict_py.write_text(
            """from cog import Path, Input
from cog.ext.pipelines import include

flux_schnell = include("black-forest-labs/flux-schnell")

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
""",
            encoding="utf8",
        )
        includes = analyze_python_file(predict_py)
        assert includes == ["black-forest-labs/flux-schnell"]
def test_call_graph_with_dynamic_string():
    """An ``include()`` whose argument is an f-string is rejected.

    The generated ``predict.py`` passes ``f"black-forest-labs/flux-schnell-{i}"``
    to ``include()``. The test expects ``analyze_python_file`` to raise
    ``ValueError``. The message must end with the "Unresolvable argument ...
    Not a string literal" text that names the offending line.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        predict_py = Path(tmpdir) / "predict.py"
        # Line numbers in this source are load-bearing: the f-string include()
        # call must stay on line 5 to match the expected error message below.
        predict_py.write_text(
            """from cog import Path, Input
from cog.ext.pipelines import include

i = 2
flux_schnell = include(f"black-forest-labs/flux-schnell-{i}")

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
""",
            encoding="utf8",
        )
        with pytest.raises(ValueError) as excinfo:
            analyze_python_file(predict_py)
        assert str(excinfo.value).endswith(
            "Unresolvable argument at line 5: Not a string literal"
        )
def test_call_graph_include_constructed_in_local_scope():
    """An ``include()`` call inside a function body is rejected.

    The generated ``predict.py`` calls ``include(...)`` inside ``run()``
    rather than at module level. The test expects ``analyze_python_file`` to
    raise ``ValueError`` whose message says the call must be in global scope.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        predict_py = Path(tmpdir) / "predict.py"
        # Line numbers in this source are load-bearing: the include() call
        # inside run() must stay on line 8 to match the expected error message.
        predict_py.write_text(
            """from cog import Path, Input
from cog.ext.pipelines import include

def run(
    prompt: str = Input(description="Describe the image to generate"),
    seed: int = Input(description="A seed", default=0)
) -> Path:
    flux_schnell = include("black-forest-labs/flux-schnell")
    output_url = flux_schnell(prompt=prompt, seed=seed)[0]
    return output_url
""",
            encoding="utf8",
        )
        with pytest.raises(ValueError) as excinfo:
            analyze_python_file(predict_py)
        assert str(excinfo.value).endswith(
            "Invalid scope at line 8: `cog.ext.pipelines.include(...)` must be in global scope"
        )