
Commit 230192f

Merge pull request #685 from garylin2099/llm_mock

Reduce test time with a global LLM mock

2 parents: 136b3f5 + bd4a35f


51 files changed: +289 −217 lines

metagpt/actions/invoice_ocr.py (2 additions, 0 deletions)

@@ -88,6 +88,8 @@ async def _unzip(file_path: Path) -> Path:
     async def _ocr(invoice_file_path: Path):
         ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
         ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
+        for result in ocr_result[0]:
+            result[1] = (result[1][0], round(result[1][1], 2))  # round long confidence scores to reduce token costs
         return ocr_result

     async def run(self, file_path: Path, *args, **kwargs) -> list:
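
The fix above trims PaddleOCR's confidence scores, which are long floats, before the OCR result is serialized into an LLM prompt. A minimal sketch of the effect, using an invented entry that mirrors PaddleOCR's [box, (text, confidence)] result structure:

# Invented PaddleOCR-style result: ocr_result[0] is a list of [box, (text, confidence)] entries
ocr_result = [[
    [[[10.0, 10.0], [90.0, 10.0], [90.0, 30.0], [10.0, 30.0]], ("Invoice No. 123", 0.8712345678901234)],
]]

for result in ocr_result[0]:
    # Same transformation as the patched _ocr: two decimals instead of ~17 digits
    result[1] = (result[1][0], round(result[1][1], 2))

print(ocr_result[0][0][1])  # prints ('Invoice No. 123', 0.87): far fewer tokens when dumped into a prompt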

tests/conftest.py (37 additions, 46 deletions)

@@ -12,57 +12,20 @@
 import os
 import re
 import uuid
-from typing import Optional

 import pytest

 from metagpt.config import CONFIG, Config
 from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
 from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.openai_api import OpenAILLM
 from metagpt.utils.git_repository import GitRepository
+from tests.mock.mock_llm import MockLLM

-
-class MockLLM(OpenAILLM):
-    rsp_cache: dict = {}
-
-    async def original_aask(
-        self,
-        msg: str,
-        system_msgs: Optional[list[str]] = None,
-        format_msgs: Optional[list[dict[str, str]]] = None,
-        timeout=3,
-        stream=True,
-    ):
-        """A copy of metagpt.provider.base_llm.BaseLLM.aask, we can't use super().aask because it will be mocked"""
-        if system_msgs:
-            message = self._system_msgs(system_msgs)
-        else:
-            message = [self._default_system_msg()] if self.use_system_prompt else []
-        if format_msgs:
-            message.extend(format_msgs)
-        message.append(self._user_msg(msg))
-        rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
-        return rsp
-
-    async def aask(
-        self,
-        msg: str,
-        system_msgs: Optional[list[str]] = None,
-        format_msgs: Optional[list[dict[str, str]]] = None,
-        timeout=3,
-        stream=True,
-    ) -> str:
-        if msg not in self.rsp_cache:
-            # Call the original unmocked method
-            rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
-            logger.info(f"Added '{rsp[:20]}' ... to response cache")
-            self.rsp_cache[msg] = rsp
-            return rsp
-        else:
-            logger.info("Use response cache")
-            return self.rsp_cache[msg]
+RSP_CACHE_NEW = {}  # used globally for producing new and useful only response cache
+ALLOW_OPENAI_API_CALL = os.environ.get(
+    "ALLOW_OPENAI_API_CALL", True
+)  # NOTE: should change to default False once mock is complete


 @pytest.fixture(scope="session")
@@ -76,16 +39,37 @@ def rsp_cache():
     else:
         rsp_cache_json = {}
     yield rsp_cache_json
-    with open(new_rsp_cache_file_path, "w") as f2:
+    with open(rsp_cache_file_path, "w") as f2:
         json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
+    with open(new_rsp_cache_file_path, "w") as f2:
+        json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)


-@pytest.fixture(scope="function")
-def llm_mock(rsp_cache, mocker):
-    llm = MockLLM()
+# Hook to capture the test result
+@pytest.hookimpl(tryfirst=True, hookwrapper=True)
+def pytest_runtest_makereport(item, call):
+    outcome = yield
+    rep = outcome.get_result()
+    if rep.when == "call":
+        item.test_outcome = rep
+
+
+@pytest.fixture(scope="function", autouse=True)
+def llm_mock(rsp_cache, mocker, request):
+    llm = MockLLM(allow_open_api_call=ALLOW_OPENAI_API_CALL)
     llm.rsp_cache = rsp_cache
     mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
+    mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
     yield mocker
+    if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
+        if llm.rsp_candidates:
+            for rsp_candidate in llm.rsp_candidates:
+                cand_key = list(rsp_candidate.keys())[0]
+                cand_value = list(rsp_candidate.values())[0]
+                if cand_key not in llm.rsp_cache:
+                    logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
+                llm.rsp_cache.update(rsp_candidate)
+                RSP_CACHE_NEW.update(rsp_candidate)


 class Context:
@@ -173,6 +157,13 @@ def init_config():
     Config()


+@pytest.fixture(scope="function")
+def new_filename(mocker):
+    # NOTE: Mock new filename to make reproducible llm aask, should consider changing after implementing requirement segmentation
+    mocker.patch("metagpt.utils.file_repository.FileRepository.new_filename", lambda: "20240101")
+    yield mocker
+
+
 @pytest.fixture
 def aiohttp_mocker(mocker):
     class MockAioResponse:
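
The MockLLM class itself has moved out of conftest.py into tests/mock/mock_llm.py, which is among the 51 changed files but not rendered on this page. Judging from how the fixture uses it (the allow_open_api_call flag, the injected rsp_cache, the rsp_candidates list, and the patched aask/aask_batch), a minimal sketch of what such a cache-backed mock could look like; everything below beyond the names already visible in the diff is an assumption, not the committed file:

# Hypothetical sketch of tests/mock/mock_llm.py; the committed file is not shown in this diff
from typing import Optional

from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAILLM


class MockLLM(OpenAILLM):
    def __init__(self, allow_open_api_call: bool = False):
        super().__init__()
        self.allow_open_api_call = allow_open_api_call
        self.rsp_cache: dict = {}  # prompt -> recorded response, injected by the llm_mock fixture
        self.rsp_candidates: list[dict] = []  # {prompt: response} pairs produced during one test

    async def original_aask(
        self,
        msg: str,
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        timeout=3,
        stream=True,
    ):
        """A copy of BaseLLM.aask; super().aask cannot be used because it is mocked (as in the old conftest)."""
        if system_msgs:
            message = self._system_msgs(system_msgs)
        else:
            message = [self._default_system_msg()] if self.use_system_prompt else []
        if format_msgs:
            message.extend(format_msgs)
        message.append(self._user_msg(msg))
        return await self.acompletion_text(message, stream=stream, timeout=timeout)

    async def aask(self, msg, system_msgs=None, format_msgs=None, timeout=3, stream=True) -> str:
        if msg in self.rsp_cache:
            logger.info("Use response cache")
            rsp = self.rsp_cache[msg]
        elif self.allow_open_api_call:
            rsp = await self.original_aask(msg, system_msgs, format_msgs, timeout, stream)
        else:
            raise ValueError("Prompt not in cache and ALLOW_OPENAI_API_CALL is disabled")
        self.rsp_candidates.append({msg: rsp})  # the fixture persists these only if the test passes
        return rsp

    # aask_batch would be wrapped the same way, since the fixture also patches BaseLLM.aask_batch

Recording candidates rather than writing to the cache directly lets the teardown in llm_mock discard responses from failing tests, so only prompt/response pairs from passing runs end up back in rsp_cache.json.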

tests/data/rsp_cache.json (117 additions, 67 deletions)

Large diff not rendered by default.
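
The session-scoped rsp_cache fixture round-trips this file with json.load/json.dump, so it is presumably a flat JSON object mapping full prompt strings to recorded completions. A quick way to inspect it (output shown truncated; the path comes from the fixture):

import json
from pathlib import Path

# Flat mapping: prompt text -> recorded LLM response (per the conftest fixture's json.dump)
cache = json.loads(Path("tests/data/rsp_cache.json").read_text(encoding="utf-8"))
for prompt, response in list(cache.items())[:3]:
    print(f"{prompt[:60]!r} -> {response[:60]!r}")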

tests/metagpt/actions/test_debug_error.py (0 additions, 1 deletion)

@@ -117,7 +117,6 @@ def test_player_calculate_score_with_multiple_aces(self):


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_debug_error():
     CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
     ctx = RunCodeContext(

tests/metagpt/actions/test_design_api.py (0 additions, 1 deletion)

@@ -17,7 +17,6 @@


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_design_api():
     inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
     for prd in inputs:

tests/metagpt/actions/test_design_api_review.py (0 additions, 1 deletion)

@@ -11,7 +11,6 @@


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_design_api_review():
     prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
     api_design = """

tests/metagpt/actions/test_generate_questions.py (0 additions, 1 deletion)

@@ -20,7 +20,6 @@


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_generate_questions():
     action = GenerateQuestions()
     rsp = await action.run(context)

tests/metagpt/actions/test_invoice_ocr.py (0 additions, 1 deletion)

@@ -54,7 +54,6 @@ async def test_generate_table(invoice_path: Path, expected_result: dict):
     ("invoice_path", "query", "expected_result"),
     [(Path("invoices/invoice-1.pdf"), "Invoicing date", "2023年02月03日")],
 )
-@pytest.mark.usefixtures("llm_mock")
 async def test_reply_question(invoice_path: Path, query: dict, expected_result: str):
     invoice_path = TEST_DATA_PATH / invoice_path
     ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path))

tests/metagpt/actions/test_prepare_interview.py (0 additions, 1 deletion)

@@ -12,7 +12,6 @@


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_prepare_interview():
     action = PrepareInterview()
     rsp = await action.run("I just graduated and hope to find a job as a Python engineer")

tests/metagpt/actions/test_project_management.py (0 additions, 1 deletion)

@@ -18,7 +18,6 @@


 @pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_mock")
 async def test_design_api():
     await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
     await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
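
The decorator removals in the test files above are all the same change: since llm_mock is now autouse=True, every test function receives the patched BaseLLM.aask without opting in. A hypothetical new test would therefore look like any other async test, with no mock bookkeeping (test name and body invented for illustration):

import pytest

from metagpt.actions.generate_questions import GenerateQuestions


@pytest.mark.asyncio
async def test_generate_questions_cached():  # hypothetical test
    # No @pytest.mark.usefixtures("llm_mock") needed: the autouse fixture has already
    # patched BaseLLM.aask, so this call either replays a response from rsp_cache.json
    # or, if ALLOW_OPENAI_API_CALL permits, records a new candidate response.
    rsp = await GenerateQuestions().run("Our user interview needs follow-up questions")
    assert rsp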
