diff --git a/paddle/common/enforce.cc b/paddle/common/enforce.cc index 0719035db4c494..6dd4f0372e2b37 100644 --- a/paddle/common/enforce.cc +++ b/paddle/common/enforce.cc @@ -64,10 +64,11 @@ int GetCallStackLevel() { return FLAGS_call_stack_level; } std::string SimplifyErrorTypeFormat(const std::string& str) { std::ostringstream sout; size_t type_end_pos = str.find(':', 0); - if (str.substr(type_end_pos - 5, type_end_pos) == "Error:") { + if (type_end_pos != str.npos && type_end_pos >= 5 && + str.substr(type_end_pos - 5, 6) == "Error:") { // Remove "Error:", add "()" // Examples: - // InvalidArgumentError: xxx -> (InvalidArgument): xxx + // InvalidArgumentError: xxx -> (InvalidArgument) xxx sout << "(" << str.substr(0, type_end_pos - 5) << ")" << str.substr(type_end_pos + 1); } else { diff --git a/test/legacy_test/test_cpp_error_msg.py b/test/legacy_test/test_cpp_error_msg.py new file mode 100644 index 00000000000000..164ab16187c1c9 --- /dev/null +++ b/test/legacy_test/test_cpp_error_msg.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle + + +class TestCppErrorMsg(unittest.TestCase): + def setUp(self) -> None: + paddle.base.set_flags({'FLAGS_call_stack_level': 1}) + + def test_invalid_argument(self): + with self.assertRaises(ValueError) as em: + input_value = paddle.to_tensor([1, 2, 3, 4, 5]) + paddle.bincount(input_value, minlength=-1) + # InvalidArgumentError: xxx -> (InvalidArgument) xxx + self.assertEqual( + str(em.exception).startswith("(InvalidArgument)"), True + ) + + +if __name__ == "__main__": + unittest.main()