Add distance threshold optimizer classes #292
Changes from 15 commits
**redisvl/utils/threshold_optimizer/base.py** (new file, +58 lines)

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, List, TypeVar

from redisvl.utils.threshold_optimizer.utils import _validate_test_dict


class EvalMetric(Enum):
    """Evaluation metrics for threshold optimization."""

    F1 = "f1"
    PRECISION = "precision"
    RECALL = "recall"

    def __str__(self) -> str:
        return self.value

    @classmethod
    def from_string(cls, metric: str) -> "EvalMetric":
        """Convert a string to an EvalMetric enum member."""
        try:
            return cls(metric.lower())
        except ValueError:
            raise ValueError(
                f"Invalid metric: {metric}. Valid options are: {', '.join(m.value for m in cls)}"
            )


T = TypeVar("T")  # Type variable for the optimizable object (Cache or Router)


class BaseThresholdOptimizer(ABC):
    """Abstract base class for threshold optimizers."""

    def __init__(
        self,
        optimizable: T,
        test_dict: List[Dict],
        opt_fn: Callable,
        eval_metric: str = "f1",
    ):
        """Initialize the optimizer.

        Args:
            optimizable: The object to optimize (Cache or Router).
            test_dict: List of test cases.
            opt_fn: Function that performs the optimization.
            eval_metric: Name of the metric to optimize for (default: "f1").
        """
        self.test_data = _validate_test_dict(test_dict)
        self.optimizable = optimizable
        self.eval_metric = EvalMetric(eval_metric)
        self.opt_fn = opt_fn

    @abstractmethod
    def optimize(self, **kwargs: Any):
        """Optimize thresholds using the provided optimization function."""
        pass
```
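A quick illustrative check of the metric parsing above (not part of the diff):

```python
# Illustrative only: EvalMetric.from_string is case-insensitive, and the enum
# stringifies to its underlying value.
assert EvalMetric.from_string("F1") is EvalMetric.F1
assert str(EvalMetric.PRECISION) == "precision"
```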
**Cache threshold optimizer** (new file in the same package, +85 lines)

```python
from typing import Any, Callable, Dict, List

import numpy as np
from ranx import Qrels, Run, evaluate

from redisvl.extensions.llmcache.semantic import SemanticCache
from redisvl.query import RangeQuery
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
from redisvl.utils.threshold_optimizer.schema import TestData
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels


def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
    """Format observed data for evaluation with ranx."""
    run_dict: Dict[str, Dict[str, int]] = {}

    for td in test_data:
        run_dict[td.q_id] = {}
        for res in td.response:
            if float(res["vector_distance"]) < threshold:
                # The relevance value of 1 is arbitrary; F1 only checks for a match.
                run_dict[td.q_id][res["id"]] = 1

        if not run_dict[td.q_id]:
            # ranx errors if a query has no results at all, but a sentinel key
            # that matches nothing still yields the correct score.
            run_dict[td.q_id][NULL_RESPONSE_KEY] = 1

    return Run(run_dict)


def _eval_cache(
    test_data: List[TestData], threshold: float, qrels: Qrels, metric: str
) -> float:
    """Format run data and evaluate the supported metric."""
    run = _generate_run_cache(test_data, threshold)
    return evaluate(qrels, run, metric, make_comparable=True)


def _get_best_threshold(metrics: dict) -> float:
    """Return the threshold with the highest metric score.

    If multiple thresholds tie on score, the lowest threshold is returned.
    """
    return max(metrics.items(), key=lambda x: (x[1]["score"], -x[0]))[0]
```
> **Collaborator:** Any particular reason these aren't members of the class?
>
> **Author:** My thought process is that this shows you can pass in any function that meets this definition, independent of any class state, if you want to go a more custom route. It's a little atypical, though, since it is still an internal method that doesn't really make sense to expose outside of this context.
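Illustrating the author's point: below is a minimal sketch of a custom `opt_fn` that a caller could supply instead of the default `_grid_search_opt_cache` defined just after it. The function name and the coarse threshold grid are hypothetical, not part of the PR.

```python
# Sketch only: a user-supplied opt_fn with the signature the optimizer passes
# through (optimizable, test_data, eval_metric, **kwargs).
def coarse_grid_opt_cache(
    cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric
):
    # Gather responses once at a permissive threshold, as the default does.
    for td in test_data:
        vec = cache._vectorizer.embed(td.query)
        query = RangeQuery(
            vec, vector_field_name="prompt_vector", distance_threshold=1.0
        )
        td.response = cache.index.query(query)

    qrels = _format_qrels(test_data)
    # Evaluate only a handful of thresholds instead of a fine grid.
    scores = {
        t: _eval_cache(test_data, t, qrels, eval_metric.value)
        for t in (0.2, 0.4, 0.6)
    }
    cache.set_threshold(max(scores, key=lambda t: scores[t]))
```

It would be passed as `CacheThresholdOptimizer(cache, test_dict, opt_fn=coarse_grid_opt_cache)`.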
```python
def _grid_search_opt_cache(
    cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric
):
    """Evaluate every threshold in a linspace to find the optimal one for the cache."""
    thresholds = np.linspace(0.01, 0.8, 60)
    metrics = {}

    for td in test_data:
        vec = cache._vectorizer.embed(td.query)
        query = RangeQuery(
            vec, vector_field_name="prompt_vector", distance_threshold=1.0
        )
        res = cache.index.query(query)
        td.response = res

    qrels = _format_qrels(test_data)

    for threshold in thresholds:
        score = _eval_cache(test_data, threshold, qrels, eval_metric.value)
        metrics[threshold] = {"score": score}

    best_threshold = _get_best_threshold(metrics)
    cache.set_threshold(best_threshold)


class CacheThresholdOptimizer(BaseThresholdOptimizer):
    def __init__(
        self,
        cache: SemanticCache,
        test_dict: List[Dict],
        opt_fn: Callable = _grid_search_opt_cache,
        eval_metric: str = "f1",
    ):
        super().__init__(cache, test_dict, opt_fn, eval_metric)

    def optimize(self, **kwargs: Any):
        """Optimize thresholds using the provided optimization function for the cache case."""
        self.opt_fn(self.optimizable, self.test_data, self.eval_metric, **kwargs)
```
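A hedged usage sketch (not from the diff). It assumes a populated `SemanticCache` and a running Redis instance; the entry id and queries are made up, and `query_match` is the cache entry id expected to match, or `None` when no hit is expected:

```python
# Hypothetical usage; assumes an existing, populated SemanticCache.
cache = SemanticCache(name="llm_cache", redis_url="redis://localhost:6379")

test_dict = [
    {"query": "What is the capital of France?", "query_match": "llm_cache:abc123"},
    {"query": "Tell me a joke", "query_match": None},  # a true negative
]

optimizer = CacheThresholdOptimizer(cache, test_dict)
optimizer.optimize()
# Reading back the chosen threshold; attribute name assumed from the
# SemanticCache API that set_threshold() writes to.
print(cache.distance_threshold)
```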
**Router threshold optimizer** (new file in the same package, +99 lines)

```python
import random
from typing import Any, Callable, Dict, List

import numpy as np
from ranx import Qrels, Run, evaluate

from redisvl.extensions.router.semantic import SemanticRouter
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
from redisvl.utils.threshold_optimizer.schema import TestData
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels


def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
    """Format router results into a ranx Run."""
    run_dict: Dict[Any, Any] = {}

    for td in test_data:
        run_dict[td.q_id] = {}
        route_match = router(td.query)
        if route_match and route_match.name == td.query_match:
            run_dict[td.q_id][td.query_match] = 1
        else:
            run_dict[td.q_id][NULL_RESPONSE_KEY] = 1

    return Run(run_dict)


def _eval_router(
    router: SemanticRouter, test_data: List[TestData], qrels: Qrels, eval_metric: str
) -> float:
    """Evaluate the chosen metric given run and qrels data."""
    run = _generate_run_router(test_data, router)
    return evaluate(qrels, run, eval_metric, make_comparable=True)


def _router_random_search(
    route_names: List[str], route_thresholds: dict, search_step=0.10
):
    """Randomly sample a new threshold for each route from a window around its current value."""
    score_threshold_values = []
    for route in route_names:
        score_threshold_values.append(
            np.linspace(
                start=max(route_thresholds[route] - search_step, 0),
                stop=route_thresholds[route] + search_step,
                num=100,
            )
        )

    return {
        route: float(random.choice(score_threshold_values[i]))
        for i, route in enumerate(route_names)
    }


def _random_search_opt_router(
    router: SemanticRouter,
    test_data: List[TestData],
    qrels: Qrels,
    eval_metric: EvalMetric,
    **kwargs: Any,
):
    """Run the full random-search optimization for the router with the given metric."""
    best_score = _eval_router(router, test_data, qrels, eval_metric.value)
    best_thresholds = router.route_thresholds

    max_iterations = kwargs.get("max_iterations", 20)

    print(f"Starting score {best_score}, starting thresholds {router.route_thresholds}")
    for _ in range(max_iterations):
        route_names = router.route_names
        route_thresholds = router.route_thresholds
        thresholds = _router_random_search(
            route_names=route_names, route_thresholds=route_thresholds
        )
        router.update_route_thresholds(thresholds)
        score = _eval_router(router, test_data, qrels, eval_metric.value)
        if score > best_score:
            best_score = score
            best_thresholds = thresholds

    # Restore the best thresholds found; the router currently holds the
    # thresholds from the last iteration, not necessarily the best ones.
    router.update_route_thresholds(best_thresholds)
    print(f"Ending score {best_score}, ending thresholds {best_thresholds}")


class RouterThresholdOptimizer(BaseThresholdOptimizer):
    def __init__(
        self,
        router: SemanticRouter,
        test_dict: List[Dict],
        opt_fn: Callable = _random_search_opt_router,
        eval_metric: str = "f1",
    ):
        super().__init__(router, test_dict, opt_fn, eval_metric)

    def optimize(self, **kwargs: Any):
        """Optimize thresholds using the provided optimization function for the router case."""
        qrels = _format_qrels(self.test_data)
        self.opt_fn(self.optimizable, self.test_data, qrels, self.eval_metric, **kwargs)
```
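A hedged usage sketch (not part of the diff). It assumes an existing, connected `SemanticRouter` with routes named "greeting" and "farewell"; route names and queries are illustrative:

```python
# Hypothetical usage; `router` is assumed to be a configured SemanticRouter.
test_dict = [
    {"query": "hello there", "query_match": "greeting"},
    {"query": "see you later", "query_match": "farewell"},
    {"query": "what's the GDP of Peru?", "query_match": None},  # expect no route
]

optimizer = RouterThresholdOptimizer(router, test_dict)
optimizer.optimize(max_iterations=30)  # forwarded to the random search via **kwargs
print(router.route_thresholds)  # per-route thresholds after optimization
```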
**redisvl/utils/threshold_optimizer/schema.py** (new file, +11 lines)

```python
from typing import List, Optional

from pydantic import BaseModel, Field
from ulid import ULID


class TestData(BaseModel):
    q_id: str = Field(default_factory=lambda: str(ULID()))
    query: str
    query_match: Optional[str]
    response: List[dict] = []
```
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| from typing import List | ||
|
|
||
| from ranx import Qrels | ||
|
|
||
| from redisvl.utils.threshold_optimizer.schema import TestData | ||
|
|
||
| NULL_RESPONSE_KEY = "no_match" | ||
|
|
||
|
|
||
| def _format_qrels(test_data: List[TestData]) -> Qrels: | ||
| """Utility function for creating qrels for evaluation with ranx""" | ||
| qrels_dict = {} | ||
|
|
||
| for td in test_data: | ||
| if td.query_match: | ||
| qrels_dict[td.q_id] = {td.query_match: 1} | ||
| else: | ||
| # This is for capturing true negatives from test set | ||
| qrels_dict[td.q_id] = {NULL_RESPONSE_KEY: 1} | ||
|
|
||
| return Qrels(qrels_dict) | ||
|
|
||
|
|
||
| def _validate_test_dict(test_dict: List[dict]) -> List[TestData]: | ||
| """Convert/validate test_dict for use in optimizer""" | ||
| return [TestData(**d) for d in test_dict] |
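A small sketch of what these helpers produce (entry id and queries illustrative):

```python
# One positive case and one true negative.
data = _validate_test_dict(
    [
        {"query": "hi", "query_match": "cache:entry1"},
        {"query": "unrelated", "query_match": None},
    ]
)
qrels = _format_qrels(data)
# qrels maps each auto-generated q_id to {"cache:entry1": 1} for the first
# entry and {"no_match": 1} for the second.
```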