Skip to content

Commit df5c044

Browse files
authored
[ENH]: (Rust client): export more types, config -> schema on create_collection() (#5699)
1 parent 7ac13ec commit df5c044

File tree

5 files changed

+54
-15
lines changed

5 files changed

+54
-15
lines changed

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ export type RemoveTaskResponse = {
284284

285285
/**
286286
* Schema representation for collection index configurations
287+
*
287288
* This represents the server-side schema structure used for index management
288289
*/
289290
export type Schema = {

rust/chroma/src/client/chroma_http_client.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use backon::Retryable;
33
use chroma_api_types::ErrorResponse;
44
use chroma_error::ChromaValidationError;
55
use chroma_types::Collection;
6-
use chroma_types::CollectionConfiguration;
76
use chroma_types::Metadata;
7+
use chroma_types::Schema;
88
use chroma_types::WhereError;
99
use parking_lot::Mutex;
1010
use reqwest::Method;
@@ -437,10 +437,10 @@ impl ChromaHttpClient {
437437
pub async fn get_or_create_collection(
438438
&self,
439439
name: impl AsRef<str>,
440-
configuration: Option<CollectionConfiguration>,
440+
schema: Option<Schema>,
441441
metadata: Option<Metadata>,
442442
) -> Result<ChromaCollection, ChromaHttpClientError> {
443-
self.common_create_collection(name, configuration, metadata, true)
443+
self.common_create_collection(name, schema, metadata, true)
444444
.await
445445
}
446446

@@ -473,10 +473,10 @@ impl ChromaHttpClient {
473473
pub async fn create_collection(
474474
&self,
475475
name: impl AsRef<str>,
476-
configuration: Option<CollectionConfiguration>,
476+
schema: Option<Schema>,
477477
metadata: Option<Metadata>,
478478
) -> Result<ChromaCollection, ChromaHttpClientError> {
479-
self.common_create_collection(name, configuration, metadata, false)
479+
self.common_create_collection(name, schema, metadata, false)
480480
.await
481481
}
482482

@@ -620,7 +620,7 @@ impl ChromaHttpClient {
620620
async fn common_create_collection(
621621
&self,
622622
name: impl AsRef<str>,
623-
configuration: Option<CollectionConfiguration>,
623+
schema: Option<Schema>,
624624
metadata: Option<Metadata>,
625625
get_or_create: bool,
626626
) -> Result<ChromaCollection, ChromaHttpClientError> {
@@ -637,7 +637,7 @@ impl ChromaHttpClient {
637637
),
638638
Some(serde_json::json!({
639639
"name": name.as_ref(),
640-
"configuration": configuration,
640+
"schema": schema,
641641
"metadata": metadata,
642642
"get_or_create": get_or_create,
643643
})),
@@ -877,6 +877,7 @@ mod tests {
877877
use super::*;
878878
use crate::client::ChromaRetryOptions;
879879
use crate::tests::with_client;
880+
use chroma_types::{EmbeddingFunctionConfiguration, EmbeddingFunctionNewConfiguration};
880881
use httpmock::{HttpMockResponse, MockServer};
881882
use std::sync::atomic::AtomicBool;
882883
use std::time::Duration;
@@ -1059,13 +1060,23 @@ mod tests {
10591060
#[test_log::test]
10601061
async fn test_live_cloud_create_collection() {
10611062
with_client(|client| async move {
1062-
let collection = client.create_collection("foo", None, None).await.unwrap();
1063+
let schema = Schema::default_with_embedding_function(
1064+
EmbeddingFunctionConfiguration::Known(EmbeddingFunctionNewConfiguration {
1065+
name: "bar".to_string(),
1066+
config: serde_json::json!({}),
1067+
}),
1068+
);
1069+
let collection = client
1070+
.create_collection("foo", Some(schema.clone()), None)
1071+
.await
1072+
.unwrap();
10631073
assert_eq!(collection.collection.name, "foo");
10641074

1065-
client
1075+
let collection = client
10661076
.get_or_create_collection("foo", None, None)
10671077
.await
10681078
.unwrap();
1079+
assert_eq!(collection.schema().clone().unwrap(), schema);
10691080
})
10701081
.await;
10711082
}

rust/chroma/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,13 @@
101101
#![deny(missing_docs)]
102102

103103
pub mod client;
104-
pub mod collection;
104+
mod collection;
105105
pub mod embed;
106106
pub mod types;
107107

108108
pub use client::ChromaHttpClient;
109109
pub use client::ChromaHttpClientOptions;
110+
pub use collection::ChromaCollection;
110111

111112
#[cfg(test)]
112113
mod tests {

rust/chroma/src/types.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@
33
pub use chroma_api_types::{GetUserIdentityResponse, HeartbeatResponse};
44

55
pub use chroma_types::{
6-
plan::SearchPayload, AddCollectionRecordsRequest, AddCollectionRecordsResponse, Collection,
7-
DeleteCollectionRecordsRequest, DeleteCollectionRecordsResponse, ForkCollectionRequest,
8-
GetRequest, GetResponse, IncludeList, Metadata, QueryRequest, QueryResponse, Schema,
9-
SearchRequest, SearchResponse, UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse,
10-
UpdateMetadata, UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, Where,
6+
plan::SearchPayload, AddCollectionRecordsRequest, AddCollectionRecordsResponse,
7+
BooleanOperator, Collection, CompositeExpression, DeleteCollectionRecordsRequest,
8+
DeleteCollectionRecordsResponse, DocumentExpression, DocumentOperator,
9+
EmbeddingFunctionConfiguration, EmbeddingFunctionNewConfiguration, ForkCollectionRequest,
10+
GetRequest, GetResponse, Include, IncludeList, Metadata, MetadataComparison,
11+
MetadataExpression, MetadataSetValue, MetadataValue, PrimitiveOperator, QueryRequest,
12+
QueryResponse, Schema, SearchRequest, SearchResponse, SetOperator,
13+
UpdateCollectionRecordsRequest, UpdateCollectionRecordsResponse, UpdateMetadata,
14+
UpdateMetadataValue, UpsertCollectionRecordsRequest, UpsertCollectionRecordsResponse, Where,
1115
};
16+
17+
pub use chroma_types::operator::Key;

rust/types/src/collection_schema.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ pub const EMBEDDING_KEY: &str = "#embedding";
8888
// ============================================================================
8989

9090
/// Schema representation for collection index configurations
91+
///
9192
/// This represents the server-side schema structure used for index management
9293
9394
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -1114,6 +1115,25 @@ impl Schema {
11141115
}
11151116
}
11161117

1118+
pub fn default_with_embedding_function(
1119+
embedding_function: EmbeddingFunctionConfiguration,
1120+
) -> Schema {
1121+
let mut schema = Schema::new_default(KnnIndex::Spann);
1122+
if let Some(float_list) = &mut schema.defaults.float_list {
1123+
if let Some(vector_index) = &mut float_list.vector_index {
1124+
vector_index.config.embedding_function = Some(embedding_function.clone());
1125+
}
1126+
}
1127+
if let Some(embedding_types) = schema.keys.get_mut(EMBEDDING_KEY) {
1128+
if let Some(float_list) = &mut embedding_types.float_list {
1129+
if let Some(vector_index) = &mut float_list.vector_index {
1130+
vector_index.config.embedding_function = Some(embedding_function);
1131+
}
1132+
}
1133+
}
1134+
schema
1135+
}
1136+
11171137
/// Check if schema is default by comparing it word-by-word with new_default
11181138
fn is_schema_default(schema: &Schema) -> bool {
11191139
// Compare with both possible default schemas (HNSW and SPANN)

0 commit comments

Comments
 (0)