Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
except ImportError:
from StringIO import StringIO
PY3 = False

else:
types.ClassType = type
from pickle import _Pickler as Pickler
Expand Down Expand Up @@ -205,6 +206,8 @@ def _walk_global_ops(code):
if op in GLOBAL_OPS:
yield op, oparg

HAVE_ENUM = False

else:
def _walk_global_ops(code):
"""
Expand All @@ -216,6 +219,9 @@ def _walk_global_ops(code):
if op in GLOBAL_OPS:
yield op, instr.arg

HAVE_ENUM = True
import enum


class CloudPickler(Pickler):

Expand Down Expand Up @@ -865,6 +871,51 @@ def save_weakset(self, obj):

dispatch[weakref.WeakSet] = save_weakset

if HAVE_ENUM:

def save_enum_instance(self, obj):
"""Save an **instance** of enum.Enum.

This is what gets called to save enum members.
"""
self.save_reduce(getattr, (type(obj), obj.name))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only line that isn't covered by the test suite currently.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good catch. It's easy enough to add coverage for this case if we decide we want this feature, though I'm not totally convinced it's a good idea because idiomatic usage of Enum involves lots of is comparisons, which will break silently if a user unpickles the same Enum multiple times.


def save_enum_subclass(self, obj):
"""Save a **subclass** of enum.Enum.

This is what gets called to save the Enum class itself.
"""
if obj is enum.Enum:
self.save_global(obj)
return

# EnumMeta uses a custom dictionary subclass during class
# construction to keep track of member order. The EnumMeta
# constructor assumes that it will get an instance of the custom
# subclass, so we need to re-create what EnumMeta would have
# received when creating this object.
clsdict = enum._EnumDict()
for member in obj:
clsdict[member.name] = member.value

self.save_reduce(type(obj), (obj.__name__, obj.__bases__, clsdict))

def save_enum_classdict(self, obj):
"""
Save a dictionary to use as the third argument to EnumMeta.

This is called to save the custom subclass of dict used by
EnumMeta.__prepare__.
"""
# Get the entries in _member_names order so that we preserve the
# order of the enum members.
items = [(name, obj[name]) for name in obj._member_names]
self.save_reduce(type(obj), (), None, None, items)

dispatch[enum.Enum] = save_enum_instance
dispatch[enum.EnumMeta] = save_enum_subclass
dispatch[enum._EnumDict] = save_enum_classdict

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
Expand Down
46 changes: 46 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@
except ImportError:
tornado = None


try:
import enum

class SomeEnum(enum.Enum):
a = 'a'
b = 'b'
c = 'c'

except ImportError:
SomeEnum = None

import cloudpickle
from cloudpickle.cloudpickle import _find_module, _make_empty_cell, cell_set

Expand Down Expand Up @@ -652,6 +664,40 @@ def __init__(self, x):

self.assertEqual(set(weakset), set([depickled1, depickled2]))

def _check_enum(self, original):
if sys.version_info < (3,):
zip_longest = itertools.izip_longest
else:
zip_longest = itertools.zip_longest

depickled = pickle_depickle(original)
for orig_member, depickled_member in zip_longest(original, depickled):
self.assertEqual(orig_member.name, depickled_member.name)
self.assertEqual(orig_member.value, depickled_member.value)
self.assertIsInstance(depickled_member, depickled)

enum_then_members = [original] + list(original.__members__.values())
enum_ = enum_then_members[0]
members = enum_then_members[1:]

for before, after in itertools.zip_longest(enum_, members):
self.assertIs(before, after)

@pytest.mark.skipif(SomeEnum is None, reason="enum module doesn't exist")
def test_global_enum(self):
self._check_enum(SomeEnum)

@pytest.mark.skipif(SomeEnum is None, reason="enum module doesn't exist")
def test_dynamic_enum(self):

class DynamicEnum(enum.Enum):
a = 'a'
b = 'b'
c = 'c'
d = 'd'

self._check_enum(DynamicEnum)


if __name__ == '__main__':
unittest.main()