Skip to content
Merged
25 changes: 25 additions & 0 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
)
from weaviate.types import UUID, UUIDS

import weaviate.classes as wvc

UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a")
UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75")
Expand Down Expand Up @@ -1754,3 +1756,26 @@ def test_none_query_hybrid_bm25(collection_factory: CollectionFactory) -> None:
bm25_objs = collection.query.bm25(query=None, return_metadata=MetadataQuery.full()).objects
assert len(bm25_objs) == 3
assert all(obj.metadata.score is not None and obj.metadata.score == 0.0 for obj in bm25_objs)


def test_bm25_operators(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)

if collection._connection._weaviate_version.is_lower_than(1, 31, 0):
pytest.skip("bm25 operators are only supported in versions higher than 1.31.0")

uuid1 = collection.data.insert({"name": "banana one"})
uuid2 = collection.data.insert({"name": "banana two"})
uuid3 = collection.data.insert({"name": "banana three"})
uuid4 = collection.data.insert({"name": "banana four"})

objs = collection.query.bm25(
"banana two",
operator=wvc.query.BM25Operator.or_(minimum_match=1),
)
assert len(objs.objects) == 4
assert objs.objects[0].uuid == uuid2
assert sorted(obj.uuid for obj in objs.objects[1:]) == sorted([uuid1, uuid3, uuid4])
25 changes: 25 additions & 0 deletions integration/test_collection_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,28 @@ def test_aggregate_max_vector_distance(collection_factory: CollectionFactory) ->
return_metrics=[wvc.aggregate.Metrics("name").text(count=True)],
)
assert res.total_count == 2


def test_hybrid_bm25_operators(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)

if collection._connection._weaviate_version.is_lower_than(1, 31, 0):
pytest.skip("bm25 operators are only supported in versions higher than 1.31.0")

uuid1 = collection.data.insert({"name": "banana one"}, vector=[1, 0, 0, 0])
uuid2 = collection.data.insert({"name": "banana two"}, vector=[0, 1, 0, 0])
uuid3 = collection.data.insert({"name": "banana three"}, vector=[0, 1, 0, 0])
uuid4 = collection.data.insert({"name": "banana four"}, vector=[1, 0, 0, 0])

objs = collection.query.hybrid(
"banana two",
vector=None,
alpha=0.0,
bm25_operator=wvc.query.BM25Operator.or_(minimum_match=1),
)
assert len(objs.objects) == 4
assert objs.objects[0].uuid == uuid2
assert sorted(obj.uuid for obj in objs.objects[1:]) == sorted([uuid1, uuid3, uuid4])
4 changes: 4 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
Sort,
TargetVectors,
)
from weaviate.collections.classes.grpc import (
BM25OperatorFactory as BM25Operator,
)
from weaviate.collections.classes.types import GeoCoordinate

__all__ = [
Expand All @@ -24,6 +27,7 @@
"GroupBy",
"HybridFusion",
"HybridVector",
"BM25Operator",
"MetadataQuery",
"Metrics",
"Move",
Expand Down
2 changes: 1 addition & 1 deletion weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
)
from weaviate.collections.classes.config_vector_index import (
VectorFilterStrategy,
_MuveraConfigCreate,
_EncodingConfigCreate,
_MultiVectorConfigCreate,
_MuveraConfigCreate,
_QuantizerConfigCreate,
_VectorIndexConfigCreate,
_VectorIndexConfigDynamicCreate,
Expand Down
51 changes: 49 additions & 2 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
ClassVar,
Any,
Dict,
Generic,
List,
Expand All @@ -15,7 +15,7 @@
)

from pydantic import ConfigDict, Field
from typing_extensions import TypeGuard, TypeVar
from typing_extensions import ClassVar, TypeGuard, TypeVar

from weaviate.collections.classes.types import _WeaviateInput
from weaviate.exceptions import WeaviateInvalidInputError
Expand Down Expand Up @@ -242,6 +242,53 @@ class Rerank(_WeaviateInput):
query: Optional[str] = Field(default=None)


@dataclass
class BM25OperatorOptions:
# replace with ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed
operator: ClassVar[Any]


@dataclass
class BM25OperatorOr(BM25OperatorOptions):
"""Define the 'Or' operator for keyword queries."""

operator = base_search_pb2.SearchOperatorOptions.OPERATOR_OR
minimum_should_match: Optional[int]


@dataclass
class BM25OperatorAnd(BM25OperatorOptions):
"""Define the 'And' operator for keyword queries."""

operator = base_search_pb2.SearchOperatorOptions.OPERATOR_AND


class BM25OperatorFactory:
"""Define how the BM25 query's token matching should be performed."""

def __init__(self) -> None:
raise TypeError("BM25Operator cannot be instantiated. Use the static methods to create.")

@staticmethod
def or_(minimum_match: int) -> BM25OperatorOptions:
"""Use the 'Or' operator for keyword queries, where at least a minimum number of tokens must match.

Note that the query is tokenized using the respective tokenization method of each property.

Args:
minimum_match: The minimum number of keyword tokens (excluding stopwords) that must match for an object to be considered a match.
"""
return BM25OperatorOr(minimum_should_match=minimum_match)

@staticmethod
def and_() -> BM25OperatorOptions:
"""Use the 'And' operator for keyword queries, where all query tokens must match.

Note that the query is tokenized using the respective tokenization method of each property.
"""
return BM25OperatorAnd()


OneDimensionalVectorType = Sequence[NUMBER]
"""Represents a one-dimensional vector, e.g. one produced by `text2vec-jinaai`"""
TwoDimensionalVectorType = Sequence[Sequence[NUMBER]]
Expand Down
1 change: 1 addition & 0 deletions weaviate/collections/grpc/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def hybrid(
alpha,
vector,
properties,
None, # no keyword operator for hybrid search
None,
distance,
target_vector,
Expand Down
16 changes: 15 additions & 1 deletion weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
REFERENCES,
HybridFusion,
HybridVectorType,
BM25OperatorOptions,
BM25OperatorOr,
Move,
NearVectorInputType,
QueryNested,
Expand Down Expand Up @@ -155,6 +157,7 @@ def hybrid(
distance: Optional[NUMBER] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
bm25_operator: Optional[BM25OperatorOptions] = None,
autocut: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[_GroupBy] = None,
Expand All @@ -181,6 +184,7 @@ def hybrid(
alpha,
vector,
properties,
bm25_operator,
fusion_type,
distance,
target_vector,
Expand All @@ -194,6 +198,7 @@ def bm25(
properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
autocut: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[_GroupBy] = None,
Expand Down Expand Up @@ -224,7 +229,16 @@ def bm25(
autocut=autocut,
bm25=(
base_search_pb2.BM25(
query=query, properties=properties if properties is not None else []
query=query,
properties=properties if properties is not None else [],
search_operator=base_search_pb2.SearchOperatorOptions(
operator=operator.operator,
minimum_or_tokens_match=operator.minimum_should_match
if isinstance(operator, BM25OperatorOr)
else None,
)
if operator is not None
else None,
)
if query is not None
else None
Expand Down
11 changes: 11 additions & 0 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from weaviate.collections.classes.grpc import (
HybridFusion,
HybridVectorType,
BM25OperatorOptions,
BM25OperatorOr,
Move,
NearVectorInputType,
OneDimensionalVectorType,
Expand Down Expand Up @@ -579,6 +581,7 @@ def _parse_hybrid(
alpha: Optional[float],
vector: Optional[HybridVectorType],
properties: Optional[List[str]],
keyword_operator: Optional[BM25OperatorOptions],
fusion_type: Optional[HybridFusion],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
Expand Down Expand Up @@ -724,6 +727,14 @@ def _parse_hybrid(
vector_bytes=vector_bytes,
vector_distance=distance,
vectors=vectors,
bm25_search_operator=base_search_pb2.SearchOperatorOptions(
operator=keyword_operator.operator,
minimum_or_tokens_match=keyword_operator.minimum_should_match
if isinstance(keyword_operator, BM25OperatorOr)
else None,
)
if keyword_operator is not None
else None,
)
if query is not None or vector is not None
else None
Expand Down
22 changes: 21 additions & 1 deletion weaviate/collections/queries/bm25/generate/async_.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Generic, List, Literal, Optional, Type, Union, overload

from weaviate.collections.classes.filters import _Filters
from weaviate.collections.classes.grpc import METADATA, PROPERTIES, REFERENCES, GroupBy, Rerank
from weaviate.collections.classes.grpc import (
METADATA,
PROPERTIES,
REFERENCES,
GroupBy,
BM25OperatorOptions,
Rerank,
)
from weaviate.collections.classes.internal import (
CrossReferences,
GenerativeGroupByReturn,
Expand Down Expand Up @@ -34,6 +41,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -55,6 +63,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -76,6 +85,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -97,6 +107,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -118,6 +129,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -139,6 +151,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -160,6 +173,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -181,6 +195,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -202,6 +217,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -223,6 +239,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -244,6 +261,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -265,6 +283,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -286,6 +305,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[BM25OperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[GroupBy] = None,
Expand Down
Loading