diff --git a/tests/memory/test_message_validation.py b/tests/memory/test_message_validation.py new file mode 100644 index 0000000000..772a653e68 --- /dev/null +++ b/tests/memory/test_message_validation.py @@ -0,0 +1,125 @@ +"""Essential message validation tests for Memory and AsyncMemory classes.""" + +import pytest +from unittest.mock import Mock, patch +from mem0.exceptions import ValidationError as Mem0ValidationError +from mem0.memory.main import Memory, AsyncMemory +from mem0.configs.base import MemoryConfig + + +class TestMessageValidation: + """Test message validation for Memory and AsyncMemory classes.""" + + @pytest.fixture + def memory_instance(self): + """Create Memory instance with mocked dependencies.""" + config = MemoryConfig() + with patch('mem0.memory.main.EmbedderFactory') as mock_embedder, \ + patch('mem0.memory.main.VectorStoreFactory') as mock_vector, \ + patch('mem0.memory.main.LlmFactory') as mock_llm, \ + patch('mem0.memory.main.SQLiteManager') as mock_db, \ + patch('mem0.memory.main.GraphStoreFactory') as mock_graph, \ + patch('mem0.memory.main.VectorStoreFactory') as mock_telemetry: + + mock_embedder.create.return_value = Mock() + mock_vector.create.return_value = Mock() + mock_llm.create.return_value = Mock() + mock_db.return_value = Mock() + mock_graph.create.return_value = None + mock_telemetry.create.return_value = Mock() + + yield Memory(config) + + @pytest.fixture + def async_memory_instance(self): + """Create AsyncMemory instance with mocked dependencies.""" + config = MemoryConfig() + with patch('mem0.memory.main.EmbedderFactory') as mock_embedder, \ + patch('mem0.memory.main.VectorStoreFactory') as mock_vector, \ + patch('mem0.memory.main.LlmFactory') as mock_llm, \ + patch('mem0.memory.main.SQLiteManager') as mock_db, \ + patch('mem0.memory.main.GraphStoreFactory') as mock_graph, \ + patch('mem0.memory.main.VectorStoreFactory') as mock_telemetry: + + mock_embedder.create.return_value = Mock() + mock_vector.create.return_value = Mock() + mock_llm.create.return_value = Mock() + mock_db.return_value = Mock() + mock_graph.create.return_value = None + mock_telemetry.create.return_value = Mock() + + yield AsyncMemory(config) + + # Message Type Validation + def test_valid_message_types(self, memory_instance): + """Test valid message types.""" + # String message + result = memory_instance.add("Hello", user_id="test") + assert result is not None + + # Dict message + result = memory_instance.add({"role": "user", "content": "Hello"}, user_id="test") + assert result is not None + + # List message + result = memory_instance.add([{"role": "user", "content": "Hello"}], user_id="test") + assert result is not None + + def test_invalid_message_types(self, memory_instance): + """Test invalid message types raise ValidationError.""" + # None message + with pytest.raises(Mem0ValidationError) as exc_info: + memory_instance.add(None, user_id="test") + assert exc_info.value.error_code == "VALIDATION_003" + + # Integer message + with pytest.raises(Mem0ValidationError) as exc_info: + memory_instance.add(123, user_id="test") + assert exc_info.value.error_code == "VALIDATION_003" + + # List of strings - this will fail in parse_vision_messages, not validation + with pytest.raises(TypeError): + memory_instance.add(["hello", "world"], user_id="test") + + # Session ID Validation + def test_no_session_ids(self, memory_instance): + """Test missing session IDs raises ValidationError.""" + with pytest.raises(Mem0ValidationError) as exc_info: + memory_instance.add("Hello") + assert exc_info.value.error_code == "VALIDATION_001" + + def test_valid_session_ids(self, memory_instance): + """Test valid session IDs.""" + # user_id + result = memory_instance.add("Hello", user_id="test") + assert result is not None + + # agent_id + result = memory_instance.add("Hello", agent_id="test") + assert result is not None + + # run_id + result = memory_instance.add("Hello", run_id="test") + assert result is not None + + # Memory Type Validation + def test_invalid_memory_type(self, memory_instance): + """Test invalid memory_type raises ValidationError.""" + with pytest.raises(Mem0ValidationError) as exc_info: + memory_instance.add("Hello", user_id="test", memory_type="invalid") + assert exc_info.value.error_code == "VALIDATION_002" + + # AsyncMemory Tests + @pytest.mark.asyncio + async def test_async_invalid_message_type(self, async_memory_instance): + """Test AsyncMemory with invalid message type.""" + with pytest.raises(Mem0ValidationError) as exc_info: + await async_memory_instance.add(123, user_id="test") + assert exc_info.value.error_code == "VALIDATION_003" + + @pytest.mark.asyncio + async def test_async_no_session_ids(self, async_memory_instance): + """Test AsyncMemory with no session IDs.""" + with pytest.raises(Mem0ValidationError) as exc_info: + await async_memory_instance.add("Hello") + assert exc_info.value.error_code == "VALIDATION_001" diff --git a/tests/memory/test_parse_messages.py b/tests/memory/test_parse_messages.py new file mode 100644 index 0000000000..47934eb8a1 --- /dev/null +++ b/tests/memory/test_parse_messages.py @@ -0,0 +1,96 @@ +"""Essential tests for message parsing utilities.""" + +import pytest +from unittest.mock import patch +from mem0.memory.utils import parse_messages, parse_vision_messages + + +class TestParseMessages: + """Test the parse_messages utility function.""" + + def test_parse_basic_messages(self): + """Test parsing basic message types.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "system", "content": "You are helpful"} + ] + result = parse_messages(messages) + + assert "user: Hello" in result + assert "assistant: Hi there" in result + assert "system: You are helpful" in result + + def test_parse_empty_messages(self): + """Test parsing empty messages list.""" + result = parse_messages([]) + assert result == "" + + def test_parse_messages_with_missing_fields(self): + """Test parsing messages with missing fields.""" + # Test missing role - should raise KeyError + with pytest.raises(KeyError): + parse_messages([{"content": "Missing role"}]) + + # Test missing content - should raise KeyError + with pytest.raises(KeyError): + parse_messages([{"role": "user"}]) + + # Test invalid role - should be processed but not included + messages = [{"role": "invalid_role", "content": "Invalid role"}] + result = parse_messages(messages) + assert result == "" + + def test_parse_messages_with_special_content(self): + """Test parsing messages with special content.""" + messages = [ + {"role": "user", "content": "Hello δΈ–η•Œ! πŸš€"}, + {"role": "user", "content": "```python\ndef hello():\n pass\n```"} + ] + result = parse_messages(messages) + assert "Hello δΈ–η•Œ! πŸš€" in result + assert "```python" in result + + +class TestParseVisionMessages: + """Test the parse_vision_messages utility function.""" + + def test_parse_regular_messages(self): + """Test parsing regular text messages.""" + messages = [{"role": "user", "content": "Hello"}] + result = parse_vision_messages(messages) + assert result == messages + + def test_parse_vision_messages_with_image(self): + """Test parsing vision messages with image URL.""" + messages = [{ + "role": "user", + "content": { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"} + } + }] + + with patch('mem0.memory.utils.get_image_description') as mock_get_desc: + mock_get_desc.return_value = "Image description" + result = parse_vision_messages(messages) + + assert result[0]["content"] == "Image description" + + def test_parse_vision_messages_download_error(self): + """Test parsing vision messages when image download fails.""" + messages = [{ + "role": "user", + "content": { + "type": "image_url", + "image_url": {"url": "https://invalid.com/image.jpg"} + } + }] + + with patch('mem0.memory.utils.get_image_description') as mock_get_desc: + mock_get_desc.side_effect = Exception("Download failed") + + with pytest.raises(Exception) as exc_info: + parse_vision_messages(messages) + + assert "Error while downloading" in str(exc_info.value)