diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index 59ee9a92c1..6e06ffc24c 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -8,6 +8,7 @@ FieldCondition, Filter, MatchValue, + MatchAny, PointIdsList, PointStruct, Range, @@ -150,14 +151,54 @@ def _create_filter(self, filters: dict) -> Filter: """ if not filters: return None - - conditions = [] + + must_conditions = [] + should_conditions = [] + must_not_conditions = [] + for key, value in filters.items(): - if isinstance(value, dict) and "gte" in value and "lte" in value: - conditions.append(FieldCondition(key=key, range=Range(gte=value["gte"], lte=value["lte"]))) + if key == "AND": + for item in value: + sub_filter = self._create_filter(item) + if sub_filter: + must_conditions.append(sub_filter) + elif key == "OR": + for item in value: + sub_filter = self._create_filter(item) + if sub_filter: + should_conditions.append(sub_filter) + elif key == "NOT": + for item in value: + sub_filter = self._create_filter(item) + if sub_filter: + must_not_conditions.append(sub_filter) + elif isinstance(value, dict): + # Handle operators + range_ops = {k: v for k, v in value.items() if k in ["gt", "gte", "lt", "lte"]} + other_ops = {k: v for k, v in value.items() if k not in ["gt", "gte", "lt", "lte"]} + + if range_ops: + must_conditions.append(FieldCondition(key=key, range=Range(**range_ops))) + + for operator, operand in other_ops.items(): + if operator == "eq": + must_conditions.append(FieldCondition(key=key, match=MatchValue(value=operand))) + elif operator == "ne": + must_not_conditions.append(FieldCondition(key=key, match=MatchValue(value=operand))) + elif operator == "in": + must_conditions.append(FieldCondition(key=key, match=MatchAny(any=operand))) + elif operator == "nin": + must_not_conditions.append(FieldCondition(key=key, match=MatchAny(any=operand))) + else: + logger.warning(f"Unsupported operator: {operator}") else: - conditions.append(FieldCondition(key=key, match=MatchValue(value=value))) - return Filter(must=conditions) if conditions else None + must_conditions.append(FieldCondition(key=key, match=MatchValue(value=value))) + + return Filter( + must=must_conditions if must_conditions else None, + should=should_conditions if should_conditions else None, + must_not=must_not_conditions if must_not_conditions else None + ) def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list: """ diff --git a/pyproject.toml b/pyproject.toml index 279c81c4d2..fdd6b15820 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,8 @@ test = [ "pytest>=8.2.2", "pytest-mock>=3.14.0", "pytest-asyncio>=0.23.7", + "sentence-transformers>=3.0.0", + "qdrant-client>=1.9.0", ] dev = [ "ruff>=0.6.5", diff --git a/tests/test_integration_local.py b/tests/test_integration_local.py new file mode 100644 index 0000000000..b20738a769 --- /dev/null +++ b/tests/test_integration_local.py @@ -0,0 +1,120 @@ +import pytest +import os +from mem0 import Memory +from mem0.configs.base import MemoryConfig + +from unittest.mock import patch, MagicMock + +# Skip test if dependencies are missing +try: + import sentence_transformers + import qdrant_client + DEPENDENCIES_INSTALLED = True +except ImportError: + DEPENDENCIES_INSTALLED = False + +@pytest.fixture +def local_memory(): + """ + Fixture for a fully local Memory instance: + - Embedding: HuggingFace (sentence-transformers/all-MiniLM-L6-v2) + - Vector Store: Qdrant (:memory:) + - LLM: Mocked (to avoid API key requirements) + """ + with patch("mem0.memory.main.LlmFactory.create") as mock_llm: + # Configuration + config = MemoryConfig( + embedder={ + "provider": "huggingface", + "config": { + "model": "sentence-transformers/all-MiniLM-L6-v2" + } + }, + vector_store={ + "provider": "qdrant", + "config": { + "collection_name": "test_integration_local", + "path": ":memory:", + "embedding_model_dims": 384, + } + }, + history_db_path=":memory:" + ) + + # Setup Mock LLM behavior if needed (for infer=True) + # For this basic integration test, we might use infer=False or expect the mock to be called + mock_llm_instance = MagicMock() + mock_llm.return_value = mock_llm_instance + + yield Memory(config=config) + +@pytest.mark.skipif(not DEPENDENCIES_INSTALLED, reason="sentence-transformers or qdrant-client not installed") +def test_full_lifecycle(local_memory): + """ + Test the full lifecycle of a memory: + 1. Add + 2. Search + 3. Get + 4. Update + 5. Delete + """ + user_id = "test_user_integration" + memory_text = "I am writing an integration test for mem0." + + # 1. Add + print("\n[TEST] Adding memory...") + add_response = local_memory.add( + messages=[{"role": "user", "content": memory_text}], + user_id=user_id, + infer=False + ) + assert len(add_response["results"]) > 0 + memory_id = add_response["results"][0]["id"] + print(f"[TEST] Added memory ID: {memory_id}") + + # 2. Search + print("[TEST] Searching memory...") + search_response = local_memory.search("integration test", user_id=user_id) + assert len(search_response["results"]) > 0 + found_memory = search_response["results"][0] + assert found_memory["memory"] == memory_text + print("[TEST] Search successful.") + + # 3. Get + print("[TEST] Getting memory by ID...") + get_response = local_memory.get(memory_id) + assert get_response["memory"] == memory_text + print("[TEST] Get successful.") + + # 4. Update + print("[TEST] Updating memory...") + new_text = "I have updated this memory locally." + update_response = local_memory.update(memory_id, data=new_text) + assert "updated" in update_response["message"].lower() + + # Verify update + get_updated = local_memory.get(memory_id) + assert get_updated["memory"] == new_text + print("[TEST] Update successful.") + + # 5. History + print("[TEST] Checking history...") + history = local_memory.history(memory_id) + assert len(history) >= 2 # Add + Update + assert history[0]["event"] == "ADD" + assert history[-1]["event"] == "UPDATE" + print("[TEST] History check successful.") + + # 6. Delete + print("[TEST] Deleting memory...") + delete_response = local_memory.delete(memory_id) + assert "deleted" in delete_response["message"].lower() + + # Verify deletion + get_deleted = local_memory.get(memory_id) + # Qdrant behavior: might return None or raise error depending on implementation + # Based on base implementation, it usually returns None or error, but let's check basic get_all + all_memories = local_memory.get_all(user_id=user_id) + # Validating it's gone from the user's list + assert not any(m["id"] == memory_id for m in all_memories.get("results", [])) + print("[TEST] Delete successful.") diff --git a/tests/test_qdrant_filters.py b/tests/test_qdrant_filters.py new file mode 100644 index 0000000000..f84cc18087 --- /dev/null +++ b/tests/test_qdrant_filters.py @@ -0,0 +1,97 @@ +import pytest +from mem0.vector_stores.qdrant import Qdrant +from qdrant_client.http import models as rest + +class MockQdrant(Qdrant): + def __init__(self): + # Bypass init to test _create_filter directly + pass + +@pytest.fixture +def qdrant_store(): + return MockQdrant() + +def test_simple_kv_filter(qdrant_store): + filters = {"user_id": "alice"} + q_filter = qdrant_store._create_filter(filters) + + assert isinstance(q_filter, rest.Filter) + assert len(q_filter.must) == 1 + assert q_filter.must[0].key == "user_id" + assert q_filter.must[0].match.value == "alice" + +def test_operator_eq(qdrant_store): + filters = {"user_id": {"eq": "alice"}} + q_filter = qdrant_store._create_filter(filters) + + assert len(q_filter.must) == 1 + assert q_filter.must[0].key == "user_id" + assert q_filter.must[0].match.value == "alice" + +def test_operator_ne(qdrant_store): + filters = {"status": {"ne": "deleted"}} + q_filter = qdrant_store._create_filter(filters) + + # "ne" usually maps to must_not match + assert q_filter.must_not is not None + assert len(q_filter.must_not) == 1 + assert q_filter.must_not[0].key == "status" + assert q_filter.must_not[0].match.value == "deleted" + +def test_operator_range(qdrant_store): + filters = {"age": {"gt": 18, "lte": 30}} + q_filter = qdrant_store._create_filter(filters) + + assert len(q_filter.must) == 1 + range_cond = qdrant_store._create_filter({"age": {"gt": 18}}).must[0].range + # Note: Structure might vary depending on implementation (one Range object vs multiple) + # Assuming standard behavior: separate conditions or unified range + # Let's verify at least one range condition exists + cond = q_filter.must[0] + assert cond.key == "age" + assert cond.range.gt == 18 + assert cond.range.lte == 30 + +def test_operator_in(qdrant_store): + filters = {"tags": {"in": ["ai", "python"]}} + q_filter = qdrant_store._create_filter(filters) + + assert len(q_filter.must) == 1 + assert q_filter.must[0].key == "tags" + assert q_filter.must[0].match.any == ["ai", "python"] + +def test_operator_nin(qdrant_store): + filters = {"tags": {"nin": ["spam", "ads"]}} + q_filter = qdrant_store._create_filter(filters) + + assert len(q_filter.must_not) == 1 + assert q_filter.must_not[0].key == "tags" + assert q_filter.must_not[0].match.any == ["spam", "ads"] + +def test_logical_or(qdrant_store): + # OR: [{"role": "admin"}, {"status": "active"}] + # Mem0 defined format for OR might be implicit list or strict key "$or" or "OR" + # Based on memory/main.py logic, it might modify keys. + # But usually vector store receives filters directly. + # Assuming Mem0 convention for explicit OR is key "OR" or similar? + # Wait, looking at memory/main.py _process_metadata_filters: + # it maps standard operators but structure is passed. + # Let's test standard Qdrant/Mongo style if applicable, or just assume input is what's passed. + + # According to `mem0/vector_stores/qdrant.py` current impl, it iterates `.items()`. + # Let's implement support for a special key "OR" that takes a list of conditions. + filters = { + "OR": [ + {"role": "admin"}, + {"role": "editor"} + ] + } + q_filter = qdrant_store._create_filter(filters) + + assert q_filter.should is not None + assert len(q_filter.should) == 2 + # Since _create_filter returns a Filter object, the list contains nested Filters + # Each nested Filter wraps the condition in its 'must' list + assert isinstance(q_filter.should[0], rest.Filter) + assert q_filter.should[0].must[0].key == "role" + assert q_filter.should[0].must[0].match.value == "admin"