Skip to content
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
0.4.3
=====

- Fixed a regression: `AttributeError` when loading pickles that hold a
reference to a dynamically defined class from the `__main__` module.
([issue #131]( https://github.com/cloudpipe/cloudpickle/issues/131)).

0.4.2
=====

Expand Down
6 changes: 5 additions & 1 deletion cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,16 @@ def save_global(self, obj, name=None, pack=struct.pack):
The name of this method is somewhat misleading: all types get
dispatched here.
"""
if obj.__module__ == "__main__":
return self.save_dynamic_class(obj)

try:
return Pickler.save_global(self, obj, name=name)
except Exception:
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)

typ = type(obj)
if typ is not obj and isinstance(obj, (type, types.ClassType)):
Expand Down
57 changes: 57 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set

from .testutils import subprocess_pickle_echo
from .testutils import assert_run_python_script


HAVE_WEAKSET = hasattr(weakref, 'WeakSet')
Expand Down Expand Up @@ -736,6 +737,62 @@ def test_builtin_type__new__(self):
for t in list, tuple, set, frozenset, dict, object:
self.assertTrue(pickle_depickle(t.__new__) is t.__new__)

def test_interactively_defined_function(self):
# Check that callables defined in the __main__ module of a Python
# script (or jupyter kernel) can be pickled / unpickled / executed.
code = """\
from testutils import subprocess_pickle_echo

CONSTANT = 42

class Foo(object):

def method(self, x):
return x

foo = Foo()

def f0(x):
return x ** 2

def f1():
return Foo

def f2(x):
return Foo().method(x)

def f3():
return Foo().method(CONSTANT)

def f4(x):
return foo.method(x)

cloned = subprocess_pickle_echo(lambda x: x**2)
assert cloned(3) == 9

cloned = subprocess_pickle_echo(f0)
assert cloned(3) == 9

cloned = subprocess_pickle_echo(Foo)
assert cloned().method(2) == Foo().method(2)

cloned = subprocess_pickle_echo(Foo())
assert cloned.method(2) == Foo().method(2)

cloned = subprocess_pickle_echo(f1)
assert cloned()().method('a') == f1()().method('a')

cloned = subprocess_pickle_echo(f2)
assert cloned(2) == f2(2)

cloned = subprocess_pickle_echo(f3)
assert cloned() == f3()

cloned = subprocess_pickle_echo(f4)
assert cloned(2) == f4(2)
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed protocol comparing to the original

assert_run_python_script(textwrap.dedent(code))

@pytest.mark.skipif(sys.version_info >= (3, 0),
reason="hardcoded pickle bytes for 2.7")
def test_function_pickle_compat_0_4_0(self):
Expand Down
52 changes: 47 additions & 5 deletions tests/testutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
import os
from subprocess import Popen
from subprocess import PIPE
import os.path as op
import tempfile
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError

from cloudpickle import dumps
from pickle import loads
Expand All @@ -28,9 +29,13 @@ def subprocess_pickle_echo(input_data):

"""
pickled_input_data = dumps(input_data)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

cmd = [sys.executable, __file__]
cwd = os.getcwd()
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd)
cmd = [sys.executable, __file__] # run then pickle_echo() in __main__
cloudpickle_repo_folder = op.normpath(
op.join(op.dirname(__file__), '..'))
cwd = cloudpickle_repo_folder
pythonpath = "{src}/tests:{src}".format(src=cloudpickle_repo_folder)
env = {'PYTHONPATH': pythonpath}
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env)
try:
comm_kwargs = {}
if timeout_supported:
Expand Down Expand Up @@ -68,5 +73,42 @@ def pickle_echo(stream_in=None, stream_out=None):
stream_out.close()


def assert_run_python_script(source_code, timeout=5):
"""Utility to help check pickleability of objects defined in __main__

The script provided in the source code should return 0 and not print
anything on stderr or stdout.
"""
fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
os.close(fd)
try:
with open(source_file, 'wb') as f:
f.write(source_code.encode('utf-8'))
cmd = [sys.executable, source_file]
cloudpickle_repo_folder = op.normpath(
op.join(op.dirname(__file__), '..'))
pythonpath = "{src}/tests:{src}".format(src=cloudpickle_repo_folder)
kwargs = {
'cwd': cloudpickle_repo_folder,
'stderr': STDOUT,
'env': {'PYTHONPATH': pythonpath},
}
if timeout_supported:
kwargs['timeout'] = timeout
try:
try:
out = check_output(cmd, **kwargs)
except CalledProcessError as e:
raise RuntimeError(u"script errored with output:\n%s"
% e.output.decode('utf-8'))
if out != b"":
raise AssertionError(out.decode('utf-8'))
except TimeoutExpired as e:
raise RuntimeError(u"script timeout, output so far:\n%s"
% e.output.decode('utf-8'))
finally:
os.unlink(source_file)


if __name__ == '__main__':
pickle_echo()