Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mem0/vector_stores/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
import uuid
from typing import Dict, List, Mapping, Optional
from urllib.parse import urlparse

from pydantic import BaseModel
from urllib.parse import urlparse

try:
import weaviate
Expand All @@ -13,7 +13,7 @@
)

import weaviate.classes.config as wvcc
from weaviate.classes.init import Auth, AdditionalConfig, Timeout
from weaviate.classes.init import AdditionalConfig, Auth, Timeout
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.util import get_valid_uuid

Expand Down Expand Up @@ -48,10 +48,10 @@ def __init__(
auth_config (dict, optional): Authentication configuration for Weaviate. Defaults to None.
additional_headers (dict, optional): Additional headers for requests. Defaults to None.
"""
if "localhost" in cluster_url:
if "localhost" in cluster_url:
self.client = weaviate.connect_to_local(headers=additional_headers)
elif auth_client_secret:
self.client = weaviate.connect_to_wcs(
elif auth_client_secret:
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url=cluster_url,
auth_credentials=Auth.api_key(auth_client_secret),
headers=additional_headers,
Expand All @@ -76,7 +76,7 @@ def __init__(
grpc_secure,
headers=additional_headers,
skip_init_checks=True,
additional_config=AdditionalConfig(timeout=Timeout(init=2.0))
additional_config=AdditionalConfig(timeout=Timeout(init=2.0)),
)

self.collection_name = collection_name
Expand Down Expand Up @@ -208,12 +208,16 @@ def search(
del payload[id_field]

payload["id"] = str(obj.uuid).split("'")[0] # Include the id in the payload
if obj.metadata.distance is not None:
score = 1 - obj.metadata.distance # Convert distance to similarity score
elif obj.metadata.score is not None:
score = obj.metadata.score
else:
score = 1.0 # Default score if none provided
results.append(
OutputData(
id=str(obj.uuid),
score=1
if obj.metadata.distance is None
else 1 - obj.metadata.distance, # Convert distance to score
score=score,
payload=payload,
)
)
Expand Down