diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py
index e8f4223b2..fb40342e8 100644
--- a/cloudpickle/cloudpickle.py
+++ b/cloudpickle/cloudpickle.py
@@ -238,7 +238,7 @@ def save_function(self, obj, name=None):
         # a builtin_function_or_method which comes in as an attribute of some
         # object (e.g., object.__new__, itertools.chain.from_iterable) will end
         # up with modname "__main__" and so end up here. But these functions
-        # have no __code__ attribute in CPython, so the handling for 
+        # have no __code__ attribute in CPython, so the handling for
         # user-defined functions below will fail.
         # So we pickle them here using save_reduce; have to do it differently
         # for different python versions.
@@ -282,6 +282,36 @@ def save_function(self, obj, name=None):
             self.memoize(obj)
     dispatch[types.FunctionType] = save_function
 
+    def _save_subimports(self, code, top_level_dependencies):
+        """
+        Ensure de-pickler imports any package child-modules that
+        are needed by the function
+
+        `code` is the function's code object; `top_level_dependencies`
+        is any iterable of the objects the function refers to (they are
+        only iterated and type-checked, so they need not be hashable).
+        """
+        # names the bytecode may look up; loop-invariant, so build it once
+        co_names = set(code.co_names)
+        # check if any known dependency is an imported package
+        for x in top_level_dependencies:
+            # getattr: do not assume every module object has __package__
+            if (isinstance(x, types.ModuleType)
+                    and getattr(x, '__package__', None)):
+                # check if the package has any currently loaded sub-imports
+                prefix = x.__name__ + '.'
+                # walk a snapshot: sys.modules may be mutated while we
+                # iterate, which would raise RuntimeError on the live dict
+                for name, module in list(sys.modules.items()):
+                    if name.startswith(prefix):
+                        # check whether the function can address the sub-module
+                        tokens = set(name[len(prefix):].split('.'))
+                        if not tokens - co_names:
+                            # ensure unpickler executes this import
+                            self.save(module)
+                            # then discards the reference to it
+                            self.write(pickle.POP)
+
     def save_function_tuple(self, func):
         """  Pickles an actual func object.
 
@@ -307,6 +337,11 @@ def save_function_tuple(self, func):
         save(_fill_function)  # skeleton function updater
         write(pickle.MARK)    # beginning of tuple that _fill_function expects
 
+        # pass a plain list, not a set: globals and closure values are
+        # arbitrary objects and may be unhashable (lists, dicts, ...)
+        self._save_subimports(
+            code, list(f_globals.values()) + list(closure or ()))
+
         # create a skeleton function object and memoize it
         save(_make_skel_func)
         save((code, closure, base_globals))
diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py
index a1dec1151..c540d5d4c 100644
--- a/tests/cloudpickle_test.py
+++ b/tests/cloudpickle_test.py
@@ -9,6 +9,8 @@
 import itertools
 import platform
 import textwrap
+import base64
+import subprocess
 
 try:
     # try importing numpy and scipy. These are not hard dependencies and
@@ -360,6 +362,92 @@ def f():
         self.assertTrue(f2 is f3)
         self.assertEqual(f2(), res)
 
+    def test_submodule(self):
+        # Function that refers (by attribute) to a sub-module of a package.
+
+        # Choose any module NOT imported by __init__ of its parent package
+        # examples in standard library include:
+        # - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree
+
+        global xml  # imitate performing this import at top of file
+        import xml.etree.ElementTree
+        def example():
+            x = xml.etree.ElementTree.Comment  # potential AttributeError
+
+        s = cloudpickle.dumps(example)
+
+        # refresh the environment, i.e., unimport the dependency
+        del xml
+        for item in list(sys.modules):
+            if item.split('.')[0] == 'xml':
+                del sys.modules[item]
+
+        # deserialise
+        f = pickle.loads(s)
+        f()  # perform test for error
+
+    def test_submodule_closure(self):
+        # Same as test_submodule except the package is not a global
+        def scope():
+            import xml.etree.ElementTree
+            def example():
+                x = xml.etree.ElementTree.Comment  # potential AttributeError
+            return example
+        example = scope()
+
+        s = cloudpickle.dumps(example)
+
+        # refresh the environment (unimport dependency)
+        for item in list(sys.modules):
+            if item.split('.')[0] == 'xml':
+                del sys.modules[item]
+
+        f = cloudpickle.loads(s)
+        f()  # test
+
+    def test_multiprocess(self):
+        # running a function pickled by another process (a la dask.distributed)
+        def scope():
+            import curses.textpad
+            def example():
+                x = xml.etree.ElementTree.Comment
+                x = curses.textpad.Textbox
+            return example
+        global xml
+        import xml.etree.ElementTree
+        example = scope()
+
+        s = cloudpickle.dumps(example)
+
+        # choose "subprocess" rather than "multiprocessing" because the latter
+        # library uses fork to preserve the parent environment.
+        command = ("import pickle, base64; "
+                   "pickle.loads(base64.b32decode('" +
+                   base64.b32encode(s).decode('ascii') +
+                   "'))()")
+        self.assertEqual(subprocess.call([sys.executable, '-c', command]), 0)
+
+    def test_import(self):
+        # like test_multiprocess except subpackage modules referenced directly
+        # (unlike test_submodule)
+        global etree
+        def scope():
+            import curses.textpad as foobar
+            def example():
+                x = etree.Comment
+                x = foobar.Textbox
+            return example
+        example = scope()
+        import xml.etree.ElementTree as etree
+
+        s = cloudpickle.dumps(example)
+
+        command = ("import pickle, base64; "
+                   "pickle.loads(base64.b32decode('" +
+                   base64.b32encode(s).decode('ascii') +
+                   "'))()")
+        self.assertEqual(subprocess.call([sys.executable, '-c', command]), 0)
+
 
 if __name__ == '__main__':
     unittest.main()