Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 61 additions & 12 deletions python/paddle/fluid/dygraph/dygraph_to_static/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import six
import sys
import traceback
import linecache

from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map

Expand All @@ -29,6 +30,9 @@
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0

SOURCE_CODE_RANGE = 5
BLANK_COUNT_BEFORE_FILE_STR = 4


def attach_error_data(error, in_runtime=False):
"""
Expand All @@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False):
Returns:
An error attached data about original source code information and traceback.
"""

e_type, e_value, e_traceback = sys.exc_info()
tb = traceback.extract_tb(e_traceback)[1:]

Expand Down Expand Up @@ -82,12 +87,49 @@ def __init__(self, location, function_name, source_code):
def formated_message(self):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return ' File "{}", line {}, in {}\n\t{}'.format(
return ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip()
if isinstance(self.source_code, str) else self.source_code)


class TraceBackFrameRange(OriginInfo):
"""
Traceback frame information.
"""

def __init__(self, location, function_name):
self.location = location
self.function_name = function_name
self.source_code = []
blank_count = []
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))

for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i)
line_lstrip = line.strip()
self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip))

if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
linecache.clearcache()

min_black_count = min(blank_count)
for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]

def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
self.location.filepath, self.location.lineno, self.function_name)
# add empty line after range code
return msg + '\n'.join(self.source_code) + '\n'


class ErrorData(object):
"""
Error data attached to an exception which is raised in un-transformed code.
Expand Down Expand Up @@ -128,26 +170,34 @@ def create_message(self):
return '\n'.join(message_lines)

# Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback:
whether_source_range = True
for filepath, lineno, funcname, code in self.origin_traceback[::-1]:
loc = Location(filepath, lineno)

dygraph_func_info = self.origin_info_map.get(loc.line_location,
None)
if dygraph_func_info:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame = self.origin_info_map[loc.line_location]
else:
traceback_frame = TraceBackFrame(loc, funcname, code)

message_lines.append(traceback_frame.formated_message())
if whether_source_range:
traceback_frame = TraceBackFrameRange(
dygraph_func_info.location,
dygraph_func_info.function_name)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(
dygraph_func_info.location,
dygraph_func_info.function_name,
dygraph_func_info.source_code)
# Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
message_lines.insert(2, traceback_frame.formated_message())

# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
# NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
# is gather than 1, for example, the error_type is IndentationError.
format_exception = traceback.format_exception_only(self.error_type,
self.error_value)
error_message = [" " * 4 + line for line in format_exception]
error_message = [
" " * BLANK_COUNT_BEFORE_FILE_STR + line
for line in format_exception
]
message_lines.extend(error_message)

return '\n'.join(message_lines)
Expand Down Expand Up @@ -175,7 +225,6 @@ def _simplify_error_value(self):
self.error_value = self.error_type(error_value_str)

def raise_new_exception(self):

# Raises the origin error if disable dygraph2static error module,
if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ def set_message(self):
['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 28, in inner_func'.format(self.filepath),
'def inner_func():',
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'return',
]

def set_func_call(self):
Expand All @@ -242,7 +245,11 @@ def set_message(self):
self.expected_message = \
[
'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, 2])'
'def func_error_in_compile_time_2(x):',
'x = fluid.dygraph.to_variable(x)',
'x = fluid.layers.reshape(x, shape=[1, 2])',
'<--- HERE',
'return x'
]


Expand All @@ -261,7 +268,10 @@ def set_exception_type(self):
def set_message(self):
self.expected_message = \
['File "{}", line 91, in forward'.format(self.filepath),
'@paddle.jit.to_static',
'def forward(self):',
'self.test_func()',
'<--- HERE'
]

def set_func_call(self):
Expand Down Expand Up @@ -318,7 +328,12 @@ def set_exception_type(self):
def set_message(self):
self.expected_message = \
['File "{}", line 80, in forward'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'def forward(self, x):',
'y = self._linear(x)',
'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'out = fluid.layers.mean(y[z])',
'return out'
]

def set_func_call(self):
Expand All @@ -329,7 +344,7 @@ def test_error(self):
self._test_raise_new_exception()


# Situation 4: NotImplementedError
# # Situation 4: NotImplementedError
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
Expand Down