Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions pipeless/src/stages/inference/onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct OnnxSessionParams {
execution_mode: Option<String>, // Parallel or sequential execution mode for onnx
inter_threads: Option<i16>, // If execution mode is Parallel (and nodes can be run in parallel), this sets the maximum number of threads to use to run them in parallel.
intra_threads: Option<i16>, // Number of threads to parallelize the execution within nodes
custom_op_lib_path: Option<String>, // Path to a custom op library
/*ir_version: Option<u32>,
opset_version: Option<u32>,
image_shape_format: Option<Vec<String>>,
Expand All @@ -20,13 +21,15 @@ impl OnnxSessionParams {
pub fn new(
stage_name: &str,
execution_provider: &str, execution_mode: Option<&str>,
inter_threads: Option<i16>, intra_threads: Option<i16>
inter_threads: Option<i16>, intra_threads: Option<i16>,
custom_op_lib_path: Option<&str>,
) -> Self {
Self {
stage_name: stage_name.to_string(),
execution_provider: execution_provider.to_string(),
execution_mode: execution_mode.map(|m| m.to_string()),
inter_threads, intra_threads,
custom_op_lib_path: custom_op_lib_path.map(|p| p.to_string()),
}
}
}
Expand Down Expand Up @@ -82,25 +85,41 @@ impl OnnxSession {
}
}

if let Some(lib_path) = onnx_params.custom_op_lib_path {
log::info!("Loading custom operations lib from: {}", lib_path);
session_builder = session_builder.with_custom_op_lib(&lib_path).unwrap();
}

let session = session_builder.with_model_from_file(model_file_path).unwrap();

// Run a first test inference that usually takes more time.
// This avoids adding an initial delay to the stream when it arrives, making the session ready
let input0_shape: Vec<usize> = session.inputs[0].dimensions()
.map(std::option::Option::unwrap)
.collect();
// Assuming the conventional input format: batch, channels, height, width
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
let width = input0_shape[2 + batch_shift];
let height = input0_shape[1 + batch_shift];
let channels = input0_shape[0 + batch_shift];
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
let cow_array = ndarray::CowArray::from(test_image);
let ort_input_value = ort::Value::from_array(
session.allocator(),
&cow_array
).unwrap();
let _ = session.run(vec![ort_input_value]);
let input0_shape: Vec<Option<usize>> = session.inputs[0].dimensions().map(|x| x).collect();
if input0_shape.len() > 2 {
// Assuming the conventional input format: batch, channels, height, width
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
let width = input0_shape[2 + batch_shift];
let height = input0_shape[1 + batch_shift];
let channels = input0_shape[0 + batch_shift];
if let (Some(width), Some(height), Some(channels)) = (width, height, channels) {
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
let cow_array = ndarray::CowArray::from(test_image);
let ort_input_value = ort::Value::from_array(
session.allocator(),
&cow_array
).unwrap();
let _ = session.run(vec![ort_input_value]);
} else {
warn!(
"Could not run an inference test because the model input shape was not properly recognized. Obtained: width: {:?}, height: {:?}, channels: {:?}",
width.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()), // Print the number on the option or "None"
height.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()),
channels.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string())
);
}
} else {
warn!("Could not run an inference test because the model input shape does not contain all the image dimensions");
}

Ok(Self { session })
} else {
Expand All @@ -112,14 +131,16 @@ impl OnnxSession {

impl super::session::SessionTrait for OnnxSession {
fn infer(&self, mut frame: pipeless::data::Frame) -> pipeless::data::Frame {
// TODO: automatically resize and transpose the input image to the format expected by the model

// FIXME: we are forcing users to provide float32 arrays which will produce the inference to fail if the model expects uint values.

let input_data = frame.get_inference_input().to_owned();
if input_data.len() == 0 {
warn!("No inference input data was provided. Did you forget to add it at your pre-process hook?");
return frame;
}

// TODO: automatically resize and transpose the input image to the format expected by the model

let input_vec = input_data.view().insert_axis(ndarray::Axis(0)).into_dyn(); // Batch image with batch size 1
let cow_array = ndarray::CowArray::from(input_vec);
let ort_input_value_result = ort::Value::from_array(
Expand Down
4 changes: 3 additions & 1 deletion pipeless/src/stages/inference/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ impl SessionParams {
warn!("'execution_mode' must be set to 'Parallel' for 'inter_threads' to take effect");
}
let intra_threads = data["intra_threads"].as_i64();
let custom_op_lib_path = data["custom_op_lib_path"].as_str();
SessionParams::Onnx(
OnnxSessionParams::new(
stage_name,
execution_provider, execution_mode,
inter_threads.map(|t| t as i16),
intra_threads.map(|t| t as i16)
intra_threads.map(|t| t as i16),
custom_op_lib_path,
))
},
super::runtime::InferenceRuntime::Openvino => unimplemented!(),
Expand Down