Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions moorcheh_sdk/resources/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@

from ..exceptions import APIError, InvalidInputError
from ..types import AnswerResponse, ChatHistoryItem
from ..utils.decorators import required_args
from ..utils.logging import setup_logging
from .base import BaseResource

logger = setup_logging(__name__)


class Answer(BaseResource):
@required_args(
["namespace", "query", "top_k", "ai_model", "temperature"],
types={
"namespace": str,
"query": str,
"top_k": int,
"ai_model": str,
"temperature": (int, float),
},
)
def generate(
self,
namespace: str,
Expand Down Expand Up @@ -64,16 +75,9 @@ def generate(
"Attempting to get generative answer for query in namespace"
f" '{namespace}'..."
)

if not namespace or not isinstance(namespace, str):
raise InvalidInputError("'namespace' must be a non-empty string.")
if not query or not isinstance(query, str):
raise InvalidInputError("'query' must be a non-empty string.")
if not isinstance(top_k, int) or top_k <= 0:
if top_k <= 0:
raise InvalidInputError("'top_k' must be a positive integer.")
if not isinstance(ai_model, str) or not ai_model:
raise InvalidInputError("'ai_model' must be a non-empty string.")
if not isinstance(temperature, (int, float)) or not (0 <= temperature <= 1):
if not (0 <= temperature <= 1):
raise InvalidInputError(
"'temperature' must be a number between 0.0 and 1.0."
)
Expand Down
30 changes: 11 additions & 19 deletions moorcheh_sdk/resources/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@
DocumentUploadResponse,
)
from ..utils.constants import INVALID_ID_CHARS
from ..utils.decorators import required_args
from ..utils.logging import setup_logging
from .base import BaseResource

logger = setup_logging(__name__)


class Documents(BaseResource):
@required_args(
["namespace_name", "documents"],
types={"namespace_name": str, "documents": list},
)
def upload(
self, namespace_name: str, documents: list[Document]
) -> DocumentUploadResponse:
Expand Down Expand Up @@ -47,12 +52,6 @@ def upload(
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if not isinstance(documents, list) or not documents:
raise InvalidInputError(
"'documents' must be a non-empty list of dictionaries."
)

logger.info(
f"Attempting to upload {len(documents)} documents to namespace"
Expand Down Expand Up @@ -105,6 +104,9 @@ def upload(
)
return cast(DocumentUploadResponse, response_data)

@required_args(
["namespace_name", "ids"], types={"namespace_name": str, "ids": list}
)
def get(self, namespace_name: str, ids: list[str | int]) -> DocumentGetResponse:
"""
Retrieves documents by their IDs from a text-based namespace.
Expand Down Expand Up @@ -134,12 +136,6 @@ def get(self, namespace_name: str, ids: list[str | int]) -> DocumentGetResponse:
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if not isinstance(ids, list) or not ids:
raise InvalidInputError(
"'ids' must be a non-empty list of strings or integers."
)
if len(ids) > 100:
raise InvalidInputError(
"Maximum of 100 document IDs can be requested per call."
Expand Down Expand Up @@ -174,6 +170,9 @@ def get(self, namespace_name: str, ids: list[str | int]) -> DocumentGetResponse:
)
return cast(DocumentGetResponse, response_data)

@required_args(
["namespace_name", "ids"], types={"namespace_name": str, "ids": list}
)
def delete(
self, namespace_name: str, ids: list[str | int]
) -> DocumentDeleteResponse:
Expand Down Expand Up @@ -201,13 +200,6 @@ def delete(
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if not isinstance(ids, list) or not ids:
raise InvalidInputError(
"'ids' must be a non-empty list of strings or integers."
)

logger.info(
f"Attempting to delete {len(ids)} document(s) from namespace"
f" '{namespace_name}' with IDs: {ids}"
Expand Down
10 changes: 5 additions & 5 deletions moorcheh_sdk/resources/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from ..exceptions import APIError, InvalidInputError
from ..types import NamespaceCreateResponse, NamespaceListResponse
from ..utils.decorators import required_args
from ..utils.logging import setup_logging
from .base import BaseResource

logger = setup_logging(__name__)


class Namespaces(BaseResource):
@required_args(
["namespace_name", "type"], types={"namespace_name": str, "type": str}
)
def create(
self, namespace_name: str, type: str, vector_dimension: int | None = None
) -> NamespaceCreateResponse:
Expand Down Expand Up @@ -41,9 +45,6 @@ def create(
logger.info(
f"Attempting to create namespace '{namespace_name}' of type '{type}'..."
)
# Client-side validation
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if type not in ["text", "vector"]:
raise InvalidInputError("Namespace type must be 'text' or 'vector'.")
if type == "vector":
Expand Down Expand Up @@ -79,6 +80,7 @@ def create(
)
return cast(NamespaceCreateResponse, response_data)

@required_args(["namespace_name"], types={"namespace_name": str})
def delete(self, namespace_name: str) -> None:
"""
Deletes a namespace and all its data.
Expand All @@ -94,8 +96,6 @@ def delete(self, namespace_name: str) -> None:
MoorchehError: For network issues.
"""
logger.info(f"Attempting to delete namespace '{namespace_name}'...")
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")

endpoint = f"/namespaces/{namespace_name}"
# API returns 200 with body now, not 204
Expand Down
18 changes: 11 additions & 7 deletions moorcheh_sdk/resources/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@

from ..exceptions import APIError, InvalidInputError
from ..types import SearchResponse
from ..utils.decorators import required_args
from ..utils.logging import setup_logging
from .base import BaseResource

logger = setup_logging(__name__)


class Search(BaseResource):
@required_args(
["namespaces", "query", "top_k", "kiosk_mode"],
types={
"namespaces": list,
"query": (str, list),
"top_k": int,
"kiosk_mode": bool,
},
)
def query(
self,
namespaces: list[str],
Expand Down Expand Up @@ -50,24 +60,18 @@ def query(
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not isinstance(namespaces, list) or not namespaces:
raise InvalidInputError("'namespaces' must be a non-empty list of strings.")
if not all(isinstance(ns, str) and ns for ns in namespaces):
raise InvalidInputError(
"All items in 'namespaces' list must be non-empty strings."
)
if not query:
raise InvalidInputError("'query' cannot be empty.")
if not isinstance(top_k, int) or top_k <= 0:
if top_k <= 0:
raise InvalidInputError("'top_k' must be a positive integer.")
if threshold is not None and (
not isinstance(threshold, (int, float)) or not (0 <= threshold <= 1)
):
raise InvalidInputError(
"'threshold' must be a number between 0 and 1, or None."
)
if not isinstance(kiosk_mode, bool):
raise InvalidInputError("'kiosk_mode' must be a boolean.")

query_type = "vector" if isinstance(query, list) else "text"
logger.info(
Expand Down
19 changes: 7 additions & 12 deletions moorcheh_sdk/resources/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

from ..exceptions import APIError, InvalidInputError
from ..types import Vector, VectorDeleteResponse, VectorUploadResponse
from ..utils.decorators import required_args
from ..utils.logging import setup_logging
from .base import BaseResource

logger = setup_logging(__name__)


class Vectors(BaseResource):
@required_args(
["namespace_name", "vectors"], types={"namespace_name": str, "vectors": list}
)
def upload(
self, namespace_name: str, vectors: list[Vector]
) -> VectorUploadResponse:
Expand Down Expand Up @@ -42,12 +46,6 @@ def upload(
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if not isinstance(vectors, list) or not vectors:
raise InvalidInputError(
"'vectors' must be a non-empty list of dictionaries."
)

logger.info(
f"Attempting to upload {len(vectors)} vectors to namespace"
Expand Down Expand Up @@ -106,6 +104,9 @@ def upload(
)
return cast(VectorUploadResponse, response_data)

@required_args(
["namespace_name", "ids"], types={"namespace_name": str, "ids": list}
)
def delete(self, namespace_name: str, ids: list[str | int]) -> VectorDeleteResponse:
"""
Deletes vectors by their IDs from a vector-based namespace.
Expand All @@ -131,12 +132,6 @@ def delete(self, namespace_name: str, ids: list[str | int]) -> VectorDeleteRespo
APIError: For other API errors.
MoorchehError: For network issues.
"""
if not namespace_name or not isinstance(namespace_name, str):
raise InvalidInputError("'namespace_name' must be a non-empty string.")
if not isinstance(ids, list) or not ids:
raise InvalidInputError(
"'ids' must be a non-empty list of strings or integers."
)

logger.info(
f"Attempting to delete {len(ids)} vector(s) from namespace"
Expand Down
64 changes: 64 additions & 0 deletions moorcheh_sdk/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import inspect
from collections.abc import Callable, Sequence
from functools import wraps
from typing import ParamSpec, TypeVar

from ..exceptions import InvalidInputError

P = ParamSpec("P")
R = TypeVar("R")


def required_args(
args: Sequence[str],
types: dict[str, type | tuple[type, ...]] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to enforce that specific arguments are provided, not None, and optionally of a specific type.

For strings and collections, it also checks that they are not empty.

Args:
args: A list of argument names that are required.
types: A dictionary mapping argument names to their expected types.
"""

def decorator(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R:
sig = inspect.signature(func)
try:
bound = sig.bind(*func_args, **func_kwargs)
except TypeError as e:
raise InvalidInputError(str(e)) from e

bound.apply_defaults()

for arg_name in args:
if arg_name not in bound.arguments:
# This might happen if the argument is not in the signature,
# but if it is in 'args' list it should be expected.
# However, bind() would fail if a required arg is missing unless it has a default.
# If it has a default of None, we might want to catch it here.
continue

val = bound.arguments[arg_name]

if val is None:
raise InvalidInputError(f"Argument '{arg_name}' cannot be None.")

if isinstance(val, (str, list, dict, set, tuple)) and not val:
raise InvalidInputError(f"Argument '{arg_name}' cannot be empty.")

if types and arg_name in types:
expected_type = types[arg_name]
if not isinstance(val, expected_type):
raise InvalidInputError(
f"Argument '{arg_name}' must be of type {expected_type}."
)

return func(*func_args, **func_kwargs)

return wrapper

return decorator
8 changes: 4 additions & 4 deletions tests/resources/test_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,12 @@ def test_generate_answer_with_prompts(client, mocker, mock_response):
@pytest.mark.parametrize(
"ns, q, tk, model, temp, history, msg",
[
("", "q", 5, "m", 0.5, [], "'namespace' must be a non-empty string"),
(None, "q", 5, "m", 0.5, [], "'namespace' must be a non-empty string"),
("ns", "", 5, "m", 0.5, [], "'query' must be a non-empty string"),
("", "q", 5, "m", 0.5, [], "Argument 'namespace' cannot be empty."),
(None, "q", 5, "m", 0.5, [], "Argument 'namespace' cannot be None."),
("ns", "", 5, "m", 0.5, [], "Argument 'query' cannot be empty."),
("ns", "q", 0, "m", 0.5, [], "'top_k' must be a positive integer"),
("ns", "q", -1, "m", 0.5, [], "'top_k' must be a positive integer"),
("ns", "q", 5, "", 0.5, [], "'ai_model' must be a non-empty string"),
("ns", "q", 5, "", 0.5, [], "Argument 'ai_model' cannot be empty."),
(
"ns",
"q",
Expand Down
23 changes: 16 additions & 7 deletions tests/resources/test_namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def test_create_namespace_conflict(client, mocker, mock_response):
@pytest.mark.parametrize(
"name, ns_type, dim, expected_error_msg",
[
("", "text", None, "'namespace_name' must be a non-empty string"),
(None, "text", None, "'namespace_name' must be a non-empty string"),
("", "text", None, "Argument 'namespace_name' cannot be empty."),
(None, "text", None, "Argument 'namespace_name' cannot be None."),
("test", "invalid_type", None, "Namespace type must be 'text' or 'vector'"),
(
"test",
Expand Down Expand Up @@ -250,11 +250,20 @@ def test_delete_namespace_not_found(client, mocker, mock_response):
client._mock_httpx_instance.request.assert_called_once()


@pytest.mark.parametrize("invalid_name", ["", None, 123])
def test_delete_namespace_invalid_name_client_side(client, invalid_name):
@pytest.mark.parametrize(
"invalid_name, expected_error",
[
("", "Argument 'namespace_name' cannot be empty."),
(None, "Argument 'namespace_name' cannot be None."),
(123, "Argument 'namespace_name' must be of type <class 'str'>."),
],
)
def test_delete_namespace_invalid_name_client_side(
client, invalid_name, expected_error
):
"""Test client-side validation for delete_namespace name."""
with pytest.raises(
InvalidInputError, match="'namespace_name' must be a non-empty string"
):
import re

with pytest.raises(InvalidInputError, match=re.escape(expected_error)):
client.namespaces.delete(invalid_name)
client._mock_httpx_instance.request.assert_not_called()
Loading