Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ pip install pyarrow==15.0.0 Cython wheel pytest
pip install -v -e .
```

If the last steps fails with an error like `libarrow_python.dylib: No such file or directory`,
you are probably suffering from bazel's aggressive caching; the sought library is longer at the
temporary directory it was the last time bazel ran. To remedy this run

> bazel clean --expunge

### Environment Requirements

- python 3.8+
Expand Down
76 changes: 10 additions & 66 deletions python/pyfory/_fory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@
except ImportError:
np = None

from cloudpickle import Pickler

from pickle import Unpickler

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -104,8 +100,6 @@ class Fory:
"serialization_context",
"require_type_registration",
"buffer",
"pickler",
"unpickler",
"_buffer_callback",
"_buffers",
"metastring_resolver",
Expand Down Expand Up @@ -133,9 +127,7 @@ def __init__(
"""
self.language = language
self.is_py = language == Language.PYTHON
self.require_type_registration = (
_ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration
)
self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration
self.ref_tracking = ref_tracking
if self.ref_tracking:
self.ref_resolver = MapRefResolver()
Expand All @@ -149,18 +141,6 @@ def __init__(
self.type_resolver.initialize()
self.serialization_context = SerializationContext()
self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
"Type registration is disabled, unknown types can be deserialized "
"which may be insecure.",
RuntimeWarning,
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
self.unpickler = None
else:
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self._buffer_callback = None
self._buffers = None
self._unsupported_callback = None
Expand Down Expand Up @@ -214,9 +194,7 @@ def _serialize(
) -> Union[Buffer, bytes]:
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = Pickler(buffer)
else:
if buffer is None:
self.buffer.writer_index = 0
buffer = self.buffer
if self.language == Language.XLANG:
Expand Down Expand Up @@ -349,27 +327,18 @@ def _deserialize(
if get_bit(buffer, reader_index, 0):
return None
is_little_endian_ = get_bit(buffer, reader_index, 1)
assert is_little_endian_, (
"Big endian is not supported for now, "
"please ensure peer machine is little endian."
)
assert is_little_endian_, "Big endian is not supported for now, please ensure peer machine is little endian."
is_target_x_lang = get_bit(buffer, reader_index, 2)
if is_target_x_lang:
self._peer_language = Language(buffer.read_int8())
else:
self._peer_language = Language.PYTHON
is_out_of_band_serialization_enabled = get_bit(buffer, reader_index, 3)
if is_out_of_band_serialization_enabled:
assert buffers is not None, (
"buffers shouldn't be null when the serialized stream is "
"produced with buffer_callback not null."
)
assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null."
self._buffers = iter(buffers)
else:
assert buffers is None, (
"buffers should be null when the serialized stream is "
"produced with buffer_callback null."
)
assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null."
if is_target_x_lang:
obj = self.xdeserialize_ref(buffer)
else:
Expand Down Expand Up @@ -442,17 +411,16 @@ def read_buffer_object(self, buffer) -> Buffer:
def handle_unsupported_write(self, buffer, obj):
if self._unsupported_callback is None or self._unsupported_callback(obj):
buffer.write_bool(True)
self.pickler.dump(obj)
# Use native serialization for all objects
self.serialize_ref(buffer, obj)
else:
buffer.write_bool(False)

def handle_unsupported_read(self, buffer):
in_band = buffer.read_bool()
if in_band:
unpickler = self.unpickler
if unpickler is None:
self.unpickler = unpickler = Unpickler(buffer)
return unpickler.load()
# The appropriate serializer will be determined by the type ID
return self.deserialize_ref(buffer)
else:
assert self._unsupported_objects is not None
return next(self._unsupported_objects)
Expand All @@ -473,7 +441,6 @@ def reset_write(self):
self.type_resolver.reset_write()
self.serialization_context.reset()
self.metastring_resolver.reset_write()
self.pickler.clear_memo()
self._buffer_callback = None
self._unsupported_callback = None

Expand All @@ -482,7 +449,6 @@ def reset_read(self):
self.type_resolver.reset_read()
self.serialization_context.reset()
self.metastring_resolver.reset_write()
self.unpickler = None
self._buffers = None
self._unsupported_objects = None

Expand Down Expand Up @@ -521,29 +487,7 @@ def reset(self):
self.objects.clear()


_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv(
"ENABLE_TYPE_REGISTRATION_FORCIBLY", "0"
) in {
_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in {
"1",
"true",
}


class _PicklerStub:
def dump(self, o):
raise ValueError(
f"Type {type(o)} is not registered, "
f"pickle is not allowed when type registration enabled, Please register"
f"the type or pass unsupported_callback"
)

def clear_memo(self):
pass


class _UnpicklerStub:
def load(self):
raise ValueError(
"pickle is not allowed when type registration enabled, Please register"
"the type or pass unsupported_callback"
)
6 changes: 4 additions & 2 deletions python/pyfory/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ def get_typeinfo(self, cls, create=True):
type_id = TypeId.NAMED_ENUM
elif type(serializer) is PickleSerializer:
type_id = PickleSerializer.PICKLE_TYPE_ID
elif isinstance(serializer, FunctionSerializer):
type_id = TypeId.NAMED_EXT
elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)):
type_id = TypeId.NAMED_EXT
if not self.require_registration:
Expand All @@ -491,8 +493,8 @@ def _create_serializer(self, cls):
break
else:
if cls is types.FunctionType:
# Use PickleSerializer for function types (including lambdas)
serializer = PickleSerializer(self.fory, cls)
# Use FunctionSerializer for function types (including lambdas)
serializer = FunctionSerializer(self.fory, cls)
elif dataclasses.is_dataclass(cls):
serializer = DataClassSerializer(self.fory, cls)
elif issubclass(cls, enum.Enum):
Expand Down
24 changes: 5 additions & 19 deletions python/pyfory/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ from typing import TypeVar, Union, Iterable
from pyfory._util import get_bit, set_bit, clear_bit
from pyfory import _fory as fmod
from pyfory._fory import Language
from pyfory._fory import _PicklerStub, _UnpicklerStub, Pickler, Unpickler
from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY
from pyfory.lib import mmh3
from pyfory.meta.metastring import Encoding
Expand Down Expand Up @@ -581,8 +580,6 @@ cdef class Fory:
cdef readonly MetaStringResolver metastring_resolver
cdef readonly SerializationContext serialization_context
cdef Buffer buffer
cdef public object pickler # pickle.Pickler
cdef public object unpickler # Optional[pickle.Unpickler]
cdef object _buffer_callback
cdef object _buffers # iterator
cdef object _unsupported_callback
Expand Down Expand Up @@ -625,11 +622,6 @@ cdef class Fory:
RuntimeWarning,
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
else:
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self.unpickler = None
self._buffer_callback = None
self._buffers = None
self._unsupported_callback = None
Expand Down Expand Up @@ -670,9 +662,7 @@ cdef class Fory:
self, obj, Buffer buffer, buffer_callback=None, unsupported_callback=None):
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = Pickler(self.buffer)
else:
if buffer is None:
self.buffer.writer_index = 0
buffer = self.buffer
if self.language == Language.XLANG:
Expand Down Expand Up @@ -800,8 +790,6 @@ cdef class Fory:

cpdef inline _deserialize(
self, Buffer buffer, buffers=None, unsupported_objects=None):
if not self.require_type_registration:
self.unpickler = Unpickler(buffer)
if unsupported_objects is not None:
self._unsupported_objects = iter(unsupported_objects)
if self.language == Language.XLANG:
Expand Down Expand Up @@ -931,16 +919,16 @@ cdef class Fory:
cpdef inline handle_unsupported_write(self, Buffer buffer, obj):
if self._unsupported_callback is None or self._unsupported_callback(obj):
buffer.write_bool(True)
self.pickler.dump(obj)
# Use native serialization for all objects
self.serialize_ref(buffer, obj)
else:
buffer.write_bool(False)

cpdef inline handle_unsupported_read(self, Buffer buffer):
cdef c_bool in_band = buffer.read_bool()
if in_band:
if self.unpickler is None:
self.unpickler.buffer = Unpickler(buffer)
return self.unpickler.load()
# The appropriate serializer will be determined by the type ID
return self.deserialize_ref(buffer)
else:
assert self._unsupported_objects is not None
return next(self._unsupported_objects)
Expand Down Expand Up @@ -970,7 +958,6 @@ cdef class Fory:
self.type_resolver.reset_write()
self.metastring_resolver.reset_write()
self.serialization_context.reset()
self.pickler.clear_memo()
self._unsupported_callback = None

cpdef inline reset_read(self):
Expand All @@ -979,7 +966,6 @@ cdef class Fory:
self.metastring_resolver.reset_read()
self.serialization_context.reset()
self._buffers = None
self.unpickler = None
self._unsupported_objects = None

cpdef inline reset(self):
Expand Down
58 changes: 28 additions & 30 deletions python/pyfory/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,11 +692,12 @@ def xread(self, buffer):
return arr

def write(self, buffer, value):
buffer.write_varuint32(PickleSerializer.PICKLE_TYPE_ID)
self.fory.handle_unsupported_write(buffer, value)
# Use the xwrite method for Python arrays
self.xwrite(buffer, value)

def read(self, buffer):
return self.fory.handle_unsupported_read(buffer)
# Use the xread method for Python arrays
return self.xread(buffer)


if np:
Expand Down Expand Up @@ -752,11 +753,12 @@ def xread(self, buffer):
return np.frombuffer(data, dtype=self.dtype)

def write(self, buffer, value):
buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID)
self.fory.handle_unsupported_write(buffer, value)
# Use the xwrite method for NumPy arrays
self.xwrite(buffer, value)

def read(self, buffer):
return self.fory.handle_unsupported_read(buffer)
# Use the xread method for NumPy arrays
return self.xread(buffer)


class NDArraySerializer(Serializer):
Expand All @@ -775,36 +777,28 @@ def xread(self, buffer):
raise NotImplementedError("Multi-dimensional array not supported currently")

def write(self, buffer, value):
buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID)
self.fory.handle_unsupported_write(buffer, value)
# Use the xwrite method for NumPy arrays
self.xwrite(buffer, value)

def read(self, buffer):
return self.fory.handle_unsupported_read(buffer)
# Use the xread method for NumPy arrays
return self.xread(buffer)


class BytesSerializer(CrossLanguageCompatibleSerializer):
def write(self, buffer, value):
self.fory.write_buffer_object(buffer, BytesBufferObject(value))

def read(self, buffer):
fory_buf = self.fory.read_buffer_object(buffer)
return fory_buf.to_pybytes()


class BytesBufferObject(BufferObject):
__slots__ = ("binary",)

def __init__(self, binary: bytes):
self.binary = binary
def xwrite(self, buffer, value):
buffer.write_bytes_and_size(value)

def total_bytes(self) -> int:
return len(self.binary)
def xread(self, buffer):
return buffer.read_bytes_and_size()

def write_to(self, buffer: "Buffer"):
buffer.write_bytes(self.binary)
def write(self, buffer, value):
# Use the xwrite method for bytes
self.xwrite(buffer, value)

def to_buffer(self) -> "Buffer":
return Buffer(self.binary)
def read(self, buffer):
# Use the xread method for bytes
return self.xread(buffer)


class StatefulSerializer(CrossLanguageCompatibleSerializer):
Expand Down Expand Up @@ -1182,10 +1176,14 @@ def xread(self, buffer):
raise NotImplementedError

def write(self, buffer, value):
self.fory.handle_unsupported_write(buffer, value)
# Use standard pickle module instead of cloudpickle
serialized = pickle.dumps(value)
buffer.write_bytes_and_size(serialized)

def read(self, buffer):
return self.fory.handle_unsupported_read(buffer)
# Use standard pickle module instead of cloudpickle
serialized = buffer.read_bytes_and_size()
return pickle.loads(serialized)


class ObjectSerializer(Serializer):
Expand Down
Loading
Loading