Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Options:
[env: DENSE_PATH=]

--hf-token <HF_TOKEN>
Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting
Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting

[env: HF_TOKEN=]

Expand Down
2 changes: 1 addition & 1 deletion backends/candle/src/layers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ pub use layer_norm::{LayerNorm, LayerNormNoBias};
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeParameters, RopeScaling};
7 changes: 7 additions & 0 deletions backends/candle/src/layers/rotary.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use candle::{DType, Device, Result, Tensor, D};
use serde::Deserialize;

/// Rotary position embedding (RoPE) parameters as nested under the
/// `rope_parameters` key of `config.json` (newer `transformers` configs place
/// `rope_theta` here instead of at the top level — see the NOTE comments
/// referencing huggingface/transformers#39847 elsewhere in this change).
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RopeParameters {
/// Base frequency (theta) used when computing the RoPE inverse frequencies.
pub rope_theta: f32,
/// RoPE variant identifier from the config; deserialized but not read yet.
#[allow(unused)]
rope_type: String,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum RopeScaling {
Expand Down
1 change: 1 addition & 0 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ impl CandleBackend {
rms_norm_eps: config.rms_norm_eps,
model_type: config.model_type.clone(),
rope_theta: config.rope_theta,
rope_parameters: config.rope_parameters,
sliding_window: config.sliding_window,
rope_scaling: config.rope_scaling,
use_bidirectional_attention: config.use_bidirectional_attention,
Expand Down
11 changes: 10 additions & 1 deletion backends/candle/src/models/flash_gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,18 @@ impl FlashGTEModel {
Self::inner_load(vb.pp("new"), config)
.or_else(|_| Self::inner_load(vb.clone(), config))?;

// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
rope_theta,
vb.device(),
config.rope_scaling.as_ref(),
)?;
Expand Down
11 changes: 10 additions & 1 deletion backends/candle/src/models/flash_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,18 @@ impl FlashMistralModel {

let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;

// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
rope_theta,
vb.device(),
config.rope_scaling.as_ref(),
)?;
Expand Down
13 changes: 12 additions & 1 deletion backends/candle/src/models/flash_qwen2.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen2Config};

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;

use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct Qwen2Attention {
Expand Down Expand Up @@ -285,9 +287,18 @@ impl FlashQwen2Model {

let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;

// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
rope_theta,
vb.device(),
None,
)?;
Expand Down
13 changes: 12 additions & 1 deletion backends/candle/src/models/flash_qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, index_select, HiddenAct, Linear, RMSNorm};
use crate::models::{Model, Qwen3Config};

use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;

use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct Qwen3Attention {
Expand Down Expand Up @@ -353,9 +355,18 @@ impl FlashQwen3Model {
None
};

// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
rope_theta,
vb.device(),
None,
)?;
Expand Down
17 changes: 14 additions & 3 deletions backends/candle/src/models/gemma3.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::layers::{
apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear,
RopeParameters,
};
use crate::models::Model;

Expand All @@ -23,9 +24,10 @@ pub struct Gemma3Config {
pub query_pre_attn_scalar: usize,
pub rms_norm_eps: f32,
pub rope_local_base_freq: f32,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub sliding_window: Option<usize>,
#[serde(rename(deserialize = "_sliding_window_pattern"))]
#[serde(rename = "_sliding_window_pattern")]
pub sliding_window_pattern: usize,
pub vocab_size: usize,
}
Expand Down Expand Up @@ -653,7 +655,16 @@ impl Gemma3Model {
.head_dim
.unwrap_or(config.hidden_size / config.num_attention_heads);

let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?;
// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta, vb.device(), None)?;
let rotary_cache =
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

Expand Down
18 changes: 15 additions & 3 deletions backends/candle/src/models/gte.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::layers::{
apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
RopeScaling,
RopeParameters, RopeScaling,
};
use crate::models::{Model, PositionEmbeddingType};

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;

use text_embeddings_backend_core::{Batch, ModelType, Pool};

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -22,7 +24,8 @@ pub struct GTEConfig {
pub layer_norm_type: String,
pub layer_norm_eps: f32,
pub position_embedding_type: PositionEmbeddingType,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub rope_scaling: Option<RopeScaling>,
#[serde(default)]
pub logn_attention_scale: bool,
Expand Down Expand Up @@ -412,10 +415,19 @@ impl GTEModel {
Self::inner_load(vb.pp("new"), config)
.or_else(|_| Self::inner_load(vb.clone(), config))?;

// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let rotary_dim = encoder.layers[0].attention.attention_head_size;
let inv_freqs = get_inv_freqs(
rotary_dim,
config.rope_theta,
rope_theta,
vb.device(),
config.rope_scaling.as_ref(),
)?;
Expand Down
6 changes: 4 additions & 2 deletions backends/candle/src/models/llama.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::layers::{HiddenAct, RopeScaling};
use crate::layers::{HiddenAct, RopeParameters, RopeScaling};

use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -14,7 +15,8 @@ pub struct LlamaConfig {
pub initializer_range: f64,
pub rms_norm_eps: f32,
pub model_type: Option<String>,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub sliding_window: Option<usize>,
pub rope_scaling: Option<RopeScaling>,
#[serde(default)]
Expand Down
6 changes: 4 additions & 2 deletions backends/candle/src/models/mistral.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::layers::{HiddenAct, RopeScaling};
use crate::layers::{HiddenAct, RopeParameters, RopeScaling};

use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -14,7 +15,8 @@ pub struct MistralConfig {
pub initializer_range: f64,
pub rms_norm_eps: f32,
pub model_type: Option<String>,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub sliding_window: Option<usize>,
pub rope_scaling: Option<RopeScaling>,
#[serde(default)]
Expand Down
5 changes: 3 additions & 2 deletions backends/candle/src/models/qwen2.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::layers::HiddenAct;
use crate::layers::{HiddenAct, RopeParameters};
use serde::Deserialize;

fn default_is_causal() -> bool {
Expand All @@ -17,7 +17,8 @@ pub struct Qwen2Config {
pub hidden_act: HiddenAct,
pub max_position_embeddings: usize,
pub rms_norm_eps: f32,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub sliding_window: Option<usize>,
pub use_sliding_window: bool,
#[serde(default = "default_is_causal")]
Expand Down
17 changes: 15 additions & 2 deletions backends/candle/src/models/qwen3.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::layers::{
apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
RopeParameters,
};
use crate::models::Model;

use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;

use text_embeddings_backend_core::{Batch, ModelType, Pool};

#[derive(Debug, Clone, PartialEq, Deserialize)]
Expand All @@ -20,7 +23,8 @@ pub struct Qwen3Config {
pub hidden_act: HiddenAct,
pub max_position_embeddings: usize,
pub rms_norm_eps: f32,
pub rope_theta: f32,
pub rope_theta: Option<f32>,
pub rope_parameters: Option<RopeParameters>,
pub sliding_window: Option<usize>,
pub use_sliding_window: bool,
pub eos_token_id: usize,
Expand Down Expand Up @@ -454,7 +458,16 @@ impl Qwen3Model {
.head_dim
.unwrap_or(config.hidden_size / config.num_attention_heads);

let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?;
// NOTE: https://github.com/huggingface/transformers/pull/39847
let rope_theta = match config.rope_theta {
Some(rope_theta) => rope_theta,
None => match &config.rope_parameters {
Some(rope_parameters) => rope_parameters.rope_theta,
None => candle::bail!("Neither `rope_theta` nor `rope_parameters.rope_theta` are defined in the `config.json`"),
},
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta, vb.device(), None)?;

let rotary_cache =
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/cli_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Options:
[env: DENSE_PATH=]

--hf-token <HF_TOKEN>
Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting
Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private or gated models, and allows for a more permissive rate limiting

[env: HF_TOKEN=]

Expand Down
4 changes: 2 additions & 2 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ pub async fn run(
|| &config.model_type == "camembert"
|| &config.model_type == "roberta"
{
config.pad_token_id + 1
config.pad_token_id.unwrap_or(0) + 1
} else {
0
};
Expand Down Expand Up @@ -459,7 +459,7 @@ pub struct ModelConfig {
#[serde(alias = "n_positions")]
pub max_position_embeddings: usize,
#[serde(default)]
pub pad_token_id: usize,
pub pad_token_id: Option<usize>,
pub id2label: Option<HashMap<String, String>>,
pub label2id: Option<HashMap<String, usize>>,
pub auto_map: Option<HashMap<String, String>>,
Expand Down
2 changes: 1 addition & 1 deletion router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ struct Args {
#[redact(partial)]
hf_api_token: Option<String>,

/// Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token
/// Your Hugging Face Hub token. If neither `--hf-token` nor `HF_TOKEN` is set, the token
/// will be read from the `$HF_HOME/token` path, if it exists. This ensures access to private
/// or gated models, and allows for a more permissive rate limiting.
#[clap(long, env, conflicts_with = "hf_api_token")]
Expand Down
Loading