Skip to content

Commit 51ce6f1

Browse files
authored
Bug fix of thinking llm in vllm (#3510)
1 parent 346d89d commit 51ce6f1

File tree

3 files changed

+122
-2
lines changed

3 files changed

+122
-2
lines changed

mem0/memory/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ def _create_procedural_memory(self, messages, metadata=None, prompt=None):
889889

890890
try:
891891
procedural_memory = self.llm.generate_response(messages=parsed_messages)
892+
procedural_memory = remove_code_blocks(procedural_memory)
892893
except Exception as e:
893894
logger.error(f"Error generating procedural memory summary: {e}")
894895
raise
@@ -1784,6 +1785,8 @@ async def _create_procedural_memory(self, messages, metadata=None, llm=None, pro
17841785
procedural_memory = response.content
17851786
else:
17861787
procedural_memory = await asyncio.to_thread(self.llm.generate_response, messages=parsed_messages)
1788+
procedural_memory = remove_code_blocks(procedural_memory)
1789+
17871790
except Exception as e:
17881791
logger.error(f"Error generating procedural memory summary: {e}")
17891792
raise

mem0/memory/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def remove_code_blocks(content: str) -> str:
4343
"""
4444
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
4545
match = re.match(pattern, content.strip())
46-
return match.group(1).strip() if match else content.strip()
46+
match_res=match.group(1).strip() if match else content.strip()
47+
return re.sub(r"<think>.*?</think>", "", match_res, flags=re.DOTALL).strip()
48+
4749

4850

4951
def extract_json(text):
@@ -182,3 +184,4 @@ def sanitize_relationship_for_cypher(relationship) -> str:
182184
sanitized = sanitized.replace(old, new)
183185

184186
return re.sub(r"_+", "_", sanitized).strip("_")
187+

tests/llms/test_vllm.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from unittest.mock import Mock, patch
1+
from unittest.mock import MagicMock, Mock, patch
22

33
import pytest
44

5+
from mem0 import AsyncMemory, Memory
56
from mem0.configs.llms.base import BaseLlmConfig
67
from mem0.llms.vllm import VllmLLM
78

@@ -84,3 +85,116 @@ def test_generate_response_with_tools(mock_vllm_client):
8485
assert len(response["tool_calls"]) == 1
8586
assert response["tool_calls"][0]["name"] == "add_memory"
8687
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
88+
89+
90+
91+
def create_mocked_memory():
92+
"""Create a fully mocked Memory instance for testing."""
93+
with patch('mem0.utils.factory.LlmFactory.create') as mock_llm_factory, \
94+
patch('mem0.utils.factory.EmbedderFactory.create') as mock_embedder_factory, \
95+
patch('mem0.utils.factory.VectorStoreFactory.create') as mock_vector_factory, \
96+
patch('mem0.memory.storage.SQLiteManager') as mock_sqlite:
97+
98+
mock_llm = MagicMock()
99+
mock_llm_factory.return_value = mock_llm
100+
101+
mock_embedder = MagicMock()
102+
mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
103+
mock_embedder_factory.return_value = mock_embedder
104+
105+
mock_vector_store = MagicMock()
106+
mock_vector_store.search.return_value = []
107+
mock_vector_store.add.return_value = None
108+
mock_vector_factory.return_value = mock_vector_store
109+
110+
mock_sqlite.return_value = MagicMock()
111+
112+
memory = Memory()
113+
memory.api_version = "v1.0"
114+
return memory, mock_llm, mock_vector_store
115+
116+
117+
def create_mocked_async_memory():
118+
"""Create a fully mocked AsyncMemory instance for testing."""
119+
with patch('mem0.utils.factory.LlmFactory.create') as mock_llm_factory, \
120+
patch('mem0.utils.factory.EmbedderFactory.create') as mock_embedder_factory, \
121+
patch('mem0.utils.factory.VectorStoreFactory.create') as mock_vector_factory, \
122+
patch('mem0.memory.storage.SQLiteManager') as mock_sqlite:
123+
124+
mock_llm = MagicMock()
125+
mock_llm_factory.return_value = mock_llm
126+
127+
mock_embedder = MagicMock()
128+
mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
129+
mock_embedder_factory.return_value = mock_embedder
130+
131+
mock_vector_store = MagicMock()
132+
mock_vector_store.search.return_value = []
133+
mock_vector_store.add.return_value = None
134+
mock_vector_factory.return_value = mock_vector_store
135+
136+
mock_sqlite.return_value = MagicMock()
137+
138+
memory = AsyncMemory()
139+
memory.api_version = "v1.0"
140+
return memory, mock_llm, mock_vector_store
141+
142+
143+
def test_thinking_tags_sync():
144+
"""Test thinking tags handling in Memory._add_to_vector_store (sync)."""
145+
memory, mock_llm, mock_vector_store = create_mocked_memory()
146+
147+
# Mock LLM responses for both phases
148+
mock_llm.generate_response.side_effect = [
149+
' <think>Sync fact extraction</think> \n{"facts": ["User loves sci-fi"]}',
150+
' <think>Sync memory actions</think> \n{"memory": [{"text": "Loves sci-fi", "event": "ADD"}]}'
151+
]
152+
153+
mock_vector_store.search.return_value = []
154+
155+
result = memory._add_to_vector_store(
156+
messages=[{"role": "user", "content": "I love sci-fi movies"}],
157+
metadata={},
158+
filters={},
159+
infer=True
160+
)
161+
162+
assert len(result) == 1
163+
assert result[0]["memory"] == "Loves sci-fi"
164+
assert result[0]["event"] == "ADD"
165+
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_async_thinking_tags_async():
170+
"""Test thinking tags handling in AsyncMemory._add_to_vector_store."""
171+
memory, mock_llm, mock_vector_store = create_mocked_async_memory()
172+
173+
# Directly mock llm.generate_response instead of via asyncio.to_thread
174+
mock_llm.generate_response.side_effect = [
175+
' <think>Async fact extraction</think> \n{"facts": ["User loves sci-fi"]}',
176+
' <think>Async memory actions</think> \n{"memory": [{"text": "Loves sci-fi", "event": "ADD"}]}'
177+
]
178+
179+
# Mock asyncio.to_thread to call the function directly (bypass threading)
180+
async def mock_to_thread(func, *args, **kwargs):
181+
if func == mock_llm.generate_response:
182+
return func(*args, **kwargs)
183+
elif hasattr(func, '__name__') and 'embed' in func.__name__:
184+
return [0.1, 0.2, 0.3]
185+
elif hasattr(func, '__name__') and 'search' in func.__name__:
186+
return []
187+
else:
188+
return func(*args, **kwargs)
189+
190+
with patch('mem0.memory.main.asyncio.to_thread', side_effect=mock_to_thread):
191+
result = await memory._add_to_vector_store(
192+
messages=[{"role": "user", "content": "I love sci-fi movies"}],
193+
metadata={},
194+
effective_filters={},
195+
infer=True
196+
)
197+
198+
assert len(result) == 1
199+
assert result[0]["memory"] == "Loves sci-fi"
200+
assert result[0]["event"] == "ADD"

0 commit comments

Comments
 (0)