diff --git a/CHANGES.md b/CHANGES.md index 73bd21cf9..0ca5a92b4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,6 +19,10 @@ over the module list ([PR #322](https://github.com/cloudpipe/cloudpickle/pull/322)). +- Add support for out-of-band pickling (Python 3.8 and later). + https://docs.python.org/3/library/pickle.html#example + ([issue #308](https://github.com/cloudpipe/cloudpickle/pull/308)) + 1.2.2 ===== diff --git a/cloudpickle/cloudpickle_fast.py b/cloudpickle/cloudpickle_fast.py index 4f52a4e95..1b6f6a38c 100644 --- a/cloudpickle/cloudpickle_fast.py +++ b/cloudpickle/cloudpickle_fast.py @@ -34,7 +34,7 @@ # Shorthands similar to pickle.dump/pickle.dumps -def dump(obj, file, protocol=None): +def dump(obj, file, protocol=None, buffer_callback=None): """Serialize obj as bytes streamed into file protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to @@ -44,10 +44,10 @@ def dump(obj, file, protocol=None): Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure compatibility with older versions of Python. """ - CloudPickler(file, protocol=protocol).dump(obj) + CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj) -def dumps(obj, protocol=None): +def dumps(obj, protocol=None, buffer_callback=None): """Serialize obj as a string of bytes allocated in memory protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to @@ -58,7 +58,7 @@ def dumps(obj, protocol=None): compatibility with older versions of Python. """ with io.BytesIO() as file: - cp = CloudPickler(file, protocol=protocol) + cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback) cp.dump(obj) return file.getvalue() @@ -421,10 +421,10 @@ class CloudPickler(Pickler): dispatch[types.MappingProxyType] = _mappingproxy_reduce dispatch[weakref.WeakSet] = _weakset_reduce - def __init__(self, file, protocol=None): + def __init__(self, file, protocol=None, buffer_callback=None): if protocol is None: protocol = DEFAULT_PROTOCOL - Pickler.__init__(self, file, protocol=protocol) + Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback) # map functions __globals__ attribute ids, to ensure that functions # sharing the same global namespace at pickling time also share their # global namespace at unpickling time. diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index ee9d55dbb..dd425ec2e 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -2052,6 +2052,22 @@ def __getattr__(self, name): with pytest.raises(pickle.PicklingError, match='recursion'): cloudpickle.dumps(a) + def test_out_of_band_buffers(self): + if self.protocol < 5: + pytest.skip("Need Pickle Protocol 5 or later") + np = pytest.importorskip("numpy") + + class LocallyDefinedClass: + data = np.zeros(10) + + data_instance = LocallyDefinedClass() + buffers = [] + pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol, + buffer_callback=buffers.append) + assert len(buffers) == 1 + reconstructed = pickle.loads(pickle_bytes, buffers=buffers) + np.testing.assert_allclose(reconstructed.data, data_instance.data) + class Protocol2CloudPickleTest(CloudPickleTest):