Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clients/ts-sdk/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -9219,6 +9219,11 @@
"description": "The base URL for the reranker API",
"nullable": true
},
"RERANKER_MODEL_NAME": {
"type": "string",
"description": "The model name for the Reranker API",
"nullable": true
},
"SEMANTIC_ENABLED": {
"type": "boolean",
"description": "Whether to use semantic search",
Expand Down
4 changes: 4 additions & 0 deletions clients/ts-sdk/src/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,10 @@ export type DatasetConfigurationDTO = {
* The base URL for the reranker API
*/
RERANKER_BASE_URL?: (string) | null;
/**
* The model name for the Reranker API
*/
RERANKER_MODEL_NAME?: (string) | null;
/**
* Whether to use semantic search
*/
Expand Down
2 changes: 1 addition & 1 deletion frontends/dashboard/src/components/ApiKeys.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ export const ApiKeys = () => {
];
return createSolidTable({
columns: columns,
data: userApiKeysQuery.data as ApiKeyRespBody[],
data: userApiKeysQuery.data,
getCoreRowModel: getCoreRowModel(),
});
});
Expand Down
49 changes: 49 additions & 0 deletions frontends/dashboard/src/components/NewDatasetModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { useNavigate } from "@solidjs/router";
import {
availableDistanceMetrics,
availableEmbeddingModels,
availableRerankerModels,
} from "shared/types";
import { createToast } from "./ShowToasts";
import { createNewDataset } from "../api/createDataset";
Expand Down Expand Up @@ -337,6 +338,54 @@ export const NewDatasetModal = (props: NewDatasetModalProps) => {
</select>
</div>

<div class="content-center py-4 sm:grid sm:grid-cols-3 sm:items-start sm:gap-4">
<label
for="embeddingSize"
class="flex h-full items-center gap-2 pt-1.5 text-sm font-medium leading-6"
>
Reranker Model{" "}
<Tooltip
body={
<FaRegularCircleQuestion class="h-3 w-3 text-black" />
}
tooltipText="Reranker Model for re-ranking search results."
/>
</label>
<select
id="embeddingSize"
name="embeddingSize"
class="col-span-2 block w-full rounded-md border-[0.5px] border-neutral-300 bg-white px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6"
value={
availableRerankerModels.find(
(model) =>
model.id ===
serverConfig.RERANKER_MODEL_NAME,
)?.name ?? availableRerankerModels[0].name
}
onChange={(e) => {
const selectedModel =
availableRerankerModels.find(
(model) =>
model.name === e.currentTarget.value,
);

setServerConfig((prev) => {
return {
...prev,
RERANKER_MODEL_NAME: selectedModel?.id,
};
});
}}
>
<For each={availableRerankerModels}>
{(model) => (
<option value={model.name}>
{model.name}
</option>
)}
</For>
</select>
</div>
<div class="content-center py-4 sm:grid sm:grid-cols-3 sm:items-start sm:gap-4">
<label
for="distanceMetric"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { DatasetConfig } from "./LegacySettingsWrapper";
import {
availableDistanceMetrics,
availableEmbeddingModels,
availableRerankerModels,
} from "shared/types";

export const GeneralServerSettings = (props: {
Expand Down Expand Up @@ -94,6 +95,84 @@ export const GeneralServerSettings = (props: {
/>
</div>

<div class="col-span-4 space-y-1 sm:col-span-2">
<div class="flex items-center">
<label
for="embeddingSize"
class="mr-2 block text-sm font-medium leading-6"
>
Reranker Model
</label>
<Tooltip
body={<AiOutlineInfoCircle />}
tooltipText="Reranker Model for re-ranking search results."
/>
</div>
<select
id="embeddingSize"
aria-readonly
title="Embedding Model is only editable on creation"
name="embeddingSize"
class="col-span-2 block w-full rounded-md border-[0.5px] border-neutral-300 bg-white px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6"
value={
availableRerankerModels.find(
(metric) =>
metric.id === props.serverConfig().RERANKER_MODEL_NAME,
)?.name ?? availableRerankerModels[0].name
}
onChange={(e) => {
const selectedModel = availableRerankerModels.find(
(model) => model.name === e.currentTarget.value,
);

const url = selectedModel?.url ?? "";

props.setServerConfig((prev) => {
return {
...prev,
RERANKER_MODEL_NAME: selectedModel?.id,
RERANKER_BASE_URL: url,
};
});
}}
>
<For each={availableRerankerModels}>
{(metric) => (
<option value={metric.name}>{metric.name}</option>
)}
</For>
</select>
</div>
<div class="col-span-4 sm:col-span-2">
<div class="flex flex-row items-center gap-2">
<label
for="rerankerApiKey"
class="block text-sm font-medium leading-6"
>
Cohere API Key (for reranker)
</label>
<Tooltip
body={<AiOutlineInfoCircle />}
tooltipText="Sets the API key for the Cohere reranker if you choose to use it."
/>
</div>
<input
type="text"
name="rerankerApiKey"
id="linesBeforeShowMore"
class="block w-full rounded-md border-[0.5px] border-neutral-300 px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6"
value={props.serverConfig().RERANKER_API_KEY ?? ""}
onChange={(e) =>
props.setServerConfig((prev) => {
return {
...prev,
RERANKER_API_KEY: e.currentTarget.value,
};
})
}
/>
</div>

<div class="col-span-4 space-y-1 sm:col-span-2">
<div class="flex items-center">
<label
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export type DatasetConfig = Exclude<
"PUBLIC_DATASET"
> & {
LLM_API_KEY?: string | null;
RERANKER_API_KEY?: string | null;
PUBLIC_DATASET?: {
enabled: boolean;
api_key: string;
Expand Down
1 change: 1 addition & 0 deletions frontends/dashboard/src/utils/serverEnvs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export const defaultServerEnvsConfiguration: DatasetConfig = {
LLM_API_KEY: "",
EMBEDDING_BASE_URL: "https://embedding.trieve.ai",
EMBEDDING_MODEL_NAME: "jina-base-en",
RERANKER_MODEL_NAME: "bge-reranker-large",
MESSAGE_TO_QUERY_PROMPT: "",
RAG_PROMPT: "",
EMBEDDING_SIZE: 768,
Expand Down
13 changes: 13 additions & 0 deletions frontends/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,19 @@ export const availableEmbeddingModels = [
},
];

export const availableRerankerModels = [
{
id: "bge-reranker-large",
name: "bge-reranker-large (hosted by Trieve)",
url: null,
},
{
id: "rerank-v3.5",
name: "Cohere rerank-v3.5 (with Cohere API key)",
url: "https://api.cohere.com/v2",
},
];

export const availableDistanceMetrics = [
{
id: "cosine",
Expand Down
51 changes: 50 additions & 1 deletion server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2293,6 +2293,9 @@ pub struct DatasetConfiguration {
pub LLM_BASE_URL: String,
#[serde(skip_serializing)]
pub LLM_API_KEY: String,
#[serde(skip_serializing)]
pub RERANKER_API_KEY: String,
pub RERANKER_MODEL_NAME: String,
pub EMBEDDING_BASE_URL: String,
pub EMBEDDING_MODEL_NAME: String,
pub RERANKER_BASE_URL: String,
Expand Down Expand Up @@ -2369,6 +2372,11 @@ pub struct DatasetConfigurationDTO {
#[serde(skip_serializing)]
/// The API key for the LLM API
pub LLM_API_KEY: Option<String>,
#[serde(skip_serializing)]
/// The API key for the Reranker API
pub RERANKER_API_KEY: Option<String>,
/// The model name for the Reranker API
pub RERANKER_MODEL_NAME: Option<String>,
/// The base URL for the embedding API
pub EMBEDDING_BASE_URL: Option<String>,
/// The name of the embedding model to use
Expand Down Expand Up @@ -2432,6 +2440,8 @@ impl From<DatasetConfigurationDTO> for DatasetConfiguration {
DatasetConfiguration {
LLM_BASE_URL: dto.LLM_BASE_URL.unwrap_or("https://api.openai.com/v1".to_string()),
LLM_API_KEY: dto.LLM_API_KEY.unwrap_or("".to_string()),
RERANKER_API_KEY: dto.RERANKER_API_KEY.unwrap_or("".to_string()),
RERANKER_MODEL_NAME: dto.RERANKER_MODEL_NAME.unwrap_or("bge-reranker-large".to_string()),
EMBEDDING_BASE_URL: dto.EMBEDDING_BASE_URL.unwrap_or("https://api.openai.com/v1".to_string()),
EMBEDDING_MODEL_NAME: dto.EMBEDDING_MODEL_NAME.unwrap_or("text-embedding-3-small".to_string()),
RERANKER_BASE_URL: dto.RERANKER_BASE_URL.unwrap_or("".to_string()),
Expand Down Expand Up @@ -2474,6 +2484,8 @@ impl From<DatasetConfiguration> for DatasetConfigurationDTO {
DatasetConfigurationDTO {
LLM_BASE_URL: Some(config.LLM_BASE_URL),
LLM_API_KEY: Some(config.LLM_API_KEY),
RERANKER_API_KEY: Some(config.RERANKER_API_KEY),
RERANKER_MODEL_NAME: Some(config.RERANKER_MODEL_NAME),
EMBEDDING_BASE_URL: Some(config.EMBEDDING_BASE_URL),
EMBEDDING_MODEL_NAME: Some(config.EMBEDDING_MODEL_NAME),
RERANKER_BASE_URL: Some(config.RERANKER_BASE_URL),
Expand Down Expand Up @@ -2520,6 +2532,8 @@ impl Default for DatasetConfiguration {
DatasetConfiguration {
LLM_BASE_URL: "https://api.openai.com/v1".to_string(),
LLM_API_KEY: "".to_string(),
RERANKER_API_KEY: "".to_string(),
RERANKER_MODEL_NAME: "bge-reranker-large".to_string(),
EMBEDDING_BASE_URL: "https://api.openai.com/v1".to_string(),
EMBEDDING_MODEL_NAME: "text-embedding-3-small".to_string(),
RERANKER_BASE_URL: "".to_string(),
Expand Down Expand Up @@ -2594,6 +2608,30 @@ impl DatasetConfiguration {
}
})
.unwrap_or("".to_string()),
RERANKER_API_KEY: configuration
.get("RERANKER_API_KEY")
.unwrap_or(&json!("".to_string()))
.as_str()
.map(|s| {
if s.is_empty() {
"".to_string()
} else {
s.to_string()
}
})
.unwrap_or("".to_string()),
RERANKER_MODEL_NAME: configuration
.get("RERANKER_MODEL_NAME")
.unwrap_or(&json!("bge-reranker-large".to_string()))
.as_str()
.map(|s| {
if s.is_empty() {
"bge-reranker-large".to_string()
} else {
s.to_string()
}
})
.unwrap_or("".to_string()),
EMBEDDING_BASE_URL: configuration
.get("EMBEDDING_BASE_URL")
.unwrap_or(&json!(get_env!("OPENAI_BASE_URL", "OPENAI_BASE_URL must be set").to_string()))
Expand Down Expand Up @@ -2676,7 +2714,7 @@ impl DatasetConfiguration {
})
.unwrap_or("text-embedding-3-small".to_string()),
RERANKER_BASE_URL: configuration
.get("RERANKER_SERVER_ORIGIN")
.get("RERANKER_BASE_URL")
.unwrap_or(&json!(get_env!("RERANKER_SERVER_ORIGIN", "RERANKER_SERVER_ORIGIN must be set").to_string()))
.as_str()
.map(|s| {
Expand Down Expand Up @@ -2818,6 +2856,9 @@ impl DatasetConfiguration {
json!({
"LLM_BASE_URL": self.LLM_BASE_URL,
"LLM_API_KEY": self.LLM_API_KEY,
"RERANKER_API_KEY": self.RERANKER_API_KEY,
"RERANKER_BASE_URL": self.RERANKER_BASE_URL,
"RERANKER_MODEL_NAME": self.RERANKER_MODEL_NAME,
"EMBEDDING_BASE_URL": self.EMBEDDING_BASE_URL,
"EMBEDDING_MODEL_NAME": self.EMBEDDING_MODEL_NAME,
"MESSAGE_TO_QUERY_PROMPT": self.MESSAGE_TO_QUERY_PROMPT,
Expand Down Expand Up @@ -2888,6 +2929,14 @@ impl DatasetConfigurationDTO {
.LLM_API_KEY
.clone()
.unwrap_or(curr_dataset_config.LLM_API_KEY),
RERANKER_API_KEY: self
.RERANKER_API_KEY
.clone()
.unwrap_or(curr_dataset_config.RERANKER_API_KEY),
RERANKER_MODEL_NAME: self
.RERANKER_MODEL_NAME
.clone()
.unwrap_or(curr_dataset_config.RERANKER_MODEL_NAME),
EMBEDDING_BASE_URL: self
.EMBEDDING_BASE_URL
.clone()
Expand Down
Loading