Skip to content

Commit 8e8fb95

Browse files
authored
[ENH]: Disallow setting only source_key without an ef (#5758)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - Setting only `source_key` without an embedding function (`ef`) in the sparse vector index config now returns an error.
- New functionality
  - ...

## Test plan

_How are these changes tested?_

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 0734fdf commit 8e8fb95

File tree

3 files changed

+47
-8
lines changed

3 files changed

+47
-8
lines changed

chromadb/api/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,17 @@ def _validate_single_sparse_vector_index(self, key: str) -> None:
20442044
f"Only one sparse vector index is allowed per collection."
20452045
)
20462046

2047+
def _validate_sparse_vector_config(self, config: SparseVectorIndexConfig) -> None:
2048+
"""
2049+
Validate that if source_key is provided then embedding_function is also provided
2050+
since there is no default embedding function. Raises ValueError otherwise.
2051+
"""
2052+
if config.source_key is not None and config.embedding_function is None:
2053+
raise ValueError(
2054+
f"If source_key is provided then embedding_function must also be provided "
2055+
f"since there is no default embedding function. Config: {config}"
2056+
)
2057+
20472058
def _set_index_for_key(self, key: str, config: IndexConfig, enabled: bool) -> None:
20482059
"""Set an index configuration for a specific key."""
20492060
config_name = self._get_config_class_name(config)
@@ -2052,6 +2063,7 @@ def _set_index_for_key(self, key: str, config: IndexConfig, enabled: bool) -> No
20522063
# Do this BEFORE creating the key entry
20532064
if config_name == "SparseVectorIndexConfig" and enabled:
20542065
self._validate_single_sparse_vector_index(key)
2066+
self._validate_sparse_vector_config(cast(SparseVectorIndexConfig, config))
20552067

20562068
if key not in self.keys:
20572069
self.keys[key] = ValueTypes()
@@ -2096,6 +2108,8 @@ def _enable_all_indexes_for_key(self, key: str) -> None:
20962108
if key not in self.keys:
20972109
self.keys[key] = ValueTypes()
20982110

2111+
self._validate_single_sparse_vector_index(key)
2112+
20992113
# Enable all index types with default configs
21002114
self.keys[key].string = StringValueType(
21012115
fts_index=FtsIndexType(enabled=True, config=FtsIndexConfig()),

chromadb/test/api/test_schema.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_chained_create_and_delete_operations(self) -> None:
246246
# 1. Create sparse vector index on "embeddings_key"
247247
# 2. Disable string inverted index on "text_key_1"
248248
# 3. Disable string inverted index on "text_key_2"
249-
sparse_config = SparseVectorIndexConfig(source_key="raw_text")
249+
sparse_config = SparseVectorIndexConfig(source_key="raw_text", embedding_function=MockSparseEmbeddingFunction())
250250
string_config = StringInvertedIndexConfig()
251251

252252
result = (
@@ -1268,7 +1268,7 @@ def test_multiple_index_types_on_same_key(self) -> None:
12681268
schema = Schema()
12691269

12701270
# Enable sparse vector on "multi_field"
1271-
sparse_config = SparseVectorIndexConfig(source_key="source")
1271+
sparse_config = SparseVectorIndexConfig(source_key="source", embedding_function=MockSparseEmbeddingFunction())
12721272
schema.create_index(config=sparse_config, key="multi_field")
12731273

12741274
# Also enable string_inverted_index on the same key
@@ -1454,7 +1454,7 @@ def test_multiple_serialize_deserialize_roundtrips(self) -> None:
14541454
hnsw=hnsw_config
14551455
)
14561456
original.create_index(config=vector_config)
1457-
original.create_index(config=SparseVectorIndexConfig(source_key="text"), key="embeddings")
1457+
original.create_index(config=SparseVectorIndexConfig(source_key="text", embedding_function=MockSparseEmbeddingFunction()), key="embeddings")
14581458
original.delete_index(config=StringInvertedIndexConfig(), key="tags")
14591459

14601460
# First roundtrip
@@ -1514,7 +1514,7 @@ def test_many_keys_stress(self) -> None:
15141514
key_name = f"field_{i}"
15151515
if i == 0:
15161516
# Enable sparse vector on ONE key only
1517-
schema.create_index(config=SparseVectorIndexConfig(source_key=f"source_{i}"), key=key_name)
1517+
schema.create_index(config=SparseVectorIndexConfig(source_key=f"source_{i}", embedding_function=MockSparseEmbeddingFunction()), key=key_name)
15181518
elif i % 2 == 1:
15191519
# Disable string inverted index
15201520
schema.delete_index(config=StringInvertedIndexConfig(), key=key_name)
@@ -1578,7 +1578,7 @@ def test_chained_operations(self) -> None:
15781578

15791579
# Chain multiple operations
15801580
result = (schema
1581-
.create_index(config=SparseVectorIndexConfig(source_key="text"), key="field1")
1581+
.create_index(config=SparseVectorIndexConfig(source_key="text", embedding_function=MockSparseEmbeddingFunction()), key="field1")
15821582
.delete_index(config=StringInvertedIndexConfig(), key="field2")
15831583
.delete_index(config=StringInvertedIndexConfig(), key="field3")
15841584
.delete_index(config=IntInvertedIndexConfig(), key="field4"))
@@ -1820,7 +1820,7 @@ def test_keys_have_independent_configs(self) -> None:
18201820
schema = Schema()
18211821

18221822
# Enable sparse vector on a key - it gets exactly what we specify
1823-
sparse_config = SparseVectorIndexConfig(source_key="default_source")
1823+
sparse_config = SparseVectorIndexConfig(source_key="default_source", embedding_function=MockSparseEmbeddingFunction())
18241824
schema.create_index(config=sparse_config, key="field1")
18251825

18261826
# Verify field1 has the sparse vector with the specified source_key
@@ -1907,7 +1907,7 @@ def test_key_specific_overrides_are_independent(self) -> None:
19071907
schema = Schema()
19081908

19091909
# Create sparse vector on one key and string indexes on others
1910-
schema.create_index(config=SparseVectorIndexConfig(source_key="source_a"), key="key_a")
1910+
schema.create_index(config=SparseVectorIndexConfig(source_key="source_a", embedding_function=MockSparseEmbeddingFunction()), key="key_a")
19111911
schema.create_index(config=StringInvertedIndexConfig(), key="key_b")
19121912
schema.create_index(config=StringInvertedIndexConfig(), key="key_c")
19131913

@@ -1992,7 +1992,7 @@ def test_partial_override_fills_from_defaults(self) -> None:
19921992
schema = Schema()
19931993

19941994
# Enable sparse vector on a key
1995-
schema.create_index(config=SparseVectorIndexConfig(source_key="my_source"), key="multi_index_field")
1995+
schema.create_index(config=SparseVectorIndexConfig(source_key="my_source", embedding_function=MockSparseEmbeddingFunction()), key="multi_index_field")
19961996

19971997
# This key now has sparse_vector overridden, but string, int, etc. should still follow global defaults
19981998
field = schema.keys["multi_index_field"]
@@ -2337,3 +2337,23 @@ def test_config_source_key_validates_special_keys() -> None:
23372337
# Regular keys (no #) are allowed
23382338
config6 = SparseVectorIndexConfig(source_key="my_field")
23392339
assert config6.source_key == "my_field"
2340+
2341+
2342+
def test_sparse_vector_config_requires_ef_with_source_key() -> None:
    """Creating a sparse vector index with only ``source_key`` must raise ValueError.

    The config is built with a ``source_key`` and no ``embedding_function``;
    the ValueError is expected while it is built and handed to
    ``Schema.create_index``, and its message must name both fields.
    """
    schema = Schema()

    # source_key is set but embedding_function is deliberately omitted.
    with pytest.raises(ValueError) as raised:
        schema.create_index(
            config=SparseVectorIndexConfig(source_key="text_field"),
            key="invalid_sparse",
        )

    # The message has to mention both of the fields involved.
    message = str(raised.value).lower()
    for required_term in ("source_key", "embedding_function"):
        assert required_term in message

rust/types/src/validators.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ pub fn validate_schema(schema: &Schema) -> Result<(), ValidationError> {
281281
.into(),
282282
));
283283
}
284+
if svit.config.source_key.is_some() && svit.config.embedding_function.is_none() {
285+
return Err(ValidationError::new("schema").with_message(
286+
"If source_key is provided then embedding_function must also be provided since there is no default embedding function.".into(),
287+
));
288+
}
284289
}
285290
// Validate source_key for sparse vector index
286291
if let Some(source_key) = &svit.config.source_key {

0 commit comments

Comments
 (0)