Skip to content
Merged
25 changes: 25 additions & 0 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
)
from weaviate.types import UUID, UUIDS

import weaviate.classes as wvc

UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a")
UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75")
Expand Down Expand Up @@ -1754,3 +1756,26 @@ def test_none_query_hybrid_bm25(collection_factory: CollectionFactory) -> None:
bm25_objs = collection.query.bm25(query=None, return_metadata=MetadataQuery.full()).objects
assert len(bm25_objs) == 3
assert all(obj.metadata.score is not None and obj.metadata.score == 0.0 for obj in bm25_objs)


def test_bm25_operators(collection_factory: CollectionFactory) -> None:
    """Run a BM25 query for "banana two" with the `Or` keyword operator (minimum_match=1).

    All four inserted objects are expected back, with the "banana two" object
    ranked first and the other three following in any order.
    """
    collection = collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    # The server version is only reachable through the created collection's connection.
    server_version = collection._connection._weaviate_version
    if server_version.is_lower_than(1, 31, 0):
        pytest.skip("bm25 operators are only supported in versions higher than 1.31.0")

    names = ("banana one", "banana two", "banana three", "banana four")
    # Inserts run in the order of `names`; each insert returns the new object's UUID.
    first, second, third, fourth = (collection.data.insert({"name": name}) for name in names)

    result = collection.query.bm25(
        "banana two",
        operator=wvc.query.KeywordOperatorFactory.Or(minimum_match=1),
    )
    hits = result.objects

    assert len(hits) == 4
    assert hits[0].uuid == second
    # Only the top hit has a fixed position, so compare the rest order-insensitively.
    assert sorted(hit.uuid for hit in hits[1:]) == sorted([first, third, fourth])
25 changes: 25 additions & 0 deletions integration/test_collection_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,28 @@ def test_aggregate_max_vector_distance(collection_factory: CollectionFactory) ->
return_metrics=[wvc.aggregate.Metrics("name").text(count=True)],
)
assert res.total_count == 2


def test_hybrid_bm25_operators(collection_factory: CollectionFactory) -> None:
    """Run a hybrid query for "banana two" with the `Or` keyword operator (minimum_match=1).

    The query passes `alpha=0.0` and no vector. All four inserted objects are
    expected back, with the "banana two" object ranked first and the other
    three following in any order.
    """
    collection = collection_factory(
        properties=[Property(name="name", data_type=DataType.TEXT)],
        vectorizer_config=Configure.Vectorizer.none(),
    )

    # The server version is only reachable through the created collection's connection.
    server_version = collection._connection._weaviate_version
    if server_version.is_lower_than(1, 31, 0):
        pytest.skip("bm25 operators are only supported in versions higher than 1.31.0")

    documents = (
        ("banana one", [1, 0, 0, 0]),
        ("banana two", [0, 1, 0, 0]),
        ("banana three", [0, 1, 0, 0]),
        ("banana four", [1, 0, 0, 0]),
    )
    # Inserts run in the order of `documents`; each insert returns the new object's UUID.
    first, second, third, fourth = (
        collection.data.insert({"name": name}, vector=vec) for name, vec in documents
    )

    result = collection.query.hybrid(
        "banana two",
        vector=None,
        alpha=0.0,
        keyword_operator=wvc.query.KeywordOperatorFactory.Or(minimum_match=1),
    )
    hits = result.objects

    assert len(hits) == 4
    assert hits[0].uuid == second
    # Only the top hit has a fixed position, so compare the rest order-insensitively.
    assert sorted(hit.uuid for hit in hits[1:]) == sorted([first, third, fourth])
2 changes: 2 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
GroupBy,
HybridFusion,
HybridVector,
KeywordOperatorFactory,
MetadataQuery,
Move,
NearMediaType,
Expand All @@ -24,6 +25,7 @@
"GroupBy",
"HybridFusion",
"HybridVector",
"KeywordOperatorFactory",
"MetadataQuery",
"Metrics",
"Move",
Expand Down
2 changes: 1 addition & 1 deletion weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
)
from weaviate.collections.classes.config_vector_index import (
VectorFilterStrategy,
_MuveraConfigCreate,
_EncodingConfigCreate,
_MultiVectorConfigCreate,
_MuveraConfigCreate,
_QuantizerConfigCreate,
_VectorIndexConfigCreate,
_VectorIndexConfigDynamicCreate,
Expand Down
38 changes: 36 additions & 2 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
ClassVar,
Any,
Dict,
Generic,
List,
Expand All @@ -15,7 +15,7 @@
)

from pydantic import ConfigDict, Field
from typing_extensions import TypeGuard, TypeVar
from typing_extensions import ClassVar, TypeGuard, TypeVar

from weaviate.collections.classes.types import _WeaviateInput
from weaviate.exceptions import WeaviateInvalidInputError
Expand Down Expand Up @@ -242,6 +242,40 @@ class Rerank(_WeaviateInput):
query: Optional[str] = Field(default=None)


@dataclass
class KeywordOperatorOptions:
    """Base class for keyword-operator options.

    Declares the class-level `operator` attribute without assigning it;
    subclasses provide the concrete value.
    """

    # Typed as Any for now: switch to
    # ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed.
    operator: ClassVar[Any]


@dataclass
class KeywordOperatorOr(KeywordOperatorOptions):
    """Keyword-operator options that select `OPERATOR_OR`."""

    # Required dataclass field; None is accepted by the annotation.
    minimum_should_match: Optional[int]

    # Unannotated on purpose: stays a plain class attribute, not a dataclass field.
    operator = base_search_pb2.SearchOperatorOptions.OPERATOR_OR


@dataclass
class KeywordOperatorAnd(KeywordOperatorOptions):
    """Keyword-operator options that select `OPERATOR_AND`; carries no extra fields."""

    # Unannotated on purpose: stays a plain class attribute, not a dataclass field.
    operator = base_search_pb2.SearchOperatorOptions.OPERATOR_AND


class KeywordOperatorFactory(_WeaviateInput):
    """Define which operator (`Or` or `And`) a keyword query should use.

    This class only groups the static constructors below; instantiating it
    raises a `TypeError`.
    """

    def __init__(self) -> None:
        """Always raise, because the class is a namespace for its static methods.

        Raises:
            TypeError: on every call.
        """
        raise TypeError("KeywordOperator cannot be instantiated. Use the static methods to create.")

    @staticmethod
    def Or(minimum_match: int) -> KeywordOperatorOptions:
        """Use the 'Or' operator for keyword queries.

        Args:
            minimum_match: Stored as `minimum_should_match` on the returned options.

        Returns:
            A `KeywordOperatorOr` instance.
        """
        return KeywordOperatorOr(minimum_should_match=minimum_match)

    @staticmethod
    def And() -> KeywordOperatorOptions:
        """Use the 'And' operator for keyword queries.

        Returns:
            A `KeywordOperatorAnd` instance.
        """
        return KeywordOperatorAnd()


OneDimensionalVectorType = Sequence[NUMBER]
"""Represents a one-dimensional vector, e.g. one produced by `text2vec-jinaai`"""
TwoDimensionalVectorType = Sequence[Sequence[NUMBER]]
Expand Down
1 change: 1 addition & 0 deletions weaviate/collections/grpc/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def hybrid(
alpha,
vector,
properties,
None, # no keyword operator for hybrid search
None,
distance,
target_vector,
Expand Down
16 changes: 15 additions & 1 deletion weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
REFERENCES,
HybridFusion,
HybridVectorType,
KeywordOperatorOptions,
KeywordOperatorOr,
Move,
NearVectorInputType,
QueryNested,
Expand Down Expand Up @@ -155,6 +157,7 @@ def hybrid(
distance: Optional[NUMBER] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
keyword_operator: Optional[KeywordOperatorOptions] = None,
autocut: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[_GroupBy] = None,
Expand All @@ -181,6 +184,7 @@ def hybrid(
alpha,
vector,
properties,
keyword_operator,
fusion_type,
distance,
target_vector,
Expand All @@ -194,6 +198,7 @@ def bm25(
properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
autocut: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[_GroupBy] = None,
Expand Down Expand Up @@ -224,7 +229,16 @@ def bm25(
autocut=autocut,
bm25=(
base_search_pb2.BM25(
query=query, properties=properties if properties is not None else []
query=query,
properties=properties if properties is not None else [],
search_operator=base_search_pb2.SearchOperatorOptions(
operator=operator.operator,
minimum_or_tokens_match=operator.minimum_should_match
if isinstance(operator, KeywordOperatorOr)
else None,
)
if operator is not None
else None,
)
if query is not None
else None
Expand Down
11 changes: 11 additions & 0 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from weaviate.collections.classes.grpc import (
HybridFusion,
HybridVectorType,
KeywordOperatorOptions,
KeywordOperatorOr,
Move,
NearVectorInputType,
OneDimensionalVectorType,
Expand Down Expand Up @@ -579,6 +581,7 @@ def _parse_hybrid(
alpha: Optional[float],
vector: Optional[HybridVectorType],
properties: Optional[List[str]],
keyword_operator: Optional[KeywordOperatorOptions],
fusion_type: Optional[HybridFusion],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
Expand Down Expand Up @@ -724,6 +727,14 @@ def _parse_hybrid(
vector_bytes=vector_bytes,
vector_distance=distance,
vectors=vectors,
bm25_search_operator=base_search_pb2.SearchOperatorOptions(
operator=keyword_operator.operator,
minimum_or_tokens_match=keyword_operator.minimum_should_match
if isinstance(keyword_operator, KeywordOperatorOr)
else None,
)
if keyword_operator is not None
else None,
)
if query is not None or vector is not None
else None
Expand Down
22 changes: 21 additions & 1 deletion weaviate/collections/queries/bm25/generate/async_.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Generic, List, Literal, Optional, Type, Union, overload

from weaviate.collections.classes.filters import _Filters
from weaviate.collections.classes.grpc import METADATA, PROPERTIES, REFERENCES, GroupBy, Rerank
from weaviate.collections.classes.grpc import (
METADATA,
PROPERTIES,
REFERENCES,
GroupBy,
KeywordOperatorOptions,
Rerank,
)
from weaviate.collections.classes.internal import (
CrossReferences,
GenerativeGroupByReturn,
Expand Down Expand Up @@ -34,6 +41,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -55,6 +63,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -76,6 +85,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -97,6 +107,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -118,6 +129,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -139,6 +151,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Literal[None] = None,
Expand All @@ -160,6 +173,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -181,6 +195,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -202,6 +217,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -223,6 +239,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -244,6 +261,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -265,6 +283,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: GroupBy,
Expand All @@ -286,6 +305,7 @@ class _BM25GenerateAsync(
query_properties: Optional[List[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
operator: Optional[KeywordOperatorOptions] = None,
auto_limit: Optional[int] = None,
filters: Optional[_Filters] = None,
group_by: Optional[GroupBy] = None,
Expand Down
Loading