diff --git a/onnxruntime/core/providers/nuphar/runtime/compute_ctx.h b/onnxruntime/core/providers/nuphar/runtime/compute_ctx.h
index ce7de45ca6581..bde9a9108ef16 100644
--- a/onnxruntime/core/providers/nuphar/runtime/compute_ctx.h
+++ b/onnxruntime/core/providers/nuphar/runtime/compute_ctx.h
@@ -178,18 +178,18 @@ class KernelComputeCtx {
   }
 
   // UpdateRealizedDims is used to sync realize dim
-  // Note insert_exclusive_axis is introduced to adjusted shape.
+  // Note insert_inclusive_axis is introduced to adjusted shape.
   // It is commonly used in Scan or other subgraphs
   // when Tensors' shapes in a subgraph are sliced from the main grahp.
-  // Using the sliced axis as insert_exclusive_axis can find the correct shape dim in the main graph
+  // Using the sliced axis as insert_inclusive_axis can find the correct shape dim in the main graph
   inline void UpdateRealizedDims(
       const std::vector<std::pair<size_t, int64_t>>& symbols,
       std::vector<int64_t>& realized_output_shape,
-      size_t insert_exclusive_axis = 65535 /*minimal maximum of size_t*/) {
+      size_t insert_inclusive_axis = 65535 /*minimal maximum of size_t*/) {
     for (const auto& s_pair : symbols) {
       size_t dim = s_pair.first;
       size_t adjusted_dim = dim;
-      if (dim > insert_exclusive_axis) {
+      if (dim >= insert_inclusive_axis) {
         adjusted_dim = dim + 1;
       }
 
diff --git a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
index c33226e632cd7..797514afec9de 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py
@@ -7,7 +7,7 @@
 from onnx import numpy_helper
 import onnxruntime as onnxrt
 import os
-from onnxruntime.nuphar.rnn_benchmark import perf_test
+from onnxruntime.nuphar.rnn_benchmark import perf_test, generate_model
 from pathlib import Path
 import shutil
 import sys
@@ -131,6 +131,51 @@ def test_rnn_benchmark(self):
                   min_duration_seconds=1)
 
+
+    def test_batch_scan(self):
+        input_dim = 3
+        hidden_dim = 5
+        bidirectional = False
+        layers = 3
+
+        lstm_model_name = 'test_batch_rnn_lstm.onnx'
+        # create an LSTM model for generating baseline data
+        generate_model('lstm', input_dim, hidden_dim, bidirectional, layers, lstm_model_name, batch_one=False, has_seq_len=True)
+
+        seq_len = 8
+        batch_size = 2
+        # prepare input
+        data_input = (np.random.rand(seq_len, batch_size, input_dim) * 2 - 1).astype(np.float32)
+        data_seq_len = np.random.randint(1, seq_len, size=(batch_size,), dtype=np.int32)
+
+        # run lstm as baseline
+        sess = onnxrt.InferenceSession(lstm_model_name)
+        first_lstm_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
+
+        lstm_data_output = []
+        lstm_data_output = first_lstm_data_output
+
+        for b in range(1, batch_size):
+            lstm_data_output = lstm_data_output + sess.run([], {'input':data_input[:,b:(b+1),:], 'seq_len':data_seq_len[b:(b+1)]})
+        lstm_data_output = np.concatenate(lstm_data_output, axis=1)
+
+        # generate a batch scan model
+        scan_model_name = 'test_batch_rnn_scan.onnx'
+        subprocess.run([sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', lstm_model_name, '--output', scan_model_name, '--mode', 'to_scan'], check=True)
+
+        # run scan_batch with batch size 1
+        sess = onnxrt.InferenceSession(scan_model_name)
+        scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
+        assert np.allclose(first_lstm_data_output, scan_batch_data_output)
+
+        # run scan_batch with batch size 2
+        scan_batch_data_output = sess.run([], {'input':data_input, 'seq_len':data_seq_len})
+        assert np.allclose(lstm_data_output, scan_batch_data_output)
+
+        # run scan_batch with batch size 1 again
+        scan_batch_data_output = sess.run([], {'input':data_input[:,0:1,:], 'seq_len':data_seq_len[0:1]})
+        assert np.allclose(first_lstm_data_output, scan_batch_data_output)
+
 
     def test_symbolic_shape_infer(self):
         cwd = os.getcwd()
         test_model_dir = os.path.join(cwd, '..', 'models')