18 changes: 17 additions & 1 deletion python/ray/cloudpickle/__init__.py
@@ -1,5 +1,21 @@
from __future__ import absolute_import
import sys

from ray.cloudpickle.cloudpickle import *
# TODO(suquark): This is a temporary flag for
# the new serialization implementation.
# Remove it when the old one is deprecated.
USE_NEW_SERIALIZER = False

if USE_NEW_SERIALIZER and sys.version_info[:2] >= (3, 8):
from ray.cloudpickle.cloudpickle_fast import *
FAST_CLOUDPICKLE_USED = True
else:
try:
import pickle5
except ImportError:
# We need pickle5 backport support for the new serializer.
USE_NEW_SERIALIZER = False
from ray.cloudpickle.cloudpickle import *
FAST_CLOUDPICKLE_USED = False

__version__ = '1.2.2.dev0'
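
For context, a minimal sketch of how calling code could check which implementation was picked at import time. The FAST_CLOUDPICKLE_USED flag comes from the hunk above; everything around it is illustrative only:

    import ray.cloudpickle as cloudpickle

    if cloudpickle.FAST_CLOUDPICKLE_USED:
        # Python >= 3.8 with USE_NEW_SERIALIZER on: cloudpickle_fast was imported.
        backend = "cloudpickle_fast"
    else:
        # Fallback: the classic cloudpickle implementation (backed by pickle5
        # if the backport happens to be installed).
        backend = "cloudpickle"
    print("serializer backend:", backend)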
30 changes: 21 additions & 9 deletions python/ray/cloudpickle/cloudpickle.py
@@ -49,7 +49,6 @@
import logging
import opcode
import operator
import pickle
import platform
import struct
import sys
@@ -59,6 +58,14 @@
import uuid
import threading

PICKLE5_ENABLED = False
Contributor:
This should only be needed for cloudpickle_fast, right? Let's focus on supporting that codepath, so it doesn't break other things.

@suquark (Member, Author), Sep 12, 2019:
If we have pickle5 enabled (python3 && python < 3.8 && pickle5 installed), then we will try to use the old cloudpickle.

try:
import pickle5 as pickle
from pickle5 import PickleBuffer # export PickleBuffer
PICKLE5_ENABLED = True
except ImportError:
import pickle

try:
from enum import Enum
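
As a side note, here is a minimal sketch of the protocol-5 out-of-band buffer mechanism that PICKLE5_ENABLED gates (it assumes the pickle5 backport is installed; on Python >= 3.8 the stdlib pickle module provides the same API):

    import pickle5 as pickle
    from pickle5 import PickleBuffer

    payload = bytearray(b"x" * 1000000)
    buffers = []

    # Wrapping the data in a PickleBuffer lets the pickler hand it to
    # buffer_callback instead of copying it into the pickle stream.
    data = pickle.dumps(PickleBuffer(payload), protocol=5,
                        buffer_callback=buffers.append)

    # The collected buffers must be supplied back, in order, at load time.
    restored = pickle.loads(data, buffers=buffers)
    assert bytes(restored) == bytes(payload)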
@@ -95,7 +102,10 @@
PY2 = True
else:
types.ClassType = type
from pickle import _Pickler as Pickler
if PICKLE5_ENABLED:
from pickle5 import _Pickler as Pickler
else:
from pickle import _Pickler as Pickler
from io import BytesIO as StringIO
string_types = (str,)
PY3 = True
@@ -466,13 +476,15 @@ def _extract_class_dict(cls):


class CloudPickler(Pickler):

dispatch = Pickler.dispatch.copy()

def __init__(self, file, protocol=None):
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
if PICKLE5_ENABLED:
Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback)
else:
Pickler.__init__(self, file, protocol=protocol)
# map ids to dictionary. used to ensure that functions can share global env
self.globals_ref = {}

@@ -1094,7 +1106,7 @@ def _rebuild_tornado_coroutine(func):

# Shorthands for legacy support

def dump(obj, file, protocol=None):
def dump(obj, file, protocol=None, buffer_callback=None):
"""Serialize obj as bytes streamed into file

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1104,10 +1116,10 @@ def dump(obj, file, protocol=None):
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(file, protocol=protocol).dump(obj)
CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)


def dumps(obj, protocol=None):
def dumps(obj, protocol=None, buffer_callback=None):
"""Serialize obj as a string of bytes allocated in memory

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1119,7 +1131,7 @@ def dumps(obj, protocol=None):
"""
file = StringIO()
try:
cp = CloudPickler(file, protocol=protocol)
cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
cp.dump(obj)
return file.getvalue()
finally:
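
For reference, a minimal round-trip sketch of the extended dump/dumps signature (it assumes pickle5 is installed so that buffer_callback is actually honored; the payload and variable names are illustrative):

    import pickle5
    from ray.cloudpickle import cloudpickle

    buffers = []
    # With protocol 5, PickleBuffer-backed data goes to buffer_callback
    # rather than being copied into the returned byte string.
    data = cloudpickle.dumps(cloudpickle.PickleBuffer(bytearray(b"payload")),
                             protocol=5, buffer_callback=buffers.append)

    # The same buffers must be passed back, in order, when deserializing.
    obj = pickle5.loads(data, buffers=buffers)
    assert bytes(obj) == b"payload"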