Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/memory/test_message_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Essential message validation tests for Memory and AsyncMemory classes."""

import pytest
from unittest.mock import Mock, patch
from mem0.exceptions import ValidationError as Mem0ValidationError
from mem0.memory.main import Memory, AsyncMemory
from mem0.configs.base import MemoryConfig


class TestMessageValidation:
"""Test message validation for Memory and AsyncMemory classes."""

@pytest.fixture
def memory_instance(self):
"""Create Memory instance with mocked dependencies."""
config = MemoryConfig()
with patch('mem0.memory.main.EmbedderFactory') as mock_embedder, \
patch('mem0.memory.main.VectorStoreFactory') as mock_vector, \
patch('mem0.memory.main.LlmFactory') as mock_llm, \
patch('mem0.memory.main.SQLiteManager') as mock_db, \
patch('mem0.memory.main.GraphStoreFactory') as mock_graph, \
patch('mem0.memory.main.VectorStoreFactory') as mock_telemetry:

mock_embedder.create.return_value = Mock()
mock_vector.create.return_value = Mock()
mock_llm.create.return_value = Mock()
mock_db.return_value = Mock()
mock_graph.create.return_value = None
mock_telemetry.create.return_value = Mock()

yield Memory(config)

@pytest.fixture
def async_memory_instance(self):
"""Create AsyncMemory instance with mocked dependencies."""
config = MemoryConfig()
with patch('mem0.memory.main.EmbedderFactory') as mock_embedder, \
patch('mem0.memory.main.VectorStoreFactory') as mock_vector, \
patch('mem0.memory.main.LlmFactory') as mock_llm, \
patch('mem0.memory.main.SQLiteManager') as mock_db, \
patch('mem0.memory.main.GraphStoreFactory') as mock_graph, \
patch('mem0.memory.main.VectorStoreFactory') as mock_telemetry:

mock_embedder.create.return_value = Mock()
mock_vector.create.return_value = Mock()
mock_llm.create.return_value = Mock()
mock_db.return_value = Mock()
mock_graph.create.return_value = None
mock_telemetry.create.return_value = Mock()

yield AsyncMemory(config)

# Message Type Validation
def test_valid_message_types(self, memory_instance):
"""Test valid message types."""
# String message
result = memory_instance.add("Hello", user_id="test")
assert result is not None

# Dict message
result = memory_instance.add({"role": "user", "content": "Hello"}, user_id="test")
assert result is not None

# List message
result = memory_instance.add([{"role": "user", "content": "Hello"}], user_id="test")
assert result is not None

def test_invalid_message_types(self, memory_instance):
"""Test invalid message types raise ValidationError."""
# None message
with pytest.raises(Mem0ValidationError) as exc_info:
memory_instance.add(None, user_id="test")
assert exc_info.value.error_code == "VALIDATION_003"

# Integer message
with pytest.raises(Mem0ValidationError) as exc_info:
memory_instance.add(123, user_id="test")
assert exc_info.value.error_code == "VALIDATION_003"

# List of strings - this will fail in parse_vision_messages, not validation
with pytest.raises(TypeError):
memory_instance.add(["hello", "world"], user_id="test")

# Session ID Validation
def test_no_session_ids(self, memory_instance):
"""Test missing session IDs raises ValidationError."""
with pytest.raises(Mem0ValidationError) as exc_info:
memory_instance.add("Hello")
assert exc_info.value.error_code == "VALIDATION_001"

def test_valid_session_ids(self, memory_instance):
"""Test valid session IDs."""
# user_id
result = memory_instance.add("Hello", user_id="test")
assert result is not None

# agent_id
result = memory_instance.add("Hello", agent_id="test")
assert result is not None

# run_id
result = memory_instance.add("Hello", run_id="test")
assert result is not None

# Memory Type Validation
def test_invalid_memory_type(self, memory_instance):
"""Test invalid memory_type raises ValidationError."""
with pytest.raises(Mem0ValidationError) as exc_info:
memory_instance.add("Hello", user_id="test", memory_type="invalid")
assert exc_info.value.error_code == "VALIDATION_002"

# AsyncMemory Tests
@pytest.mark.asyncio
async def test_async_invalid_message_type(self, async_memory_instance):
"""Test AsyncMemory with invalid message type."""
with pytest.raises(Mem0ValidationError) as exc_info:
await async_memory_instance.add(123, user_id="test")
assert exc_info.value.error_code == "VALIDATION_003"

@pytest.mark.asyncio
async def test_async_no_session_ids(self, async_memory_instance):
"""Test AsyncMemory with no session IDs."""
with pytest.raises(Mem0ValidationError) as exc_info:
await async_memory_instance.add("Hello")
assert exc_info.value.error_code == "VALIDATION_001"
96 changes: 96 additions & 0 deletions tests/memory/test_parse_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Essential tests for message parsing utilities."""

import pytest
from unittest.mock import patch
from mem0.memory.utils import parse_messages, parse_vision_messages


class TestParseMessages:
"""Test the parse_messages utility function."""

def test_parse_basic_messages(self):
"""Test parsing basic message types."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "system", "content": "You are helpful"}
]
result = parse_messages(messages)

assert "user: Hello" in result
assert "assistant: Hi there" in result
assert "system: You are helpful" in result

def test_parse_empty_messages(self):
"""Test parsing empty messages list."""
result = parse_messages([])
assert result == ""

def test_parse_messages_with_missing_fields(self):
"""Test parsing messages with missing fields."""
# Test missing role - should raise KeyError
with pytest.raises(KeyError):
parse_messages([{"content": "Missing role"}])

# Test missing content - should raise KeyError
with pytest.raises(KeyError):
parse_messages([{"role": "user"}])

# Test invalid role - should be processed but not included
messages = [{"role": "invalid_role", "content": "Invalid role"}]
result = parse_messages(messages)
assert result == ""

def test_parse_messages_with_special_content(self):
"""Test parsing messages with special content."""
messages = [
{"role": "user", "content": "Hello 世界! 🚀"},
{"role": "user", "content": "```python\ndef hello():\n pass\n```"}
]
result = parse_messages(messages)
assert "Hello 世界! 🚀" in result
assert "```python" in result


class TestParseVisionMessages:
"""Test the parse_vision_messages utility function."""

def test_parse_regular_messages(self):
"""Test parsing regular text messages."""
messages = [{"role": "user", "content": "Hello"}]
result = parse_vision_messages(messages)
assert result == messages

def test_parse_vision_messages_with_image(self):
"""Test parsing vision messages with image URL."""
messages = [{
"role": "user",
"content": {
"type": "image_url",
"image_url": {"url": "https://example.com/image.jpg"}
}
}]

with patch('mem0.memory.utils.get_image_description') as mock_get_desc:
mock_get_desc.return_value = "Image description"
result = parse_vision_messages(messages)

assert result[0]["content"] == "Image description"

def test_parse_vision_messages_download_error(self):
"""Test parsing vision messages when image download fails."""
messages = [{
"role": "user",
"content": {
"type": "image_url",
"image_url": {"url": "https://invalid.com/image.jpg"}
}
}]

with patch('mem0.memory.utils.get_image_description') as mock_get_desc:
mock_get_desc.side_effect = Exception("Download failed")

with pytest.raises(Exception) as exc_info:
parse_vision_messages(messages)

assert "Error while downloading" in str(exc_info.value)
Loading