Skip to content

Commit 9c09485

Browse files
authored
Merge pull request #128 from pipeless-ai/custom_op_lib
feat(onnxruntime): Support custom op library
2 parents 5abb981 + 643fec2 commit 9c09485

2 files changed

Lines changed: 42 additions & 19 deletions

File tree

pipeless/src/stages/inference/onnx.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub struct OnnxSessionParams {
99
execution_mode: Option<String>, // Parallel or sequential execution mode for onnx
1010
inter_threads: Option<i16>, // If execution mode is Parallel (and nodes can be run in parallel), this sets the maximum number of threads to use to run them in parallel.
1111
intra_threads: Option<i16>, // Number of threads to parallelize the execution within nodes
12+
custom_op_lib_path: Option<String>, // Path to a custom op library
1213
/*ir_version: Option<u32>,
1314
opset_version: Option<u32>,
1415
image_shape_format: Option<Vec<String>>,
@@ -20,13 +21,15 @@ impl OnnxSessionParams {
2021
pub fn new(
2122
stage_name: &str,
2223
execution_provider: &str, execution_mode: Option<&str>,
23-
inter_threads: Option<i16>, intra_threads: Option<i16>
24+
inter_threads: Option<i16>, intra_threads: Option<i16>,
25+
custom_op_lib_path: Option<&str>,
2426
) -> Self {
2527
Self {
2628
stage_name: stage_name.to_string(),
2729
execution_provider: execution_provider.to_string(),
2830
execution_mode: execution_mode.map(|m| m.to_string()),
2931
inter_threads, intra_threads,
32+
custom_op_lib_path: custom_op_lib_path.map(|p| p.to_string()),
3033
}
3134
}
3235
}
@@ -82,25 +85,41 @@ impl OnnxSession {
8285
}
8386
}
8487

88+
if let Some(lib_path) = onnx_params.custom_op_lib_path {
89+
log::info!("Loading custom operations lib from: {}", lib_path);
90+
session_builder = session_builder.with_custom_op_lib(&lib_path).unwrap();
91+
}
92+
8593
let session = session_builder.with_model_from_file(model_file_path).unwrap();
8694

8795
// Run a first test inference that usually takes more time.
8896
// This avoids adding an initial delay to the stream when it arrives, making the session ready
89-
let input0_shape: Vec<usize> = session.inputs[0].dimensions()
90-
.map(std::option::Option::unwrap)
91-
.collect();
92-
// Assuming the conventional input format: batch, channels, height, width
93-
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
94-
let width = input0_shape[2 + batch_shift];
95-
let height = input0_shape[1 + batch_shift];
96-
let channels = input0_shape[0 + batch_shift];
97-
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
98-
let cow_array = ndarray::CowArray::from(test_image);
99-
let ort_input_value = ort::Value::from_array(
100-
session.allocator(),
101-
&cow_array
102-
).unwrap();
103-
let _ = session.run(vec![ort_input_value]);
97+
let input0_shape: Vec<Option<usize>> = session.inputs[0].dimensions().map(|x| x).collect();
98+
if input0_shape.len() > 2 {
99+
// Assuming the conventional input format: batch, channels, height, width
100+
let batch_shift = if input0_shape.len() > 3 { 1 } else { 0 };
101+
let width = input0_shape[2 + batch_shift];
102+
let height = input0_shape[1 + batch_shift];
103+
let channels = input0_shape[0 + batch_shift];
104+
if let (Some(width), Some(height), Some(channels)) = (width, height, channels) {
105+
let test_image = ndarray::Array3::<u8>::zeros((channels, height, width)).into_dyn();
106+
let cow_array = ndarray::CowArray::from(test_image);
107+
let ort_input_value = ort::Value::from_array(
108+
session.allocator(),
109+
&cow_array
110+
).unwrap();
111+
let _ = session.run(vec![ort_input_value]);
112+
} else {
113+
warn!(
114+
"Could not run an inference test because the model input shape was not properly recognized. Obtained: width: {:?}, height: {:?}, channels: {:?}",
115+
width.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()), // Print the number on the option or "None"
116+
height.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string()),
117+
channels.map(|num| num.to_string()).unwrap_or_else(|| "None".to_string())
118+
);
119+
}
120+
} else {
121+
warn!("Could not run an inference test because the model input shape does not contain all the image dimensions");
122+
}
104123

105124
Ok(Self { session })
106125
} else {
@@ -112,14 +131,16 @@ impl OnnxSession {
112131

113132
impl super::session::SessionTrait for OnnxSession {
114133
fn infer(&self, mut frame: pipeless::data::Frame) -> pipeless::data::Frame {
134+
// TODO: automatically resize and transpose the input image to what the model expects
135+
136+
// FIXME: we are forcing users to provide float32 arrays which will produce the inference to fail if the model expects uint values.
137+
115138
let input_data = frame.get_inference_input().to_owned();
116139
if input_data.len() == 0 {
117140
warn!("No inference input data was provided. Did you forget to add it at your pre-process hook?");
118141
return frame;
119142
}
120143

121-
// TODO: automatically resize and transpose the input image to what the model expects
122-
123144
let input_vec = input_data.view().insert_axis(ndarray::Axis(0)).into_dyn(); // Batch image with batch size 1
124145
let cow_array = ndarray::CowArray::from(input_vec);
125146
let ort_input_value_result = ort::Value::from_array(

pipeless/src/stages/inference/session.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ impl SessionParams {
4242
warn!("'execution_mode' must be set to 'Parallel' for 'inter_threads' to take effect");
4343
}
4444
let intra_threads = data["intra_threads"].as_i64();
45+
let custom_op_lib_path = data["custom_op_lib_path"].as_str();
4546
SessionParams::Onnx(
4647
OnnxSessionParams::new(
4748
stage_name,
4849
execution_provider, execution_mode,
4950
inter_threads.map(|t| t as i16),
50-
intra_threads.map(|t| t as i16)
51+
intra_threads.map(|t| t as i16),
52+
custom_op_lib_path,
5153
))
5254
},
5355
super::runtime::InferenceRuntime::Openvino => unimplemented!(),

0 commit comments

Comments
 (0)