Skip to content

Commit 0dad8d3

Browse files
committed
Incorporate my feature into the remove_code_blocks function, and add the tests to test_vllm.py instead of an entire separate test file.
1 parent 343b435 commit 0dad8d3

File tree

4 files changed

+138
-162
lines changed

4 files changed

+138
-162
lines changed

mem0/memory/main.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
parse_vision_messages,
3232
process_telemetry_filters,
3333
remove_code_blocks,
34-
remove_thinking_tags,
3534
)
3635
from mem0.utils.factory import (
3736
EmbedderFactory,
@@ -363,8 +362,6 @@ def _add_to_vector_store(self, messages, metadata, filters, infer):
363362

364363
try:
365364
response = remove_code_blocks(response)
366-
if '</think>' in response:
367-
response=remove_thinking_tags(response)
368365
new_retrieved_facts = json.loads(response)["facts"]
369366
except Exception as e:
370367
logger.error(f"Error in new_retrieved_facts: {e}")
@@ -419,8 +416,6 @@ def _add_to_vector_store(self, messages, metadata, filters, infer):
419416
new_memories_with_actions = {}
420417
else:
421418
response = remove_code_blocks(response)
422-
if '</think>' in response:
423-
response=remove_thinking_tags(response)
424419
new_memories_with_actions = json.loads(response)
425420
except Exception as e:
426421
logger.error(f"Invalid JSON response: {e}")
@@ -894,8 +889,7 @@ def _create_procedural_memory(self, messages, metadata=None, prompt=None):
894889

895890
try:
896891
procedural_memory = self.llm.generate_response(messages=parsed_messages)
897-
if '</think>' in procedural_memory:
898-
procedural_memory=remove_thinking_tags(procedural_memory)
892+
procedural_memory = remove_code_blocks(procedural_memory)
899893
except Exception as e:
900894
logger.error(f"Error generating procedural memory summary: {e}")
901895
raise
@@ -1220,8 +1214,6 @@ async def _add_to_vector_store(
12201214
)
12211215
try:
12221216
response = remove_code_blocks(response)
1223-
if '</think>' in response:
1224-
response=remove_thinking_tags(response)
12251217
new_retrieved_facts = json.loads(response)["facts"]
12261218
except Exception as e:
12271219
logger.error(f"Error in new_retrieved_facts: {e}")
@@ -1270,8 +1262,6 @@ async def process_fact_for_search(new_mem_content):
12701262
messages=[{"role": "user", "content": function_calling_prompt}],
12711263
response_format={"type": "json_object"},
12721264
)
1273-
if '</think>' in response:
1274-
response=remove_thinking_tags(response)
12751265
except Exception as e:
12761266
logger.error(f"Error in new memory actions response: {e}")
12771267
response = ""
@@ -1795,8 +1785,8 @@ async def _create_procedural_memory(self, messages, metadata=None, llm=None, pro
17951785
procedural_memory = response.content
17961786
else:
17971787
procedural_memory = await asyncio.to_thread(self.llm.generate_response, messages=parsed_messages)
1798-
if '</think>' in procedural_memory:
1799-
procedural_memory=remove_thinking_tags
1788+
procedural_memory = remove_code_blocks(procedural_memory)
1789+
18001790
except Exception as e:
18011791
logger.error(f"Error generating procedural memory summary: {e}")
18021792
raise

mem0/memory/utils.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def remove_code_blocks(content: str) -> str:
4343
"""
4444
pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
4545
match = re.match(pattern, content.strip())
46-
return match.group(1).strip() if match else content.strip()
46+
match_res=match.group(1).strip() if match else content.strip()
47+
return re.sub(r"<think>.*?</think>", "", match_res, flags=re.DOTALL).strip()
48+
4749

4850

4951
def extract_json(text):
@@ -183,15 +185,3 @@ def sanitize_relationship_for_cypher(relationship) -> str:
183185

184186
return re.sub(r"_+", "_", sanitized).strip("_")
185187

186-
187-
def remove_thinking_tags(text: str) -> str:
188-
"""
189-
Removes <think>...</think> tags from the input text.
190-
191-
Args:
192-
text (str): The input text potentially containing <think>...</think> tags.
193-
194-
Returns:
195-
str: The text with all <think></think>tags and their content removed.
196-
"""
197-
return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

tests/llms/test_vllm.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from unittest.mock import Mock, patch
1+
from unittest.mock import MagicMock, Mock, patch
22

33
import pytest
44

5+
from mem0 import AsyncMemory, Memory
56
from mem0.configs.llms.base import BaseLlmConfig
67
from mem0.llms.vllm import VllmLLM
8+
from mem0.memory.utils import remove_code_blocks
79

810

911
@pytest.fixture
@@ -84,3 +86,132 @@ def test_generate_response_with_tools(mock_vllm_client):
8486
assert len(response["tool_calls"]) == 1
8587
assert response["tool_calls"][0]["name"] == "add_memory"
8688
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
89+
90+
91+
92+
def test_remove_thinking_tags():
    """Check that remove_code_blocks strips <think>...</think> spans from sample inputs."""
    expected = '{"facts":["test fact"]}'
    raw_inputs = (
        # a single thinking tag on one line
        '<think>abc</think>{"facts":["test fact"]}',
        # thinking content spread over multiple lines
        '<think>\nabc\n</think>\n{"facts":["test fact"]}',
        # more than one thinking tag (rare in practice)
        '<think>A</think><think>B</think>{"facts":["test fact"]}',
        # no thinking tag at all: the payload must come back untouched
        '{"facts":["test fact"]}',
    )
    for raw in raw_inputs:
        assert remove_code_blocks(raw) == expected
104+
105+
106+
107+
108+
def _build_mocked_memory(memory_cls):
    """Instantiate ``memory_cls`` with its LLM, embedder, vector store and SQLite manager patched.

    Args:
        memory_cls: The memory class to construct with no arguments
            (``Memory`` or ``AsyncMemory`` in this module).

    Returns:
        tuple: ``(memory, mock_llm, mock_vector_store)`` so that callers can
        script LLM responses and inspect vector-store calls.
    """
    # NOTE(review): the SQLiteManager patch targets mem0.memory.storage; it only
    # takes effect if the memory class resolves the name through that module
    # when it is constructed -- confirm against mem0/memory/main.py.
    with patch('mem0.utils.factory.LlmFactory.create') as mock_llm_factory, \
         patch('mem0.utils.factory.EmbedderFactory.create') as mock_embedder_factory, \
         patch('mem0.utils.factory.VectorStoreFactory.create') as mock_vector_factory, \
         patch('mem0.memory.storage.SQLiteManager') as mock_sqlite:

        mock_llm = MagicMock()
        mock_llm_factory.return_value = mock_llm

        mock_embedder = MagicMock()
        mock_embedder.embed.return_value = [0.1, 0.2, 0.3]
        mock_embedder_factory.return_value = mock_embedder

        mock_vector_store = MagicMock()
        mock_vector_store.search.return_value = []
        mock_vector_store.add.return_value = None
        mock_vector_factory.return_value = mock_vector_store

        mock_sqlite.return_value = MagicMock()

        # The instance is built while the patches are active and keeps the
        # mocks it was handed after the ``with`` block exits.
        memory = memory_cls()
        memory.api_version = "v1.0"
        return memory, mock_llm, mock_vector_store


def create_mocked_memory():
    """Create a fully mocked Memory instance for testing."""
    return _build_mocked_memory(Memory)


def create_mocked_async_memory():
    """Create a fully mocked AsyncMemory instance for testing."""
    return _build_mocked_memory(AsyncMemory)
158+
159+
160+
def test_thinking_tags_in_add_to_vector_store():
    """Memory._add_to_vector_store (sync) must cope with <think> prefixes in LLM output."""
    mem, llm, store = create_mocked_memory()

    # The LLM is called twice: first for fact extraction, then for the
    # memory-action decision. Both replies carry a thinking preamble.
    fact_reply = ' <think>Sync fact extraction</think> \n{"facts": ["User loves sci-fi"]}'
    action_reply = ' <think>Sync memory actions</think> \n{"memory": [{"text": "Loves sci-fi", "event": "ADD"}]}'
    llm.generate_response.side_effect = [fact_reply, action_reply]

    # No pre-existing memories are returned by the search.
    store.search.return_value = []

    result = mem._add_to_vector_store(
        messages=[{"role": "user", "content": "I love sci-fi movies"}],
        metadata={},
        filters={},
        infer=True,
    )

    assert len(result) == 1
    added = result[0]
    assert added["memory"] == "Loves sci-fi"
    assert added["event"] == "ADD"
182+
183+
184+
185+
@pytest.mark.asyncio
async def test_async_thinking_tags_in_add_to_vector_store():
    """AsyncMemory._add_to_vector_store must cope with <think> prefixes in LLM output."""
    mem, llm, store = create_mocked_async_memory()

    # Script llm.generate_response directly; the to_thread stub below simply
    # forwards to it. First reply is fact extraction, second is memory actions.
    fact_reply = ' <think>Async fact extraction</think> \n{"facts": ["User loves sci-fi"]}'
    action_reply = ' <think>Async memory actions</think> \n{"memory": [{"text": "Loves sci-fi", "event": "ADD"}]}'
    llm.generate_response.side_effect = [fact_reply, action_reply]

    async def inline_to_thread(func, *args, **kwargs):
        """Stand-in for asyncio.to_thread that runs the callable inline (no threads)."""
        is_llm_call = func == llm.generate_response
        if not is_llm_call:
            # Callables without a __name__ fall through to the direct call below.
            name = getattr(func, '__name__', '')
            if 'embed' in name:
                return [0.1, 0.2, 0.3]
            if 'search' in name:
                return []
        return func(*args, **kwargs)

    with patch('mem0.memory.main.asyncio.to_thread', side_effect=inline_to_thread):
        result = await mem._add_to_vector_store(
            messages=[{"role": "user", "content": "I love sci-fi movies"}],
            metadata={},
            effective_filters={},
            infer=True,
        )

    assert len(result) == 1
    added = result[0]
    assert added["memory"] == "Loves sci-fi"
    assert added["event"] == "ADD"

tests/memory/test_thinking_tag.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

0 commit comments

Comments
 (0)