diff --git a/dill/source.py b/dill/source.py index 4b538fa6..40c41e75 100644 --- a/dill/source.py +++ b/dill/source.py @@ -26,7 +26,7 @@ import re from inspect import (getblock, getfile, getmodule, getsourcefile, indentsize, isbuiltin, isclass, iscode, isframe, isfunction, ismethod, - ismodule, istraceback) + ismodule, istraceback, currentframe) from tokenize import TokenError from ._dill import IS_IPYTHON @@ -236,6 +236,30 @@ def findsource(object): if isclass(object): name = object.__name__ pat = re.compile(r'^(\s*)class\s*' + name + r'\b') + + # find the first frame that inside sourcefile + frame = currentframe() + while frame and frame.f_code.co_filename != sourcefile: + frame = frame.f_back + + # Starting from the found frame, search upward level by level. + while frame and frame.f_code.co_filename == sourcefile: + lineno = frame.f_lineno if hasattr(frame, 'f_lineno') else None + start_lineno = lineno - 1 if lineno is not None else len(lines) - 1 + candidates = [] + for i in range(start_lineno, -1, -1): + match = pat.match(lines[i]) + if match: + # if it's at toplevel, it's already the best one + if lines[i][0] == 'c': + return lines, i + candidates.append((match.group(1), -i)) + if candidates: + candidates.sort() + return lines, -candidates[0][1] + # If no match is found in the current frame, move up to the previous frame. + frame = frame.f_back + # make some effort to find the best matching class definition: # use the one with the least indentation, which is the one # that's most probably not inside a function definition. diff --git a/dill/tests/test_source.py b/dill/tests/test_source.py index 12b4519d..964414a1 100644 --- a/dill/tests/test_source.py +++ b/dill/tests/test_source.py @@ -35,6 +35,9 @@ class Bar: pass _bar = Bar() +def _wrap_getsource(obj): + return getsource(obj) + # inspect.getsourcelines # dill.source.getblocks def test_getsource(): assert getsource(f) == 'f = lambda x: x**2\n' @@ -53,6 +56,19 @@ def test_getsource(): assert getsource(Foo) == 'class Foo(object):\n def bar(self, x):\n return x*x+x\n' #XXX: add getsource for _foo, _bar +def test_getsource_redefine(): + class Foobar: + def bar(self,x): + return x*x+x + assert getsource(Foobar) == ' class Foobar:\n def bar(self,x):\n return x*x+x\n' + assert _wrap_getsource(Foobar) == ' class Foobar:\n def bar(self,x):\n return x*x+x\n' + + class Foobar: + def bar(self,x): + return x*x+x+1 + assert getsource(Foobar) == ' class Foobar:\n def bar(self,x):\n return x*x+x+1\n' + assert _wrap_getsource(Foobar) == ' class Foobar:\n def bar(self,x):\n return x*x+x+1\n' + # test itself def test_itself(): assert getimport(getimport)=='from dill.source import getimport\n' @@ -163,6 +179,7 @@ def test_foo(): if __name__ == '__main__': test_getsource() + test_getsource_redefine() test_itself() test_builtin() test_imported()