From 4e13eaca6e6cd6475509ea2aa0daeae32e0e0a73 Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 17 Jan 2024 16:28:13 +0800 Subject: [PATCH] update zhipu api due to new model and api; repair extra invalid generate output; update its unittest --- config/config.yaml | 1 + examples/llm_hello_world.py | 4 + metagpt/actions/write_code_review.py | 6 +- metagpt/config.py | 1 + metagpt/provider/base_llm.py | 4 + metagpt/provider/general_api_requestor.py | 6 +- metagpt/provider/zhipuai/async_sse_client.py | 92 +++++-------------- metagpt/provider/zhipuai/zhipu_model_api.py | 59 ++++-------- metagpt/provider/zhipuai_api.py | 65 ++++--------- metagpt/utils/file_repository.py | 1 + metagpt/utils/repair_llm_raw_output.py | 22 ++++- metagpt/utils/token_counter.py | 3 +- requirements.txt | 2 +- tests/metagpt/provider/test_zhipuai_api.py | 41 +++------ .../provider/zhipuai/test_async_sse_client.py | 10 +- .../provider/zhipuai/test_zhipu_model_api.py | 23 ++--- .../utils/test_repair_llm_raw_output.py | 32 +++++++ 17 files changed, 157 insertions(+), 215 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 6dff55b4e7..f41f7d276f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -36,6 +36,7 @@ TIMEOUT: 60 # Timeout for llm invocation #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +# ZHIPUAI_API_MODEL: "glm-4" #### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. #### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 76be1cc90c..219a303c87 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -23,6 +23,10 @@ async def main(): # streaming mode, much slower await llm.acompletion_text(hello_msg, stream=True) + # check completion if exist to test llm complete functions + if hasattr(llm, "completion"): + logger.info(llm.completion(hello_msg)) + if __name__ == "__main__": asyncio.run(main()) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index a8c9135733..3973d089bb 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -157,9 +157,11 @@ async def run(self, *args, **kwargs) -> CodingContext: cr_prompt = EXAMPLE_AND_INSTRUCTION.format( format_example=format_example, ) + len1 = len(iterative_code) if iterative_code else 0 + len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0 logger.info( - f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, " - f"{len(self.context.code_doc.content)=}" + f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, " + f"len(self.context.code_doc.content)={len2}" ) result, rewrited_code = await self.write_code_review_and_rewrite( context_prompt, cr_prompt, self.context.code_doc.filename diff --git a/metagpt/config.py b/metagpt/config.py index d633c7d28e..e837b329ba 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -144,6 +144,7 @@ def _update(self): self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("ANTHROPIC_API_KEY") self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY") + self.zhipuai_api_model = self._get("ZHIPUAI_API_MODEL") self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index d23d162c86..a50cdacd9a 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -89,6 +89,10 @@ def get_choice_text(self, rsp: dict) -> str: """Required to provide the first text of choice""" return rsp.get("choices")[0]["message"]["content"] + def get_choice_delta_text(self, rsp: dict) -> str: + """Required to provide the first text of stream choice""" + return rsp.get("choices")[0]["delta"]["content"] + def get_choice_function(self, rsp: dict) -> dict: """Required to provide the first function of choice :param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls", diff --git a/metagpt/provider/general_api_requestor.py b/metagpt/provider/general_api_requestor.py index cf31fd629b..500cd14267 100644 --- a/metagpt/provider/general_api_requestor.py +++ b/metagpt/provider/general_api_requestor.py @@ -79,10 +79,8 @@ def _interpret_response( async def _interpret_async_response( self, result: aiohttp.ClientResponse, stream: bool ) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]: - if stream and ( - "text/event-stream" in result.headers.get("Content-Type", "") - or "application/x-ndjson" in result.headers.get("Content-Type", "") - ): + content_type = result.headers.get("Content-Type", "") + if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type): # the `Content-Type` of ollama stream resp is "application/x-ndjson" return ( self._interpret_response_line(line, result.status, result.headers, stream=True) diff --git a/metagpt/provider/zhipuai/async_sse_client.py b/metagpt/provider/zhipuai/async_sse_client.py index d7168202a6..054865652e 100644 --- a/metagpt/provider/zhipuai/async_sse_client.py +++ b/metagpt/provider/zhipuai/async_sse_client.py @@ -1,75 +1,31 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : async_sse_client to make keep the use of Event to access response -# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py` +# refs to `zhipuai/core/_sse_client.py` -from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient +import json +from typing import Any, Iterator -class AsyncSSEClient(SSEClient): - async def _aread(self): - data = b"" - async for chunk in self._event_source: - for line in chunk.splitlines(True): - data += line - if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): - yield data - data = b"" - if data: - yield data - - async def async_events(self): - async for chunk in self._aread(): - event = Event() - # Split before decoding so splitlines() only uses \r and \n - for line in chunk.splitlines(): - # Decode the line. - line = line.decode(self._char_enc) - - # Lines starting with a separator are comments and are to be - # ignored. - if not line.strip() or line.startswith(_FIELD_SEPARATOR): - continue - - data = line.split(_FIELD_SEPARATOR, 1) - field = data[0] - - # Ignore unknown fields. - if field not in event.__dict__: - self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field) - continue - - if len(data) > 1: - # From the spec: - # "If value starts with a single U+0020 SPACE character, - # remove it from value." - if data[1].startswith(" "): - value = data[1][1:] - else: - value = data[1] - else: - # If no value is present after the separator, - # assume an empty value. - value = "" +class AsyncSSEClient(object): + def __init__(self, event_source: Iterator[Any]): + self._event_source = event_source - # The data field may come over multiple lines and their values - # are concatenated with each other. - if field == "data": - event.__dict__[field] += value + "\n" - else: - event.__dict__[field] = value - - # Events with no data are not dispatched. - if not event.data: - continue - - # If the data field ends with a newline, remove it. - if event.data.endswith("\n"): - event.data = event.data[0:-1] - - # Empty event names default to 'message' - event.event = event.event or "message" - - # Dispatch the event - self._logger.debug("Dispatching %s...", event) - yield event + async def stream(self) -> dict: + if isinstance(self._event_source, bytes): + raise RuntimeError( + f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`" + ) + async for chunk in self._event_source: + line = chunk.decode("utf-8") + if line.startswith(":") or not line: + return + + field, _p, value = line.partition(":") + if value.startswith(" "): + value = value[1:] + if field == "data": + if value.startswith("[DONE]"): + break + data = json.loads(value) + yield data diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 16d4102d49..a7d49623a3 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -4,46 +4,27 @@ import json -import zhipuai -from zhipuai.model_api.api import InvokeType, ModelAPI -from zhipuai.utils.http_client import headers as zhipuai_default_headers +from zhipuai import ZhipuAI +from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient -class ZhiPuModelAPI(ModelAPI): - @classmethod - def get_header(cls) -> dict: - token = cls._generate_token() - zhipuai_default_headers.update({"Authorization": token}) - return zhipuai_default_headers - - @classmethod - def get_sse_header(cls) -> dict: - token = cls._generate_token() - headers = {"Authorization": token} - return headers - - @classmethod - def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs): +class ZhiPuModelAPI(ZhipuAI): + def split_zhipu_api_url(self): # use this method to prevent zhipu api upgrading to different version. # and follow the GeneralAPIRequestor implemented based on openai sdk - zhipu_api_url = cls._build_api_url(kwargs, invoke_type) - """ - example: - zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} - """ + zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" arr = zhipu_api_url.split("/api/") - # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke") + # ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions") return f"{arr[0]}/api", f"/{arr[1]}" - @classmethod - async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs): + async def arequest(self, stream: bool, method: str, headers: dict, kwargs): # TODO to make the async request to be more generic for models in http mode. assert method in ["post", "get"] - base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs) + base_url, url = self.split_zhipu_api_url() requester = GeneralAPIRequestor(base_url=base_url) result, _, api_key = await requester.arequest( method=method, @@ -51,25 +32,23 @@ async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, head headers=headers, stream=stream, params=kwargs, - request_timeout=zhipuai.api_timeout_seconds, + request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read, ) return result - @classmethod - async def ainvoke(cls, **kwargs) -> dict: + async def acreate(self, **kwargs) -> dict: """async invoke different from raw method `async_invoke` which get the final result by task_id""" - headers = cls.get_header() - resp = await cls.arequest( - invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs - ) + headers = self._default_headers + resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs) resp = resp.decode("utf-8") resp = json.loads(resp) + if "error" in resp: + raise RuntimeError( + f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`" + ) return resp - @classmethod - async def asse_invoke(cls, **kwargs) -> AsyncSSEClient: + async def acreate_stream(self, **kwargs) -> AsyncSSEClient: """async sse_invoke""" - headers = cls.get_sse_header() - return AsyncSSEClient( - await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs) - ) + headers = self._default_headers + return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs)) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index e1ccf0de59..a6f77477a8 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -2,11 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk -import json from enum import Enum import openai -import zhipuai from requests import ConnectionError from tenacity import ( after_log, @@ -15,6 +13,7 @@ stop_after_attempt, wait_random_exponential, ) +from zhipuai.types.chat.chat_completion import Completion from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import log_llm_stream, logger @@ -35,26 +34,25 @@ class ZhiPuEvent(Enum): class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` - From now, there is only one model named `chatglm_turbo` + From now, support glm-3-turbo、glm-4, and also system_prompt. """ def __init__(self): self.__init_zhipuai(CONFIG) - self.llm = ZhiPuModelAPI - self.model = "chatglm_turbo" # so far only one model, just use it - self.use_system_prompt: bool = False # zhipuai has no system prompt when use api + self.llm = ZhiPuModelAPI(api_key=self.api_key) def __init_zhipuai(self, config: CONFIG): assert config.zhipuai_api_key - zhipuai.api_key = config.zhipuai_api_key + self.api_key = config.zhipuai_api_key + self.model = config.zhipuai_api_model # so far, it support glm-3-turbo、glm-4 # due to use openai sdk, set the api_key but it will't be used. # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. if config.openai_proxy: # FIXME: openai v1.x sdk has no proxy support openai.proxy = config.openai_proxy - def _const_kwargs(self, messages: list[dict]) -> dict: - kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs def _update_costs(self, usage: dict): @@ -67,21 +65,15 @@ def _update_costs(self, usage: dict): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") - def get_choice_text(self, resp: dict) -> str: - """get the first text of choice from llm response""" - assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] - assert assist_msg["role"] == "assistant" - return assist_msg.get("content") - def completion(self, messages: list[dict], timeout=3) -> dict: - resp = self.llm.invoke(**self._const_kwargs(messages)) - usage = resp.get("data").get("usage") + resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages)) + usage = resp.usage.model_dump() self._update_costs(usage) - return resp + return resp.model_dump() async def _achat_completion(self, messages: list[dict], timeout=3) -> dict: - resp = await self.llm.ainvoke(**self._const_kwargs(messages)) - usage = resp.get("data").get("usage") + resp = await self.llm.acreate(**self._const_kwargs(messages)) + usage = resp.get("usage", {}) self._update_costs(usage) return resp @@ -89,35 +81,18 @@ async def acompletion(self, messages: list[dict], timeout=3) -> dict: return await self._achat_completion(messages, timeout=timeout) async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: - response = await self.llm.asse_invoke(**self._const_kwargs(messages)) + response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True)) collected_content = [] usage = {} - async for event in response.async_events(): - if event.event == ZhiPuEvent.ADD.value: - content = event.data + async for chunk in response.stream(): + finish_reason = chunk.get("choices")[0].get("finish_reason") + if finish_reason == "stop": + usage = chunk.get("usage", {}) + else: + content = self.get_choice_delta_text(chunk) collected_content.append(content) log_llm_stream(content) - elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value: - content = event.data - logger.error(f"event error: {content}", end="") - elif event.event == ZhiPuEvent.FINISH.value: - """ - event.meta - { - "task_status":"SUCCESS", - "usage":{ - "completion_tokens":351, - "prompt_tokens":595, - "total_tokens":946 - }, - "task_id":"xx", - "request_id":"xxx" - } - """ - meta = json.loads(event.meta) - usage = meta.get("usage") - else: - print(f"zhipuapi else event: {event.data}", end="") + log_llm_stream("\n") self._update_costs(usage) diff --git a/metagpt/utils/file_repository.py b/metagpt/utils/file_repository.py index 0ddca414dd..11315e5952 100644 --- a/metagpt/utils/file_repository.py +++ b/metagpt/utils/file_repository.py @@ -55,6 +55,7 @@ async def save(self, filename: Path | str, content, dependencies: List[str] = No """ pathname = self.workdir / filename pathname.parent.mkdir(parents=True, exist_ok=True) + content = content if content else "" # avoid `argument must be str, not None` to make it continue async with aiofiles.open(str(pathname), mode="w") as writer: await writer.write(content) logger.info(f"save to: {str(pathname)}") diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index a96c3dce0d..b71def1360 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -120,6 +120,15 @@ def repair_json_format(output: str) -> str: elif output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" + # remove `#` in output json str, usually appeared in `glm-4` + arr = output.split("\n") + new_arr = [] + for line in arr: + idx = line.find("#") + if idx >= 0: + line = line[:idx] + new_arr.append(line) + output = "\n".join(new_arr) return output @@ -168,15 +177,17 @@ def repair_invalid_json(output: str, error: str) -> str: example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765) example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) """ - pattern = r"line ([0-9]+)" + pattern = r"line ([0-9]+) column ([0-9]+)" matches = re.findall(pattern, error, re.DOTALL) if len(matches) > 0: - line_no = int(matches[0]) - 1 + line_no = int(matches[0][0]) - 1 + col_no = int(matches[0][1]) - 1 # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` output = output.replace('"""', '"').replace("'''", '"') arr = output.split("\n") + rline = arr[line_no] # raw line line = arr[line_no].strip() # different general problems if line.endswith("],"): @@ -187,9 +198,12 @@ def repair_invalid_json(output: str, error: str) -> str: new_line = line.replace("}", "") elif line.endswith("},") and output.endswith("},"): new_line = line[:-1] - elif '",' not in line and "," not in line: + elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: + # problem, `"""` or `'''` without `,` + new_line = f",{line}" + elif '",' not in line and "," not in line and '"' not in line: new_line = f'{line}",' - elif "," not in line: + elif not line.endswith(","): # problem, miss char `,` at the end. new_line = f"{line}," elif "," in line and len(line) == 1: diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index a1b74a0746..885eb37d71 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -27,7 +27,8 @@ "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens + "glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } diff --git a/requirements.txt b/requirements.txt index 0a54236f00..93ad653dc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,7 @@ aioredis~=2.0.1 # Used by metagpt/utils/redis.py websocket-client==1.6.2 aiofiles==23.2.1 gitpython==3.1.40 -zhipuai==1.0.7 +zhipuai==2.0.1 socksio~=1.0.0 gitignore-parser==0.1.9 # connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index ab240260cf..8f06fc7172 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -3,7 +3,6 @@ # @Desc : the unittest of ZhiPuAILLM import pytest -from zhipuai.utils.sse_client import Event from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAILLM @@ -15,35 +14,16 @@ resp_content = "I'm chatglm-turbo" default_resp = { - "code": 200, - "data": { - "choices": [{"role": "assistant", "content": resp_content}], - "usage": {"prompt_tokens": 20, "completion_tokens": 20}, - }, + "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}], + "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, } -def mock_zhipuai_invoke(**kwargs) -> dict: - return default_resp - - -async def mock_zhipuai_ainvoke(**kwargs) -> dict: - return default_resp - - -async def mock_zhipuai_asse_invoke(**kwargs): +async def mock_zhipuai_acreate_stream(self, **kwargs): class MockResponse(object): async def _aread(self): class Iterator(object): - events = [ - Event(id="xxx", event="add", data=resp_content, retry=0), - Event( - id="xxx", - event="finish", - data="", - meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}', - ), - ] + events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}] async def __aiter__(self): for event in self.events: @@ -52,23 +32,26 @@ async def __aiter__(self): async for chunk in Iterator(): yield chunk - async def async_events(self): + async def stream(self): async for chunk in self._aread(): yield chunk return MockResponse() +async def mock_zhipuai_acreate(self, **kwargs) -> dict: + return default_resp + + @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke) - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) - mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream) zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) - assert resp["data"]["choices"][0]["content"] == resp_content + assert resp["choices"][0]["message"]["content"] == resp_content resp = await zhipu_gpt.aask(prompt_msg, stream=False) assert resp == resp_content diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py index 2649f595be..31b2d3d648 100644 --- a/tests/metagpt/provider/zhipuai/test_async_sse_client.py +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -11,16 +11,16 @@ async def test_async_sse_client(): class Iterator(object): async def __aiter__(self): - yield b"data: test_value" + yield b'data: {"test_key": "test_value"}' async_sse_client = AsyncSSEClient(event_source=Iterator()) - async for event in async_sse_client.async_events(): - assert event.data, "test_value" + async for chunk in async_sse_client.stream(): + assert "test_value" in chunk.values() class InvalidIterator(object): async def __aiter__(self): yield b"invalid: test_value" async_sse_client = AsyncSSEClient(event_source=InvalidIterator()) - async for event in async_sse_client.async_events(): - assert not event + async for chunk in async_sse_client.stream(): + assert not chunk diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py index 1f0a42fa6a..15673c51c7 100644 --- a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -6,15 +6,13 @@ import pytest import zhipuai -from zhipuai.model_api.api import InvokeType -from zhipuai.utils.http_client import headers as zhipuai_default_headers from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI api_key = "xxx.xxx" zhipuai.api_key = api_key -default_resp = b'{"result": "test response"}' +default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}' async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: @@ -23,22 +21,15 @@ async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: @pytest.mark.asyncio async def test_zhipu_model_api(mocker): - header = ZhiPuModelAPI.get_header() - zhipuai_default_headers.update({"Authorization": api_key}) - assert header == zhipuai_default_headers - - sse_header = ZhiPuModelAPI.get_sse_header() - assert len(sse_header["Authorization"]) == 191 - - url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"}) + url_prefix, url_suffix = ZhiPuModelAPI(api_key=api_key).split_zhipu_api_url() assert url_prefix == "https://open.bigmodel.cn/api" - assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" + assert url_suffix == "/paas/v4/chat/completions" mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) - result = await ZhiPuModelAPI.arequest( - InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} + result = await ZhiPuModelAPI(api_key=api_key).arequest( + stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"} ) assert result == default_resp - result = await ZhiPuModelAPI.ainvoke() - assert result["result"] == "test response" + result = await ZhiPuModelAPI(api_key=api_key).acreate() + assert result["choices"][0]["message"]["content"] == "test response" diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 1970c64430..1f809a0811 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -128,6 +128,19 @@ def test_repair_json_format(): output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) assert output == target_output + raw_output = """ +{ + "Language": "en_us", # define language + "Programming Language": "Python" +} +""" + target_output = """{ + "Language": "en_us", + "Programming Language": "Python" +}""" + output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) + assert output == target_output + def test_repair_invalid_json(): from metagpt.utils.repair_llm_raw_output import repair_invalid_json @@ -204,6 +217,25 @@ def test_retry_parse_json_text(): output = retry_parse_json_text(output=invalid_json_text) assert output == target_json + invalid_json_text = '''{ + "Data structures and interfaces": """ + class UI: + - game_engine: GameEngine + + __init__(engine: GameEngine) -> None + + display_board() -> None + + display_score() -> None + + prompt_move() -> str + + reset_game() -> None + """ + "Anything UNCLEAR": "no" +}''' + target_json = { + "Data structures and interfaces": "\n class UI:\n - game_engine: GameEngine\n + __init__(engine: GameEngine) -> None\n + display_board() -> None\n + display_score() -> None\n + prompt_move() -> str\n + reset_game() -> None\n ", + "Anything UNCLEAR": "no", + } + output = retry_parse_json_text(output=invalid_json_text) + assert output == target_json + def test_extract_content_from_output(): """