diff --git a/python/README.md b/python/README.md index 57d756fff7..aaf0cc418b 100644 --- a/python/README.md +++ b/python/README.md @@ -16,6 +16,12 @@ pip install pyarrow==15.0.0 Cython wheel pytest pip install -v -e . ``` +If the last step fails with an error like `libarrow_python.dylib: No such file or directory`, +you are probably suffering from bazel's aggressive caching; the sought library is no longer at the +temporary directory where it was the last time bazel ran. To remedy this, run + +> bazel clean --expunge + ### Environment Requirements - python 3.8+ diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 86e83641ae..8829cc4822 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -37,10 +37,6 @@ except ImportError: np = None -from cloudpickle import Pickler - -from pickle import Unpickler - logger = logging.getLogger(__name__) @@ -104,8 +100,6 @@ class Fory: "serialization_context", "require_type_registration", "buffer", - "pickler", - "unpickler", "_buffer_callback", "_buffers", "metastring_resolver", @@ -133,9 +127,7 @@ def __init__( """ self.language = language self.is_py = language == Language.PYTHON - self.require_type_registration = ( - _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration - ) + self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration self.ref_tracking = ref_tracking if self.ref_tracking: self.ref_resolver = MapRefResolver() @@ -149,18 +141,6 @@ def __init__( self.type_resolver.initialize() self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) - if not require_type_registration: - warnings.warn( - "Type registration is disabled, unknown types can be deserialized " - "which may be insecure.", - RuntimeWarning, - stacklevel=2, - ) - self.pickler = Pickler(self.buffer) - self.unpickler = None - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() self._buffer_callback = None self._buffers = None 
self._unsupported_callback = None @@ -214,9 +194,7 @@ def _serialize( ) -> Union[Buffer, bytes]: self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback - if buffer is not None: - self.pickler = Pickler(buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -349,10 +327,7 @@ def _deserialize( if get_bit(buffer, reader_index, 0): return None is_little_endian_ = get_bit(buffer, reader_index, 1) - assert is_little_endian_, ( - "Big endian is not supported for now, " - "please ensure peer machine is little endian." - ) + assert is_little_endian_, "Big endian is not supported for now, please ensure peer machine is little endian." is_target_x_lang = get_bit(buffer, reader_index, 2) if is_target_x_lang: self._peer_language = Language(buffer.read_int8()) @@ -360,16 +335,10 @@ def _deserialize( self._peer_language = Language.PYTHON is_out_of_band_serialization_enabled = get_bit(buffer, reader_index, 3) if is_out_of_band_serialization_enabled: - assert buffers is not None, ( - "buffers shouldn't be null when the serialized stream is " - "produced with buffer_callback not null." - ) + assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." self._buffers = iter(buffers) else: - assert buffers is None, ( - "buffers should be null when the serialized stream is " - "produced with buffer_callback null." - ) + assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null." 
if is_target_x_lang: obj = self.xdeserialize_ref(buffer) else: @@ -442,17 +411,16 @@ def read_buffer_object(self, buffer) -> Buffer: def handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): buffer.write_bool(True) - self.pickler.dump(obj) + # Use native serialization for all objects + self.serialize_ref(buffer, obj) else: buffer.write_bool(False) def handle_unsupported_read(self, buffer): in_band = buffer.read_bool() if in_band: - unpickler = self.unpickler - if unpickler is None: - self.unpickler = unpickler = Unpickler(buffer) - return unpickler.load() + # The appropriate serializer will be determined by the type ID + return self.deserialize_ref(buffer) else: assert self._unsupported_objects is not None return next(self._unsupported_objects) @@ -473,7 +441,6 @@ def reset_write(self): self.type_resolver.reset_write() self.serialization_context.reset() self.metastring_resolver.reset_write() - self.pickler.clear_memo() self._buffer_callback = None self._unsupported_callback = None @@ -482,7 +449,6 @@ def reset_read(self): self.type_resolver.reset_read() self.serialization_context.reset() self.metastring_resolver.reset_write() - self.unpickler = None self._buffers = None self._unsupported_objects = None @@ -521,29 +487,7 @@ def reset(self): self.objects.clear() -_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv( - "ENABLE_TYPE_REGISTRATION_FORCIBLY", "0" -) in { +_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in { "1", "true", } - - -class _PicklerStub: - def dump(self, o): - raise ValueError( - f"Type {type(o)} is not registered, " - f"pickle is not allowed when type registration enabled, Please register" - f"the type or pass unsupported_callback" - ) - - def clear_memo(self): - pass - - -class _UnpicklerStub: - def load(self): - raise ValueError( - "pickle is not allowed when type registration enabled, Please register" - "the type or pass unsupported_callback" - 
) diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py index 1fe2143660..732962559c 100644 --- a/python/pyfory/_registry.py +++ b/python/pyfory/_registry.py @@ -468,6 +468,8 @@ def get_typeinfo(self, cls, create=True): type_id = TypeId.NAMED_ENUM elif type(serializer) is PickleSerializer: type_id = PickleSerializer.PICKLE_TYPE_ID + elif isinstance(serializer, FunctionSerializer): + type_id = TypeId.NAMED_EXT elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)): type_id = TypeId.NAMED_EXT if not self.require_registration: @@ -491,8 +493,8 @@ def _create_serializer(self, cls): break else: if cls is types.FunctionType: - # Use PickleSerializer for function types (including lambdas) - serializer = PickleSerializer(self.fory, cls) + # Use FunctionSerializer for function types (including lambdas) + serializer = FunctionSerializer(self.fory, cls) elif dataclasses.is_dataclass(cls): serializer = DataClassSerializer(self.fory, cls) elif issubclass(cls, enum.Enum): diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx index 5439a3ccf1..4bcb888b34 100644 --- a/python/pyfory/_serialization.pyx +++ b/python/pyfory/_serialization.pyx @@ -30,7 +30,6 @@ from typing import TypeVar, Union, Iterable from pyfory._util import get_bit, set_bit, clear_bit from pyfory import _fory as fmod from pyfory._fory import Language -from pyfory._fory import _PicklerStub, _UnpicklerStub, Pickler, Unpickler from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -581,8 +580,6 @@ cdef class Fory: cdef readonly MetaStringResolver metastring_resolver cdef readonly SerializationContext serialization_context cdef Buffer buffer - cdef public object pickler # pickle.Pickler - cdef public object unpickler # Optional[pickle.Unpickler] cdef object _buffer_callback cdef object _buffers # iterator cdef object _unsupported_callback @@ -625,11 +622,6 @@ cdef 
class Fory: RuntimeWarning, stacklevel=2, ) - self.pickler = Pickler(self.buffer) - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() - self.unpickler = None self._buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -670,9 +662,7 @@ cdef class Fory: self, obj, Buffer buffer, buffer_callback=None, unsupported_callback=None): self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback - if buffer is not None: - self.pickler = Pickler(self.buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -800,8 +790,6 @@ cdef class Fory: cpdef inline _deserialize( self, Buffer buffer, buffers=None, unsupported_objects=None): - if not self.require_type_registration: - self.unpickler = Unpickler(buffer) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) if self.language == Language.XLANG: @@ -931,16 +919,16 @@ cdef class Fory: cpdef inline handle_unsupported_write(self, Buffer buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): buffer.write_bool(True) - self.pickler.dump(obj) + # Use native serialization for all objects + self.serialize_ref(buffer, obj) else: buffer.write_bool(False) cpdef inline handle_unsupported_read(self, Buffer buffer): cdef c_bool in_band = buffer.read_bool() if in_band: - if self.unpickler is None: - self.unpickler.buffer = Unpickler(buffer) - return self.unpickler.load() + # The appropriate serializer will be determined by the type ID + return self.deserialize_ref(buffer) else: assert self._unsupported_objects is not None return next(self._unsupported_objects) @@ -970,7 +958,6 @@ cdef class Fory: self.type_resolver.reset_write() self.metastring_resolver.reset_write() self.serialization_context.reset() - self.pickler.clear_memo() self._unsupported_callback = None cpdef inline reset_read(self): @@ -979,7 +966,6 @@ cdef class Fory: 
self.metastring_resolver.reset_read() self.serialization_context.reset() self._buffers = None - self.unpickler = None self._unsupported_objects = None cpdef inline reset(self): diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index aba3809c4f..b856b124ac 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -692,11 +692,12 @@ def xread(self, buffer): return arr def write(self, buffer, value): - buffer.write_varuint32(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + # Use the xwrite method for Python arrays + self.xwrite(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + # Use the xread method for Python arrays + return self.xread(buffer) if np: @@ -752,11 +753,12 @@ def xread(self, buffer): return np.frombuffer(data, dtype=self.dtype) def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + # Use the xwrite method for NumPy arrays + self.xwrite(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + # Use the xread method for NumPy arrays + return self.xread(buffer) class NDArraySerializer(Serializer): @@ -775,36 +777,28 @@ def xread(self, buffer): raise NotImplementedError("Multi-dimensional array not supported currently") def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + # Use the xwrite method for NumPy arrays + self.xwrite(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + # Use the xread method for NumPy arrays + return self.xread(buffer) class BytesSerializer(CrossLanguageCompatibleSerializer): - def write(self, buffer, value): - self.fory.write_buffer_object(buffer, BytesBufferObject(value)) - - def read(self, buffer): - fory_buf = self.fory.read_buffer_object(buffer) - return 
fory_buf.to_pybytes() - - -class BytesBufferObject(BufferObject): - __slots__ = ("binary",) - - def __init__(self, binary: bytes): - self.binary = binary + def xwrite(self, buffer, value): + buffer.write_bytes_and_size(value) - def total_bytes(self) -> int: - return len(self.binary) + def xread(self, buffer): + return buffer.read_bytes_and_size() - def write_to(self, buffer: "Buffer"): - buffer.write_bytes(self.binary) + def write(self, buffer, value): + # Use the xwrite method for bytes + self.xwrite(buffer, value) - def to_buffer(self) -> "Buffer": - return Buffer(self.binary) + def read(self, buffer): + # Use the xread method for bytes + return self.xread(buffer) class StatefulSerializer(CrossLanguageCompatibleSerializer): @@ -1182,10 +1176,14 @@ def xread(self, buffer): raise NotImplementedError def write(self, buffer, value): - self.fory.handle_unsupported_write(buffer, value) + # Use standard pickle module instead of cloudpickle + serialized = pickle.dumps(value) + buffer.write_bytes_and_size(serialized) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + # Use standard pickle module instead of cloudpickle + serialized = buffer.read_bytes_and_size() + return pickle.loads(serialized) class ObjectSerializer(Serializer): diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 3f6cbcb498..54b9c9d938 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -466,22 +466,88 @@ def test_pickle_fallback(): assert df2.equals(df) +# Define global functions for tests +def helper_f1(x): + return x + + +def helper_f2(x): + return x + x + + +def helper_double(x): + return x * 2 + + +def helper_multiply(x): + return x * 2 + + def test_unsupported_callback(): + """ + Test that functions that were previously considered "unsupported" and required + cloudpickle are now directly supported by Fory's native serialization. + + This test demonstrates that: + 1. 
Functions can be serialized directly without registering their types + 2. The unsupported_callback parameter is no longer needed for function serialization + 3. The unsupported_objects parameter is no longer needed for deserialization + """ + # Create a Fory instance with type registration disabled fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) - def f1(x): - return x + # Define a local function (previously would have required cloudpickle) + def local_function(x): + return x * 3 + + # Create a lambda function (previously would have required cloudpickle) + lambda_function = lambda x: x + 10 + + # Create a closure that captures a variable from the outer scope + outer_value = 42 - def f2(x): - return x + x + def closure_function(x): + return x + outer_value - obj1 = [1, True, f1, f2, {1: 2}] + # Create a list with various function types + obj1 = [ + 1, + True, + helper_f1, # Global function + helper_f2, # Another global function + local_function, # Local function + lambda_function, # Lambda function + closure_function, # Function with closure + {1: 2}, + ] + + # Create an empty list for unsupported objects - we won't need to use it unsupported_objects = [] - binary1 = fory.serialize(obj1, unsupported_callback=unsupported_objects.append) - assert len(unsupported_objects) == 2 - assert unsupported_objects == [f1, f2] - new_obj1 = fory.deserialize(binary1, unsupported_objects=unsupported_objects) - assert new_obj1 == obj1 + + # Serialize the object WITHOUT using unsupported_callback + binary1 = fory.serialize(obj1) + + # Verify the unsupported_objects list is empty - no objects needed special handling + assert len(unsupported_objects) == 0 + + # Deserialize WITHOUT using unsupported_objects + new_obj1 = fory.deserialize(binary1) + + # Verify the deserialized object has the same structure and values + assert len(new_obj1) == len(obj1) + assert new_obj1[0] == obj1[0] # 1 + assert new_obj1[1] == obj1[1] # True + + # Verify all 
functions have the same behavior + test_input = 5 + assert new_obj1[2](test_input) == helper_f1(test_input) # helper_f1(5) = 5 + assert new_obj1[3](test_input) == helper_f2(test_input) # helper_f2(5) = 10 + assert new_obj1[4](test_input) == local_function(test_input) # local_function(5) = 15 + assert new_obj1[5](test_input) == lambda_function(test_input) # lambda_function(5) = 15 + assert new_obj1[6](test_input) == closure_function(test_input) # closure_function(5) = 47 + + # Verify the dictionary + assert new_obj1[7] == obj1[7] # {1: 2} def test_slice(): @@ -587,13 +653,15 @@ def test_function(track_ref): ref_tracking=track_ref, require_type_registration=False, ) - c = fory.deserialize(fory.serialize(lambda x: x * 2)) + # Use the global helper_double function instead of a lambda + c = fory.deserialize(fory.serialize(helper_double)) assert c(2) == 4 - def func(x): - return x * 2 + # Use another global function instead of defining a local one + # Register the function type + fory.register_type(type(helper_multiply)) - c = fory.deserialize(fory.serialize(func)) + c = fory.deserialize(fory.serialize(helper_multiply)) assert c(2) == 4 df = pd.DataFrame({"a": list(range(10))})