Skip to content

Commit 0d3bd2a

Browse files
authored
[ENH]: Error if source_key set but no ef (#5751)
## Description of changes
_Summarize the changes made by this PR._
- Improvements & Bug fixes
  - Raises an error if `source_key` is set on a sparse vector index config but no embedding function (ef) is supplied. The check is enforced on both the client and the server side; the Python client also accepts `bm25=True` in place of an embedding function.
- New functionality
  - ...

## Test plan
_How are these changes tested?_
Added a unit test.
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan
None

## Observability plan
None

## Documentation Changes
None
1 parent a1ea81a commit 0d3bd2a

File tree

4 files changed

+52
-10
lines changed

4 files changed

+52
-10
lines changed

chromadb/api/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,20 @@ def _validate_single_sparse_vector_index(self, key: str) -> None:
20442044
f"Only one sparse vector index is allowed per collection."
20452045
)
20462046

2047+
def _validate_sparse_vector_config(self, config: SparseVectorIndexConfig) -> None:
2048+
"""
2049+
Validate that if source_key is provided then either embedding_function or bm25
2050+
must be provided since there is no default embedding function.
2051+
Raises ValueError otherwise.
2052+
"""
2053+
if (config.source_key is not None
2054+
and config.embedding_function is None
2055+
and config.bm25 is not True):
2056+
raise ValueError(
2057+
f"If source_key is provided then either embedding_function or bm25 must be provided "
2058+
f"since there is no default embedding function. Config: {config}"
2059+
)
2060+
20472061
def _set_index_for_key(self, key: str, config: IndexConfig, enabled: bool) -> None:
20482062
"""Set an index configuration for a specific key."""
20492063
config_name = self._get_config_class_name(config)
@@ -2052,6 +2066,7 @@ def _set_index_for_key(self, key: str, config: IndexConfig, enabled: bool) -> No
20522066
# Do this BEFORE creating the key entry
20532067
if config_name == "SparseVectorIndexConfig" and enabled:
20542068
self._validate_single_sparse_vector_index(key)
2069+
self._validate_sparse_vector_config(cast(SparseVectorIndexConfig, config))
20552070

20562071
if key not in self.keys:
20572072
self.keys[key] = ValueTypes()
@@ -2096,6 +2111,8 @@ def _enable_all_indexes_for_key(self, key: str) -> None:
20962111
if key not in self.keys:
20972112
self.keys[key] = ValueTypes()
20982113

2114+
self._validate_single_sparse_vector_index(key)
2115+
20992116
# Enable all index types with default configs
21002117
self.keys[key].string = StringValueType(
21012118
fts_index=FtsIndexType(enabled=True, config=FtsIndexConfig()),

chromadb/test/api/test_schema.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_chained_create_and_delete_operations(self) -> None:
246246
# 1. Create sparse vector index on "embeddings_key"
247247
# 2. Disable string inverted index on "text_key_1"
248248
# 3. Disable string inverted index on "text_key_2"
249-
sparse_config = SparseVectorIndexConfig(source_key="raw_text")
249+
sparse_config = SparseVectorIndexConfig(source_key="raw_text", bm25=True)
250250
string_config = StringInvertedIndexConfig()
251251

252252
result = (
@@ -1268,7 +1268,7 @@ def test_multiple_index_types_on_same_key(self) -> None:
12681268
schema = Schema()
12691269

12701270
# Enable sparse vector on "multi_field"
1271-
sparse_config = SparseVectorIndexConfig(source_key="source")
1271+
sparse_config = SparseVectorIndexConfig(source_key="source", bm25=True)
12721272
schema.create_index(config=sparse_config, key="multi_field")
12731273

12741274
# Also enable string_inverted_index on the same key
@@ -1454,7 +1454,7 @@ def test_multiple_serialize_deserialize_roundtrips(self) -> None:
14541454
hnsw=hnsw_config
14551455
)
14561456
original.create_index(config=vector_config)
1457-
original.create_index(config=SparseVectorIndexConfig(source_key="text"), key="embeddings")
1457+
original.create_index(config=SparseVectorIndexConfig(source_key="text", bm25=True), key="embeddings")
14581458
original.delete_index(config=StringInvertedIndexConfig(), key="tags")
14591459

14601460
# First roundtrip
@@ -1514,7 +1514,7 @@ def test_many_keys_stress(self) -> None:
15141514
key_name = f"field_{i}"
15151515
if i == 0:
15161516
# Enable sparse vector on ONE key only
1517-
schema.create_index(config=SparseVectorIndexConfig(source_key=f"source_{i}"), key=key_name)
1517+
schema.create_index(config=SparseVectorIndexConfig(source_key=f"source_{i}", bm25=True), key=key_name)
15181518
elif i % 2 == 1:
15191519
# Disable string inverted index
15201520
schema.delete_index(config=StringInvertedIndexConfig(), key=key_name)
@@ -1578,7 +1578,7 @@ def test_chained_operations(self) -> None:
15781578

15791579
# Chain multiple operations
15801580
result = (schema
1581-
.create_index(config=SparseVectorIndexConfig(source_key="text"), key="field1")
1581+
.create_index(config=SparseVectorIndexConfig(source_key="text", bm25=True), key="field1")
15821582
.delete_index(config=StringInvertedIndexConfig(), key="field2")
15831583
.delete_index(config=StringInvertedIndexConfig(), key="field3")
15841584
.delete_index(config=IntInvertedIndexConfig(), key="field4"))
@@ -1820,7 +1820,7 @@ def test_keys_have_independent_configs(self) -> None:
18201820
schema = Schema()
18211821

18221822
# Enable sparse vector on a key - it gets exactly what we specify
1823-
sparse_config = SparseVectorIndexConfig(source_key="default_source")
1823+
sparse_config = SparseVectorIndexConfig(source_key="default_source", bm25=True)
18241824
schema.create_index(config=sparse_config, key="field1")
18251825

18261826
# Verify field1 has the sparse vector with the specified source_key
@@ -1907,7 +1907,7 @@ def test_key_specific_overrides_are_independent(self) -> None:
19071907
schema = Schema()
19081908

19091909
# Create sparse vector on one key and string indexes on others
1910-
schema.create_index(config=SparseVectorIndexConfig(source_key="source_a"), key="key_a")
1910+
schema.create_index(config=SparseVectorIndexConfig(source_key="source_a", bm25=True), key="key_a")
19111911
schema.create_index(config=StringInvertedIndexConfig(), key="key_b")
19121912
schema.create_index(config=StringInvertedIndexConfig(), key="key_c")
19131913

@@ -1992,7 +1992,7 @@ def test_partial_override_fills_from_defaults(self) -> None:
19921992
schema = Schema()
19931993

19941994
# Enable sparse vector on a key
1995-
schema.create_index(config=SparseVectorIndexConfig(source_key="my_source"), key="multi_index_field")
1995+
schema.create_index(config=SparseVectorIndexConfig(source_key="my_source", bm25=True), key="multi_index_field")
19961996

19971997
# This key now has sparse_vector overridden, but string, int, etc. should still follow global defaults
19981998
field = schema.keys["multi_index_field"]
@@ -2337,3 +2337,23 @@ def test_config_source_key_validates_special_keys() -> None:
23372337
# Regular keys (no #) are allowed
23382338
config6 = SparseVectorIndexConfig(source_key="my_field")
23392339
assert config6.source_key == "my_field"
2340+
2341+
2342+
def test_sparse_vector_config_requires_ef_with_source_key() -> None:
2343+
"""Test that SparseVectorIndexConfig raises ValueError when source_key is provided without embedding_function."""
2344+
schema = Schema()
2345+
2346+
# Attempt to create a sparse vector index with source_key but neither embedding_function nor bm25=True
2347+
with pytest.raises(ValueError) as exc_info:
2348+
schema.create_index(
2349+
key="invalid_sparse",
2350+
config=SparseVectorIndexConfig(
2351+
source_key="text_field",
2352+
# No embedding_function (and no bm25=True) provided - create_index should raise ValueError
2353+
),
2354+
)
2355+
2356+
# Verify the error message mentions both source_key and embedding_function
2357+
error_msg = str(exc_info.value)
2358+
assert "source_key" in error_msg.lower()
2359+
assert "embedding_function" in error_msg.lower()

chromadb/test/api/test_schema_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2359,10 +2359,10 @@ def test_modify_collection_preserves_other_schema_fields(client: ClientAPI) -> N
23592359
collection_refreshed = client.get_collection(collection_name)
23602360
refreshed_schema = collection_refreshed.schema
23612361
assert refreshed_schema is not None
2362-
2362+
23632363
# Verify vector index was updated on server
23642364
assert refreshed_schema.defaults.float_list.vector_index.config.spann.search_nprobe == 128 # type: ignore
2365-
2365+
23662366
# Verify other value types are still intact on server
23672367
assert refreshed_schema.defaults.string is not None
23682368
assert refreshed_schema.defaults.string.string_inverted_index is not None

rust/types/src/validators.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ pub fn validate_schema(schema: &Schema) -> Result<(), ValidationError> {
281281
.into(),
282282
));
283283
}
284+
if svit.config.source_key.is_some() && svit.config.embedding_function.is_none() {
285+
return Err(ValidationError::new("schema").with_message(
286+
"If source_key is provided then embedding_function must also be provided since there is no default embedding function.".into(),
287+
));
288+
}
284289
}
285290
// Validate source_key for sparse vector index
286291
if let Some(source_key) = &svit.config.source_key {

0 commit comments

Comments
 (0)