Merged
Commits (49; the diff below shows changes from 33 of them)
9e06476  Deps (EricLBuehler, Jun 6, 2025)
7bb5a64  Add conformer (EricLBuehler, Jun 6, 2025)
ce5aff6  Nemo loading (EricLBuehler, Jun 6, 2025)
408aae1  Position embeds (EricLBuehler, Jun 6, 2025)
a9e46ed  Load t5 attn bias (EricLBuehler, Jun 6, 2025)
0a36028  Attn and feed forward (EricLBuehler, Jun 7, 2025)
4cd9196  Add conv module and glu pointwise (EricLBuehler, Jun 7, 2025)
8244e07  Implement relative attn bias (EricLBuehler, Jun 7, 2025)
3853cfe  Add the forward methods (EricLBuehler, Jun 7, 2025)
d94134e  Add encoder embedding (EricLBuehler, Jun 7, 2025)
f645bf9  Fix oproj (EricLBuehler, Jun 7, 2025)
46d7ea4  Some loading (EricLBuehler, Jun 7, 2025)
c9ac339  Conformer loads! (EricLBuehler, Jun 7, 2025)
3907feb  Fully loading speech stack (EricLBuehler, Jun 7, 2025)
8e10e52  Merger (EricLBuehler, Jun 7, 2025)
d7ce884  Dont need that (EricLBuehler, Jun 7, 2025)
06f2cfe  First pass at audio processing (EricLBuehler, Jun 8, 2025)
6123f82  Read samples (EricLBuehler, Jun 8, 2025)
2ee2742  Optional (EricLBuehler, Jun 8, 2025)
2f4e8b3  Small loading fix (EricLBuehler, Jun 8, 2025)
d6f4e99  Runs but not correct yet (EricLBuehler, Jun 8, 2025)
8010d36  Improved audio processing? (EricLBuehler, Jun 8, 2025)
865f0c3  Works with this (EricLBuehler, Jun 8, 2025)
a0f9bfc  Fix t5 attn bias (EricLBuehler, Jun 8, 2025)
34ca7d4  It works! (EricLBuehler, Jun 8, 2025)
2d84c8c  Comment (EricLBuehler, Jun 8, 2025)
b8291e3  Use some other crates (EricLBuehler, Jun 8, 2025)
5083095  Clippy (EricLBuehler, Jun 8, 2025)
99bd576  Allow bf16 on metal (EricLBuehler, Jun 8, 2025)
04bbd4e  Add prefix_audio (EricLBuehler, Jun 8, 2025)
4ea5b25  Remove unused (EricLBuehler, Jun 8, 2025)
8250b33  Typo (EricLBuehler, Jun 8, 2025)
558eb99  User specified (EricLBuehler, Jun 8, 2025)
c85b087  Add audio url parsing (EricLBuehler, Jun 8, 2025)
6fb358e  AudioProjectionMode -> InputMode (EricLBuehler, Jun 8, 2025)
9acf1c4  Audio prefix caching (EricLBuehler, Jun 8, 2025)
f810079  Fix bug in audio prefix caching (EricLBuehler, Jun 8, 2025)
bc87555  Support both at the same time! (EricLBuehler, Jun 8, 2025)
e546c5b  Tweak logging (EricLBuehler, Jun 8, 2025)
b81ca71  Support stereo (EricLBuehler, Jun 8, 2025)
ff67bf8  Add mistralrs-audio (EricLBuehler, Jun 9, 2025)
167f1b5  Support batching (EricLBuehler, Jun 9, 2025)
f8f2a31  Add server and rust api example (EricLBuehler, Jun 9, 2025)
6480ac1  Add python api (EricLBuehler, Jun 9, 2025)
8316dd6  Fix add_multimodal_message (EricLBuehler, Jun 9, 2025)
53f3e39  Fix unfold for conformer (EricLBuehler, Jun 9, 2025)
38f7a8e  Streaming example (EricLBuehler, Jun 9, 2025)
29b146c  Add web chat support (EricLBuehler, Jun 9, 2025)
dd9031b  Add modalities registry (EricLBuehler, Jun 9, 2025)
4 changes: 3 additions & 1 deletion .typos.toml
@@ -10,7 +10,9 @@ extend-ignore-identifiers-re = [
"thr",
"nd",
"uneeded",
"tese"
"tese",
"seperable",
"Seperable",
]

[files]
77 changes: 77 additions & 0 deletions Cargo.lock

(Generated file; diff not rendered by default.)

6 changes: 5 additions & 1 deletion examples/python/custom_tool_call.py
@@ -8,6 +8,7 @@
ToolChoice,
)


def local_search(query: str):
results = []
for root, _, files in os.walk("."):
@@ -35,6 +36,7 @@ def tool_cb(name: str, args: dict) -> str:
return json.dumps(local_search(args.get("query", "")))
return ""


schema = json.dumps(
{
"type": "function",
@@ -51,7 +53,9 @@ def tool_cb(name: str, args: dict) -> str:
)

runner = Runner(
which=Which.Plain(model_id="NousResearch/Hermes-3-Llama-3.1-8B", arch=Architecture.Llama),
which=Which.Plain(
model_id="NousResearch/Hermes-3-Llama-3.1-8B", arch=Architecture.Llama
),
tool_callbacks={"local_search": tool_cb},
)

4 changes: 4 additions & 0 deletions mistralrs-core/Cargo.toml
@@ -94,6 +94,10 @@ ahash.workspace = true
num-traits.workspace = true
libc.workspace = true
bm25.workspace = true
rubato = "0.16.2"
rustfft = "6.3.0"
hound = "3.5.1"
apodize = "1.0.0"

[features]
pyo3_macros = ["pyo3"]
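The four new crates cover the audio front end: hound decodes WAV data, rubato resamples, apodize supplies window functions, and rustfft computes the spectrogram FFT. As a rough sketch of how three of them compose (not code from this PR; the file name, the 512-sample frame, and the normalization are illustrative, and resampling with rubato is omitted):

```rust
use rustfft::{num_complex::Complex, FftPlanner};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Decode a 16-bit WAV file into normalized f32 samples (hypothetical path).
    let mut reader = hound::WavReader::open("example.wav")?;
    let samples: Vec<f32> = reader
        .samples::<i16>()
        .map(|s| s.map(|v| v as f32 / i16::MAX as f32))
        .collect::<Result<_, _>>()?;

    // Apply a Hann window to one 512-sample frame and run a forward FFT.
    let frame_len = 512.min(samples.len());
    let mut frame: Vec<Complex<f32>> = samples[..frame_len]
        .iter()
        .zip(apodize::hanning_iter(frame_len))
        .map(|(s, w)| Complex::new(s * w as f32, 0.0))
        .collect();
    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(frame_len);
    fft.process(&mut frame);

    // Magnitude spectrum of that frame: the building block of a mel spectrogram.
    let magnitudes: Vec<f32> = frame.iter().map(|c| c.norm()).collect();
    println!("{} frequency bins", magnitudes.len());
    Ok(())
}
```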
13 changes: 13 additions & 0 deletions mistralrs-core/src/engine/add_request.rs
@@ -122,10 +122,21 @@ impl Engine {
ref images,
messages: _,
enable_thinking: _,
audios: _,
} => Some(images.clone()),
_ => None,
};

let audios = match request.messages {
RequestMessage::VisionChat {
images: _,
messages: _,
enable_thinking: _,
ref audios,
} => Some(audios.clone()),
_ => None,
};

let matcher = Arc::new(handle_seq_error!(
ToolCallingMatcher::new(request.tool_choice.unwrap_or(ToolChoice::Auto),),
request.response
@@ -157,6 +168,7 @@
}
| RequestMessage::VisionChat {
images: _,
audios: _,
messages,
enable_thinking,
} => {
@@ -497,6 +509,7 @@
None
},
images.clone(),
audios.clone(),
block_size,
Some(matcher.clone()),
image_generation_format,
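The hunks above extract `images` and `audios` with two separate matches on the same `RequestMessage::VisionChat` variant. An equivalent single-pass extraction would look roughly like this (a sketch, not part of the diff):

```rust
// Destructure both payloads from the variant at once; other variants yield (None, None).
let (images, audios) = match request.messages {
    RequestMessage::VisionChat {
        ref images,
        ref audios,
        ..
    } => (Some(images.clone()), Some(audios.clone())),
    _ => (None, None),
};
```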
60 changes: 59 additions & 1 deletion mistralrs-core/src/layers.rs
@@ -7,7 +7,8 @@ use candle_core::{
Context, DType, Device, IndexOp, Result, Tensor, D,
};
use candle_nn::{
Conv2d, Conv2dConfig, Embedding, GroupNorm, LayerNorm, LayerNormConfig, Linear, Module,
BatchNorm, BatchNormConfig, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Embedding, GroupNorm,
LayerNorm, LayerNormConfig, Linear, Module,
};
use float8::F8E4M3;
use half::{bf16, f16};
@@ -67,6 +68,34 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
}
}

pub fn batch_norm<C: Into<BatchNormConfig>>(
num_features: usize,
config: C,
vb: ShardedVarBuilder,
) -> Result<BatchNorm> {
let config = config.into();
if config.eps < 0. {
candle_core::bail!("batch-norm eps cannot be negative {}", config.eps)
}
let running_mean = vb.get(num_features, "running_mean")?;
let running_var = vb.get(num_features, "running_var")?;

if config.affine {
let weight = vb.get(num_features, "weight")?;
let bias = vb.get(num_features, "bias")?;
BatchNorm::new(
num_features,
running_mean,
running_var,
weight,
bias,
config.eps,
)
} else {
BatchNorm::new_no_bias(num_features, running_mean, running_var, config.eps)
}
}

pub fn group_norm(
num_groups: usize,
num_channels: usize,
@@ -117,6 +146,35 @@ pub fn conv2d_no_bias(
Ok(Conv2d::new(ws, None, cfg))
}

pub fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: ShardedVarBuilder,
) -> Result<Conv1d> {
let ws = vb.get(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
)?;
let bs = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: ShardedVarBuilder,
) -> Result<Conv1d> {
let ws = vb.get(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
)?;
Ok(Conv1d::new(ws, None, cfg))
}

pub fn linear(in_dim: usize, out_dim: usize, vb: ShardedVarBuilder) -> Result<Linear> {
let ws = vb.get((out_dim, in_dim), "weight")?;
let bs = vb.get(out_dim, "bias")?;
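For context on the new helpers, a minimal sketch of wiring `conv1d` and `batch_norm` together (assuming a `ShardedVarBuilder` named `vb` whose checkpoint provides `weight`, `bias`, `running_mean`, and `running_var` under the prefixes used here; the `dw_conv`/`batch_norm` prefixes and the 80-channel shapes are illustrative, not from this PR):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{BatchNormConfig, Conv1dConfig, Module, ModuleT};

// `conv1d`, `batch_norm`, and `ShardedVarBuilder` are the items from this layers module.
fn build_conv_block(vb: ShardedVarBuilder) -> Result<Tensor> {
    // Depthwise 1-D convolution: groups == in_channels, as a conformer conv module would use.
    let conv_cfg = Conv1dConfig {
        padding: 1,
        groups: 80,
        ..Default::default()
    };
    let conv = conv1d(80, 80, 3, conv_cfg, vb.pp("dw_conv"))?;

    // Batch norm over the same 80 channels, loading running statistics from the checkpoint.
    let bn = batch_norm(80, BatchNormConfig::default(), vb.pp("batch_norm"))?;

    // (batch, channels, time) input; run conv then batch norm in eval mode.
    let x = Tensor::zeros((1, 80, 100), DType::F32, &Device::Cpu)?;
    bn.forward_t(&conv.forward(&x)?, /* train = */ false)
}
```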
18 changes: 10 additions & 8 deletions mistralrs-core/src/lib.rs
@@ -89,16 +89,18 @@ pub use pipeline::{
DiffusionLoaderBuilder, DiffusionLoaderType, GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig,
GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, Idefics2Loader,
IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths,
LoraAdapterPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader,
NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
SpeechLoader, SpeechPipeline, Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder,
VisionLoaderType, VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
LoraAdapterPaths, MistralLoader, MixtralLoader, ModelKind, ModelPaths,
MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
LlguidanceGrammar, MessageContent, NormalRequest, Request, RequestMessage, SearchContextSize,
TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
ApproximateUserLocation, AudioInput, Constraint, DetokenizationRequest,
ImageGenerationResponseFormat, LlguidanceGrammar, MessageContent, NormalRequest, Request,
RequestMessage, SearchContextSize, TokenizationRequest, WebSearchOptions,
WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
1 change: 1 addition & 0 deletions mistralrs-core/src/pipeline/amoe.rs
@@ -584,6 +584,7 @@ fn new_dummy_seq(
None,
None,
images,
None,
None, // TODO incorrect for PagedAttention
None,
None,