Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/HTTP.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The API consists of the following endpoints. They can be viewed in your browser
To support additional features, we have extended the completion and chat completion request objects. Both have the same keys added:

- `top_k`: `int` | `null`. If non-null, it is only relevant if positive.
- `grammar`: `{"type" : "regex" | "lark" | "json_schema" | "llguidance", "value": string}` or `null`. Grammar to use.
- `grammar`: `{"type" : "regex" | "lark" | "json_schema" | "llguidance", "value": string}` or `null`. Grammar to use. This is mutually exclusive with the OpenAI-compatible `response_format`.
- `adapters`: `array of string` | `null`. Adapter names to activate for this request.
- `min_p`: `float` | `null`. If non-null, it is only relevant if 1 >= min_p >= 0.

Expand Down
9 changes: 3 additions & 6 deletions examples/server/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@
"content": "Gimme a sample address.",
}
],
max_tokens=256,
frequency_penalty=1.0,
top_p=0.1,
temperature=0,
extra_body={
"grammar": {
"type": "json_schema",
"value": addr_schema,
}
response_format={
"type": "json_schema",
"json_schema": {"name": "My Schema", "schema": addr_schema},
},
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, Extra, Field

from openai import OpenAI

Expand All @@ -26,29 +26,24 @@ class Airplane(BaseModel):

class Fleet(BaseModel):
fleet_name: str = Field(..., title="Fleet Name", min_length=1, max_length=50)
airplanes: List[Airplane] = Field(..., title="Fleet Airplanes", min_length=1)

airplanes: List[Airplane] = Field(
..., title="Fleet Airplanes", min_length=1, max_length=3
)

fleet_schema = Fleet.model_json_schema()

completion = client.chat.completions.create(
completion = client.beta.chat.completions.parse(
model="mistral",
messages=[
{
"role": "user",
"content": "Can you please make me a fleet of airplanes?",
}
],
max_tokens=256,
frequency_penalty=1.0,
top_p=0.1,
temperature=0,
extra_body={
"grammar": {
"type": "json_schema",
"value": fleet_schema,
}
},
response_format=Fleet,
)

print(completion.choices[0].message.content)
event = completion.choices[0].message.parsed
print(event)
4 changes: 2 additions & 2 deletions mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
Option text_new.map(ToString::to_string)
),
role: "assistant".to_string(),
tool_calls: Some(tool_calls),
tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
},
index: seq.get_response_index(),
finish_reason: is_done.map(|x| x.to_string()),
Expand Down Expand Up @@ -211,7 +211,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
message: crate::ResponseMessage {
content: text_new.map(ToString::to_string),
role: "assistant".to_string(),
tool_calls,
tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
},
logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
};
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ macro_rules! generate_repr {
pub struct ResponseMessage {
pub content: Option<String>,
pub role: String,
pub tool_calls: Vec<ToolCallResponse>,
pub tool_calls: Option<Vec<ToolCallResponse>>,
}

generate_repr!(ResponseMessage);
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ macro_rules! handle_pipeline_forward_error {
message: ResponseMessage {
content: Some(res),
role: "assistant".to_string(),
tool_calls: Vec::new(),
tool_calls: None,
},
logprobs: None,
};
Expand Down
66 changes: 44 additions & 22 deletions mistralrs-server/src/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use std::{
use tokio::sync::mpsc::{channel, Receiver, Sender};

use crate::{
openai::{ChatCompletionRequest, Grammar, MessageInnerContent, StopTokens},
openai::{
ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
ResponseFormat, StopTokens,
},
util,
};
use anyhow::Result;
Expand Down Expand Up @@ -190,27 +193,33 @@ async fn parse_request(
let mut messages = Vec::new();
let mut image_urls = Vec::new();
for message in req_messages {
match message.content.deref() {
Either::Left(content) => {
let content = match message.content.as_deref() {
Some(content) => content.clone(),
None => {
// Handle tool call
let content = match content {
Some(content) => content.to_string(),
None => {
use anyhow::Context;
let calls = message.tool_calls.as_ref()
.context("No content was provided, expected tool calls to be provided.")?
.iter().map(|call| &call.function).collect::<Vec<_>>();

serde_json::to_string(&calls)?
}
};
use anyhow::Context;
let calls = message
.tool_calls
.as_ref()
.context(
"No content was provided, expected tool calls to be provided.",
)?
.iter()
.map(|call| &call.function)
.collect::<Vec<_>>();

Either::Left(serde_json::to_string(&calls)?)
}
};

match &content {
Either::Left(content) => {
let mut message_map: IndexMap<
String,
Either<String, Vec<IndexMap<String, Value>>>,
> = IndexMap::new();
message_map.insert("role".to_string(), Either::Left(message.role));
message_map.insert("content".to_string(), Either::Left(content));
message_map.insert("content".to_string(), Either::Left(content.clone()));
messages.push(message_map);
}
Either::Right(image_messages) => {
Expand Down Expand Up @@ -351,6 +360,25 @@ async fn parse_request(
};

let is_streaming = oairequest.stream.unwrap_or(false);

if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
}

let constraint = match oairequest.grammar {
Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
None => match oairequest.response_format {
Some(ResponseFormat::JsonSchema {
json_schema: JsonSchemaResponseFormat { name: _, schema },
}) => Constraint::JsonSchema(schema),
Some(ResponseFormat::Text) => Constraint::None,
None => Constraint::None,
},
};

Ok((
Request::Normal(NormalRequest {
id: state.next_request_id(),
Expand All @@ -373,13 +401,7 @@ async fn parse_request(
return_logprobs: oairequest.logprobs,
is_streaming,
suffix: None,
constraint: match oairequest.grammar {
Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
None => Constraint::None,
},
constraint,
adapters: oairequest.adapters,
tool_choice: oairequest.tool_choice,
tools: oairequest.tools,
Expand Down
26 changes: 22 additions & 4 deletions mistralrs-server/src/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,18 @@ impl Deref for MessageInnerContent {
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
Either<Option<String>, Vec<HashMap<String, MessageInnerContent>>>,
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);

impl Deref for MessageContent {
type Target = Either<Option<String>, Vec<HashMap<String, MessageInnerContent>>>;
type Target = Either<String, Vec<HashMap<String, MessageInnerContent>>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct FunctionCalled {
pub description: Option<String>,
pub name: String,
#[serde(alias = "arguments")]
pub parameters: String,
Expand All @@ -48,7 +47,7 @@ pub struct ToolCall {

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
pub content: MessageContent,
pub content: Option<MessageContent>,
pub role: String,
pub name: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
Expand Down Expand Up @@ -98,6 +97,23 @@ pub enum Grammar {
Lark(String),
}

/// Payload of the OpenAI-compatible `response_format` when `type` is
/// `"json_schema"`: a human-readable schema name plus the JSON Schema itself.
///
/// The `schema` value is kept as raw `serde_json::Value` because it is passed
/// through to the constraint engine untouched (no validation happens here).
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct JsonSchemaResponseFormat {
    /// Display name of the schema (ignored by the constraint engine).
    pub name: String,
    /// The JSON Schema document used to constrain generation.
    pub schema: serde_json::Value,
}

/// OpenAI-compatible `response_format` request field.
///
/// Deserialized from JSON tagged by `"type"` (serde internally-tagged enum),
/// e.g. `{"type": "text"}` or
/// `{"type": "json_schema", "json_schema": {"name": "...", "schema": {...}}}`.
/// NOTE(review): per the accompanying docs change, this is mutually exclusive
/// with the mistral.rs-specific `grammar` field.
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    /// Plain-text output; no generation constraint is applied.
    #[serde(rename = "text")]
    Text,
    /// Constrain generation to the provided JSON Schema.
    #[serde(rename = "json_schema")]
    JsonSchema {
        json_schema: JsonSchemaResponseFormat,
    },
}

#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
#[schema(example = json!(vec![Message{content:"Why did the crab cross the road?".to_string(), role:"user".to_string(), name: None}]))]
Expand Down Expand Up @@ -136,6 +152,8 @@ pub struct ChatCompletionRequest {
pub tools: Option<Vec<Tool>>,
#[schema(example = json!(Option::None::<ToolChoice>))]
pub tool_choice: Option<ToolChoice>,
#[schema(example = json!(Option::None::<ResponseFormat>))]
pub response_format: Option<ResponseFormat>,

// mistral.rs additional
#[schema(example = json!(Option::None::<usize>))]
Expand Down
4 changes: 2 additions & 2 deletions mistralrs/examples/tools/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ async fn main() -> Result<()> {

let message = &response.choices[0].message;

if !message.tool_calls.is_empty() {
let called = &message.tool_calls[0];
if let Some(tool_calls) = &message.tool_calls {
let called = &tool_calls[0];
if called.function.name == "get_weather" {
let input: GetWeatherInput = serde_json::from_str(&called.function.arguments)?;
println!("Called tool `get_weather` with arguments {input:?}");
Expand Down
Loading