Summary of this patch (free text before the first "diff --git" line).

cloudpickle/cloudpickle.py
  * Add _whichmodule(): returns obj.__module__ when set, otherwise scans a
    copy of sys.modules (skipping __main__ and None entries) and ignores
    errors raised while introspecting modules.
  * Add _is_global(): True only if obj can be found again by name in an
    imported, non-dynamic module other than __main__.
  * save_function(): replace the hand-written lookup logic with
    _is_global(); fall back to save_function_tuple() otherwise.
  * save_global(): dispatch on _BUILTIN_TYPE_NAMES, an explicit name, and
    _is_global() instead of try/except around Pickler.save_global().
  * Remove the now unused islambda() and _restore_attr() helpers.
  * _is_dynamic(): when __spec__ is None, ask _find_spec() before declaring
    the module dynamic (PyPy built-in modules such as _codecs).

tests/
  * Cover a dynamic module registered in sys.modules, the PyPy _codecs
    case, and objects whose __reduce__ returns a string.

diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py
index 1c91021e9..b3893175c 100644
--- a/cloudpickle/cloudpickle.py
+++ b/cloudpickle/cloudpickle.py
@@ -50,6 +50,7 @@
 import opcode
 import operator
 import pickle
+import platform
 import struct
 import sys
 import traceback
@@ -92,6 +93,7 @@
     string_types = (str,)
     PY3 = True
     PY2 = False
+    from importlib._bootstrap import _find_spec
 
 
 def _ensure_tracking(class_def):
@@ -123,6 +125,69 @@ def _getattribute(obj, name):
         return getattr(obj, name, None), None
 
 
+def _whichmodule(obj, name):
+    """Find the module an object belongs to.
+
+    This function differs from ``pickle.whichmodule`` in two ways:
+    - it does not mangle the cases where obj's module is __main__ and obj was
+      not found in any module.
+    - Errors arising during module introspection are ignored, as those errors
+      are considered unwanted side effects.
+    """
+    module_name = getattr(obj, '__module__', None)
+    if module_name is not None:
+        return module_name
+    # Protect the iteration by using a list copy of sys.modules against dynamic
+    # modules that trigger imports of other modules upon calls to getattr.
+    for module_name, module in list(sys.modules.items()):
+        if module_name == '__main__' or module is None:
+            continue
+        try:
+            if _getattribute(module, name)[0] is obj:
+                return module_name
+        except Exception:
+            pass
+    return None
+
+
+def _is_global(obj, name=None):
+    """Determine if obj can be pickled as attribute of a file-backed module"""
+    if name is None:
+        name = getattr(obj, '__qualname__', None)
+    if name is None:
+        name = getattr(obj, '__name__', None)
+
+    module_name = _whichmodule(obj, name)
+
+    if module_name is None:
+        # In this case, obj.__module__ is None AND obj was not found in any
+        # imported module. obj is thus treated as dynamic.
+        return False
+
+    if module_name == "__main__":
+        return False
+
+    module = sys.modules.get(module_name, None)
+    if module is None:
+        # The main reason why obj's module would not be imported is that this
+        # module has been dynamically created, using for example
+        # types.ModuleType. The other possibility is that module was removed
+        # from sys.modules after obj was created/imported. But this case is not
+        # supported, as the standard pickle does not support it either.
+        return False
+
+    # module has been added to sys.modules, but it can still be dynamic.
+    if _is_dynamic(module):
+        return False
+
+    try:
+        obj2, parent = _getattribute(module, name)
+    except AttributeError:
+        # obj was not found inside the module it points to
+        return False
+    return obj2 is obj
+
+
 def _make_cell_set_template_code():
     """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
 
@@ -236,10 +301,6 @@ def cell_set(cell, value):
 EXTENDED_ARG = dis.EXTENDED_ARG
 
 
-def islambda(func):
-    return getattr(func, '__name__') == '<lambda>'
-
-
 _BUILTIN_TYPE_NAMES = {}
 for k, v in types.__dict__.items():
     if type(v) is type:
@@ -392,61 +453,9 @@ def save_function(self, obj, name=None):
         Determines what kind of function obj is (e.g. lambda, defined at
         interactive prompt, etc) and handles the pickling appropriately.
         """
-        write = self.write
-
-        if name is None:
-            name = getattr(obj, '__qualname__', None)
-        if name is None:
-            name = getattr(obj, '__name__', None)
-        try:
-            # whichmodule() could fail, see
-            # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
-            modname = pickle.whichmodule(obj, name)
-        except Exception:
-            modname = None
-        # print('which gives %s %s %s' % (modname, obj, name))
-        try:
-            themodule = sys.modules[modname]
-        except KeyError:
-            # eval'd items such as namedtuple give invalid items for their function __module__
-            modname = '__main__'
-
-        if modname == '__main__':
-            themodule = None
-
-        try:
-            lookedup_by_name, _ = _getattribute(themodule, name)
-        except Exception:
-            lookedup_by_name = None
-
-        if themodule:
-            if lookedup_by_name is obj:
-                return self.save_global(obj, name)
-
-        # if func is lambda, def'ed at prompt, is in main, or is nested, then
-        # we'll pickle the actual function object rather than simply saving a
-        # reference (as is done in default pickler), via save_function_tuple.
-        if (islambda(obj)
-                or getattr(obj.__code__, 'co_filename', None) == '<stdin>'
-                or themodule is None):
-            self.save_function_tuple(obj)
-            return
-        else:
-            # func is nested
-            if lookedup_by_name is None or lookedup_by_name is not obj:
-                self.save_function_tuple(obj)
-                return
-
-        if obj.__dict__:
-            # essentially save_reduce, but workaround needed to avoid recursion
-            self.save(_restore_attr)
-            write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
-            self.memoize(obj)
-            self.save(obj.__dict__)
-            write(pickle.TUPLE + pickle.REDUCE)
-        else:
-            write(pickle.GLOBAL + modname + '\n' + name + '\n')
-            self.memoize(obj)
+        if not _is_global(obj, name=name):
+            return self.save_function_tuple(obj)
+        return Pickler.save_global(self, obj, name=name)
 
     dispatch[types.FunctionType] = save_function
 
@@ -801,23 +810,15 @@ def save_global(self, obj, name=None, pack=struct.pack):
             return self.save_reduce(type, (Ellipsis,), obj=obj)
         elif obj is type(NotImplemented):
             return self.save_reduce(type, (NotImplemented,), obj=obj)
-
-        if obj.__module__ == "__main__":
-            return self.save_dynamic_class(obj)
-
-        try:
-            return Pickler.save_global(self, obj, name=name)
-        except Exception:
-            if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
-                if obj in _BUILTIN_TYPE_NAMES:
-                    return self.save_reduce(
-                        _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
-
-            typ = type(obj)
-            if typ is not obj and isinstance(obj, (type, types.ClassType)):
-                return self.save_dynamic_class(obj)
-
-            raise
+        elif obj in _BUILTIN_TYPE_NAMES:
+            return self.save_reduce(
+                _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
+        elif name is not None:
+            Pickler.save_global(self, obj, name=name)
+        elif not _is_global(obj, name=name):
+            self.save_dynamic_class(obj)
+        else:
+            Pickler.save_global(self, obj, name=name)
 
     dispatch[type] = save_global
     dispatch[types.ClassType] = save_global
@@ -1085,13 +1086,6 @@ def dynamic_subimport(name, vars):
     return mod
 
 
-# restores function attributes
-def _restore_attr(obj, attr):
-    for key, val in attr.items():
-        setattr(obj, key, val)
-    return obj
-
-
 def _gen_ellipsis():
     return Ellipsis
 
@@ -1298,7 +1292,29 @@ def _is_dynamic(module):
         return False
 
     if hasattr(module, '__spec__'):
-        return module.__spec__ is None
+        if module.__spec__ is not None:
+            return False
+
+        # In PyPy, Some built-in modules such as _codecs can have their
+        # __spec__ attribute set to None despite being imported. For such
+        # modules, the ``_find_spec`` utility of the standard library is used.
+        parent_name = module.__name__.rpartition('.')[0]
+        if parent_name:  # pragma: no cover
+            # This code handles the case where an imported package (and not
+            # module) remains with __spec__ set to None. It is however untested
+            # as no package in the PyPy stdlib has __spec__ set to None after
+            # it is imported.
+            try:
+                parent = sys.modules[parent_name]
+            except KeyError:
+                msg = "parent {!r} not in sys.modules"
+                raise ImportError(msg.format(parent_name))
+            else:
+                pkgpath = parent.__path__
+        else:
+            pkgpath = None
+        return _find_spec(module.__name__, pkgpath, module) is None
+
     else:
         # Backward compat for Python 2
         import imp
diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py
index 7f7d7dfd8..77fde239b 100644
--- a/tests/cloudpickle_test.py
+++ b/tests/cloudpickle_test.py
@@ -478,6 +478,15 @@ def method(self, x):
         mod1, mod2 = pickle_depickle([mod, mod])
         self.assertEqual(id(mod1), id(mod2))
 
+        # Ensure proper pickling of mod's functions when module "looks" like a
+        # file-backed module even though it is not:
+        try:
+            sys.modules['mod'] = mod
+            depickled_f = pickle_depickle(mod.f, protocol=self.protocol)
+            self.assertEqual(mod.f(5), depickled_f(5))
+        finally:
+            sys.modules.pop('mod', None)
+
     def test_module_locals_behavior(self):
         # Makes sure that a local function defined in another module is
         # correctly serialized. This notably checks that the globals are
@@ -621,6 +630,10 @@ def test_is_dynamic_module(self):
         dynamic_module = types.ModuleType('dynamic_module')
         assert _is_dynamic(dynamic_module)
 
+        if platform.python_implementation() == 'PyPy':
+            import _codecs
+            assert not _is_dynamic(_codecs)
+
     def test_Ellipsis(self):
         self.assertEqual(Ellipsis, pickle_depickle(Ellipsis,
                                                    protocol=self.protocol))
@@ -1023,7 +1036,7 @@ def __init__(self, x):
         self.assertEqual(set(weakset), {depickled1, depickled2})
 
     def test_faulty_module(self):
-        for module_name in ['_faulty_module', '_missing_module', None]:
+        for module_name in ['_missing_module', None]:
             class FaultyModule(object):
                 def __getattr__(self, name):
                     # This throws an exception while looking up within
@@ -1794,6 +1807,15 @@ def f(a, /, b=1):
         """.format(protocol=self.protocol)
         assert_run_python_script(textwrap.dedent(code))
 
+    def test___reduce___returns_string(self):
+        # Non regression test for objects with a __reduce__ method returning a
+        # string, meaning "save by attribute using save_global"
+        from .mypkg import some_singleton
+        assert some_singleton.__reduce__() == "some_singleton"
+        depickled_singleton = pickle_depickle(
+            some_singleton, protocol=self.protocol)
+        assert depickled_singleton is some_singleton
+
 
 class Protocol2CloudPickleTest(CloudPickleTest):
     protocol = 2
diff --git a/tests/mypkg/__init__.py b/tests/mypkg/__init__.py
index fe3cc6b1d..60d5b8d28 100644
--- a/tests/mypkg/__init__.py
+++ b/tests/mypkg/__init__.py
@@ -4,3 +4,12 @@
 def package_function():
     """Function living inside a package, not a simple module"""
     return "hello from a package!"
+
+
+class _SingletonClass(object):
+    def __reduce__(self):
+        # This reducer is only valid for the top level "some_singleton" object.
+        return "some_singleton"
+
+
+some_singleton = _SingletonClass()