diff --git a/README.md b/README.md index 238a0c958..f872c7f3f 100644 --- a/README.md +++ b/README.md @@ -250,7 +250,7 @@ Options: [env: DENSE_PATH=] --hf-token - Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting + Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` are set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting [env: HF_TOKEN=] diff --git a/backends/candle/src/layers/mod.rs b/backends/candle/src/layers/mod.rs index 1849d31b5..8f455b90d 100644 --- a/backends/candle/src/layers/mod.rs +++ b/backends/candle/src/layers/mod.rs @@ -14,4 +14,4 @@ pub use layer_norm::{LayerNorm, LayerNormNoBias}; pub use linear::{HiddenAct, Linear}; #[allow(unused_imports)] pub use rms_norm::RMSNorm; -pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling}; +pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeParameters, RopeScaling}; diff --git a/backends/candle/src/layers/rotary.rs b/backends/candle/src/layers/rotary.rs index 967d11c5f..b23ad0aac 100644 --- a/backends/candle/src/layers/rotary.rs +++ b/backends/candle/src/layers/rotary.rs @@ -1,6 +1,13 @@ use candle::{DType, Device, Result, Tensor, D}; use serde::Deserialize; +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct RopeParameters { + pub rope_theta: f32, + #[allow(unused)] + rope_type: String, +} + #[derive(Debug, Clone, PartialEq, Deserialize)] #[serde(untagged)] pub enum RopeScaling { diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 5396e5027..d44ff5e53 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -538,6 +538,7 @@ impl CandleBackend { rms_norm_eps: config.rms_norm_eps, model_type: config.model_type.clone(), 
rope_theta: config.rope_theta, + rope_parameters: config.rope_parameters, sliding_window: config.sliding_window, rope_scaling: config.rope_scaling, use_bidirectional_attention: config.use_bidirectional_attention, diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs index c1a5d5112..014dbbca2 100644 --- a/backends/candle/src/models/flash_gte.rs +++ b/backends/candle/src/models/flash_gte.rs @@ -199,9 +199,18 @@ impl FlashGTEModel { Self::inner_load(vb.pp("new"), config) .or_else(|_| Self::inner_load(vb.clone(), config))?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, - config.rope_theta, + rope_theta, vb.device(), config.rope_scaling.as_ref(), )?; diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index 09956cdf7..3cc4b503a 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -268,9 +268,18 @@ impl FlashMistralModel { let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, - config.rope_theta, + rope_theta, vb.device(), 
config.rope_scaling.as_ref(), )?; diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs index f00eb3f14..c133d92d7 100644 --- a/backends/candle/src/models/flash_qwen2.rs +++ b/backends/candle/src/models/flash_qwen2.rs @@ -1,9 +1,11 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen2Config}; + use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; + use text_embeddings_backend_core::{Batch, ModelType, Pool}; struct Qwen2Attention { @@ -285,9 +287,18 @@ impl FlashQwen2Model { let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, - config.rope_theta, + rope_theta, vb.device(), None, )?; diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index efa96dfa0..29cc5c8b7 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,9 +1,11 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm}; use crate::models::{Model, Qwen3Config}; + use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; + use text_embeddings_backend_core::{Batch, ModelType, Pool}; struct Qwen3Attention { @@ -353,9 +355,18 @@ 
impl FlashQwen3Model { None }; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, - config.rope_theta, + rope_theta, vb.device(), None, )?; diff --git a/backends/candle/src/models/gemma3.rs b/backends/candle/src/models/gemma3.rs index 81485e0b1..90c52a00c 100644 --- a/backends/candle/src/models/gemma3.rs +++ b/backends/candle/src/models/gemma3.rs @@ -1,5 +1,6 @@ use crate::layers::{ apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, + RopeParameters, }; use crate::models::Model; @@ -23,9 +24,10 @@ pub struct Gemma3Config { pub query_pre_attn_scalar: usize, pub rms_norm_eps: f32, pub rope_local_base_freq: f32, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub sliding_window: Option<usize>, - #[serde(rename(deserialize = "_sliding_window_pattern"))] + #[serde(rename = "_sliding_window_pattern")] pub sliding_window_pattern: usize, pub vocab_size: usize, } @@ -653,7 +655,16 @@ impl Gemma3Model { .head_dim .unwrap_or(config.hidden_size / config.num_attention_heads); - let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + + let inv_freqs = get_inv_freqs(rotary_dim, rope_theta, vb.device(), None)?; let 
rotary_cache = get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?; diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs index d5cf34120..6387a4df9 100644 --- a/backends/candle/src/models/gte.rs +++ b/backends/candle/src/models/gte.rs @@ -1,12 +1,14 @@ use crate::layers::{ apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear, - RopeScaling, + RopeParameters, RopeScaling, }; use crate::models::{Model, PositionEmbeddingType}; + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; + use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -22,7 +24,8 @@ pub struct GTEConfig { pub layer_norm_type: String, pub layer_norm_eps: f32, pub position_embedding_type: PositionEmbeddingType, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub rope_scaling: Option<RopeScaling>, #[serde(default)] pub logn_attention_scale: bool, @@ -412,10 +415,19 @@ impl GTEModel { Self::inner_load(vb.pp("new"), config) .or_else(|_| Self::inner_load(vb.clone(), config))?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + let rotary_dim = encoder.layers[0].attention.attention_head_size; let inv_freqs = get_inv_freqs( rotary_dim, - config.rope_theta, + rope_theta, vb.device(), config.rope_scaling.as_ref(), )?; diff --git a/backends/candle/src/models/llama.rs b/backends/candle/src/models/llama.rs index dc993f808..0b7e26227 100644 --- a/backends/candle/src/models/llama.rs +++ 
b/backends/candle/src/models/llama.rs @@ -1,4 +1,5 @@ -use crate::layers::{HiddenAct, RopeScaling}; +use crate::layers::{HiddenAct, RopeParameters, RopeScaling}; + use serde::Deserialize; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -14,7 +15,8 @@ pub struct LlamaConfig { pub initializer_range: f64, pub rms_norm_eps: f32, pub model_type: Option<String>, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub sliding_window: Option<usize>, pub rope_scaling: Option<RopeScaling>, #[serde(default)] diff --git a/backends/candle/src/models/mistral.rs b/backends/candle/src/models/mistral.rs index 8fa91d017..85a590896 100644 --- a/backends/candle/src/models/mistral.rs +++ b/backends/candle/src/models/mistral.rs @@ -1,4 +1,5 @@ -use crate::layers::{HiddenAct, RopeScaling}; +use crate::layers::{HiddenAct, RopeParameters, RopeScaling}; + use serde::Deserialize; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -14,7 +15,8 @@ pub struct MistralConfig { pub initializer_range: f64, pub rms_norm_eps: f32, pub model_type: Option<String>, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub sliding_window: Option<usize>, pub rope_scaling: Option<RopeScaling>, #[serde(default)] diff --git a/backends/candle/src/models/qwen2.rs b/backends/candle/src/models/qwen2.rs index 10cbb76ec..a19727f70 100644 --- a/backends/candle/src/models/qwen2.rs +++ b/backends/candle/src/models/qwen2.rs @@ -1,4 +1,4 @@ -use crate::layers::HiddenAct; +use crate::layers::{HiddenAct, RopeParameters}; use serde::Deserialize; fn default_is_causal() -> bool { @@ -17,7 +17,8 @@ pub struct Qwen2Config { pub hidden_act: HiddenAct, pub max_position_embeddings: usize, pub rms_norm_eps: f32, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub sliding_window: Option<usize>, pub use_sliding_window: bool, #[serde(default = "default_is_causal")] diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 8c76aa090..f21094b9d 100644 --- 
a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -1,10 +1,13 @@ use crate::layers::{ apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm, + RopeParameters, }; use crate::models::Model; + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; + use text_embeddings_backend_core::{Batch, ModelType, Pool}; #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -20,7 +23,8 @@ pub struct Qwen3Config { pub hidden_act: HiddenAct, pub max_position_embeddings: usize, pub rms_norm_eps: f32, - pub rope_theta: f32, + pub rope_theta: Option<f32>, + pub rope_parameters: Option<RopeParameters>, pub sliding_window: Option<usize>, pub use_sliding_window: bool, pub eos_token_id: usize, @@ -454,7 +458,16 @@ impl Qwen3Model { .head_dim .unwrap_or(config.hidden_size / config.num_attention_heads); - let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?; + // NOTE: https://github.com/huggingface/transformers/pull/39847 + let rope_theta = match config.rope_theta { + Some(rope_theta) => rope_theta, + None => match &config.rope_parameters { + Some(rope_parameters) => rope_parameters.rope_theta, + None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"), + }, + }; + + let inv_freqs = get_inv_freqs(rotary_dim, rope_theta, vb.device(), None)?; let rotary_cache = get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?; diff --git a/docs/source/en/cli_arguments.md b/docs/source/en/cli_arguments.md index 7d05a7f12..fea608657 100644 --- a/docs/source/en/cli_arguments.md +++ b/docs/source/en/cli_arguments.md @@ -134,7 +134,7 @@ Options: [env: DENSE_PATH=] --hf-token - Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. 
This ensures access to private or gated models, and allows for a more permissive rate limiting + Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` are set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting [env: HF_TOKEN=] diff --git a/router/src/lib.rs b/router/src/lib.rs index a05508184..21920f578 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -181,7 +181,7 @@ pub async fn run( || &config.model_type == "camembert" || &config.model_type == "roberta" { - config.pad_token_id + 1 + config.pad_token_id.unwrap_or(0) + 1 } else { 0 }; @@ -459,7 +459,7 @@ pub struct ModelConfig { #[serde(alias = "n_positions")] pub max_position_embeddings: usize, #[serde(default)] - pub pad_token_id: usize, + pub pad_token_id: Option<usize>, pub id2label: Option<HashMap<String, String>>, pub label2id: Option<HashMap<String, usize>>, pub auto_map: Option<HashMap<String, String>>, diff --git a/router/src/main.rs b/router/src/main.rs index baabb5c49..4e14acffa 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -136,7 +136,7 @@ struct Args { #[redact(partial)] hf_api_token: Option<String>, - /// Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token + /// Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` are set, the token /// will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private /// or gated models, and allows for a more permissive rate limiting. #[clap(long, env, conflicts_with = "hf_api_token")]