Skip to content
110 changes: 70 additions & 40 deletions src/chroma_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, TypedDict
from typing import Dict, List, TypedDict, Union
from enum import Enum
import chromadb
from mcp.server.fastmcp import FastMCP
Expand Down Expand Up @@ -145,8 +145,8 @@ def get_chroma_client(args=None):

@mcp.tool()
async def chroma_list_collections(
limit: Optional[int] = None,
offset: Optional[int] = None
limit: int | None = None,
offset: int | None = None
) -> List[str]:
"""List all collection names in the Chroma database with pagination support.

Expand All @@ -155,12 +155,15 @@ async def chroma_list_collections(
offset: Optional number of collections to skip before returning results

Returns:
List of collection names
List of collection names or ["__NO_COLLECTIONS_FOUND__"] if database is empty
"""
client = get_chroma_client()
try:
colls = client.list_collections(limit=limit, offset=offset)
# iterate over colls and output the names
# Safe handling: If colls is None or empty, return a special marker
if not colls:
return ["__NO_COLLECTIONS_FOUND__"]
# Otherwise iterate to get collection names
return [coll.name for coll in colls]

except Exception as e:
Expand All @@ -177,16 +180,16 @@ async def chroma_list_collections(
@mcp.tool()
async def chroma_create_collection(
collection_name: str,
embedding_function_name: Optional[str] = "default",
metadata: Optional[Dict] = None,
space: Optional[str] = None,
ef_construction: Optional[int] = None,
ef_search: Optional[int] = None,
max_neighbors: Optional[int] = None,
num_threads: Optional[int] = None,
batch_size: Optional[int] = None,
sync_threshold: Optional[int] = None,
resize_factor: Optional[float] = None,
embedding_function_name: str = "default",
metadata: Dict | None = None,
space: str | None = None,
ef_construction: int | None = None,
ef_search: int | None = None,
max_neighbors: int | None = None,
num_threads: int | None = None,
batch_size: int | None = None,
sync_threshold: int | None = None,
resize_factor: float | None = None,
) -> str:
"""Create a new Chroma collection with configurable HNSW parameters.

Expand Down Expand Up @@ -305,13 +308,13 @@ async def chroma_get_collection_count(collection_name: str) -> int:
@mcp.tool()
async def chroma_modify_collection(
collection_name: str,
new_name: Optional[str] = None,
new_metadata: Optional[Dict] = None,
ef_search: Optional[int] = None,
num_threads: Optional[int] = None,
batch_size: Optional[int] = None,
sync_threshold: Optional[int] = None,
resize_factor: Optional[float] = None,
new_name: str | None = None,
new_metadata: Dict | None = None,
ef_search: int | None = None,
num_threads: int | None = None,
batch_size: int | None = None,
sync_threshold: int | None = None,
resize_factor: float | None = None,
) -> str:
"""Modify a Chroma collection's name or metadata.

Expand Down Expand Up @@ -377,35 +380,62 @@ async def chroma_delete_collection(collection_name: str) -> str:
async def chroma_add_documents(
collection_name: str,
documents: List[str],
metadatas: Optional[List[Dict]] = None,
ids: Optional[List[str]] = None
ids: List[str],
metadatas: List[Dict] | None = None
) -> str:
"""Add documents to a Chroma collection.

Args:
collection_name: Name of the collection to add documents to
documents: List of text documents to add
ids: List of IDs for the documents (required)
metadatas: Optional list of metadata dictionaries for each document
ids: Optional list of IDs for the documents
"""
if not documents:
raise ValueError("The 'documents' list cannot be empty.")

if not ids:
raise ValueError("The 'ids' list is required and cannot be empty.")

# Check if there are empty strings in the ids list
if any(not id.strip() for id in ids):
raise ValueError("IDs cannot be empty strings.")

if len(ids) != len(documents):
raise ValueError(f"Number of ids ({len(ids)}) must match number of documents ({len(documents)}).")

client = get_chroma_client()
try:
collection = client.get_or_create_collection(collection_name)

# Generate sequential IDs if none provided
if ids is None:
ids = [str(i) for i in range(len(documents))]
# Check for duplicate IDs
existing_ids = collection.get(include=[])["ids"]
duplicate_ids = [id for id in ids if id in existing_ids]

if duplicate_ids:
raise ValueError(
f"The following IDs already exist in collection '{collection_name}': {duplicate_ids}. "
f"Use 'chroma_update_documents' to update existing documents."
)

collection.add(
result = collection.add(
documents=documents,
metadatas=metadatas,
ids=ids
)

return f"Successfully added {len(documents)} documents to collection {collection_name}"
# Check the return value
if result and isinstance(result, dict):
# If the return value is a dictionary, it may contain success information
if 'success' in result and not result['success']:
raise Exception(f"Failed to add documents: {result.get('error', 'Unknown error')}")

# If the return value contains the actual number added
if 'count' in result:
return f"Successfully added {result['count']} documents to collection {collection_name}"

# Default return
return f"Successfully added {len(documents)} documents to collection {collection_name}, result is {result}"
except Exception as e:
raise Exception(f"Failed to add documents to collection '{collection_name}': {str(e)}") from e

Expand All @@ -414,8 +444,8 @@ async def chroma_query_documents(
collection_name: str,
query_texts: List[str],
n_results: int = 5,
where: Optional[Dict] = None,
where_document: Optional[Dict] = None,
where: Dict | None = None,
where_document: Dict | None = None,
include: List[str] = ["documents", "metadatas", "distances"]
) -> Dict:
"""Query documents from a Chroma collection with advanced filtering.
Expand Down Expand Up @@ -452,12 +482,12 @@ async def chroma_query_documents(
@mcp.tool()
async def chroma_get_documents(
collection_name: str,
ids: Optional[List[str]] = None,
where: Optional[Dict] = None,
where_document: Optional[Dict] = None,
ids: List[str] | None = None,
where: Dict | None = None,
where_document: Dict | None = None,
include: List[str] = ["documents", "metadatas"],
limit: Optional[int] = None,
offset: Optional[int] = None
limit: int | None = None,
offset: int | None = None
) -> Dict:
"""Get documents from a Chroma collection with optional filtering.

Expand Down Expand Up @@ -496,9 +526,9 @@ async def chroma_get_documents(
async def chroma_update_documents(
collection_name: str,
ids: List[str],
embeddings: Optional[List[List[float]]] = None,
metadatas: Optional[List[Dict]] = None,
documents: Optional[List[str]] = None
embeddings: List[List[float]] | None = None,
metadatas: List[Dict] | None = None,
documents: List[str] | None = None
) -> str:
"""Update documents in a Chroma collection.

Expand Down
4 changes: 2 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_get_chroma_client_ephemeral():
@pytest.mark.asyncio
async def test_list_collections():
    """Check that the chroma_list_collections tool returns a list when paginated."""
    # Use explicit pagination values; the earlier call with limit/offset of
    # None was superseded and its result was discarded anyway.
    result = await mcp.call_tool("chroma_list_collections", {"limit": 50, "offset": 0})
    assert isinstance(result, list)

@pytest.mark.asyncio
Expand Down Expand Up @@ -538,7 +538,7 @@ async def test_list_collections_success():
await mcp.call_tool("chroma_create_collection", {"collection_name": collection_name})

# List collections
result = await mcp.call_tool("chroma_list_collections", {"limit": None, "offset": None})
result = await mcp.call_tool("chroma_list_collections", {"limit": 50, "offset": 0})
assert isinstance(result, list)
assert any(collection_name in item.text for item in result)

Expand Down