Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions mistralrs/examples/uqff/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use anyhow::Result;
use mistralrs::{
IsqType, PagedAttentionMetaBuilder, RequestBuilder, TextMessageRole, TextMessages,
UqffTextModelBuilder,
};

#[tokio::main]
async fn main() -> Result<()> {
let model = UqffTextModelBuilder::new(
"EricB/Phi-3.5-mini-instruct-UQFF",
"phi3.5-mini-instruct-q8_0.uqff".into(),
)
.into_inner()
.with_isq(IsqType::Q8_0)
.with_logging()
.with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?
.build()
.await?;

let messages = TextMessages::new()
.add_message(
TextMessageRole::System,
"You are an AI agent with a specialty in programming.",
)
.add_message(
TextMessageRole::User,
"Hello! How are you? Please write generic binary search function in Rust.",
);

let response = model.send_chat_request(messages).await?;

println!("{}", response.choices[0].message.content.as_ref().unwrap());
dbg!(
response.usage.avg_prompt_tok_per_sec,
response.usage.avg_compl_tok_per_sec
);

// Next example: Return some logprobs with the `RequestBuilder`, which enables higher configurability.
let request = RequestBuilder::new().return_logprobs(true).add_message(
TextMessageRole::User,
"Please write a mathematical equation where a few numbers are added.",
);

let response = model.send_chat_request(request).await?;

println!(
"Logprobs: {:?}",
&response.choices[0]
.logprobs
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()[0..3]
);

Ok(())
}
43 changes: 43 additions & 0 deletions mistralrs/examples/uqff_vision/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use anyhow::Result;
use mistralrs::{
IsqType, TextMessageRole, UqffVisionModelBuilder, VisionLoaderType, VisionMessages,
};

#[tokio::main]
async fn main() -> Result<()> {
let model = UqffVisionModelBuilder::new(
"EricB/Phi-3.5-vision-instruct-UQFF",
VisionLoaderType::Phi3V,
"phi3.5-vision-instruct-q8_0.uqff".into(),
)
.into_inner()
.with_isq(IsqType::Q4K)
.with_logging()
.build()
.await?;

let bytes = match reqwest::blocking::get(
"https://cdn.britannica.com/45/5645-050-B9EC0205/head-treasure-flower-disk-flowers-inflorescence-ray.jpg",
) {
Ok(http_resp) => http_resp.bytes()?.to_vec(),
Err(e) => anyhow::bail!(e),
};
let image = image::load_from_memory(&bytes)?;

let messages = VisionMessages::new().add_image_message(
TextMessageRole::User,
"What is depicted here? Please describe the scene in detail.",
image,
&model,
)?;

let response = model.send_chat_request(messages).await?;

println!("{}", response.choices[0].message.content.as_ref().unwrap());
dbg!(
response.usage.avg_prompt_tok_per_sec,
response.usage.avg_compl_tok_per_sec
);

Ok(())
}
6 changes: 4 additions & 2 deletions mistralrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ pub mod v0_4_api {
};
pub use super::model::{best_device, Model};
pub use super::speculative::TextSpeculativeBuilder;
pub use super::text_model::{PagedAttentionMetaBuilder, TextModelBuilder};
pub use super::vision_model::VisionModelBuilder;
pub use super::text_model::{
PagedAttentionMetaBuilder, TextModelBuilder, UqffTextModelBuilder,
};
pub use super::vision_model::{UqffVisionModelBuilder, VisionModelBuilder};
pub use super::xlora_model::XLoraModelBuilder;
}

Expand Down
54 changes: 53 additions & 1 deletion mistralrs/src/text_model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use mistralrs_core::*;
use std::{num::NonZeroUsize, path::PathBuf};
use std::{
num::NonZeroUsize,
ops::{Deref, DerefMut},
path::PathBuf,
};

use crate::{best_device, Model};

Expand Down Expand Up @@ -318,3 +322,51 @@ impl TextModelBuilder {
Ok(Model::new(runner.build()))
}
}

#[derive(Clone)]
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the TextModelBuilder, so users should take care to not call UQFF-related methods.
pub struct UqffTextModelBuilder(TextModelBuilder);

impl UqffTextModelBuilder {
/// A few defaults are applied here:
/// - MoQE ISQ organization
/// - Token source is from the cache (.cache/huggingface/token)
/// - Maximum number of sequences running is 32
/// - Number of sequences to hold in prefix cache is 16.
/// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
pub fn new(model_id: impl ToString, uqff_file: PathBuf) -> Self {
let mut inner = TextModelBuilder::new(model_id);
inner = inner.from_uqff(uqff_file);
Self(inner)
}

pub async fn build(self) -> anyhow::Result<Model> {
self.0.build().await
}

/// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
pub fn into_inner(self) -> TextModelBuilder {
self.0
}
}

impl Deref for UqffTextModelBuilder {
type Target = TextModelBuilder;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for UqffTextModelBuilder {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl From<UqffTextModelBuilder> for TextModelBuilder {
fn from(value: UqffTextModelBuilder) -> Self {
value.0
}
}
52 changes: 51 additions & 1 deletion mistralrs/src/vision_model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use mistralrs_core::*;
use std::{num::NonZeroUsize, path::PathBuf};
use std::{
num::NonZeroUsize,
ops::{Deref, DerefMut},
path::PathBuf,
};

use crate::{best_device, Model};

Expand Down Expand Up @@ -215,3 +219,49 @@ impl VisionModelBuilder {
Ok(Model::new(runner.build()))
}
}

#[derive(Clone)]
/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
/// This wraps and implements `DerefMut` for the VisionModelBuilder, so users should take care to not call UQFF-related methods.
pub struct UqffVisionModelBuilder(VisionModelBuilder);

impl UqffVisionModelBuilder {
/// A few defaults are applied here:
/// - Token source is from the cache (.cache/huggingface/token)
/// - Maximum number of sequences running is 32
/// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
pub fn new(model_id: impl ToString, loader_type: VisionLoaderType, uqff_file: PathBuf) -> Self {
let mut inner = VisionModelBuilder::new(model_id, loader_type);
inner = inner.from_uqff(uqff_file);
Self(inner)
}

pub async fn build(self) -> anyhow::Result<Model> {
self.0.build().await
}

/// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
pub fn into_inner(self) -> VisionModelBuilder {
self.0
}
}

impl Deref for UqffVisionModelBuilder {
type Target = VisionModelBuilder;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for UqffVisionModelBuilder {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl From<UqffVisionModelBuilder> for VisionModelBuilder {
fn from(value: UqffVisionModelBuilder) -> Self {
value.0
}
}
Loading