Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion src/chroma_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,82 @@ async def chroma_get_documents(
limit=limit,
offset=offset
)


@mcp.tool()
async def chroma_update_documents(
collection_name: str,
ids: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict]] = None,
documents: Optional[List[str]] = None
) -> str:
"""Update documents in a Chroma collection.

Args:
collection_name: Name of the collection to update documents in
ids: List of document IDs to update (required)
embeddings: Optional list of new embeddings for the documents.
Must match length of ids if provided.
metadatas: Optional list of new metadata dictionaries for the documents.
Must match length of ids if provided.
documents: Optional list of new text documents.
Must match length of ids if provided.

Returns:
A confirmation message indicating the number of documents updated.

Raises:
ValueError: If 'ids' is empty or if none of 'embeddings', 'metadatas',
or 'documents' are provided, or if the length of provided
update lists does not match the length of 'ids'.
Exception: If the collection does not exist or if the update operation fails.
"""
if not ids:
raise ValueError("The 'ids' list cannot be empty.")

if embeddings is None and metadatas is None and documents is None:
raise ValueError(
"At least one of 'embeddings', 'metadatas', or 'documents' "
"must be provided for update."
)

# Ensure provided lists match the length of ids if they are not None
if embeddings is not None and len(embeddings) != len(ids):
raise ValueError("Length of 'embeddings' list must match length of 'ids' list.")
if metadatas is not None and len(metadatas) != len(ids):
raise ValueError("Length of 'metadatas' list must match length of 'ids' list.")
if documents is not None and len(documents) != len(ids):
raise ValueError("Length of 'documents' list must match length of 'ids' list.")


client = get_chroma_client()
try:
collection = client.get_collection(collection_name)
except Exception as e:
raise Exception(
f"Failed to get collection '{collection_name}': {str(e)}"
) from e

# Prepare arguments for update, excluding None values at the top level
update_args = {
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
}
kwargs = {k: v for k, v in update_args.items() if v is not None}

try:
collection.update(**kwargs)
return (
f"Successfully processed update request for {len(ids)} documents in "
f"collection '{collection_name}'. Note: Non-existent IDs are ignored by ChromaDB."
)
except Exception as e:
raise Exception(
f"Failed to update documents in collection '{collection_name}': {str(e)}"
) from e

def validate_thought_data(input_data: Dict) -> Dict:
"""Validate thought data structure."""
if not input_data.get("sessionId"):
Expand Down
151 changes: 150 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import os
from unittest.mock import patch, MagicMock
import argparse
from mcp.server.fastmcp.exceptions import ToolError # Import ToolError
import json # Import json for parsing results


# Add pytest-asyncio marker
pytest_plugins = ['pytest_asyncio']
Expand Down Expand Up @@ -261,4 +264,150 @@ def test_required_args_for_cloud_client():
# Check that error was called for missing api-key (the first check in the code)
mock_error.assert_called_with(
"API key must be provided via --api-key flag or CHROMA_API_KEY environment variable when using cloud client"
)
)

# --- Tests for chroma_update_documents ---

@pytest.mark.asyncio
async def test_update_documents_success():
"""Test successful document update."""
collection_name = "test_update_collection_success"
doc_ids = ["doc1", "doc2"]
initial_docs = ["Initial doc 1", "Initial doc 2"]
initial_metadatas = [{"source": "initial"}, {"source": "initial"}]
updated_docs = ["Updated doc 1", initial_docs[1]] # Update only first doc content
updated_metadatas = [initial_metadatas[0], {"source": "updated"}] # Update only second doc metadata

try:
# 1. Create collection
await mcp.call_tool("chroma_create_collection", {"collection_name": collection_name})

# 2. Add initial documents
await mcp.call_tool("chroma_add_documents", {
"collection_name": collection_name,
"documents": initial_docs,
"metadatas": initial_metadatas,
"ids": doc_ids
})

# 3. Update documents (pass both documents and metadatas)
update_result = await mcp.call_tool("chroma_update_documents", {
"collection_name": collection_name,
"ids": doc_ids,
"documents": updated_docs,
"metadatas": updated_metadatas
})
assert len(update_result) == 1
# Updated success message check
assert (
f"Successfully processed update request for {len(doc_ids)} documents"
in update_result[0].text
)

# 4. Verify updates
get_result_raw = await mcp.call_tool("chroma_get_documents", {
"collection_name": collection_name,
"ids": doc_ids,
"include": ["documents", "metadatas"]
})
# Corrected: Parse the JSON string from TextContent
assert len(get_result_raw) == 1
get_result = json.loads(get_result_raw[0].text)
assert isinstance(get_result, dict)

assert get_result.get("ids") == doc_ids
# Check updated document content
assert get_result.get("documents") == updated_docs
# Check updated metadata
assert get_result.get("metadatas") == updated_metadatas

finally:
# Clean up
await mcp.call_tool("chroma_delete_collection", {"collection_name": collection_name})

@pytest.mark.asyncio
async def test_update_documents_invalid_args():
"""Test update documents with invalid arguments."""
collection_name = "test_update_collection_invalid"

try:
await mcp.call_tool("chroma_create_collection", {"collection_name": collection_name})
await mcp.call_tool("chroma_add_documents", {
"collection_name": collection_name,
"documents": ["Test doc"],
"ids": ["doc1"]
})

# Test with empty IDs list - Expect ToolError wrapping ValueError
with pytest.raises(ToolError, match="The 'ids' list cannot be empty."):
await mcp.call_tool("chroma_update_documents", {
"collection_name": collection_name,
"ids": [],
"documents": ["New content"]
})

# Test with no update fields provided - Expect ToolError wrapping ValueError
with pytest.raises(
ToolError,
match="At least one of 'embeddings', 'metadatas', or 'documents' must be provided"
):
await mcp.call_tool("chroma_update_documents", {
"collection_name": collection_name,
"ids": ["doc1"]
# No embeddings, metadatas, or documents
})

finally:
# Clean up
await mcp.call_tool("chroma_delete_collection", {"collection_name": collection_name})

@pytest.mark.asyncio
async def test_update_documents_collection_not_found():
"""Test updating documents in a non-existent collection."""
# Expect ToolError wrapping the Exception from the function
with pytest.raises(ToolError, match="Failed to get collection"):
await mcp.call_tool("chroma_update_documents", {
"collection_name": "non_existent_collection",
"ids": ["doc1"],
"documents": ["New content"]
})

@pytest.mark.asyncio
async def test_update_documents_id_not_found():
"""Test updating a document with an ID that does not exist. Expect no exception."""
collection_name = "test_update_id_not_found"
try:
await mcp.call_tool("chroma_create_collection", {"collection_name": collection_name})
await mcp.call_tool("chroma_add_documents", {
"collection_name": collection_name,
"documents": ["Test doc"],
"ids": ["existing_id"]
})

# Attempt to update a non-existent ID - should not raise Exception
update_result = await mcp.call_tool("chroma_update_documents", {
"collection_name": collection_name,
"ids": ["non_existent_id"],
"documents": ["New content"]
})
# Check the success message (even though the ID didn't exist)
assert len(update_result) == 1
assert "Successfully processed update request" in update_result[0].text

# Optionally, verify that the existing document was not changed
get_result_raw = await mcp.call_tool("chroma_get_documents", {
"collection_name": collection_name,
"ids": ["existing_id"],
"include": ["documents"]
})
# Corrected assertion: Parse JSON and check structure/content
assert len(get_result_raw) == 1
get_result = json.loads(get_result_raw[0].text)
assert isinstance(get_result, dict)
assert "documents" in get_result
assert isinstance(get_result["documents"], list)
assert get_result["documents"] == ["Test doc"]

finally:
# Clean up
await mcp.call_tool("chroma_delete_collection", {"collection_name": collection_name})