diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 908725e36..6f957f20a 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -65,6 +65,7 @@ except ImportError: from StringIO import StringIO PY3 = False + else: types.ClassType = type from pickle import _Pickler as Pickler @@ -205,6 +206,8 @@ def _walk_global_ops(code): if op in GLOBAL_OPS: yield op, oparg + HAVE_ENUM = False + else: def _walk_global_ops(code): """ @@ -216,6 +219,9 @@ def _walk_global_ops(code): if op in GLOBAL_OPS: yield op, instr.arg + HAVE_ENUM = True + import enum + class CloudPickler(Pickler): @@ -865,6 +871,51 @@ def save_weakset(self, obj): dispatch[weakref.WeakSet] = save_weakset + if HAVE_ENUM: + + def save_enum_instance(self, obj): + """Save an **instance** of enum.Enum. + + This is what gets called to save enum members. + """ + self.save_reduce(getattr, (type(obj), obj.name)) + + def save_enum_subclass(self, obj): + """Save a **subclass** of enum.Enum. + + This is what gets called to save the Enum class itself. + """ + if obj is enum.Enum: + self.save_global(obj) + return + + # EnumMeta uses a custom dictionary subclass during class + # construction to keep track of member order. The EnumMeta + # constructor assumes that it will get an instance of the custom + # subclass, so we need to re-create what EnumMeta would have + # received when creating this object. + clsdict = enum._EnumDict() + for member in obj: + clsdict[member.name] = member.value + + self.save_reduce(type(obj), (obj.__name__, obj.__bases__, clsdict)) + + def save_enum_classdict(self, obj): + """ + Save a dictionary to use as the third argument to EnumMeta. + + This is called to save the custom subclass of dict used by + EnumMeta.__prepare__. + """ + # Get the entries in _member_names order so that we preserve the + # order of the enum members. + items = [(name, obj[name]) for name in obj._member_names] + self.save_reduce(type(obj), (), None, None, items) + + dispatch[enum.Enum] = save_enum_instance + dispatch[enum.EnumMeta] = save_enum_subclass + dispatch[enum._EnumDict] = save_enum_classdict + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 22f656a77..f1c5842a0 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -40,6 +40,18 @@ except ImportError: tornado = None + +try: + import enum + + class SomeEnum(enum.Enum): + a = 'a' + b = 'b' + c = 'c' + +except ImportError: + SomeEnum = None + import cloudpickle from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set @@ -652,6 +664,40 @@ def __init__(self, x): self.assertEqual(set(weakset), set([depickled1, depickled2])) + def _check_enum(self, original): + if sys.version_info < (3,): + zip_longest = itertools.izip_longest + else: + zip_longest = itertools.zip_longest + + depickled = pickle_depickle(original) + for orig_member, depickled_member in zip_longest(original, depickled): + self.assertEqual(orig_member.name, depickled_member.name) + self.assertEqual(orig_member.value, depickled_member.value) + self.assertIsInstance(depickled_member, depickled) + + enum_then_members = [original] + list(original.__members__.values()) + enum_ = enum_then_members[0] + members = enum_then_members[1:] + + for before, after in itertools.zip_longest(enum_, members): + self.assertIs(before, after) + + @pytest.mark.skipif(SomeEnum is None, reason="enum module doesn't exist") + def test_global_enum(self): + self._check_enum(SomeEnum) + + @pytest.mark.skipif(SomeEnum is None, reason="enum module doesn't exist") + def test_dynamic_enum(self): + + class DynamicEnum(enum.Enum): + a = 'a' + b = 'b' + c = 'c' + d = 'd' + + self._check_enum(DynamicEnum) + if __name__ == '__main__': unittest.main()