Commit b4b5a71

Eric Buehler committed
Update loader
1 parent e47e246 commit b4b5a71

File tree: 5 files changed, +192 -9 lines changed

mistralrs-core/src/models/smollm3.rs

Lines changed: 0 additions & 4 deletions
@@ -490,10 +490,6 @@ impl SmolLm3 {
         })
     }
 
-    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
-        self.wte.forward(input_ids)
-    }
-
     pub fn forward(
         &self,
         input_ids: &Tensor,

mistralrs-core/src/pipeline/loaders/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ pub use normal_loaders::{
     AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
     LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata,
     NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader,
-    Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
+    Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
 };
 
 pub use vision_loaders::{

mistralrs-core/src/pipeline/loaders/normal_loaders.rs

Lines changed: 187 additions & 1 deletion
@@ -169,6 +169,8 @@ pub enum NormalLoaderType {
     GLM4,
     #[serde(rename = "qwen3moe")]
     Qwen3Moe,
+    #[serde(rename = "smollm3")]
+    SmolLm3,
 }
 
 // https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
@@ -190,6 +192,7 @@ impl NormalLoaderType {
             "Qwen3ForCausalLM" => Ok(Self::Qwen3),
             "Glm4ForCausalLM" => Ok(Self::GLM4),
             "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
+            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
             other => anyhow::bail!(
                 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
             ),
@@ -216,7 +219,8 @@ impl FromStr for NormalLoaderType {
             "qwen3" => Ok(Self::Qwen3),
             "glm4" => Ok(Self::GLM4),
             "qwen3moe" => Ok(Self::Qwen3Moe),
-            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`.")),
+            "smollm3" => Ok(Self::SmolLm3),
+            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`, `smollm3`.")),
         }
     }
 }
@@ -239,6 +243,7 @@ impl Display for NormalLoaderType {
             Self::Qwen3 => write!(f, "qwen3"),
             Self::GLM4 => write!(f, "glm4"),
             Self::Qwen3Moe => write!(f, "qwen3moe"),
+            Self::SmolLm3 => write!(f, "smollm3"),
         }
     }
 }
@@ -290,6 +295,7 @@ impl AutoNormalLoader {
             NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
             NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
             NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
+            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
         }
     }
 }
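
Taken together, the hunks above register the new architecture at every lookup point: the NormalLoaderType enum (serde name "smollm3"), the Hugging Face class-name mapping ("SmolLM3ForCausalLM"), FromStr, Display, and AutoNormalLoader's dispatch. Below is a minimal standalone sketch of that dual mapping; the LoaderType enum and from_causal_lm_name helper here are illustrative stand-ins, not mistral.rs types.

use std::str::FromStr;

// Illustrative stand-in for NormalLoaderType; only two variants shown.
#[derive(Debug, PartialEq)]
enum LoaderType {
    Qwen3Moe,
    SmolLm3,
}

impl LoaderType {
    // Mirrors the Transformers class-name mapping added in the diff above.
    fn from_causal_lm_name(name: &str) -> Result<Self, String> {
        match name {
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
            other => Err(format!("Unsupported -CausalLM class `{other}`")),
        }
    }
}

impl FromStr for LoaderType {
    type Err = String;
    // Mirrors the lowercase architecture-name mapping used for explicit selection.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "qwen3moe" => Ok(Self::Qwen3Moe),
            "smollm3" => Ok(Self::SmolLm3),
            a => Err(format!("Unknown architecture `{a}`")),
        }
    }
}

fn main() {
    // Both spellings resolve to the same loader variant.
    assert_eq!(
        LoaderType::from_causal_lm_name("SmolLM3ForCausalLM").unwrap(),
        "smollm3".parse::<LoaderType>().unwrap()
    );
    println!("SmolLM3 resolves consistently from config.json and from an explicit architecture name");
}
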
@@ -3526,3 +3532,183 @@ impl DeviceMappedModelLoader for Qwen3MoELoader {
         Ok(Box::new(cfg))
     }
 }
+
+// ======================== SmolLm3 loader
+
+/// [`NormalLoader`] for a SmolLm3 model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
+pub struct SmolLm3Loader;
+
+impl NormalModelLoader for SmolLm3Loader {
+    fn load(
+        &self,
+        config: &str,
+        vb: ShardedVarBuilder,
+        normal_loading_metadata: NormalLoadingMetadata,
+        attention_mechanism: AttentionImplementation,
+    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
+        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
+
+        Ok(Box::new(models::smollm3::SmolLm3::new(
+            &cfg,
+            vb,
+            self.is_gptx(config)?,
+            normal_loading_metadata,
+            attention_mechanism,
+        )?))
+    }
+    fn load_xlora(
+        &self,
+        _config: &str,
+        _vb: ShardedVarBuilder,
+        _lora_config: &[((String, String), LoraConfig)],
+        _xlora_config: Option<XLoraConfig>,
+        _xlora_ordering: Ordering,
+        _normal_loading_metadata: NormalLoadingMetadata,
+        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
+    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
+        todo!()
+    }
+    fn is_gptx(&self, _: &str) -> Result<bool> {
+        Ok(true)
+    }
+    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
+        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
+        Ok(Box::new(cfg))
+    }
+}
+
+impl IsqModelLoader for SmolLm3Loader {
+    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
+        Ok(vec![
+            Regex::new(r"lm_head\.(weight|bias)$")?,
+            // Attention
+            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
+            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
+            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
+            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
+            // MLP
+            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
+            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
+            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
+        ])
+    }
+    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
+        self.isq_layer_regexes(config)
+    }
+}
+
+impl DeviceMappedModelLoader for SmolLm3Loader {
+    fn mapped_max_act_size_elems(
+        &self,
+        config: &str,
+        params: &AutoDeviceMapParams,
+        prompt_chunksize: usize,
+    ) -> Result<usize> {
+        let AutoDeviceMapParams::Text {
+            max_seq_len: _,
+            max_batch_size,
+        } = params
+        else {
+            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
+        };
+
+        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
+
+        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
+    }
+    fn non_mapped_max_act_size_elems(
+        &self,
+        _config: &str,
+        _params: &AutoDeviceMapParams,
+    ) -> Result<usize> {
+        Ok(0)
+    }
+
+    fn non_mapped_size_in_bytes(
+        &self,
+        config: &str,
+        dtype: DType,
+        weight_pack_factor: usize,
+    ) -> Result<usize> {
+        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
+
+        let elems = {
+            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
+            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
+            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
+                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
+            } else {
+                0
+            };
+            let norm = cfg.hidden_size;
+            embed_tokens + lm_head + norm
+        };
+        Ok(elems * dtype.size_in_bytes())
+    }
+
+    fn layer_sizes_in_bytes(
+        &self,
+        config: &str,
+        dtype: DType,
+        weight_pack_factor: usize,
+    ) -> Result<Vec<usize>> {
+        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
+
+        let per_layer_elems = {
+            let input_layernorm = cfg.hidden_size;
+            let post_attention_layernorm = cfg.hidden_size;
+
+            let size_in = cfg.hidden_size;
+            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
+            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
+            let q_proj = size_in * size_q / weight_pack_factor;
+            let k_proj = size_in * size_kv / weight_pack_factor;
+            let v_proj = size_in * size_kv / weight_pack_factor;
+            let o_proj = size_q * size_in / weight_pack_factor;
+
+            let h_size = cfg.hidden_size;
+            let i_size = cfg.intermediate_size;
+            let gate_proj = h_size * i_size / weight_pack_factor;
+            let up_proj = h_size * i_size / weight_pack_factor;
+            let down_proj = i_size * h_size / weight_pack_factor;
+
+            input_layernorm
+                + post_attention_layernorm
+                + q_proj
+                + k_proj
+                + v_proj
+                + o_proj
+                + gate_proj
+                + up_proj
+                + down_proj
+        };
+        Ok(vec![
+            per_layer_elems * dtype.size_in_bytes();
+            cfg.num_hidden_layers
+        ])
+    }
+
+    fn num_layers(&self, config: &str) -> Result<usize> {
+        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
+
+        Ok(cfg.num_hidden_layers)
+    }
+    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
+        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
+
+        let cfg = ModelConfigMetadata {
+            max_seq_len: cfg.max_position_embeddings,
+            num_layers: cfg.num_hidden_layers,
+            hidden_size: cfg.hidden_size,
+            num_kv_heads: cfg.num_key_value_heads,
+            num_attn_heads: cfg.num_attention_heads,
+            sliding_window: None,
+            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
+            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
+        };
+
+        Ok(Box::new(cfg))
+    }
+}
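
For a rough sense of what the sizing methods above report, the following standalone sketch reruns the same arithmetic with assumed SmolLM3-3B-style hyperparameters (hidden_size 2048, 16 attention heads, 4 KV heads, intermediate_size 11008, vocab_size 128256, 36 layers, tied embeddings). These numbers are illustrative assumptions, not values taken from this commit, and the code mirrors rather than calls the loader.

// Standalone sketch of the sizing math in the new loader's non_mapped_size_in_bytes
// and layer_sizes_in_bytes. Hyperparameters below are assumed SmolLM3-3B-like values.
fn main() {
    let hidden_size = 2048usize;
    let num_attention_heads = 16usize;
    let num_key_value_heads = 4usize;
    let intermediate_size = 11008usize;
    let vocab_size = 128_256usize;
    let num_hidden_layers = 36usize;
    let tie_word_embeddings = true;
    let weight_pack_factor = 1usize; // no ISQ packing
    let dtype_bytes = 2usize; // f16

    let head_dim = hidden_size / num_attention_heads;

    // Per-layer weights: two RMSNorms, attention projections, gated MLP.
    let q_proj = hidden_size * head_dim * num_attention_heads / weight_pack_factor;
    let kv_proj = hidden_size * head_dim * num_key_value_heads / weight_pack_factor;
    let o_proj = head_dim * num_attention_heads * hidden_size / weight_pack_factor;
    let mlp = 3 * hidden_size * intermediate_size / weight_pack_factor;
    let per_layer_elems = 2 * hidden_size + q_proj + 2 * kv_proj + o_proj + mlp;

    // Non-mapped weights: embeddings, lm_head (skipped when tied and unpacked), final norm.
    let embed_tokens = hidden_size * vocab_size / weight_pack_factor;
    let lm_head = if !tie_word_embeddings || weight_pack_factor != 1 {
        embed_tokens
    } else {
        0
    };
    let non_mapped_elems = embed_tokens + lm_head + hidden_size;

    println!(
        "per layer: ~{:.1} MiB, all {num_hidden_layers} layers: ~{:.2} GiB, non-mapped: ~{:.1} MiB",
        (per_layer_elems * dtype_bytes) as f64 / (1 << 20) as f64,
        (per_layer_elems * dtype_bytes * num_hidden_layers) as f64 / (1 << 30) as f64,
        (non_mapped_elems * dtype_bytes) as f64 / (1 << 20) as f64,
    );
}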

mistralrs-core/src/pipeline/mod.rs

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ pub use loaders::{
     ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
     Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
     QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
-    Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
-    VisionModelLoader,
+    SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType,
+    VisionModel, VisionModelLoader,
 };
 use mistralrs_quant::IsqType;
 pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};

mistralrs-core/src/pipeline/normal.rs

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@ use super::{
 use super::{
     AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, Gemma2Loader, GemmaLoader,
     LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader,
-    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
+    Phi3_5MoELoader, Qwen2Loader, Qwen3Loader, Qwen3MoELoader, SmolLm3Loader, Starcoder2Loader,
 };
 use crate::amoe::AnyMoeExpertType;
 use crate::device_map::{self, DeviceMapper};
@@ -224,6 +224,7 @@ impl NormalLoaderBuilder {
             Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
             Some(NormalLoaderType::GLM4) => Box::new(GLM4Loader),
             Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
+            Some(NormalLoaderType::SmolLm3) => Box::new(SmolLm3Loader),
             None => Box::new(AutoNormalLoader),
         };
         Ok(Box::new(NormalLoader {
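
The builder hunk follows the established pattern: an explicitly requested architecture maps straight to its concrete loader, while None falls through to AutoNormalLoader, which infers the type from config.json. Below is a minimal sketch of that dispatch, using illustrative stand-in types rather than the real mistral.rs traits.

// Illustrative stand-in types; the real mistral.rs traits and loaders differ.
trait Loader {
    fn name(&self) -> &'static str;
}

struct SmolLm3Loader;
struct AutoLoader;

impl Loader for SmolLm3Loader {
    fn name(&self) -> &'static str {
        "smollm3"
    }
}
impl Loader for AutoLoader {
    fn name(&self) -> &'static str {
        "auto (inferred from config.json)"
    }
}

enum LoaderType {
    SmolLm3,
}

// Mirrors the match shown in the NormalLoaderBuilder hunk above: a requested type
// selects a concrete loader, while None defers to automatic detection.
fn select_loader(requested: Option<LoaderType>) -> Box<dyn Loader> {
    match requested {
        Some(LoaderType::SmolLm3) => Box::new(SmolLm3Loader),
        None => Box::new(AutoLoader),
    }
}

fn main() {
    println!("{}", select_loader(Some(LoaderType::SmolLm3)).name());
    println!("{}", select_loader(None).name());
}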
