diff --git a/CHANGES.md b/CHANGES.md index 2e1b6f00..17e621c1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -17,6 +17,10 @@ dev _is_parametrized_type_hint to limit false positives. ([PR #409](https://github.com/cloudpipe/cloudpickle/pull/409)) +- Support pickling / depickling of OrderedDict KeysView, ValuesView, and + ItemsView, following similar strategy for vanilla Python dictionaries. + ([PR #423](https://github.com/cloudpipe/cloudpickle/pull/423)) + - Suppressed a source of non-determinism when pickling dynamically defined functions and handles the deprecation of co_lnotab in Python 3.10+. ([PR #428](https://github.com/cloudpipe/cloudpickle/pull/428)) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 07dc7617..763e9d6f 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -55,6 +55,7 @@ import warnings from .compat import pickle +from collections import OrderedDict from typing import Generic, Union, Tuple, Callable from pickle import _getattribute from importlib._bootstrap import _find_spec @@ -855,13 +856,22 @@ def _get_bases(typ): return getattr(typ, bases_attr) -def _make_dict_keys(obj): - return dict.fromkeys(obj).keys() +def _make_dict_keys(obj, is_ordered=False): + if is_ordered: + return OrderedDict.fromkeys(obj).keys() + else: + return dict.fromkeys(obj).keys() -def _make_dict_values(obj): - return {i: _ for i, _ in enumerate(obj)}.values() +def _make_dict_values(obj, is_ordered=False): + if is_ordered: + return OrderedDict((i, _) for i, _ in enumerate(obj)).values() + else: + return {i: _ for i, _ in enumerate(obj)}.values() -def _make_dict_items(obj): - return obj.items() +def _make_dict_items(obj, is_ordered=False): + if is_ordered: + return OrderedDict(obj).items() + else: + return obj.items() diff --git a/cloudpickle/cloudpickle_fast.py b/cloudpickle/cloudpickle_fast.py index c46914c6..10ceef1b 100644 --- a/cloudpickle/cloudpickle_fast.py +++ b/cloudpickle/cloudpickle_fast.py @@ -23,7 +23,7 @@ import typing from enum import Enum -from collections import ChainMap +from collections import ChainMap, OrderedDict from .compat import pickle, Pickler from .cloudpickle import ( @@ -437,6 +437,24 @@ def _dict_items_reduce(obj): return _make_dict_items, (dict(obj), ) +def _odict_keys_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_keys, (list(obj), True) + + +def _odict_values_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_values, (list(obj), True) + + +def _odict_items_reduce(obj): + return _make_dict_items, (dict(obj), True) + + # COLLECTIONS OF OBJECTS STATE SETTERS # ------------------------------------ # state setters are called at unpickling time, once the object is created and @@ -513,6 +531,9 @@ class CloudPickler(Pickler): _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce + _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce + _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce + _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index e2d84012..baca23cc 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -229,6 +229,24 @@ def test_dict_items(self): self.assertEqual(results, items) assert isinstance(results, _collections_abc.dict_items) + def test_odict_keys(self): + keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys() + results = pickle_depickle(keys) + self.assertEqual(results, keys) + assert type(keys) == type(results) + + def test_odict_values(self): + values = collections.OrderedDict([("a", 1), ("b", 2)]).values() + results = pickle_depickle(values) + self.assertEqual(list(results), list(values)) + assert type(values) == type(results) + + def test_odict_items(self): + items = collections.OrderedDict([("a", 1), ("b", 2)]).items() + results = pickle_depickle(items) + self.assertEqual(results, items) + assert type(items) == type(results) + def test_sliced_and_non_contiguous_memoryview(self): buffer_obj = memoryview(b"Hello!" * 3)[2:15:2] self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),