Skip to content

Commit 45babd6

Browse files
author
KeDengMS
authored
symbolic shape inference: fix warnings in GPT-2 model (#2608)
And revise nuphar perf test on BERT squad
1 parent bc89ecc commit 45babd6

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,13 @@ def _broadcast_shapes(self, shape1, shape2):
263263
else:
264264
new_dim = self._merge_symbols([dim1, dim2])
265265
if not new_dim:
266-
print('unsupported broadcast between ' + str(dim1) + ' ' + str(dim2))
266+
# warning about unsupported broadcast when not auto merge
267+
# note that auto merge has the risk of incorrectly merge symbols while one of them being 1
268+
# for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
269+
if self.auto_merge_:
270+
self._add_suggested_merge([dim1, dim2], apply=True)
271+
else:
272+
print('unsupported broadcast between ' + str(dim1) + ' ' + str(dim2))
267273
new_shape = [new_dim] + new_shape
268274
return new_shape
269275

@@ -625,9 +631,9 @@ def _infer_ConstantOfShape(self, node):
625631
sympy_shape = self._get_int_values(node)[0]
626632
vi = self.known_vi_[node.output[0]]
627633
if sympy_shape is not None:
628-
self._update_computed_dims(sympy_shape)
629634
if type(sympy_shape) != list:
630635
sympy_shape = [sympy_shape]
636+
self._update_computed_dims(sympy_shape)
631637
else:
632638
# create new dynamic shape
633639
sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node)

onnxruntime/test/python/onnxruntime_test_python_nuphar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,9 @@ def test_bert_squad(self):
108108
onnx_test_runner = os.path.join(cwd, 'onnx_test_runner')
109109
subprocess.run([onnx_test_runner, '-e', 'nuphar', '-n', 'download_sample_10', cwd], check=True, cwd=cwd)
110110

111-
# run onnxruntime_perf_test
111+
# run onnxruntime_perf_test, note that nuphar currently is not integrated with ORT thread pool, so set -x 1 to avoid thread confliction with OpenMP
112112
onnxruntime_perf_test = os.path.join(cwd, 'onnxruntime_perf_test')
113-
subprocess.run([onnxruntime_perf_test, '-e', 'nuphar', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd)
114-
subprocess.run([onnxruntime_perf_test, '-e', 'cpu', '-o', '99', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd)
113+
subprocess.run([onnxruntime_perf_test, '-e', 'nuphar', '-x', '1', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd)
115114

116115

117116
def test_rnn_benchmark(self):

0 commit comments

Comments
 (0)