Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ TIMEOUT: 60 # Timeout for llm invocation

#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
# ZHIPUAI_API_MODEL: "glm-4"

#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
Expand Down
4 changes: 4 additions & 0 deletions examples/llm_hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ async def main():
# streaming mode, much slower
await llm.acompletion_text(hello_msg, stream=True)

# check completion if exist to test llm complete functions
if hasattr(llm, "completion"):
logger.info(llm.completion(hello_msg))


if __name__ == "__main__":
asyncio.run(main())
6 changes: 4 additions & 2 deletions metagpt/actions/write_code_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ async def run(self, *args, **kwargs) -> CodingContext:
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
f"{len(self.context.code_doc.content)=}"
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.context.code_doc.content)={len2}"
)
result, rewrited_code = await self.write_code_review_and_rewrite(
context_prompt, cr_prompt, self.context.code_doc.filename
Expand Down
1 change: 1 addition & 0 deletions metagpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def _update(self):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.zhipuai_api_model = self._get("ZHIPUAI_API_MODEL")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
Expand Down
4 changes: 4 additions & 0 deletions metagpt/provider/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def get_choice_text(self, rsp: dict) -> str:
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]

def get_choice_delta_text(self, rsp: dict) -> str:
"""Required to provide the first text of stream choice"""
return rsp.get("choices")[0]["delta"]["content"]

def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",
Expand Down
6 changes: 2 additions & 4 deletions metagpt/provider/general_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ def _interpret_response(
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")
):
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)
Expand Down
92 changes: 24 additions & 68 deletions metagpt/provider/zhipuai/async_sse_client.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to make keep the use of Event to access response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
# refs to `zhipuai/core/_sse_client.py`

from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
import json
from typing import Any, Iterator


class AsyncSSEClient(SSEClient):
async def _aread(self):
data = b""
async for chunk in self._event_source:
for line in chunk.splitlines(True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data

async def async_events(self):
async for chunk in self._aread():
event = Event()
# Split before decoding so splitlines() only uses \r and \n
for line in chunk.splitlines():
# Decode the line.
line = line.decode(self._char_enc)

# Lines starting with a separator are comments and are to be
# ignored.
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
continue

data = line.split(_FIELD_SEPARATOR, 1)
field = data[0]

# Ignore unknown fields.
if field not in event.__dict__:
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
continue

if len(data) > 1:
# From the spec:
# "If value starts with a single U+0020 SPACE character,
# remove it from value."
if data[1].startswith(" "):
value = data[1][1:]
else:
value = data[1]
else:
# If no value is present after the separator,
# assume an empty value.
value = ""
class AsyncSSEClient(object):
def __init__(self, event_source: Iterator[Any]):
self._event_source = event_source

# The data field may come over multiple lines and their values
# are concatenated with each other.
if field == "data":
event.__dict__[field] += value + "\n"
else:
event.__dict__[field] = value

# Events with no data are not dispatched.
if not event.data:
continue

# If the data field ends with a newline, remove it.
if event.data.endswith("\n"):
event.data = event.data[0:-1]

# Empty event names default to 'message'
event.event = event.event or "message"

# Dispatch the event
self._logger.debug("Dispatching %s...", event)
yield event
async def stream(self) -> dict:
if isinstance(self._event_source, bytes):
raise RuntimeError(
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
async for chunk in self._event_source:
line = chunk.decode("utf-8")
if line.startswith(":") or not line:
return

field, _p, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
if field == "data":
if value.startswith("[DONE]"):
break
data = json.loads(value)
yield data
59 changes: 19 additions & 40 deletions metagpt/provider/zhipuai/zhipu_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,72 +4,51 @@

import json

import zhipuai
from zhipuai.model_api.api import InvokeType, ModelAPI
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from zhipuai import ZhipuAI
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT

from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient


class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_header(cls) -> dict:
token = cls._generate_token()
zhipuai_default_headers.update({"Authorization": token})
return zhipuai_default_headers

@classmethod
def get_sse_header(cls) -> dict:
token = cls._generate_token()
headers = {"Authorization": token}
return headers

@classmethod
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
class ZhiPuModelAPI(ZhipuAI):
def split_zhipu_api_url(self):
# use this method to prevent zhipu api upgrading to different version.
# and follow the GeneralAPIRequestor implemented based on openai sdk
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
"""
example:
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
return f"{arr[0]}/api", f"/{arr[1]}"

@classmethod
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]

base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
base_url, url = self.split_zhipu_api_url()
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
url=url,
headers=headers,
stream=stream,
params=kwargs,
request_timeout=zhipuai.api_timeout_seconds,
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
)
return result

@classmethod
async def ainvoke(cls, **kwargs) -> dict:
async def acreate(self, **kwargs) -> dict:
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
headers = cls.get_header()
resp = await cls.arequest(
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
)
headers = self._default_headers
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
resp = resp.decode("utf-8")
resp = json.loads(resp)
if "error" in resp:
raise RuntimeError(
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
return resp

@classmethod
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
"""async sse_invoke"""
headers = cls.get_sse_header()
return AsyncSSEClient(
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
)
headers = self._default_headers
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))
65 changes: 20 additions & 45 deletions metagpt/provider/zhipuai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk

import json
from enum import Enum

import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
after_log,
Expand All @@ -15,6 +13,7 @@
stop_after_attempt,
wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion

from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
Expand All @@ -35,26 +34,25 @@ class ZhiPuEvent(Enum):
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
From now, support glm-3-turbo、glm-4, and also system_prompt.
"""

def __init__(self):
self.__init_zhipuai(CONFIG)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.llm = ZhiPuModelAPI(api_key=self.api_key)

def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
self.api_key = config.zhipuai_api_key
self.model = config.zhipuai_api_model # so far, it support glm-3-turbo、glm-4
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.openai_proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.openai_proxy

def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs

def _update_costs(self, usage: dict):
Expand All @@ -67,57 +65,34 @@ def _update_costs(self, usage: dict):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")

def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")

def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp
return resp.model_dump()

async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp

async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages, timeout=timeout)

async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for event in response.async_events():
if event.event == ZhiPuEvent.ADD.value:
content = event.data
async for chunk in response.stream():
finish_reason = chunk.get("choices")[0].get("finish_reason")
if finish_reason == "stop":
usage = chunk.get("usage", {})
else:
content = self.get_choice_delta_text(chunk)
collected_content.append(content)
log_llm_stream(content)
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")
elif event.event == ZhiPuEvent.FINISH.value:
"""
event.meta
{
"task_status":"SUCCESS",
"usage":{
"completion_tokens":351,
"prompt_tokens":595,
"total_tokens":946
},
"task_id":"xx",
"request_id":"xxx"
}
"""
meta = json.loads(event.meta)
usage = meta.get("usage")
else:
print(f"zhipuapi else event: {event.data}", end="")

log_llm_stream("\n")

self._update_costs(usage)
Expand Down
1 change: 1 addition & 0 deletions metagpt/utils/file_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def save(self, filename: Path | str, content, dependencies: List[str] = No
"""
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
logger.info(f"save to: {str(pathname)}")
Expand Down
Loading