Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.2.0] - 2024-04-02
## [0.2.1] - 2025-04-03

### Added

- The ability to select embedding functions when creating collections (default, cohere, openai, jina, voyageai, roboflow)

### Changed
- Upgraded to v1.0.0 of Chroma
- Fix dotenv path support during argparse

## [0.2.0] - 2025-04-02

### Added
- New `delete_document` tool for removing documents from collections
Expand Down
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ This server provides data retrieval capabilities powered by Chroma, enabling AI
- List all collections with pagination support
- Get collection information and statistics
- Configure HNSW parameters for optimized vector search
- Select embedding functions when creating collections

- **Document Operations**
- Add documents with optional metadata and custom IDs
Expand All @@ -66,6 +67,13 @@ This server provides data retrieval capabilities powered by Chroma, enabling AI
- `chroma_update_documents` - Update existing documents' content, metadata, or embeddings
- `chroma_delete_documents` - Delete specific documents from a collection

### Embedding Functions
Chroma MCP supports several embedding functions: `default`, `cohere`, `openai`, `jina`, `voyageai`, and `roboflow`.

The embedding functions utilize Chroma's collection configuration, which persists the selected embedding function of a collection for retrieval. Once a collection is created using the collection configuration, on retrieval for future queries and inserts, the same embedding function will be used, without needing to specify the embedding function again. Embedding function persistance was added in v1.0.0 of Chroma, so if you created a collection using version <=0.6.3, this feature is not supported.

When accessing embedding functions that utilize external APIs, please be sure to add the environment variable for the API key with the correct format, found in [Embedding Function Environment Variables](#embedding-function-environment-variables)

## Usage with Claude Desktop

1. To add an ephemeral client, add the following to your `claude_desktop_config.json` file:
Expand Down Expand Up @@ -168,4 +176,7 @@ export CHROMA_CUSTOM_AUTH_CREDENTIALS="your-custom-auth-credentials"
export CHROMA_SSL="true"
```


#### Embedding Function Environment Variables
When using external embedding functions that access an API key, follow the naming convention
`CHROMA_<>_API_KEY="<key>"`.
So to set a Cohere API key, set the environment variable `CHROMA_COHERE_API_KEY=""` in the file matching the dotenv-path argument, or as a system environment variable.
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "chroma-mcp"
version = "0.2.0"
version = "0.2.1"
description = "Chroma MCP Server - Vector Database Integration for LLM Applications"
readme = "README.md"
requires-python = ">=3.10"
Expand All @@ -16,11 +16,17 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules"
]
dependencies = [
"chromadb>=0.6.3",
"chromadb>=1.0.0",
"cohere>=5.14.2",
"httpx>=0.28.1",
"mcp[cli]>=1.2.1",
"openai>=1.70.0",
"pillow>=11.1.0",
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"python-dotenv>=0.19.0",
"typing-extensions>=4.13.1",
"voyageai>=0.3.2",
]

[project.urls]
Expand Down
157 changes: 116 additions & 41 deletions src/chroma_mcp/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TypedDict
from enum import Enum
import chromadb
from mcp.server.fastmcp import FastMCP
import os
Expand All @@ -9,6 +10,21 @@
import uuid
import time
import json
from typing_extensions import TypedDict


from chromadb.api.collection_configuration import (
CreateCollectionConfiguration, CreateHNSWConfiguration, UpdateHNSWConfiguration, UpdateCollectionConfiguration
)
from chromadb.api import EmbeddingFunction
from chromadb.utils.embedding_functions import (
DefaultEmbeddingFunction,
CohereEmbeddingFunction,
OpenAIEmbeddingFunction,
JinaEmbeddingFunction,
VoyageAIEmbeddingFunction,
RoboflowEmbeddingFunction,
)

# Initialize FastMCP server
mcp = FastMCP("chroma")
Expand Down Expand Up @@ -49,7 +65,7 @@ def create_parser():
type=lambda x: x.lower() in ['true', 'yes', '1', 't', 'y'],
default=os.getenv('CHROMA_SSL', 'true').lower() in ['true', 'yes', '1', 't', 'y'])
parser.add_argument('--dotenv-path',
help='Path to .chroma_env file',
help='Path to .env file',
default=os.getenv('CHROMA_DOTENV_PATH', '.chroma_env'))
return parser

Expand All @@ -64,7 +80,7 @@ def get_chroma_client(args=None):

# Load environment variables from .env file if it exists
load_dotenv(dotenv_path=args.dotenv_path)

print(args.dotenv_path)
if args.client_type == 'http':
if not args.host:
raise ValueError("Host must be provided via --host flag or CHROMA_HOST environment variable when using HTTP client")
Expand Down Expand Up @@ -143,58 +159,87 @@ async def chroma_list_collections(
"""
client = get_chroma_client()
try:
return client.list_collections(limit=limit, offset=offset)
colls = client.list_collections(limit=limit, offset=offset)
# iterate over colls and output the names
return [coll.name for coll in colls]

except Exception as e:
raise Exception(f"Failed to list collections: {str(e)}") from e


mcp_known_embedding_functions: Dict[str, EmbeddingFunction] = {
"default": DefaultEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"openai": OpenAIEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"voyageai": VoyageAIEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
}
@mcp.tool()
async def chroma_create_collection(
collection_name: str,
hnsw_space: Optional[str] = None,
hnsw_construction_ef: Optional[int] = None,
hnsw_search_ef: Optional[int] = None,
hnsw_M: Optional[int] = None,
hnsw_num_threads: Optional[int] = None,
hnsw_resize_factor: Optional[float] = None,
hnsw_batch_size: Optional[int] = None,
hnsw_sync_threshold: Optional[int] = None
embedding_function_name: Optional[str] = "default",
metadata: Optional[Dict] = None,
space: Optional[str] = None,
ef_construction: Optional[int] = None,
ef_search: Optional[int] = None,
max_neighbors: Optional[int] = None,
num_threads: Optional[int] = None,
batch_size: Optional[int] = None,
sync_threshold: Optional[int] = None,
resize_factor: Optional[float] = None,
) -> str:
"""Create a new Chroma collection with configurable HNSW parameters.

Args:
collection_name: Name of the collection to create
hnsw_space: Distance function used in HNSW index. Options: 'l2', 'ip', 'cosine'
hnsw_construction_ef: Size of the dynamic candidate list for constructing the HNSW graph
hnsw_search_ef: Size of the dynamic candidate list for searching the HNSW graph
hnsw_M: Number of bi-directional links created for every new element
hnsw_num_threads: Number of threads to use during HNSW construction
hnsw_resize_factor: Factor to resize the index by when it's full
hnsw_batch_size: Number of elements to batch together during index construction
hnsw_sync_threshold: Number of elements to process before syncing index to disk
space: Distance function used in HNSW index. Options: 'l2', 'ip', 'cosine'
ef_construction: Size of the dynamic candidate list for constructing the HNSW graph
ef_search: Size of the dynamic candidate list for searching the HNSW graph
max_neighbors: Maximum number of neighbors to consider during HNSW graph construction
num_threads: Number of threads to use during HNSW construction
batch_size: Number of elements to batch together during index construction
sync_threshold: Number of elements to process before syncing index to disk
resize_factor: Factor to resize the index by when it's full
embedding_function_name: Name of the embedding function to use. Options: 'default', 'cohere', 'openai', 'jina', 'voyageai', 'ollama', 'roboflow'
metadata: Optional metadata dict to add to the collection
"""
client = get_chroma_client()


# Build HNSW configuration directly in metadata, only including non-None values
metadata = {
k: v for k, v in {
"hnsw:space": hnsw_space,
"hnsw:construction_ef": hnsw_construction_ef,
"hnsw:M": hnsw_M,
"hnsw:search_ef": hnsw_search_ef,
"hnsw:num_threads": hnsw_num_threads,
"hnsw:resize_factor": hnsw_resize_factor,
"hnsw:batch_size": hnsw_batch_size,
"hnsw:sync_threshold": hnsw_sync_threshold
}.items() if v is not None
}
embedding_function = mcp_known_embedding_functions[embedding_function_name]

hnsw_config = CreateHNSWConfiguration()
if space:
hnsw_config["space"] = space
if ef_construction:
hnsw_config["ef_construction"] = ef_construction
if ef_search:
hnsw_config["ef_search"] = ef_search
if max_neighbors:
hnsw_config["max_neighbors"] = max_neighbors
if num_threads:
hnsw_config["num_threads"] = num_threads
if batch_size:
hnsw_config["batch_size"] = batch_size
if sync_threshold:
hnsw_config["sync_threshold"] = sync_threshold
if resize_factor:
hnsw_config["resize_factor"] = resize_factor



configuration=CreateCollectionConfiguration(
hnsw=hnsw_config,
embedding_function=embedding_function()
)

try:
client.create_collection(
name=collection_name,
metadata=metadata if metadata else None
configuration=configuration,
metadata=metadata
)
config_msg = f" with HNSW configuration: {metadata}" if metadata else ""
config_msg = f" with configuration: {configuration}"
return f"Successfully created collection {collection_name}{config_msg}"
except Exception as e:
raise Exception(f"Failed to create collection '{collection_name}': {str(e)}") from e
Expand Down Expand Up @@ -261,29 +306,53 @@ async def chroma_get_collection_count(collection_name: str) -> int:
async def chroma_modify_collection(
collection_name: str,
new_name: Optional[str] = None,
new_metadata: Optional[Dict] = None
new_metadata: Optional[Dict] = None,
ef_search: Optional[int] = None,
num_threads: Optional[int] = None,
batch_size: Optional[int] = None,
sync_threshold: Optional[int] = None,
resize_factor: Optional[float] = None,
) -> str:
"""Modify a Chroma collection's name or metadata.

Args:
collection_name: Name of the collection to modify
new_name: Optional new name for the collection
new_metadata: Optional new metadata for the collection
ef_search: Size of the dynamic candidate list for searching the HNSW graph
num_threads: Number of threads to use during HNSW construction
batch_size: Number of elements to batch together during index construction
sync_threshold: Number of elements to process before syncing index to disk
resize_factor: Factor to resize the index by when it's full
"""
client = get_chroma_client()
try:
collection = client.get_collection(collection_name)

if new_name:
collection.modify(name=new_name)
if new_metadata:
collection.modify(metadata=new_metadata)
hnsw_config = UpdateHNSWConfiguration()
if ef_search:
hnsw_config["ef_search"] = ef_search
if num_threads:
hnsw_config["num_threads"] = num_threads
if batch_size:
hnsw_config["batch_size"] = batch_size
if sync_threshold:
hnsw_config["sync_threshold"] = sync_threshold
if resize_factor:
hnsw_config["resize_factor"] = resize_factor

configuration = UpdateCollectionConfiguration(
hnsw=hnsw_config
)
collection.modify(name=new_name, configuration=configuration, metadata=new_metadata)

modified_aspects = []
if new_name:
modified_aspects.append("name")
if new_metadata:
modified_aspects.append("metadata")
if ef_search or num_threads or batch_size or sync_threshold or resize_factor:
modified_aspects.append("hnsw")

return f"Successfully modified collection {collection_name}: updated {' and '.join(modified_aspects)}"
except Exception as e:
Expand Down Expand Up @@ -595,6 +664,12 @@ def main():
parser = create_parser()
args = parser.parse_args()

if args.dotenv_path:
load_dotenv(dotenv_path=args.dotenv_path)
# re-parse args to read the updated environment variables
parser = create_parser()
args = parser.parse_args()

# Validate required arguments based on client type
if args.client_type == 'http':
if not args.host:
Expand Down
35 changes: 23 additions & 12 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,23 +559,34 @@ async def test_create_collection_success():
# Test basic creation
result = await mcp.call_tool("chroma_create_collection", {"collection_name": collection_name})
assert "Successfully created collection" in result[0].text

# Test creation with HNSW configuration
hnsw_collection = "test_hnsw_collection"
hnsw_result = await mcp.call_tool("chroma_create_collection", {
hnsw_params = {
"collection_name": hnsw_collection,
"hnsw_space": "cosine",
"hnsw_construction_ef": 100,
"hnsw_search_ef": 50,
"hnsw_M": 16
})
"space": "cosine",
"ef_construction": 100,
"ef_search": 50,
"max_neighbors": 16 # Assuming M corresponds to max_neighbors
}
hnsw_result = await mcp.call_tool("chroma_create_collection", hnsw_params)
assert "Successfully created collection" in hnsw_result[0].text
assert "HNSW configuration" in hnsw_result[0].text

# Check if the specific config values are in the output string
assert "'space': 'cosine'" in hnsw_result[0].text
assert "'ef_construction': 100" in hnsw_result[0].text
assert "'ef_search': 50" in hnsw_result[0].text
assert "'max_neighbors': 16" in hnsw_result[0].text

finally:
# Clean up
await mcp.call_tool("chroma_delete_collection", {"collection_name": collection_name})
await mcp.call_tool("chroma_delete_collection", {"collection_name": hnsw_collection})
# Cleanup: delete the collections if they exist
try:
await mcp.call_tool("chroma_delete_collection", {"collection_name": collection_name})
except Exception:
pass
try:
await mcp.call_tool("chroma_delete_collection", {"collection_name": hnsw_collection})
except Exception:
pass

@pytest.mark.asyncio
async def test_create_collection_duplicate():
Expand Down
Loading