|
28 | 28 | CollectionRerankOptions, |
29 | 29 | VectorServiceOptions, |
30 | 30 | ) |
31 | | -from langchain_community.vectorstores.utils import maximal_marginal_relevance |
32 | 31 | from langchain_core.runnables.utils import gather_with_concurrency |
33 | 32 | from langchain_core.vectorstores import VectorStore |
34 | 33 | from typing_extensions import override |
|
56 | 55 | _DefaultVSDocumentCodec, |
57 | 56 | ) |
58 | 57 |
|
# Optional SIMD-accelerated backend for cosine similarity: prefer simsimd
# when the package is installed, otherwise the pure-NumPy path is used.
is_simd_available: bool = False
try:
    import simsimd as simd
except ImportError:
    pass
else:
    is_simd_available = True
| 65 | + |
59 | 66 | if TYPE_CHECKING: |
60 | 67 | from collections.abc import AsyncIterable, Awaitable, Iterable, Sequence |
61 | 68 |
|
@@ -310,6 +317,75 @@ def _describe_error(_errd: Exception) -> list[str]: |
310 | 317 | return err_msg |
311 | 318 |
|
312 | 319 |
|
_Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]


def _cosine_similarity(x: _Matrix, y: _Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices.

    Args:
        x: a matrix of shape (m, d).
        y: a matrix of shape (n, d).

    Returns:
        An (m, n) array whose (i, j) entry is the cosine similarity between
        row i of ``x`` and row j of ``y``; an empty array when either input
        is empty. On the NumPy path, pairs involving a zero-norm row (or any
        non-finite result) are reported as 0.0.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same number of columns.
    """
    if len(x) == 0 or len(y) == 0:
        return np.array([])

    # Single conversion up front; the simsimd branch below only narrows the
    # dtype instead of re-converting from the original inputs.
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape[1] != y.shape[1]:
        msg = (
            f"Number of columns in X and Y must be the same. X has shape {x.shape} "
            f"and Y has shape {y.shape}."
        )
        raise ValueError(msg)

    if is_simd_available:
        # simsimd expects float32 input; cdist returns cosine *distance*,
        # so convert back to similarity with 1 - d.
        x = x.astype(np.float32)
        y = y.astype(np.float32)
        return 1 - np.array(simd.cdist(x, y, metric="cosine"))

    logger.debug(
        "Unable to use simsimd, defaulting to NumPy implementation. If you want "
        "to use simsimd please install with `pip install simsimd`."
    )
    x_norm = np.linalg.norm(x, axis=1)
    y_norm = np.linalg.norm(y, axis=1)
    # Ignore divide by zero errors run time warnings as those are handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarity: np.ndarray = np.dot(x, y.T) / np.outer(x_norm, y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
| 353 | + |
| 354 | + |
def _maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list[list[float]],
    lambda_mult: float = 0.5,
    k: int = 4,
) -> list[int]:
    """Calculate maximal marginal relevance.

    Greedily selects up to ``k`` indices from ``embedding_list``, trading off
    similarity to ``query_embedding`` against redundancy with the items
    already selected.

    Args:
        query_embedding: the query vector; a 1-D vector is promoted to a
            single-row matrix.
        embedding_list: candidate embeddings, one per row.
        lambda_mult: weight in [0, 1] on query similarity vs. diversity
            (1.0 = pure relevance, 0.0 = pure diversity).
        k: maximum number of indices to return.

    Returns:
        The list of selected indices into ``embedding_list``, in selection
        order; empty when ``k <= 0`` or there are no candidates.
    """
    num_candidates = len(embedding_list)
    if min(k, num_candidates) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = _cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    # Running max similarity of each candidate to the selected set. Updating
    # it with only the newly picked row each iteration yields the same maxima
    # as recomputing candidates-vs-selected from scratch, but in O(k*n)
    # instead of O(k^2 * n) and without repeatedly copying a growing matrix.
    redundancy = _cosine_similarity(
        embedding_list, [embedding_list[most_similar]]
    )[:, 0]
    while len(idxs) < min(k, num_candidates):
        best_score = -np.inf
        idx_to_add = -1
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundancy[i]
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        redundancy = np.maximum(
            redundancy,
            _cosine_similarity(embedding_list, [embedding_list[idx_to_add]])[:, 0],
        )
    return idxs
| 387 | + |
| 388 | + |
313 | 389 | class AstraDBVectorStoreError(Exception): |
314 | 390 | """An exception during vector-store activities. |
315 | 391 |
|
@@ -3308,9 +3384,9 @@ def _get_mmr_hits( |
3308 | 3384 | lambda_mult: float, |
3309 | 3385 | prefetch_hit_pairs: list[tuple[Document, list[float]]], |
3310 | 3386 | ) -> list[Document]: |
3311 | | - mmr_chosen_indices = maximal_marginal_relevance( |
| 3387 | + mmr_chosen_indices = _maximal_marginal_relevance( |
3312 | 3388 | np.array(embedding, dtype=np.float32), |
3313 | | - [hit_pair[1] for hit_pair in prefetch_hit_pairs], |
| 3389 | + [embedding for _, embedding in prefetch_hit_pairs], |
3314 | 3390 | k=k, |
3315 | 3391 | lambda_mult=lambda_mult, |
3316 | 3392 | ) |
|
0 commit comments