diff --git a/CHANGES.md b/CHANGES.md
index abbcd0dfe..1b5d66ee0 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,10 @@
+0.4.3
+=====
+
+- Fixed a regression: `AttributeError` when loading pickles that hold a
+  reference to a dynamically defined class from the `__main__` module.
+  ([issue #131](https://github.com/cloudpipe/cloudpickle/issues/131)).
+
 0.4.2
 =====
 
diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py
index 45cffe5ba..f429a58d1 100644
--- a/cloudpickle/cloudpickle.py
+++ b/cloudpickle/cloudpickle.py
@@ -620,12 +620,16 @@ def save_global(self, obj, name=None, pack=struct.pack):
         The name of this method is somewhat misleading: all types get
         dispatched here.
         """
+        if obj.__module__ == "__main__":
+            return self.save_dynamic_class(obj)
+
         try:
             return Pickler.save_global(self, obj, name=name)
         except Exception:
             if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
                 if obj in _BUILTIN_TYPE_NAMES:
-                    return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
+                    return self.save_reduce(
+                        _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
 
             typ = type(obj)
             if typ is not obj and isinstance(obj, (type, types.ClassType)):
diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py
index 5b1bbcb19..2a73786c4 100644
--- a/tests/cloudpickle_test.py
+++ b/tests/cloudpickle_test.py
@@ -44,6 +44,7 @@
 from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set
 
 from .testutils import subprocess_pickle_echo
+from .testutils import assert_run_python_script
 
 
 HAVE_WEAKSET = hasattr(weakref, 'WeakSet')
@@ -736,6 +737,62 @@ def test_builtin_type__new__(self):
         for t in list, tuple, set, frozenset, dict, object:
             self.assertTrue(pickle_depickle(t.__new__) is t.__new__)
 
+    def test_interactively_defined_function(self):
+        # Check that callables defined in the __main__ module of a Python
+        # script (or jupyter kernel) can be pickled / unpickled / executed.
+        code = """\
+        from testutils import subprocess_pickle_echo
+
+        CONSTANT = 42
+
+        class Foo(object):
+
+            def method(self, x):
+                return x
+
+        foo = Foo()
+
+        def f0(x):
+            return x ** 2
+
+        def f1():
+            return Foo
+
+        def f2(x):
+            return Foo().method(x)
+
+        def f3():
+            return Foo().method(CONSTANT)
+
+        def f4(x):
+            return foo.method(x)
+
+        cloned = subprocess_pickle_echo(lambda x: x**2)
+        assert cloned(3) == 9
+
+        cloned = subprocess_pickle_echo(f0)
+        assert cloned(3) == 9
+
+        cloned = subprocess_pickle_echo(Foo)
+        assert cloned().method(2) == Foo().method(2)
+
+        cloned = subprocess_pickle_echo(Foo())
+        assert cloned.method(2) == Foo().method(2)
+
+        cloned = subprocess_pickle_echo(f1)
+        assert cloned()().method('a') == f1()().method('a')
+
+        cloned = subprocess_pickle_echo(f2)
+        assert cloned(2) == f2(2)
+
+        cloned = subprocess_pickle_echo(f3)
+        assert cloned() == f3()
+
+        cloned = subprocess_pickle_echo(f4)
+        assert cloned(2) == f4(2)
+        """
+        assert_run_python_script(textwrap.dedent(code))
+
     @pytest.mark.skipif(sys.version_info >= (3, 0),
                         reason="hardcoded pickle bytes for 2.7")
     def test_function_pickle_compat_0_4_0(self):
diff --git a/tests/testutils.py b/tests/testutils.py
index 110e2f78d..6a50f9732 100644
--- a/tests/testutils.py
+++ b/tests/testutils.py
@@ -1,7 +1,8 @@
 import sys
 import os
-from subprocess import Popen
-from subprocess import PIPE
+import os.path as op
+import tempfile
+from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
 
 from cloudpickle import dumps
 from pickle import loads
@@ -28,9 +29,9 @@ def subprocess_pickle_echo(input_data):
 
     """
     pickled_input_data = dumps(input_data)
-    cmd = [sys.executable, __file__]
-    cwd = os.getcwd()
-    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd)
+    cmd = [sys.executable, __file__]  # run pickle_echo() in __main__
+    cwd, env = _make_cwd_env()
+    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env)
     try:
         comm_kwargs = {}
         if timeout_supported:
@@ -68,5 +69,67 @@ def pickle_echo(stream_in=None, stream_out=None):
     stream_out.close()
 
 
+def _make_cwd_env():
+    """Return the ``(cwd, env)`` pair used to launch test child processes.
+
+    ``cwd`` is the cloudpickle repository root (the parent of this file's
+    folder). ``env`` is a copy of the current environment whose
+    ``PYTHONPATH`` lists the ``tests`` folder and the repository root, so
+    that the child can import both ``testutils`` and ``cloudpickle``.
+    """
+    cloudpickle_repo_folder = op.normpath(
+        op.join(op.dirname(__file__), '..'))
+    # Join with os.pathsep rather than a hard-coded ':' so that the search
+    # path is also valid on Windows, and start from a copy of os.environ so
+    # that the child keeps the variables the interpreter needs to start.
+    pythonpath = os.pathsep.join(
+        [op.join(cloudpickle_repo_folder, 'tests'), cloudpickle_repo_folder])
+    env = dict(os.environ, PYTHONPATH=pythonpath)
+    return cloudpickle_repo_folder, env
+
+
+def assert_run_python_script(source_code, timeout=5):
+    """Utility to help check pickleability of objects defined in __main__
+
+    The script provided in the source code should return 0 and not print
+    anything on stderr or stdout.
+
+    ``source_code`` is written to a temporary file that is run with
+    ``sys.executable`` and removed afterwards. ``timeout`` (in seconds) is
+    only passed on when ``timeout_supported`` is true.
+
+    Raises ``RuntimeError`` if the script exits with a non-zero status or
+    times out, and ``AssertionError`` if it prints anything.
+    """
+    fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
+    os.close(fd)
+    try:
+        with open(source_file, 'wb') as f:
+            f.write(source_code.encode('utf-8'))
+        cmd = [sys.executable, source_file]
+        cwd, env = _make_cwd_env()
+        # stderr is merged into stdout so that any output is detected.
+        kwargs = {
+            'cwd': cwd,
+            'stderr': STDOUT,
+            'env': env,
+        }
+        if timeout_supported:
+            kwargs['timeout'] = timeout
+        try:
+            out = check_output(cmd, **kwargs)
+        except CalledProcessError as e:
+            raise RuntimeError(u"script errored with output:\n%s"
+                               % e.output.decode('utf-8'))
+        except TimeoutExpired as e:
+            # No output may have been captured before the timeout.
+            raise RuntimeError(u"script timeout, output so far:\n%s"
+                               % (e.output or b"").decode('utf-8'))
+        if out != b"":
+            raise AssertionError(out.decode('utf-8'))
+    finally:
+        os.unlink(source_file)
+
+
 if __name__ == '__main__':
     pickle_echo()