Skip to content

Commit ab61e79

Browse files
siiddhanttManishMadan2882
authored andcommitted
test: implement full API test suite with mongomock and centralized fixtures (#2068)
1 parent 314c104 commit ab61e79

14 files changed

Lines changed: 1601 additions & 167 deletions

File tree

application/utils.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_encoding():
2121

2222

2323
def get_gpt_model() -> str:
24-
"""Get the appropriate GPT model based on provider"""
24+
"""Get GPT model based on provider"""
2525
model_map = {
2626
"openai": "gpt-4o-mini",
2727
"anthropic": "claude-2",
@@ -32,16 +32,7 @@ def get_gpt_model() -> str:
3232

3333

3434
def safe_filename(filename):
35-
"""
36-
Creates a safe filename that preserves the original extension.
37-
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
38-
39-
Args:
40-
filename (str): The original filename
41-
42-
Returns:
43-
str: A safe filename that can be used for storage
44-
"""
35+
"""Create safe filename, preserving extension. Handles non-Latin characters."""
4536
if not filename:
4637
return str(uuid.uuid4())
4738
_, extension = os.path.splitext(filename)
@@ -83,8 +74,14 @@ def count_tokens_docs(docs):
8374
return tokens
8475

8576

77+
def get_missing_fields(data, required_fields):
78+
"""Check for missing required fields. Returns list of missing field names."""
79+
return [field for field in required_fields if field not in data]
80+
81+
8682
def check_required_fields(data, required_fields):
87-
missing_fields = [field for field in required_fields if field not in data]
83+
"""Validate required fields. Returns Flask 400 response if validation fails, None otherwise."""
84+
missing_fields = get_missing_fields(data, required_fields)
8885
if missing_fields:
8986
return make_response(
9087
jsonify(
@@ -98,7 +95,8 @@ def check_required_fields(data, required_fields):
9895
return None
9996

10097

101-
def validate_required_fields(data, required_fields):
98+
def get_field_validation_errors(data, required_fields):
99+
"""Check for missing and empty fields. Returns dict with 'missing_fields' and 'empty_fields', or None."""
102100
missing_fields = []
103101
empty_fields = []
104102

@@ -107,12 +105,24 @@ def validate_required_fields(data, required_fields):
107105
missing_fields.append(field)
108106
elif not data[field]:
109107
empty_fields.append(field)
110-
errors = []
111-
if missing_fields:
112-
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
113-
if empty_fields:
114-
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
115-
if errors:
108+
if missing_fields or empty_fields:
109+
return {"missing_fields": missing_fields, "empty_fields": empty_fields}
110+
return None
111+
112+
113+
def validate_required_fields(data, required_fields):
114+
"""Validate required fields (must exist and be non-empty). Returns Flask 400 response if validation fails, None otherwise."""
115+
errors_dict = get_field_validation_errors(data, required_fields)
116+
if errors_dict:
117+
errors = []
118+
if errors_dict["missing_fields"]:
119+
errors.append(
120+
f"Missing required fields: {', '.join(errors_dict['missing_fields'])}"
121+
)
122+
if errors_dict["empty_fields"]:
123+
errors.append(
124+
f"Empty values in required fields: {', '.join(errors_dict['empty_fields'])}"
125+
)
116126
return make_response(
117127
jsonify({"success": False, "message": " | ".join(errors)}), 400
118128
)
@@ -124,10 +134,7 @@ def get_hash(data):
124134

125135

126136
def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
127-
"""
128-
Limits chat history based on token count.
129-
Returns a list of messages that fit within the token limit.
130-
"""
137+
"""Limit chat history to fit within token limit."""
131138
from application.core.settings import settings
132139

133140
max_token_limit = (
@@ -161,7 +168,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
161168

162169

163170
def validate_function_name(function_name):
164-
"""Validates if a function name matches the allowed pattern."""
171+
"""Validate function name matches allowed pattern (alphanumeric, underscore, hyphen)."""
165172
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
166173
return False
167174
return True

tests/agents/test_base_agent.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44
from application.agents.classic_agent import ClassicAgent
55
from application.core.settings import settings
6-
from tests.conftest import FakeMongoCollection
76

87

98
@pytest.mark.unit
@@ -168,10 +167,13 @@ def test_get_user_tools(
168167
mock_llm_creator,
169168
mock_llm_handler_creator,
170169
):
171-
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
172-
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
173-
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": True},
174-
}
170+
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
171+
user_tools.insert_one(
172+
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
173+
)
174+
user_tools.insert_one(
175+
{"_id": "2", "user": "test_user", "name": "tool2", "status": True}
176+
)
175177

176178
agent = ClassicAgent(**agent_base_params)
177179
tools = agent._get_user_tools("test_user")
@@ -187,10 +189,13 @@ def test_get_user_tools_filters_by_status(
187189
mock_llm_creator,
188190
mock_llm_handler_creator,
189191
):
190-
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"].docs = {
191-
"1": {"_id": "1", "user": "test_user", "name": "tool1", "status": True},
192-
"2": {"_id": "2", "user": "test_user", "name": "tool2", "status": False},
193-
}
192+
user_tools = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
193+
user_tools.insert_one(
194+
{"_id": "1", "user": "test_user", "name": "tool1", "status": True}
195+
)
196+
user_tools.insert_one(
197+
{"_id": "2", "user": "test_user", "name": "tool2", "status": False}
198+
)
194199

195200
agent = ClassicAgent(**agent_base_params)
196201
tools = agent._get_user_tools("test_user")
@@ -209,17 +214,16 @@ def test_get_tools_by_api_key(
209214
tool_id = str(ObjectId())
210215
tool_obj_id = ObjectId(tool_id)
211216

212-
fake_agent_collection = FakeMongoCollection()
213-
fake_agent_collection.docs["api_key_123"] = {
214-
"key": "api_key_123",
215-
"tools": [tool_id],
216-
}
217-
218-
fake_tools_collection = FakeMongoCollection()
219-
fake_tools_collection.docs[tool_id] = {"_id": tool_obj_id, "name": "api_tool"}
217+
agents_collection = mock_mongo_db[settings.MONGO_DB_NAME]["agents"]
218+
agents_collection.insert_one(
219+
{
220+
"key": "api_key_123",
221+
"tools": [tool_id],
222+
}
223+
)
220224

221-
mock_mongo_db[settings.MONGO_DB_NAME]["agents"] = fake_agent_collection
222-
mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"] = fake_tools_collection
225+
tools_collection = mock_mongo_db[settings.MONGO_DB_NAME]["user_tools"]
226+
tools_collection.insert_one({"_id": tool_obj_id, "name": "api_tool"})
223227

224228
agent = ClassicAgent(**agent_base_params)
225229
tools = agent._get_tools("api_key_123")

tests/api/__init__.py

Whitespace-only changes.

tests/api/answer/__init__.py

Whitespace-only changes.

tests/api/answer/routes/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)