import os
import shutil
import uuid
from pathlib import Path
from typing import Callable
from unittest.mock import patch
import litellm
import pytest
from dotenv import load_dotenv
from kiln_ai.datamodel.basemodel import KilnAttachmentModel
from kiln_ai.pytest_mock_files import MockFileFactoryMimeType
from kiln_ai.pytest_test_output import make_test_output_dir
from kiln_ai.utils.config import Config
@pytest.fixture(autouse=True)
def _clear_httpx_clients() -> None:
    """Flush litellm's in-memory LLM client cache at the start of every test.

    autouse with no yield: runs at setup only, so clients cached by a prior
    test are not reused by the next one.
    """
    litellm.in_memory_llm_clients_cache.flush_cache()
@pytest.fixture(autouse=True)
def skip_remote_model_list():
    """Force KILN_SKIP_REMOTE_MODEL_LIST=true for the duration of each test.

    The variable is removed at teardown so it cannot leak into the
    environment outside the test.
    """
    os.environ["KILN_SKIP_REMOTE_MODEL_LIST"] = "true"
    yield
    # pop() with a default is a no-op if something already removed the key
    os.environ.pop("KILN_SKIP_REMOTE_MODEL_LIST", None)
@pytest.fixture(scope="session", autouse=True)
def load_env():
    """Load environment variables from a .env file once per test session."""
    load_dotenv()
# Wipe the Config singleton around every test so state cannot leak between them.
@pytest.fixture(autouse=True)
def reset_config():
    """Null out Config._shared_instance before and after each test."""
    setattr(Config, "_shared_instance", None)
    yield
    setattr(Config, "_shared_instance", None)
# Point Config.settings_path at a temp file so tests never clobber the user's
# real settings.
@pytest.fixture(autouse=True)
def use_temp_settings_dir(tmp_path):
    """Patch Config.settings_path to return a throwaway settings.yaml path."""
    patcher = patch.object(
        Config, "settings_path", return_value=str(tmp_path / "settings.yaml")
    )
    patcher.start()
    try:
        yield
    finally:
        # Always undo the patch, even if the test body raises.
        patcher.stop()
@pytest.fixture(scope="session", autouse=True)
def setup_test_logging():
    """Configure litellm logging once for the whole test session.

    Model-call logs go to test_model_calls.log.
    """
    from kiln_ai.utils.logging import setup_litellm_logging
    setup_litellm_logging("test_model_calls.log")
    yield
def pytest_addoption(parser):
    """Register the custom command-line flags that gate optional test groups.

    Every flag is a boolean switch that defaults to off.
    """
    option_specs = [
        ("--runpaid", "run tests that make paid API calls"),
        ("--runslow", "run slow tests"),
        (
            "--runsinglewithoutchecks",
            "if testing a single test, don't check for skips like runpaid",
        ),
        ("--ollama", "run tests that use ollama server"),
    ]
    for flag, description in option_specs:
        parser.addoption(flag, action="store_true", default=False, help=description)
def is_single_manual_test(config, items) -> bool:
    """Return True when this run is a single manually-invoked test.

    Parametrized variants of one test (a.b.c[param]) still count as a single
    test. Only applies when --runsinglewithoutchecks was passed (eg, when
    running one test from an IDE like vscode).
    """
    if not config.getoption("--runsinglewithoutchecks"):
        return False
    if not items:
        return False
    if len(items) == 1:
        return True
    # Multiple items are still "one test" when they are all parametrized
    # cases sharing the first item's "name[" prefix.
    shared_prefix = items[0].name.split("[", 1)[0] + "["
    return all(collected.name.startswith(shared_prefix) for collected in items)
def pytest_collection_modifyitems(config, items):
    """Add skip markers to gated test groups whose enabling flag is absent.

    A single manually-invoked test always runs, regardless of flags.
    """
    if is_single_manual_test(config, items):
        return
    # (flag that enables the group, marker keyword, skip reason)
    gated_groups = [
        ("--runpaid", "paid", "need --runpaid option to run"),
        ("--runslow", "slow", "need --runslow option to run"),
        ("--ollama", "ollama", "need --ollama option to run"),
    ]
    for flag, keyword, reason in gated_groups:
        if config.getoption(flag):
            continue
        skip_marker = pytest.mark.skip(reason=reason)
        for item in items:
            if keyword in item.keywords:
                item.add_marker(skip_marker)
@pytest.fixture
def test_output_dir(request: pytest.FixtureRequest) -> Path:
    """Output directory for the requesting test, built by the shared helper."""
    return make_test_output_dir(request)
@pytest.fixture
def test_data_dir() -> Path:
    """Directory of asset files (various mime types) used by the tests."""
    return Path(__file__).parent.joinpath("libs", "core", "tests", "assets")
@pytest.fixture
def mock_file_factory(
tmp_path, test_data_dir
) -> Callable[[MockFileFactoryMimeType], Path]:
"""
Create a mock file factory that creates a mock file for the given mime type.
The file is copied to the tmp path so it is safe to alter it without contaminating the original file.
"""
def create_file(mime_type: MockFileFactoryMimeType) -> Path:
match mime_type:
# document
case MockFileFactoryMimeType.PDF:
filename = test_data_dir / "document_paper.pdf"
case MockFileFactoryMimeType.CSV:
filename = test_data_dir / "document_people.csv"
case MockFileFactoryMimeType.TXT:
filename = test_data_dir / "document_ice_cubes.txt"
case MockFileFactoryMimeType.HTML:
filename = test_data_dir / "document_ice_cubes.html"
case MockFileFactoryMimeType.MD:
filename = test_data_dir / "document_ice_cubes.md"
# images
case MockFileFactoryMimeType.PNG:
filename = test_data_dir / "image_kodim23.png"
case MockFileFactoryMimeType.JPG:
filename = test_data_dir / "image_nasa.jpg"
case MockFileFactoryMimeType.JPEG:
filename = test_data_dir / "image_nasa.jpeg"
# audio
case MockFileFactoryMimeType.OGG:
filename = test_data_dir / "audio_ice_cubes.ogg"
case MockFileFactoryMimeType.MP3:
filename = test_data_dir / "audio_ice_cubes.mp3"
case MockFileFactoryMimeType.WAV:
filename = test_data_dir / "audio_ice_cubes.wav"
# video
case MockFileFactoryMimeType.MP4:
filename = test_data_dir / "video_tv_bars.mp4"
case MockFileFactoryMimeType.MOV:
filename = test_data_dir / "video_tv_bars.mov"
case _:
raise ValueError(f"No test file found for mime type: {mime_type}")
# copy the file to the tmp path to avoid test contamination of the original file
path_copy = tmp_path / f"{uuid.uuid4()!s}.{filename.suffix}"
shutil.copy(filename, path_copy)
return path_copy
return create_file
@pytest.fixture
def mock_attachment_factory(mock_file_factory):
    """
    Build mock KilnAttachmentModel instances for a given mime type.

    When text is provided the attachment wraps that text directly; otherwise
    it wraps a temp copy of the matching asset from mock_file_factory.
    """

    def create_attachment(
        mime_type: MockFileFactoryMimeType,
        text: str | None = None,
    ) -> KilnAttachmentModel:
        if text is None:
            source_path = mock_file_factory(mime_type)
            return KilnAttachmentModel.from_file(source_path)
        return KilnAttachmentModel.from_data(text, mime_type)

    return create_attachment