Skip to content

Commit 66536d0

Browse files
committed
Added kwargs support for vicinity
1 parent 956bd52 commit 66536d0

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

semhash/index.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import numpy as np
46
from vicinity import Backend
57
from vicinity.backends import AbstractBackend, get_backend_class
@@ -27,17 +29,21 @@ def __init__(self, vectors: np.ndarray, items: list[DictItem], backend: Abstract
2729
self.vectors = vectors
2830

2931
@classmethod
30-
def from_vectors_and_items(cls, vectors: np.ndarray, items: list[DictItem], backend_type: Backend) -> Index:
32+
def from_vectors_and_items(
33+
cls, vectors: np.ndarray, items: list[DictItem], backend_type: Backend | str, **kwargs: Any
34+
) -> Index:
3135
"""
3236
Load the index from vectors and items.
3337
3438
:param vectors: The vectors of the items.
3539
:param items: The items in the index.
3640
:param backend_type: The type of backend to use.
41+
:param **kwargs: Additional arguments to pass to the backend.
3742
:return: The index.
3843
"""
3944
backend_class = get_backend_class(backend_type)
40-
backend = backend_class.from_vectors(vectors)
45+
arguments = backend_class.argument_class(**kwargs)
46+
backend = backend_class.from_vectors(vectors, **arguments.dict())
4147

4248
return cls(vectors, items, backend)
4349

semhash/semhash.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import defaultdict
44
from math import ceil
5-
from typing import Generic, Literal, Sequence
5+
from typing import Any, Generic, Literal, Sequence
66

77
import numpy as np
88
from frozendict import frozendict
@@ -102,7 +102,8 @@ def from_records(
102102
columns: Sequence[str] | None = None,
103103
use_ann: bool = True,
104104
model: Encoder | None = None,
105-
ann_backend: Backend = Backend.USEARCH,
105+
ann_backend: Backend | str = Backend.USEARCH,
106+
**kwargs: Any,
106107
) -> SemHash:
107108
"""
108109
Initialize a SemHash instance from records.
@@ -114,6 +115,7 @@ def from_records(
114115
:param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
115116
:param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
116117
:param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
118+
:param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
117119
:return: A SemHash instance with a fitted vicinity index.
118120
:raises ValueError: If columns are not provided for dictionary records.
119121
"""
@@ -158,6 +160,7 @@ def from_records(
158160
vectors=embeddings,
159161
items=items,
160162
backend_type=backend,
163+
**kwargs,
161164
)
162165

163166
return cls(index=index, columns=columns, model=model, was_string=was_string)

0 commit comments

Comments
 (0)