Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ nomic-v2-moe = ["dep:candle-core", "dep:candle-nn", "hf-hub"]
image-models = ["dep:image"]
mkl = ["qwen3", "nomic-v2-moe", "dep:intel-mkl-src", "candle-nn/mkl", "candle-core/mkl"]
accelerate = ["qwen3", "nomic-v2-moe", "dep:accelerate-src", "candle-core/accelerate", "candle-nn/accelerate"]
directml = ["ort/directml"]
cuda = ["qwen3", "nomic-v2-moe", "candle-core/cuda", "candle-nn/cuda"]
cudnn = ["qwen3", "nomic-v2-moe", "candle-core/cudnn", "candle-nn/cudnn", "cuda"]
metal = ["qwen3", "nomic-v2-moe", "candle-core/metal", "candle-nn/metal"]
Expand Down
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,29 @@ println!("Rerank result: {:?}", results);

Alternatively, local model files can be used for inference via the `try_new_from_user_defined(...)` methods of respective structs.

### DirectML (Windows)

To run models on a GPU via DirectML on Windows, enable the `directml` feature:

```toml
[dependencies]
fastembed = { version = "5", features = ["directml"] }
```

Then pass a DirectML execution provider via `with_execution_providers` when initializing a model:

```rust
use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
use ort::ep::DirectML;

let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::AllMiniLML6V2)
.with_execution_providers(vec![DirectML::default().into()]),
)?;
```

When a DirectML execution provider is detected, fastembed automatically disables memory pattern optimization and parallel execution on that ONNX Runtime session, since the DirectML execution provider requires both to be off.

## LICENSE

[Apache 2.0](https://github.com/Anush008/fastembed-rs/blob/main/LICENSE)
40 changes: 34 additions & 6 deletions src/image_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,25 @@ impl ImageEmbedding {
.get(&model_file_name)
.context(format!("Failed to retrieve {}", model_file_name))?;

let session = Session::builder()?
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;

let mut builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_file(model_file_reference)?;
.with_intra_threads(threads)?;

if has_directml {
builder = builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}

let session = builder.commit_from_file(model_file_reference)?;

Ok(Self::new(preprocessor, session))
}
Expand All @@ -83,11 +97,25 @@ impl ImageEmbedding {

let preprocessor = Compose::from_bytes(model.preprocessor_file)?;

let session = Session::builder()?
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;

let mut builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_memory(&model.onnx_file)?;
.with_intra_threads(threads)?;

if has_directml {
builder = builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}

let session = builder.commit_from_memory(&model.onnx_file)?;

Ok(Self::new(preprocessor, session))
}
Expand Down
33 changes: 30 additions & 3 deletions src/text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,25 @@ impl TextEmbedding {
// prioritise loading pooling config if available, if not (thanks qdrant!), look for it in hardcoded
let post_processing = TextEmbedding::get_default_pooling_method(&model_name);

let session = Session::builder()?
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;

let mut builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?
.commit_from_file(model_file_reference)?;
.with_intra_threads(threads)?;

if has_directml {
builder = builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}

let session = builder.commit_from_file(model_file_reference)?;

let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?;
Ok(Self::new(
Expand All @@ -100,12 +114,25 @@ impl TextEmbedding {

let threads = available_parallelism()?.get();

#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;

let session = {
let mut session_builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?;

if has_directml {
session_builder = session_builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}

for external_initializer_file in model.external_initializers {
session_builder = session_builder.with_external_initializer_file_in_memory(
external_initializer_file.file_name,
Expand Down
Loading