From af4a30927957e05ec643242b506210051394e63c Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Tue, 10 Dec 2019 11:45:50 -0800 Subject: [PATCH] symbolic shape inference: fix warnings in GPT-2 model And revise nuphar perf test on BERT squad --- .../providers/nuphar/scripts/symbolic_shape_infer.py | 10 ++++++++-- .../test/python/onnxruntime_test_python_nuphar.py | 5 ++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py index 3342977caa28c..f2536f8c91ff4 100644 --- a/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py +++ b/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py @@ -263,7 +263,13 @@ def _broadcast_shapes(self, shape1, shape2): else: new_dim = self._merge_symbols([dim1, dim2]) if not new_dim: - print('unsupported broadcast between ' + str(dim1) + ' ' + str(dim2)) + # warning about unsupported broadcast when not auto merge + # note that auto merge has the risk of incorrectly merging symbols when one of them is 1 + # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' + if self.auto_merge_: + self._add_suggested_merge([dim1, dim2], apply=True) + else: + print('unsupported broadcast between ' + str(dim1) + ' ' + str(dim2)) new_shape = [new_dim] + new_shape return new_shape @@ -625,9 +631,9 @@ def _infer_ConstantOfShape(self, node): sympy_shape = self._get_int_values(node)[0] vi = self.known_vi_[node.output[0]] if sympy_shape is not None: - self._update_computed_dims(sympy_shape) if type(sympy_shape) != list: sympy_shape = [sympy_shape] + self._update_computed_dims(sympy_shape) else: # create new dynamic shape sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node,0), node) diff --git a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py index 
c33226e632cd7..b8c66367b94b3 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py +++ b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py @@ -108,10 +108,9 @@ def test_bert_squad(self): onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') subprocess.run([onnx_test_runner, '-e', 'nuphar', '-n', 'download_sample_10', cwd], check=True, cwd=cwd) - # run onnxruntime_perf_test + # run onnxruntime_perf_test, note that nuphar currently is not integrated with ORT thread pool, so set -x 1 to avoid thread conflicts with OpenMP onnxruntime_perf_test = os.path.join(cwd, 'onnxruntime_perf_test') - subprocess.run([onnxruntime_perf_test, '-e', 'nuphar', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd) - subprocess.run([onnxruntime_perf_test, '-e', 'cpu', '-o', '99', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd) + subprocess.run([onnxruntime_perf_test, '-e', 'nuphar', '-x', '1', '-t', '20', bert_squad_model, '1.txt'], check=True, cwd=cwd) def test_rnn_benchmark(self):