@@ -9,6 +9,7 @@ pub struct OnnxSessionParams {
99 execution_mode : Option < String > , // Parallel or sequential exeuction mode or onnx
1010 inter_threads : Option < i16 > , // If execution mode is Parallel (and nodes can be run in parallel), this sets the maximum number of threads to use to run them in parallel.
1111 intra_threads : Option < i16 > , // Number of threads to parallelize the execution within nodes
12+ custom_op_lib_path : Option < String > , // Path to a custom op library
1213 /*ir_version: Option<u32>,
1314 opset_version: Option<u32>,
1415 image_shape_format: Option<Vec<String>>,
@@ -20,13 +21,15 @@ impl OnnxSessionParams {
2021 pub fn new (
2122 stage_name : & str ,
2223 execution_provider : & str , execution_mode : Option < & str > ,
23- inter_threads : Option < i16 > , intra_threads : Option < i16 >
24+ inter_threads : Option < i16 > , intra_threads : Option < i16 > ,
25+ custom_op_lib_path : Option < & str > ,
2426 ) -> Self {
2527 Self {
2628 stage_name : stage_name. to_string ( ) ,
2729 execution_provider : execution_provider. to_string ( ) ,
2830 execution_mode : execution_mode. map ( |m| m. to_string ( ) ) ,
2931 inter_threads, intra_threads,
32+ custom_op_lib_path : custom_op_lib_path. map ( |p| p. to_string ( ) ) ,
3033 }
3134 }
3235}
@@ -82,25 +85,41 @@ impl OnnxSession {
8285 }
8386 }
8487
88+ if let Some ( lib_path) = onnx_params. custom_op_lib_path {
89+ log:: info!( "Loading custom operations lib from: {}" , lib_path) ;
90+ session_builder = session_builder. with_custom_op_lib ( & lib_path) . unwrap ( ) ;
91+ }
92+
8593 let session = session_builder. with_model_from_file ( model_file_path) . unwrap ( ) ;
8694
8795 // Run a first test inference that usually takes more time.
8896 // This avoids to add an initial delay to the stream when it arrives, making the session ready
89- let input0_shape: Vec < usize > = session. inputs [ 0 ] . dimensions ( )
90- . map ( std:: option:: Option :: unwrap)
91- . collect ( ) ;
92- // Assuming the conventional input format: batch, channels, height, witdh
93- let batch_shift = if input0_shape. len ( ) > 3 { 1 } else { 0 } ;
94- let width = input0_shape[ 2 + batch_shift] ;
95- let height = input0_shape[ 1 + batch_shift] ;
96- let channels = input0_shape[ 0 + batch_shift] ;
97- let test_image = ndarray:: Array3 :: < u8 > :: zeros ( ( channels, height, width) ) . into_dyn ( ) ;
98- let cow_array = ndarray:: CowArray :: from ( test_image) ;
99- let ort_input_value = ort:: Value :: from_array (
100- session. allocator ( ) ,
101- & cow_array
102- ) . unwrap ( ) ;
103- let _ = session. run ( vec ! [ ort_input_value] ) ;
97+ let input0_shape: Vec < Option < usize > > = session. inputs [ 0 ] . dimensions ( ) . map ( |x| x) . collect ( ) ;
98+ if input0_shape. len ( ) > 2 {
99+ // Assuming the conventional input format: batch, channels, height, witdh
100+ let batch_shift = if input0_shape. len ( ) > 3 { 1 } else { 0 } ;
101+ let width = input0_shape[ 2 + batch_shift] ;
102+ let height = input0_shape[ 1 + batch_shift] ;
103+ let channels = input0_shape[ 0 + batch_shift] ;
104+ if let ( Some ( width) , Some ( height) , Some ( channels) ) = ( width, height, channels) {
105+ let test_image = ndarray:: Array3 :: < u8 > :: zeros ( ( channels, height, width) ) . into_dyn ( ) ;
106+ let cow_array = ndarray:: CowArray :: from ( test_image) ;
107+ let ort_input_value = ort:: Value :: from_array (
108+ session. allocator ( ) ,
109+ & cow_array
110+ ) . unwrap ( ) ;
111+ let _ = session. run ( vec ! [ ort_input_value] ) ;
112+ } else {
113+ warn ! (
114+ "Could not run an inference test because the model input shape was not properly recognized. Obtained: width: {:?}, height: {:?}, channels: {:?}" ,
115+ width. map( |num| num. to_string( ) ) . unwrap_or_else( || "None" . to_string( ) ) , // Print the number on the option or "None"
116+ height. map( |num| num. to_string( ) ) . unwrap_or_else( || "None" . to_string( ) ) ,
117+ channels. map( |num| num. to_string( ) ) . unwrap_or_else( || "None" . to_string( ) )
118+ ) ;
119+ }
120+ } else {
121+ warn ! ( "Could not run an inference test because the model input shape does not contain all the image dimensions" ) ;
122+ }
104123
105124 Ok ( Self { session } )
106125 } else {
@@ -112,14 +131,16 @@ impl OnnxSession {
112131
113132impl super :: session:: SessionTrait for OnnxSession {
114133 fn infer ( & self , mut frame : pipeless:: data:: Frame ) -> pipeless:: data:: Frame {
134+ // TODO: automatically resize and traspose the input image to the expected by the model
135+
136+ // FIXME: we are forcing users to provide float32 arrays which will produce the inference to fail if the model expects uint values.
137+
115138 let input_data = frame. get_inference_input ( ) . to_owned ( ) ;
116139 if input_data. len ( ) == 0 {
117140 warn ! ( "No inference input data was provided. Did you forget to add it at your pre-process hook?" ) ;
118141 return frame;
119142 }
120143
121- // TODO: automatically resize and traspose the input image to the expected by the model
122-
123144 let input_vec = input_data. view ( ) . insert_axis ( ndarray:: Axis ( 0 ) ) . into_dyn ( ) ; // Batch image with batch size 1
124145 let cow_array = ndarray:: CowArray :: from ( input_vec) ;
125146 let ort_input_value_result = ort:: Value :: from_array (
0 commit comments