Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,641 changes: 1,821 additions & 820 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ protobuf = { version = "^5.29.1", optional = true }
cohere = { version = ">=4.44", optional = true }
mistralai = { version = ">=1.0.0", optional = true }
voyageai = { version = ">=0.2.2", optional = true }
boto3 = { version = "^1.36.0", optional = true, extras = ["bedrock"] }
ranx = { version = "^0.3.0", python=">=3.10", optional = true }
boto3 = {version = "1.36.0", optional = true, extras = ["bedrock"]}

[tool.poetry.extras]
openai = ["openai"]
Expand All @@ -53,6 +54,7 @@ vertexai = ["google_cloud_aiplatform", "protobuf"]
cohere = ["cohere"]
mistralai = ["mistralai"]
voyageai = ["voyageai"]
ranx = ["ranx"]
bedrock = ["boto3"]

[tool.poetry.group.dev.dependencies]
Expand Down
2 changes: 2 additions & 0 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import weakref
from typing import Any, Dict, List, Optional

import numpy as np
from redis import Redis

from redisvl.extensions.constants import (
Expand All @@ -23,6 +24,7 @@
from redisvl.index import AsyncSearchIndex, SearchIndex
from redisvl.query import RangeQuery
from redisvl.query.filter import FilterExpression
from redisvl.query.query import BaseQuery
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.utils.log import get_logger
from redisvl.utils.utils import (
Expand Down
27 changes: 19 additions & 8 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def update_routing_config(self, routing_config: RoutingConfig):
"""
self.routing_config = routing_config

def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
    """Apply new distance thresholds to the routes named in the mapping.

    Routes whose name is not a key of ``route_thresholds`` keep their current
    threshold; keys that match no route are ignored.

    Args:
        route_thresholds (Dict[str, Optional[float]]): Mapping of route name
            to the distance threshold that route should use.
    """
    # Select the affected routes first, then assign their new thresholds.
    targeted = [r for r in self.routes if r.name in route_thresholds]
    for target in targeted:
        target.distance_threshold = route_thresholds[target.name]  # type: ignore

def _route_ref_key(self, route_name: str, reference: str) -> str:
"""Generate the route reference key."""
reference_hash = hashify(reference)
Expand Down Expand Up @@ -263,18 +273,16 @@ def _get_route_matches(
aggregation_method: DistanceAggregationMethod,
max_k: int = 1,
) -> List[RouteMatch]:
"""Get the route matches for a given vector and aggregation method."""
"""Get route response from vector db"""

thresholds = [route.distance_threshold for route in self.routes]
if thresholds:
distance_threshold = max(thresholds)
else:
raise ValueError("No distance thresholds provided for the semantic router")
# what's interesting about this is that we only provide one distance_threshold for a range query not multiple
# therefore you might take the max_threshold and further refine from there.
distance_threshold = max(route.distance_threshold for route in self.routes)

vector_range_query = RangeQuery(
vector=vector,
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=distance_threshold,
distance_threshold=float(distance_threshold),
return_fields=["route_name"],
)

Expand All @@ -293,6 +301,9 @@ def _get_route_matches(
)
raise e

for match in aggregation_result.rows:
print(f"\n {match=}")

# process aggregation results into route matches
return [
self._process_route(route_match) for route_match in aggregation_result.rows
Expand Down Expand Up @@ -329,7 +340,7 @@ def _classify_multi_route(
) -> List[RouteMatch]:
"""Classify to multiple routes, up to max_k (int), using a vector."""

route_matches = self._get_route_matches(vector, aggregation_method, max_k)
route_matches = self._get_route_matches(vector, aggregation_method, max_k=max_k)

# process route matches
top_route_matches: List[RouteMatch] = []
Expand Down
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add imports here similar to vectorizers so we can do from redisvl.utils.threshold_optimizer import CacheThresholdOptimizer

Might also consider just naming it redisvl.utils.optimize to follow the same grammar as redisvl.utils.vectorize

Empty file.
58 changes: 58 additions & 0 deletions redisvl/utils/threshold_optimizer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, List, TypeVar

from redisvl.utils.threshold_optimizer.utils import _validate_test_dict


class EvalMetric(Enum):
"""Evaluation metrics for threshold optimization."""

F1 = "f1"
PRECISION = "precision"
RECALL = "recall"

def __str__(self) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think EvalMetric(str, Enum) would do the trick? See how we handle other enums in the lib

return self.value

@classmethod
def from_string(cls, metric: str) -> "EvalMetric":
"""Convert string to EvalMetric enum."""
try:
return cls(metric.lower())
except ValueError:
raise ValueError(
f"Invalid metric: {metric}. Valid options are: {', '.join(m.value for m in cls)}"
)


T = TypeVar("T") # Type variable for the optimizable object (Cache or Router)


class BaseThresholdOptimizer(ABC):
    """Abstract base class for threshold optimizers."""

    def __init__(
        self,
        optimizable: T,
        test_dict: List[Dict],
        opt_fn: Callable,
        eval_metric: str = "f1",
    ):
        """Initialize the optimizer.

        Args:
            optimizable: The object to optimize (Cache or Router).
            test_dict: List of test cases; each dict is validated and
                converted by ``_validate_test_dict``.
            opt_fn: Function to perform optimization.
            eval_metric: Name of the metric to optimize ("f1", "precision"
                or "recall", case-insensitive). An EvalMetric member is
                also accepted.

        Raises:
            ValueError: If ``eval_metric`` is not a supported metric name.
        """
        self.test_data = _validate_test_dict(test_dict)
        self.optimizable = optimizable
        # Parse through from_string so names are matched case-insensitively
        # and invalid names list the valid options.
        self.eval_metric = (
            eval_metric
            if isinstance(eval_metric, EvalMetric)
            else EvalMetric.from_string(eval_metric)
        )
        self.opt_fn = opt_fn

    @abstractmethod
    def optimize(self, **kwargs: Any):
        """Optimize thresholds using the provided optimization function."""
        pass
85 changes: 85 additions & 0 deletions redisvl/utils/threshold_optimizer/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any, Callable, Dict, List

import numpy as np
from ranx import Qrels, Run, evaluate

from redisvl.extensions.llmcache.semantic import SemanticCache
from redisvl.query import RangeQuery
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
from redisvl.utils.threshold_optimizer.schema import TestData
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels


def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
    """Build a ranx Run from stored cache responses at a given threshold.

    For every test case, each stored response whose ``vector_distance`` is
    strictly below ``threshold`` is recorded as a hit keyed by its ``id``.
    """
    run_dict: Dict[str, Dict[str, int]] = {}

    for case in test_data:
        # The stored value 1 is arbitrary; only presence of the key matters
        # for the match-based metrics.
        hits = {
            res["id"]: 1
            for res in case.response
            if float(res["vector_distance"]) < threshold
        }
        # Per the original author's note, ranx errors on an empty result set,
        # so a placeholder key is recorded when nothing matched.
        run_dict[case.q_id] = hits if hits else {NULL_RESPONSE_KEY: 1}

    return Run(run_dict)


def _eval_cache(
    test_data: List[TestData], threshold: float, qrels: Qrels, metric: str
) -> float:
    """Score the cache at ``threshold`` against ``qrels`` using ``metric``.

    The run is built with ``_generate_run_cache`` and passed to ranx
    ``evaluate`` with ``make_comparable=True``.
    """
    return evaluate(
        qrels,
        _generate_run_cache(test_data, threshold),
        metric,
        make_comparable=True,
    )


def _get_best_threshold(metrics: dict) -> float:
"""
Returns the threshold with the highest F1 score.
If multiple thresholds have the same F1 score, returns the lowest threshold.
"""
return max(metrics.items(), key=lambda x: (x[1]["score"], -x[0]))[0]


def _grid_search_opt_cache(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any particular reason these aren't members of the class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thought process is this shows that you can pass in any function that meets this definition not dependent on any class state if you wanted to go a more custom route. It's a little atypical though since it is still an internal method that doesn't really make sense to expose outside of this context.

cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric
):
"""Evaluates all thresholds in linspace for cache to determine optimal"""
thresholds = np.linspace(0.01, 0.8, 60)
metrics = {}

for td in test_data:
vec = cache._vectorizer.embed(td.query)
query = RangeQuery(
vec, vector_field_name="prompt_vector", distance_threshold=1.0
)
res = cache.index.query(query)
td.response = res

qrels = _format_qrels(test_data)

for threshold in thresholds:
score = _eval_cache(test_data, threshold, qrels, eval_metric.value)
metrics[threshold] = {"score": score}

best_threshold = _get_best_threshold(metrics)
cache.set_threshold(best_threshold)


class CacheThresholdOptimizer(BaseThresholdOptimizer):
    """Threshold optimizer for a SemanticCache."""

    def __init__(
        self,
        cache: SemanticCache,
        test_dict: List[Dict],
        opt_fn: Callable = _grid_search_opt_cache,
        eval_metric: str = "f1",
    ):
        """Set up the optimizer for ``cache``.

        Args:
            cache: The semantic cache whose threshold is tuned.
            test_dict: List of test-case dicts.
            opt_fn: Optimization routine; called as
                ``opt_fn(cache, test_data, eval_metric, **kwargs)``.
                Defaults to the grid search defined in this module.
            eval_metric: Name of the metric to optimize.
        """
        super().__init__(
            optimizable=cache,
            test_dict=test_dict,
            opt_fn=opt_fn,
            eval_metric=eval_metric,
        )

    def optimize(self, **kwargs: Any):
        """Run the configured optimization function against the cache.

        Any keyword arguments are forwarded to ``opt_fn`` unchanged.
        """
        run_optimization = self.opt_fn
        run_optimization(
            self.optimizable, self.test_data, self.eval_metric, **kwargs
        )
99 changes: 99 additions & 0 deletions redisvl/utils/threshold_optimizer/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import random
from typing import Any, Callable, Dict, List

import numpy as np
from ranx import Qrels, Run, evaluate

from redisvl.extensions.router.semantic import SemanticRouter
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
from redisvl.utils.threshold_optimizer.schema import TestData
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels


def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
    """Format router results into a ranx Run.

    Each query is routed and the route the router actually predicted is
    recorded. When no route is predicted, NULL_RESPONSE_KEY is recorded, which
    is the same key the qrels use for test cases that expect no match.

    Recording the predicted route, rather than only checking equality with the
    expected one, keeps negative test cases correct: a negative that is wrongly
    routed no longer collapses to NULL_RESPONSE_KEY (which would score as a
    hit), and a negative that is correctly left unrouted is no longer stored
    under a ``None`` key.

    Args:
        test_data: Test cases to route.
        router: The semantic router under evaluation; called with each query.

    Returns:
        A ranx Run with exactly one entry per test case.
    """
    run_dict: Dict[Any, Any] = {}

    for td in test_data:
        route_match = router(td.query)
        predicted_route = route_match.name if route_match else None
        # The stored value 1 is arbitrary; only the key is compared.
        run_dict[td.q_id] = {predicted_route or NULL_RESPONSE_KEY: 1}

    return Run(run_dict)


def _eval_router(
    router: SemanticRouter, test_data: List[TestData], qrels: Qrels, eval_metric: str
) -> float:
    """Score the router's current thresholds against ``qrels``.

    The run is built with ``_generate_run_router`` and passed to ranx
    ``evaluate`` with ``make_comparable=True``.
    """
    return evaluate(
        qrels,
        _generate_run_router(test_data, router),
        eval_metric,
        make_comparable=True,
    )


def _router_random_search(
route_names: List[str], route_thresholds: dict, search_step=0.10
):
"""Performances random search for many threshold to many route context"""
score_threshold_values = []
for route in route_names:
score_threshold_values.append(
np.linspace(
start=max(route_thresholds[route] - search_step, 0),
stop=route_thresholds[route] + search_step,
num=100,
)
)

return {
route: float(random.choice(score_threshold_values[i]))
for i, route in enumerate(route_names)
}


def _random_search_opt_router(
    router: SemanticRouter,
    test_data: List[TestData],
    qrels: Qrels,
    eval_metric: EvalMetric,
    **kwargs: Any,
):
    """Tune per-route thresholds by random search and apply the best set found.

    The router's starting thresholds are scored first. Each iteration samples
    new thresholds around the router's current ones (that is, around the
    previous sample), applies them and re-scores. The best-scoring set is
    applied to the router before returning.

    Args:
        router: Router whose route thresholds are tuned in place.
        test_data: Test cases used for scoring.
        qrels: Expected results for the test cases.
        eval_metric: Metric to maximize.
        **kwargs: ``max_iterations`` (int, default 20) sets the number of
            random samples to try.
    """
    best_score = _eval_router(router, test_data, qrels, eval_metric.value)
    # Snapshot the starting thresholds so later updates on the router cannot
    # alter the fallback "best" set.
    best_thresholds = dict(router.route_thresholds)

    max_iterations = kwargs.get("max_iterations", 20)

    print(f"Starting score {best_score}, starting thresholds {router.route_thresholds}")
    for _ in range(max_iterations):
        route_names = router.route_names
        route_thresholds = router.route_thresholds
        thresholds = _router_random_search(
            route_names=route_names, route_thresholds=route_thresholds
        )
        router.update_route_thresholds(thresholds)
        score = _eval_router(router, test_data, qrels, eval_metric.value)
        if score > best_score:
            best_score = score
            best_thresholds = thresholds

    # Restore the best thresholds before reporting, so the message shows the
    # thresholds that produced best_score rather than the last sample tried.
    router.update_route_thresholds(best_thresholds)
    print(f"Ending score {best_score}, ending thresholds {router.route_thresholds}")


class RouterThresholdOptimizer(BaseThresholdOptimizer):
    """Threshold optimizer for a SemanticRouter."""

    def __init__(
        self,
        router: SemanticRouter,
        test_dict: List[Dict],
        opt_fn: Callable = _random_search_opt_router,
        eval_metric: str = "f1",
    ):
        """Set up the optimizer for ``router``.

        Args:
            router: The semantic router whose route thresholds are tuned.
            test_dict: List of test-case dicts.
            opt_fn: Optimization routine; called as
                ``opt_fn(router, test_data, qrels, eval_metric, **kwargs)``.
                Defaults to the random search defined in this module.
            eval_metric: Name of the metric to optimize.
        """
        super().__init__(
            optimizable=router,
            test_dict=test_dict,
            opt_fn=opt_fn,
            eval_metric=eval_metric,
        )

    def optimize(self, **kwargs: Any):
        """Build qrels from the test data and run the optimization function.

        Any keyword arguments are forwarded to ``opt_fn`` unchanged.
        """
        expected = _format_qrels(self.test_data)
        run_optimization = self.opt_fn
        run_optimization(
            self.optimizable, self.test_data, expected, self.eval_metric, **kwargs
        )
11 changes: 11 additions & 0 deletions redisvl/utils/threshold_optimizer/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List, Optional

from pydantic import BaseModel, Field
from ulid import ULID


class TestData(BaseModel):
    """A single labeled query used when optimizing thresholds."""

    # Unique id for the query; a fresh ULID string is generated when omitted.
    q_id: str = Field(default_factory=lambda: str(ULID()))
    # The query text to run against the cache or router.
    query: str
    # Expected match for the query. None (the default) marks a query that is
    # expected to match nothing, so such cases may omit the key entirely.
    query_match: Optional[str] = None
    # Raw query results attached during optimization; starts empty.
    response: List[dict] = Field(default_factory=list)
26 changes: 26 additions & 0 deletions redisvl/utils/threshold_optimizer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List

from ranx import Qrels

from redisvl.utils.threshold_optimizer.schema import TestData

NULL_RESPONSE_KEY = "no_match"


def _format_qrels(test_data: List[TestData]) -> Qrels:
    """Build the ranx Qrels (expected results) for the given test cases.

    A test case with a truthy ``query_match`` expects that match; any other
    case expects NULL_RESPONSE_KEY, which captures true negatives from the
    test set.
    """
    expected = {
        case.q_id: {case.query_match or NULL_RESPONSE_KEY: 1} for case in test_data
    }
    return Qrels(expected)


def _validate_test_dict(test_dict: List[dict]) -> List[TestData]:
    """Validate raw test-case dicts by converting each into a TestData model."""
    validated: List[TestData] = []
    for raw_case in test_dict:
        validated.append(TestData(**raw_case))
    return validated
2 changes: 1 addition & 1 deletion schemas/semantic_router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ routes:
- goodbye
metadata:
type: farewell
distance_threshold: 0.3
distance_threshold: 0.2
vectorizer:
type: hf
model: sentence-transformers/all-mpnet-base-v2
Expand Down
Loading