diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index cb2883f80..030d44a3f 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -56,6 +56,7 @@ import types import weakref + if sys.version < '3': from pickle import Pickler try: @@ -69,6 +70,92 @@ from io import BytesIO as StringIO PY3 = True + +def _make_cell_set_template_code(): + """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF + + Notes + ----- + In Python 3, we could use an easier function: + + .. code-block:: python + + def f(): + cell = None + + def _stub(value): + nonlocal cell + cell = value + + return _stub + + _cell_set_template_code = f() + + This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is + invalid syntax on Python 2. If we use this function we also don't need + to do the weird freevars/cellvars swap below + """ + def inner(value): + lambda: cell # make ``cell`` a closure so that we get a STORE_DEREF + cell = value + + co = inner.__code__ + + # NOTE: we are marking the cell variable as a free variable intentionally + # so that we simulate an inner function instead of the outer function. This + # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. + if not PY3: + return types.CodeType( + co.co_argcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + else: + return types.CodeType( + co.co_argcount, + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + + +_cell_set_template_code = _make_cell_set_template_code() + + +def cell_set(cell, value): + """Set the value of a closure cell. + """ + return types.FunctionType( + _cell_set_template_code, + {}, + '_cell_set_inner', + (), + (cell,), + )(value) + + #relevant opcodes STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] @@ -305,7 +392,6 @@ def _save_subimports(self, code, top_level_dependencies): # then discards the reference to it self.write(pickle.POP) - def save_function_tuple(self, func): """ Pickles an actual func object. @@ -326,16 +412,23 @@ def save_function_tuple(self, func): save = self.save write = self.write - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + code, f_globals, defaults, closure_values, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater write(pickle.MARK) # beginning of tuple that _fill_function expects - self._save_subimports(code, set(f_globals.values()) | set(closure)) + self._save_subimports( + code, + itertools.chain(f_globals.values(), closure_values or ()), + ) # create a skeleton function object and memoize it save(_make_skel_func) - save((code, closure, base_globals)) + save(( + code, + len(closure_values) if closure_values is not None else -1, + base_globals, + )) write(pickle.REDUCE) self.memoize(func) @@ -343,6 +436,7 @@ def save_function_tuple(self, func): save(f_globals) save(defaults) save(dct) + save(closure_values) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @@ -380,7 +474,7 @@ def extract_code_globals(cls, co): def extract_func_data(self, func): """ Turn the function into a tuple of data necessary to recreate it: - code, globals, defaults, closure, dict + code, globals, defaults, closure_values, dict """ code = func.__code__ @@ -397,7 +491,10 @@ def extract_func_data(self, func): defaults = func.__defaults__ # process closure - closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] + closure = ( + [c.cell_contents for c in func.__closure__] + if func.__closure__ is not None else None + ) # save the dict dct = func.__dict__ @@ -799,38 +896,46 @@ def _gen_ellipsis(): def _gen_not_implemented(): return NotImplemented -def _fill_function(func, globals, defaults, dict): +def _fill_function(func, globals, defaults, dict, closure_values): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). - """ + """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict - return func + cells = func.__closure__ + if cells is not None: + for cell, value in zip(cells, closure_values): + cell_set(cell, value) + return func -def _make_cell(value): - return (lambda: value).__closure__[0] +def _make_empty_cell(): + if False: + # trick the compiler into creating an empty cell in our lambda + cell = None + raise AssertionError('this route should not be executed') -def _reconstruct_closure(values): - return tuple([_make_cell(v) for v in values]) + return (lambda: cell).__closure__[0] -def _make_skel_func(code, closures, base_globals = None): +def _make_skel_func(code, cell_count, base_globals=None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other func attributes (e.g. func_globals) are empty. """ - closure = _reconstruct_closure(closures) if closures else None - if base_globals is None: base_globals = {} base_globals['__builtins__'] = __builtins__ - return types.FunctionType(code, base_globals, - None, None, closure) + closure = ( + tuple(_make_empty_cell() for _ in range(cell_count)) + if cell_count >= 0 else + None + ) + return types.FunctionType(code, base_globals, None, None, closure) def _find_module(mod_name): diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 29deb8cdf..19f1faf1f 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -38,7 +38,7 @@ from io import BytesIO import cloudpickle -from cloudpickle.cloudpickle import _find_module +from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set from .testutils import subprocess_pickle_echo @@ -133,6 +133,52 @@ def test_nested_lambdas(self): f2 = lambda x: f1(x) // b self.assertEqual(pickle_depickle(f2)(1), 1) + def test_recursive_closure(self): + def f1(): + def g(): + return g + return g + + def f2(base): + def g(n): + return base if n <= 1 else n * g(n - 1) + return g + + g1 = pickle_depickle(f1()) + self.assertEqual(g1(), g1) + + g2 = pickle_depickle(f2(2)) + self.assertEqual(g2(5), 240) + + def test_closure_none_is_preserved(self): + def f(): + """a function with no closure cells + """ + + self.assertTrue( + f.__closure__ is None, + msg='f actually has closure cells!', + ) + + g = pickle_depickle(f) + + self.assertTrue( + g.__closure__ is None, + msg='g now has closure cells even though f does not', + ) + + def test_unhashable_closure(self): + def f(): + s = set((1, 2)) # mutable set is unhashable + + def g(): + return len(s) + + return g + + g = pickle_depickle(f()) + self.assertEqual(g(), 2) + @pytest.mark.skipif(sys.version_info >= (3, 4) and sys.version_info < (3, 4, 3), reason="subprocess has a bug in 3.4.0 to 3.4.2") @@ -448,6 +494,19 @@ def example(): "'))()") assert not subprocess.call([sys.executable, '-c', command]) + def test_cell_manipulation(self): + cell = _make_empty_cell() + + with pytest.raises(ValueError): + cell.cell_contents + + ob = object() + cell_set(cell, ob) + self.assertTrue( + cell.cell_contents is ob, + msg='cell contents not set correctly', + ) + if __name__ == '__main__': unittest.main()