diff --git a/Cargo.lock b/Cargo.lock
index 1ef02ed58..0bba377aa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -378,6 +378,12 @@ version = "0.22.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 
+[[package]]
+name = "base64ct"
+version = "1.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
+
 [[package]]
 name = "bindgen"
 version = "0.69.5"
@@ -932,6 +938,16 @@ dependencies = [
  "parking_lot_core",
 ]
 
+[[package]]
+name = "der"
+version = "0.7.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
+dependencies = [
+ "pem-rfc7468",
+ "zeroize",
+]
+
 [[package]]
 name = "deranged"
 version = "0.4.0"
@@ -1753,7 +1769,7 @@ dependencies = [
  "serde_json",
  "thiserror 2.0.12",
  "tokio",
- "ureq",
+ "ureq 2.12.1",
  "windows-sys 0.59.0",
 ]
@@ -1874,7 +1890,7 @@ dependencies = [
  "httpdate",
  "itoa",
  "pin-project-lite",
- "smallvec",
+ "smallvec 1.15.0",
  "tokio",
  "want",
 ]
@@ -2027,7 +2043,7 @@ dependencies = [
  "icu_normalizer_data",
  "icu_properties",
  "icu_provider",
- "smallvec",
+ "smallvec 1.15.0",
  "utf16_iter",
  "utf8_iter",
  "write16",
@@ -2102,7 +2118,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e"
 dependencies = [
  "idna_adapter",
- "smallvec",
+ "smallvec 1.15.0",
  "utf8_iter",
 ]
@@ -2855,7 +2871,7 @@ dependencies = [
  "tar",
  "thiserror 1.0.69",
  "toml",
- "ureq",
+ "ureq 2.12.1",
  "url",
  "uuid",
  "walkdir",
@@ -3136,27 +3152,28 @@ dependencies = [
 
 [[package]]
 name = "ort"
-version = "2.0.0-rc.9"
+version = "2.0.0-rc.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52afb44b6b0cffa9bf45e4d37e5a4935b0334a51570658e279e9e3e6cf324aa5"
+checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721"
 dependencies = [
  "half",
  "ndarray",
  "ort-sys",
+ "smallvec 2.0.0-alpha.10",
  "tracing",
 ]
 
 [[package]]
 name = "ort-sys"
-version = "2.0.0-rc.9"
+version = "2.0.0-rc.10"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c41d7757331aef2d04b9cb09b45583a59217628beaf91895b7e76187b6e8c088"
+checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890"
 dependencies = [
  "flate2",
  "pkg-config",
  "sha2",
  "tar",
- "ureq",
+ "ureq 3.0.12",
 ]
@@ -3190,7 +3207,7 @@ dependencies = [
  "cfg-if",
  "libc",
  "redox_syscall",
- "smallvec",
+ "smallvec 1.15.0",
  "windows-targets 0.52.6",
 ]
@@ -3200,6 +3217,15 @@ version = "1.0.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
 
+[[package]]
+name = "pem-rfc7468"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
+dependencies = [
+ "base64ct",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.3.1"
@@ -4251,6 +4277,12 @@ version = "1.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9"
 
+[[package]]
+name = "smallvec"
+version = "2.0.0-alpha.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b"
+
 [[package]]
 name = "socket2"
 version = "0.5.9"
@@ -4507,6 +4539,7 @@ dependencies = [
  "nohash-hasher",
  "num_cpus",
  "ort",
+ "ort-sys",
  "serde",
  "serde_json",
  "text-embeddings-backend-core",
@@ -5101,7 +5134,7 @@ dependencies = [
  "once_cell",
  "opentelemetry 0.22.0",
  "opentelemetry_sdk 0.22.1",
- "smallvec",
+ "smallvec 1.15.0",
  "tracing",
  "tracing-core",
  "tracing-log 0.2.0",
@@ -5119,7 +5152,7 @@ dependencies = [
  "once_cell",
  "opentelemetry 0.23.0",
  "opentelemetry_sdk 0.23.0",
- "smallvec",
+ "smallvec 1.15.0",
  "tracing",
  "tracing-core",
  "tracing-log 0.2.0",
@@ -5162,7 +5195,7 @@ dependencies = [
  "serde",
  "serde_json",
  "sharded-slab",
- "smallvec",
+ "smallvec 1.15.0",
  "thread_local",
  "tracing",
  "tracing-core",
@@ -5248,7 +5281,7 @@ version = "0.1.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
 dependencies = [
- "smallvec",
+ "smallvec 1.15.0",
 ]
@@ -5300,6 +5333,37 @@ dependencies = [
  "webpki-roots",
 ]
 
+[[package]]
+name = "ureq"
+version = "3.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f0fde9bc91026e381155f8c67cb354bcd35260b2f4a29bcc84639f762760c39"
+dependencies = [
+ "base64 0.22.1",
+ "der",
+ "log",
+ "native-tls",
+ "percent-encoding",
+ "rustls-pemfile",
+ "rustls-pki-types",
+ "socks",
+ "ureq-proto",
+ "utf-8",
+ "webpki-root-certs 0.26.11",
+]
+
+[[package]]
+name = "ureq-proto"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59db78ad1923f2b1be62b6da81fe80b173605ca0d57f85da2e005382adf693f7"
+dependencies = [
+ "base64 0.22.1",
+ "http 1.3.1",
+ "httparse",
+ "log",
+]
+
 [[package]]
 name = "url"
 version = "2.5.4"
@@ -5317,6 +5381,12 @@ version = "2.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
 
+[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
 [[package]]
 name = "utf16_iter"
 version = "1.0.5"
@@ -5583,6 +5653,24 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "webpki-root-certs"
+version = "0.26.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75c7f0ef91146ebfb530314f5f1d24528d7f0767efbfd31dce919275413e393e"
+dependencies = [
+ "webpki-root-certs 1.0.2",
+]
+
+[[package]]
+name = "webpki-root-certs"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4e4ffd8df1c57e87c325000a3d6ef93db75279dc3a231125aac571650f22b12a"
+dependencies = [
+ "rustls-pki-types",
+]
+
 [[package]]
 name = "webpki-roots"
 version = "0.26.8"
diff --git a/backends/ort/Cargo.toml b/backends/ort/Cargo.toml
index 40ccc5605..e0fb2f524 100644
--- a/backends/ort/Cargo.toml
+++ b/backends/ort/Cargo.toml
@@ -10,7 +10,8 @@ anyhow = { workspace = true }
 nohash-hasher = { workspace = true }
 ndarray = "0.16.1"
 num_cpus = { workspace = true }
-ort = { version = "2.0.0-rc.8", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
+ort = { version = "2.0.0-rc.10", default-features = false, features = ["std", "download-binaries", "half", "onednn", "ndarray"] }
features = ["std", "download-binaries", "half", "onednn", "ndarray"] } +ort-sys = { version = "=2.0.0-rc.10", default-features = false } # https://github.com/pykeio/ort/issues/399 text-embeddings-backend-core = { path = "../core" } tracing = { workspace = true } thiserror = { workspace = true } diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs index add5b33dc..2d6384e38 100644 --- a/backends/ort/src/lib.rs +++ b/backends/ort/src/lib.rs @@ -1,17 +1,33 @@ use ndarray::{s, Axis}; use nohash_hasher::BuildNoHashHasher; use ort::session::{builder::GraphOptimizationLevel, Session}; +use serde::Deserialize; use std::collections::HashMap; use std::ops::{Div, Mul}; use std::path::Path; +use std::sync::Mutex; use text_embeddings_backend_core::{ Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, }; +#[derive(Debug, Clone, Deserialize)] +pub struct PastKeyValuesConfig { + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, +} + pub struct OrtBackend { - session: Session, + session: Mutex, + + token_type_ids: bool, + // NOTE: required since the key can either be `token_type_ids` or `input_type` + token_type_ids_key: String, + position_ids: bool, + past_key_values: bool, + past_key_values_config: Option, + pool: Pool, - type_id_name: Option, } impl OrtBackend { @@ -20,27 +36,24 @@ impl OrtBackend { dtype: String, model_type: ModelType, ) -> Result { - // Check dtype if dtype != "float32" { return Err(BackendError::Start(format!( - "DType {dtype} is not supported" + "Dtype {dtype} is not supported for `ort`, only float32." ))); }; - // Check model type let pool = match model_type { ModelType::Classifier => Pool::Cls, ModelType::Embedding(pool) => match pool { Pool::Splade => { return Err(BackendError::Start(format!( - "Pooling {pool} is not supported for this backend. Use `candle` backend instead." + "Pooling {pool} is not supported for `ort`, use `candle` instead." ))); } - pool => pool, + _ => pool, }, }; - // Get model path let onnx_path = { let default_path = model_path.join("model.onnx"); match default_path.exists() { @@ -49,7 +62,6 @@ impl OrtBackend { } }; - // Start onnx session let session = Session::builder() .s()? 
             .with_intra_threads(num_cpus::get())
@@ -59,19 +71,58 @@
             .commit_from_file(onnx_path)
             .s()?;
 
-        // Check if the model requires type tokens
-        let mut type_id_name = None;
+        let mut token_type_ids = false;
+        let mut token_type_ids_key = String::from("token_type_ids");
+        let mut position_ids = false;
+        let mut past_key_values = false;
+
         for input in &session.inputs {
-            if &input.name == "token_type_ids" || &input.name == "input_type" {
-                type_id_name = Some(input.name.clone());
-                break;
+            match input.name.as_str() {
+                "token_type_ids" | "input_type" => {
+                    token_type_ids = true;
+                    token_type_ids_key = input.name.clone();
+                }
+                "position_ids" => {
+                    position_ids = true;
+                }
+                name if name.starts_with("past_key_values.") => {
+                    past_key_values = true;
+                }
+                // NOTE: no need to handle `input_ids` and `attention_mask` since those are
+                // always required
+                _ => {}
             }
         }
 
+        let past_key_values_config = match past_key_values {
+            true => {
+                let path = model_path.join("config.json");
+                if !path.exists() {
+                    return Err(BackendError::Start(format!(
+                        "config.json not found at {:?}",
+                        path
+                    )));
+                }
+                let content = std::fs::read_to_string(path).map_err(|e| {
+                    BackendError::Start(format!("Failed to read config.json: {}", e))
+                })?;
+                Some(
+                    serde_json::from_str::<PastKeyValuesConfig>(&content).map_err(|e| {
+                        BackendError::Start(format!("Failed to parse config.json: {}", e))
+                    })?,
+                )
+            }
+            false => None,
+        };
+
         Ok(Self {
-            session,
+            session: Mutex::new(session),
+            token_type_ids,
+            token_type_ids_key,
+            position_ids,
+            past_key_values,
+            past_key_values_config,
             pool,
-            type_id_name,
         })
     }
 }
@@ -96,14 +147,15 @@ impl Backend for OrtBackend {
         // Whether at least one of the requests in the batch is padded
         let mut masking = false;
 
-        let (input_ids, type_ids, input_lengths, attention_mask) = {
+        let (input_ids, token_type_ids, input_lengths, attention_mask, position_ids) = {
             let elems = batch_size * max_length;
 
             if batch_size > 1 {
                 // Prepare padded batch
                 let mut input_ids = Vec::with_capacity(elems);
-                let mut type_ids = Vec::with_capacity(elems);
+                let mut token_type_ids = Vec::with_capacity(elems);
                 let mut attention_mask = Vec::with_capacity(elems);
+                let mut position_ids = Vec::with_capacity(elems);
                 let mut input_lengths = Vec::with_capacity(batch_size);
 
                 for i in 0..batch_size {
@@ -113,10 +165,11 @@
                     input_lengths.push(seq_length as f32);
 
                     // Copy values
-                    for j in start..end {
+                    for (pos, j) in (start..end).enumerate() {
                         input_ids.push(batch.input_ids[j] as i64);
-                        type_ids.push(batch.token_type_ids[j] as i64);
+                        token_type_ids.push(batch.token_type_ids[j] as i64);
                         attention_mask.push(1_i64);
+                        position_ids.push(pos as i64);
                     }
 
                     // Add padding if needed
@@ -124,22 +177,31 @@
                     if padding > 0 {
                         // Set bool to use attention mask
                         masking = true;
-                        for _ in 0..padding {
+                        for pad_pos in 0..padding {
                             input_ids.push(0);
-                            type_ids.push(0);
+                            token_type_ids.push(0);
                             attention_mask.push(0_i64);
+                            position_ids.push((seq_length + pad_pos) as i64);
                         }
                     }
                 }
 
-                (input_ids, type_ids, input_lengths, attention_mask)
+                (
+                    input_ids,
+                    token_type_ids,
+                    input_lengths,
+                    attention_mask,
+                    position_ids,
+                )
             } else {
                 let attention_mask = vec![1_i64; elems];
+                let position_ids: Vec<i64> = (0..max_length as i64).collect();
 
                 (
                     batch.input_ids.into_iter().map(|v| v as i64).collect(),
                     batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
                     vec![batch.max_length as f32],
                     attention_mask,
+                    position_ids,
                 )
             }
         };
@@ -148,24 +210,63 @@
         let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
         let attention_mask =
             ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
+        let position_ids =
+            ndarray::Array2::from_shape_vec((batch_size, max_length), position_ids).e()?;
         let input_lengths = ndarray::Array1::from_vec(input_lengths);
 
-        // Create onnx inputs
-        let inputs = match self.type_id_name.as_ref() {
-            Some(type_id_name) => {
-                // Add type ids to inputs
-                let type_ids =
-                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
-                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
+        let inputs = {
+            let mut inputs = ort::inputs![
+                "input_ids" => ort::value::Tensor::from_array(input_ids).e()?,
+                "attention_mask" => ort::value::Tensor::from_array(attention_mask.clone()).e()?,
+            ];
+
+            if self.token_type_ids {
+                let token_type_ids_tensor =
+                    ndarray::Array2::from_shape_vec((batch_size, max_length), token_type_ids)
+                        .e()?;
+                let token_type_ids_value =
+                    ort::value::Tensor::from_array(token_type_ids_tensor).e()?;
+                inputs.push((
+                    self.token_type_ids_key.clone().into(),
+                    token_type_ids_value.into(),
+                ));
+            }
+
+            if self.position_ids {
+                let position_ids_value = ort::value::Tensor::from_array(position_ids).e()?;
+                inputs.push(("position_ids".into(), position_ids_value.into()));
             }
-            None => {
-                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
-                    .e()?
+
+            if self.past_key_values {
+                let config = self.past_key_values_config.as_ref().unwrap();
+                let head_size = config.hidden_size / config.num_key_value_heads;
+
+                for i in 0..config.num_hidden_layers {
+                    let key_shape = (batch_size, config.num_key_value_heads, 0, head_size);
+                    let value_shape = (batch_size, config.num_key_value_heads, 0, head_size);
+
+                    let empty_key = ndarray::Array4::<f32>::zeros(key_shape);
+                    let empty_value = ndarray::Array4::<f32>::zeros(value_shape);
+
+                    let key_value = ort::value::Tensor::from_array(empty_key).e()?;
+                    let value_value = ort::value::Tensor::from_array(empty_value).e()?;
+                    inputs.push((
+                        format!("past_key_values.{}.key", i).into(),
+                        key_value.into(),
+                    ));
+                    inputs.push((
+                        format!("past_key_values.{}.value", i).into(),
+                        value_value.into(),
+                    ));
+                }
             }
+
+            inputs
         };
 
         // Run model
-        let outputs = self.session.run(inputs).e()?;
+        let mut session = self.session.lock().unwrap();
+        let outputs = session.run(inputs).e()?;
 
         // Get last_hidden_state ndarray
         let outputs = outputs
@@ -173,9 +274,9 @@
            .or(outputs.get("token_embeddings"))
             .ok_or(BackendError::Inference(format!(
                 "Unknown output keys: {:?}",
-                self.session.outputs
+                outputs
             )))?
-            .try_extract_tensor::<f32>()
+            .try_extract_array::<f32>()
             .e()?
             .to_owned();
@@ -202,7 +303,6 @@
         };
 
         let pooled_embeddings = match self.pool {
-            // CLS pooling
             Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
             Pool::LastToken => {
                 let axis_len = outputs.len_of(Axis(1));
@@ -211,7 +311,6 @@
                     .into_owned()
                     .into_dyn()
             }
-            // Mean pooling
             Pool::Mean => {
                 if masking {
                     let mut attention_mask = attention_mask;
@@ -302,14 +401,15 @@
         let batch_size = batch.len();
         let max_length = batch.max_length as usize;
 
-        let (input_ids, type_ids, attention_mask) = {
+        let (input_ids, token_type_ids, attention_mask, position_ids) = {
            let elems = batch_size * max_length;
 
             if batch_size > 1 {
                 // Prepare padded batch
                 let mut input_ids = Vec::with_capacity(elems);
-                let mut type_ids = Vec::with_capacity(elems);
+                let mut token_type_ids = Vec::with_capacity(elems);
                 let mut attention_mask = Vec::with_capacity(elems);
+                let mut position_ids = Vec::with_capacity(elems);
 
                 for i in 0..batch_size {
                     let start = batch.cumulative_seq_lengths[i] as usize;
@@ -317,30 +417,34 @@
                     let seq_length = (end - start) as u32;
 
                     // Copy values
-                    for j in start..end {
+                    for (pos, j) in (start..end).enumerate() {
                         input_ids.push(batch.input_ids[j] as i64);
-                        type_ids.push(batch.token_type_ids[j] as i64);
+                        token_type_ids.push(batch.token_type_ids[j] as i64);
                         attention_mask.push(1_i64);
+                        position_ids.push(pos as i64);
                     }
 
                     // Add padding if needed
                     let padding = batch.max_length - seq_length;
                     if padding > 0 {
-                        for _ in 0..padding {
+                        for pad_pos in 0..padding {
                             input_ids.push(0);
-                            type_ids.push(0);
+                            token_type_ids.push(0);
                             attention_mask.push(0_i64);
+                            position_ids.push((seq_length + pad_pos) as i64);
                         }
                     }
                 }
 
-                (input_ids, type_ids, attention_mask)
+                (input_ids, token_type_ids, attention_mask, position_ids)
             } else {
                 let attention_mask = vec![1_i64; elems];
+                let position_ids: Vec<i64> = (0..max_length as i64).collect();
 
                 (
                     batch.input_ids.into_iter().map(|v| v as i64).collect(),
                     batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
                     attention_mask,
+                    position_ids,
                 )
             }
         };
@@ -349,29 +453,65 @@
         let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
         let attention_mask =
             ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
+        let position_ids =
+            ndarray::Array2::from_shape_vec((batch_size, max_length), position_ids).e()?;
+
+        let inputs = {
+            let mut inputs = ort::inputs![
+                "input_ids" => ort::value::Tensor::from_array(input_ids).e()?,
+                "attention_mask" => ort::value::Tensor::from_array(attention_mask.clone()).e()?,
+            ];
+
+            if self.token_type_ids {
+                let token_type_ids_tensor =
+                    ndarray::Array2::from_shape_vec((batch_size, max_length), token_type_ids)
+                        .e()?;
+                let token_type_ids_value =
+                    ort::value::Tensor::from_array(token_type_ids_tensor).e()?;
+                inputs.push((
+                    self.token_type_ids_key.clone().into(),
+                    token_type_ids_value.into(),
+                ));
+            }
 
-        // Create onnx inputs
-        let inputs = match self.type_id_name.as_ref() {
-            Some(type_id_name) => {
-                // Add type ids to inputs
-                let type_ids =
-                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
-                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
+            if self.position_ids {
+                let position_ids_value = ort::value::Tensor::from_array(position_ids).e()?;
+                inputs.push(("position_ids".into(), position_ids_value.into()));
             }
-            None => {
-                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
-                    .e()?
+
+            if self.past_key_values {
+                let config = self.past_key_values_config.as_ref().unwrap();
+                let head_size = config.hidden_size / config.num_key_value_heads;
+
+                for i in 0..config.num_hidden_layers {
+                    let key_shape = (batch_size, config.num_key_value_heads, 0, head_size);
+                    let value_shape = (batch_size, config.num_key_value_heads, 0, head_size);
+
+                    let empty_key = ndarray::Array4::<f32>::zeros(key_shape);
+                    let empty_value = ndarray::Array4::<f32>::zeros(value_shape);
+
+                    let key_value = ort::value::Tensor::from_array(empty_key).e()?;
+                    let value_value = ort::value::Tensor::from_array(empty_value).e()?;
+                    inputs.push((
+                        format!("past_key_values.{}.key", i).into(),
+                        key_value.into(),
+                    ));
+                    inputs.push((
+                        format!("past_key_values.{}.value", i).into(),
+                        value_value.into(),
+                    ));
+                }
             }
+
+            inputs
         };
 
         // Run model
-        let outputs = self.session.run(inputs).e()?;
+        let mut session = self.session.lock().unwrap();
+        let outputs = session.run(inputs).e()?;
 
         // Get last_hidden_state ndarray
-        let outputs = outputs["logits"]
-            .try_extract_tensor::<f32>()
-            .e()?
-            .to_owned();
+        let outputs = outputs["logits"].try_extract_array::<f32>().e()?.to_owned();
 
         let mut predictions =
             HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
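
A note for reviewers on the `past_key_values` handling above: decoder-style ONNX exports declare one `past_key_values.{i}.key` / `past_key_values.{i}.value` input per layer, and the graph expects them even on the first forward pass. The backend therefore feeds tensors whose past-sequence axis has length 0, shaped `[batch, num_key_value_heads, 0, head_size]` with `head_size = hidden_size / num_key_value_heads` read from `config.json`. A minimal standalone sketch of just that construction follows; the helper name `empty_past_key_values` and the toy dimensions in `main` are illustrative, not part of this diff:

```rust
use ndarray::Array4;

/// Build the zero-length KV-cache inputs expected by a cache-aware ONNX model.
/// The third axis (past sequence length) is 0, so the model behaves as if no
/// tokens had been cached yet.
fn empty_past_key_values(
    batch_size: usize,
    num_hidden_layers: usize,
    num_key_value_heads: usize,
    hidden_size: usize,
) -> Vec<(String, Array4<f32>)> {
    let head_size = hidden_size / num_key_value_heads;
    let mut kv = Vec::with_capacity(2 * num_hidden_layers);
    for i in 0..num_hidden_layers {
        // Shape: [batch, kv_heads, 0, head_dim] -- note the empty sequence axis.
        let empty = Array4::<f32>::zeros((batch_size, num_key_value_heads, 0, head_size));
        kv.push((format!("past_key_values.{}.key", i), empty.clone()));
        kv.push((format!("past_key_values.{}.value", i), empty));
    }
    kv
}

fn main() {
    // Toy config: 2 layers, hidden_size 64, 4 KV heads -> head_size 16.
    let kv = empty_past_key_values(1, 2, 4, 64);
    assert_eq!(kv.len(), 4);
    assert_eq!(kv[0].0, "past_key_values.0.key");
    assert_eq!(kv[0].1.shape(), &[1, 4, 0, 16]);
}
```

In the diff these arrays are wrapped with `ort::value::Tensor::from_array` and pushed into the session inputs. Since the shapes depend only on values from `config.json`, this is also why `create` refuses to start when `past_key_values.*` inputs exist but `config.json` is missing.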