Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbgpt/core/interface/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class ModelRequest:
max_new_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""

stop: Optional[str] = None
stop: Optional[Union[str, List[str]]] = None
"""The stop condition of the model inference."""
stop_token_ids: Optional[List[int]] = None
"""The stop token ids of the model inference."""
Expand Down
13 changes: 13 additions & 0 deletions dbgpt/model/cluster/apiserver/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, code: int, message: str):
# Settings model for the API server. A single module-level instance
# (`api_settings`) is created right after this class and mutated at startup.
class APISettings(BaseModel):
    # Optional list of API key strings; None by default.
    # NOTE(review): presumably used to authenticate incoming requests -- the
    # consuming code is not visible here, confirm before relying on it.
    api_keys: Optional[List[str]] = None
    # Embedding batch size; defaults to 4 and is overwritten at startup when an
    # explicit `embedding_batch_size` is supplied.
    # NOTE(review): "bach" looks like a typo for "batch", but the attribute is
    # assigned under this exact name elsewhere, so it must not be renamed in
    # isolation.
    embedding_bach_size: int = 4
    # Controls what request validation does with a `stop` list of more than 4
    # entries: False (default) -> reject with a PARAM_OUT_OF_RANGE error;
    # True -> silently keep only the first 4 entries.
    ignore_stop_exceeds_error: bool = False


api_settings = APISettings()
Expand Down Expand Up @@ -146,6 +147,15 @@ def check_requests(request) -> Optional[JSONResponse]:
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.stop} is not valid under any of the given schemas - 'stop'",
)
if request.stop and isinstance(request.stop, list) and len(request.stop) > 4:
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop
if not api_settings.ignore_stop_exceeds_error:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"Invalid 'stop': array too long. Expected an array with maximum length 4, but got an array with length {len(request.stop)} instead.",
)
else:
request.stop = request.stop[:4]

return None

Expand Down Expand Up @@ -581,6 +591,7 @@ def initialize_apiserver(
port: int = None,
api_keys: List[str] = None,
embedding_batch_size: Optional[int] = None,
ignore_stop_exceeds_error: bool = False,
):
import os

Expand Down Expand Up @@ -614,6 +625,7 @@ def initialize_apiserver(

if embedding_batch_size:
api_settings.embedding_bach_size = embedding_batch_size
api_settings.ignore_stop_exceeds_error = ignore_stop_exceeds_error

app.include_router(router, prefix="/api", tags=["APIServer"])

Expand Down Expand Up @@ -664,6 +676,7 @@ def run_apiserver():
port=apiserver_params.port,
api_keys=api_keys,
embedding_batch_size=apiserver_params.embedding_batch_size,
ignore_stop_exceeds_error=apiserver_params.ignore_stop_exceeds_error,
)


Expand Down
4 changes: 2 additions & 2 deletions dbgpt/model/cluster/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage
Expand All @@ -15,7 +15,7 @@ class PromptRequest(BaseModel):
prompt: str = None
temperature: float = None
max_new_tokens: int = None
stop: str = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: List[int] = []
context_len: int = None
echo: bool = True
Expand Down
3 changes: 3 additions & 0 deletions dbgpt/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ class ModelAPIServerParameters(BaseServerParameters):
embedding_batch_size: Optional[int] = field(
default=None, metadata={"help": "Embedding batch size"}
)
ignore_stop_exceeds_error: Optional[bool] = field(
default=False, metadata={"help": "Ignore exceeds stop words error"}
)

log_file: Optional[str] = field(
default="dbgpt_model_apiserver.log",
Expand Down
3 changes: 3 additions & 0 deletions dbgpt/model/proxy/llms/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def chatgpt_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
async for r in client.generate_stream(request):
yield r
Expand Down Expand Up @@ -188,6 +189,8 @@ def _build_request(
payload["temperature"] = request.temperature
if request.max_new_tokens:
payload["max_tokens"] = request.max_new_tokens
if request.stop:
payload["stop"] = request.stop
return payload

async def generate(
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async def deepseek_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
async for r in client.generate_stream(request):
yield r
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def gemini_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
yield r
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/moonshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def moonshot_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
async for r in client.generate_stream(request):
yield r
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def spark_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
yield r
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/model/proxy/llms/tongyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def tongyi_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
yield r
Expand Down Expand Up @@ -96,6 +97,7 @@ def sync_generate_stream(
top_p=0.8,
stream=True,
result_format="message",
stop=request.stop,
)
for r in res:
if r:
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/yi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def yi_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
async for r in client.generate_stream(request):
yield r
Expand Down
1 change: 1 addition & 0 deletions dbgpt/model/proxy/llms/zhipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def zhipu_generate_stream(
temperature=params.get("temperature"),
context=context,
max_new_tokens=params.get("max_new_tokens"),
stop=params.get("stop"),
)
for r in client.sync_generate_stream(request):
yield r
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/rag/knowledge/docx.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _load(self) -> List[Document]:
documents = self._loader.load()
else:
docs = []
_SerializedRelationships.load_from_xml = load_from_xml_v2 # type: ignore
_SerializedRelationships.load_from_xml = load_from_xml_v2 # type: ignore
doc = docx.Document(self._path)
content = []

Expand Down