Skip to content

Commit b7bad59

Browse files
committed
feat(python): drop-in replacement for pickle serialization (#2629)
<!-- **Thanks for contributing to Apache Fory™.** **If this is your first time opening a PR on fory, you can refer to [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md).** Contribution Checklist - The **Apache Fory™** community has requirements on the naming of PR titles. You can also find instructions in [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md). - Apache Fory™ has a strong focus on performance. If the PR you submit will have an impact on performance, please benchmark it first and provide the benchmark result here. --> Implement serialization for any picklable object, so that pyfory can be used as a drop-in replacement for pickle with smaller serialized size and faster speed. <!-- Describe the details of this PR. --> Closes #2417 <!-- If any user-facing interface changes, please [open an issue](https://github.com/apache/fory/issues/new/choose) describing the need to do so and update the document if necessary. Delete section if not applicable. --> - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? <!-- When the PR has an impact on performance (if you don't know whether the PR will have an impact on performance, you can submit the PR first, and if it will have an impact on performance, the code reviewer will explain it), be sure to attach benchmark data here. Delete section if not applicable. -->
1 parent 6c87ed5 commit b7bad59

8 files changed

Lines changed: 294 additions & 364 deletions

File tree

python/pyfory/_fory.py

Lines changed: 23 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import enum
1919
import logging
2020
import os
21-
import warnings
2221
from abc import ABC, abstractmethod
2322
from typing import Union, Iterable, TypeVar
2423

@@ -37,9 +36,6 @@
3736
except ImportError:
3837
np = None
3938

40-
from cloudpickle import Pickler
41-
42-
from pickle import Unpickler
4339

4440
logger = logging.getLogger(__name__)
4541

@@ -101,11 +97,8 @@ class Fory:
10197
"ref_tracking",
10298
"ref_resolver",
10399
"type_resolver",
104-
"serialization_context",
105100
"require_type_registration",
106101
"buffer",
107-
"pickler",
108-
"unpickler",
109102
"_buffer_callback",
110103
"_buffers",
111104
"metastring_resolver",
@@ -115,7 +108,6 @@ class Fory:
115108
"max_depth",
116109
"depth",
117110
)
118-
serialization_context: "SerializationContext"
119111

120112
def __init__(
121113
self,
@@ -152,19 +144,7 @@ def __init__(
152144
self.metastring_resolver = MetaStringResolver()
153145
self.type_resolver = TypeResolver(self)
154146
self.type_resolver.initialize()
155-
self.serialization_context = SerializationContext()
156147
self.buffer = Buffer.allocate(32)
157-
if not require_type_registration:
158-
warnings.warn(
159-
"Type registration is disabled, unknown types can be deserialized which may be insecure.",
160-
RuntimeWarning,
161-
stacklevel=2,
162-
)
163-
self.pickler = Pickler(self.buffer)
164-
self.unpickler = None
165-
else:
166-
self.pickler = _PicklerStub()
167-
self.unpickler = _UnpicklerStub()
168148
self._buffer_callback = None
169149
self._buffers = None
170150
self._unsupported_callback = None
@@ -231,9 +211,7 @@ def _serialize(
231211
) -> Union[Buffer, bytes]:
232212
self._buffer_callback = buffer_callback
233213
self._unsupported_callback = unsupported_callback
234-
if buffer is not None:
235-
self.pickler = Pickler(buffer)
236-
else:
214+
if buffer is None:
237215
self.buffer.writer_index = 0
238216
buffer = self.buffer
239217
if self.language == Language.XLANG:
@@ -463,21 +441,11 @@ def read_buffer_object(self, buffer) -> Buffer:
463441

464442
def handle_unsupported_write(self, buffer, obj):
465443
if self._unsupported_callback is None or self._unsupported_callback(obj):
466-
buffer.write_bool(True)
467-
self.pickler.dump(obj)
468-
else:
469-
buffer.write_bool(False)
444+
raise NotImplementedError(f"{type(obj)} is not supported for write")
470445

471446
def handle_unsupported_read(self, buffer):
472-
in_band = buffer.read_bool()
473-
if in_band:
474-
unpickler = self.unpickler
475-
if unpickler is None:
476-
self.unpickler = unpickler = Unpickler(buffer)
477-
return unpickler.load()
478-
else:
479-
assert self._unsupported_objects is not None
480-
return next(self._unsupported_objects)
447+
assert self._unsupported_objects is not None
448+
return next(self._unsupported_objects)
481449

482450
def write_ref_pyobject(self, buffer, value, typeinfo=None):
483451
if self.ref_resolver.write_ref_or_null(buffer, value):
@@ -490,6 +458,25 @@ def write_ref_pyobject(self, buffer, value, typeinfo=None):
490458
def read_ref_pyobject(self, buffer):
491459
return self.deserialize_ref(buffer)
492460

461+
def reset_write(self):
462+
self.ref_resolver.reset_write()
463+
self.type_resolver.reset_write()
464+
self.metastring_resolver.reset_write()
465+
self._buffer_callback = None
466+
self._unsupported_callback = None
467+
468+
def reset_read(self):
469+
self.depth = 0
470+
self.ref_resolver.reset_read()
471+
self.type_resolver.reset_read()
472+
self.metastring_resolver.reset_write()
473+
self._buffers = None
474+
self._unsupported_objects = None
475+
476+
def reset(self):
477+
self.reset_write()
478+
self.reset_read()
479+
493480
def inc_depth(self):
494481
self.depth += 1
495482
if self.depth > self.max_depth:
@@ -507,19 +494,15 @@ def throw_depth_limit_exceeded_exception(self):
507494
def reset_write(self):
508495
self.ref_resolver.reset_write()
509496
self.type_resolver.reset_write()
510-
self.serialization_context.reset()
511497
self.metastring_resolver.reset_write()
512-
self.pickler.clear_memo()
513498
self._buffer_callback = None
514499
self._unsupported_callback = None
515500

516501
def reset_read(self):
517502
self.depth = 0
518503
self.ref_resolver.reset_read()
519504
self.type_resolver.reset_read()
520-
self.serialization_context.reset()
521505
self.metastring_resolver.reset_write()
522-
self.unpickler = None
523506
self._buffers = None
524507
self._unsupported_objects = None
525508

@@ -562,20 +545,3 @@ def reset(self):
562545
"1",
563546
"true",
564547
}
565-
566-
567-
class _PicklerStub:
568-
def dump(self, o):
569-
raise ValueError(
570-
f"Type {type(o)} is not registered, "
571-
f"pickle is not allowed when type registration enabled, "
572-
f"Please register the type or pass unsupported_callback"
573-
)
574-
575-
def clear_memo(self):
576-
pass
577-
578-
579-
class _UnpicklerStub:
580-
def load(self):
581-
raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback")

python/pyfory/_registry.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing import TypeVar, Union
2626
from enum import Enum
2727

28-
from pyfory._serialization import ENABLE_FORY_CYTHON_SERIALIZATION
28+
from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION
2929
from pyfory import Language
3030
from pyfory.error import TypeUnregisteredError
3131

@@ -35,9 +35,6 @@
3535
NDArraySerializer,
3636
PyArraySerializer,
3737
DynamicPyArraySerializer,
38-
_PickleStub,
39-
PickleStrongCacheStub,
40-
PickleCacheStub,
4138
NoneSerializer,
4239
BooleanSerializer,
4340
ByteSerializer,
@@ -56,14 +53,15 @@
5653
SetSerializer,
5754
EnumSerializer,
5855
SliceSerializer,
59-
PickleCacheSerializer,
60-
PickleStrongCacheSerializer,
61-
PickleSerializer,
6256
DataClassSerializer,
6357
StatefulSerializer,
6458
ReduceSerializer,
6559
FunctionSerializer,
6660
ObjectSerializer,
61+
TypeSerializer,
62+
MethodSerializer,
63+
UnsupportedSerializer,
64+
NativeFuncMethodSerializer,
6765
)
6866
from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder
6967
from pyfory.type import (
@@ -75,6 +73,7 @@
7573
Float32Type,
7674
Float64Type,
7775
load_class,
76+
record_class_factory,
7877
)
7978
from pyfory._fory import (
8079
DYNAMIC_TYPE_ID,
@@ -158,9 +157,10 @@ class TypeResolver:
158157
"metastring_resolver",
159158
"language",
160159
"_type_id_to_typeinfo",
160+
"_internal_py_serializer_map",
161161
)
162162

163-
def __init__(self, fory):
163+
def __init__(self, fory, meta_share=False):
164164
self.fory = fory
165165
self.metastring_resolver = fory.metastring_resolver
166166
self.language = fory.language
@@ -182,34 +182,41 @@ def __init__(self, fory):
182182
self.namespace_decoder = MetaStringDecoder(".", "_")
183183
self.typename_encoder = MetaStringEncoder("$", "_")
184184
self.typename_decoder = MetaStringDecoder("$", "_")
185+
self._internal_py_serializer_map = {}
185186

186187
def initialize(self):
187-
self._initialize_xlang()
188+
self._initialize_common()
188189
if self.fory.language == Language.PYTHON:
189190
self._initialize_py()
191+
else:
192+
self._initialize_xlang()
190193

191194
def _initialize_py(self):
192195
register = functools.partial(self._register_type, internal=True)
193-
register(
194-
_PickleStub,
195-
type_id=PickleSerializer.PICKLE_TYPE_ID,
196-
serializer=PickleSerializer,
197-
)
198-
register(
199-
PickleStrongCacheStub,
200-
type_id=97,
201-
serializer=PickleStrongCacheSerializer(self.fory),
202-
)
203-
register(
204-
PickleCacheStub,
205-
type_id=98,
206-
serializer=PickleCacheSerializer(self.fory),
207-
)
208196
register(type(None), serializer=NoneSerializer)
209197
register(tuple, serializer=TupleSerializer)
210198
register(slice, serializer=SliceSerializer)
199+
register(np.ndarray, serializer=NDArraySerializer)
200+
register(array.array, serializer=DynamicPyArraySerializer)
201+
self._internal_py_serializer_map = {
202+
ReduceSerializer: (self._stub_cls("__Reduce__"), self._next_type_id()),
203+
TypeSerializer: (self._stub_cls("__Type__"), self._next_type_id()),
204+
MethodSerializer: (self._stub_cls("__Method__"), self._next_type_id()),
205+
NativeFuncMethodSerializer: (self._stub_cls("__NativeFunction__"), self._next_type_id()),
206+
}
207+
for serializer, (stub_cls, type_id) in self._internal_py_serializer_map.items():
208+
register(stub_cls, serializer=serializer, type_id=type_id)
209+
210+
@staticmethod
211+
def _stub_cls(name: str):
212+
return record_class_factory(name, [])
211213

212214
def _initialize_xlang(self):
215+
register = functools.partial(self._register_type, internal=True)
216+
register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer)
217+
register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer)
218+
219+
def _initialize_common(self):
213220
register = functools.partial(self._register_type, internal=True)
214221
register(None, type_id=TypeId.NA, serializer=NoneSerializer)
215222
register(bool, type_id=TypeId.BOOL, serializer=BooleanSerializer)
@@ -240,7 +247,6 @@ def _initialize_xlang(self):
240247
type_id=typeid,
241248
serializer=PyArraySerializer(self.fory, ftype, typeid),
242249
)
243-
register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer)
244250
if np:
245251
# overwrite pyarray with same type id.
246252
# if pyarray are needed, one must annotate that value with XXXArrayType
@@ -256,7 +262,6 @@ def _initialize_xlang(self):
256262
type_id=typeid,
257263
serializer=Numpy1DArraySerializer(self.fory, ftype, dtype),
258264
)
259-
register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer)
260265
register(list, type_id=TypeId.LIST, serializer=ListSerializer)
261266
register(set, type_id=TypeId.SET, serializer=SetSerializer)
262267
register(dict, type_id=TypeId.MAP, serializer=MapSerializer)
@@ -416,7 +421,7 @@ def __register_type(
416421
self._named_type_to_typeinfo[(namespace, typename)] = typeinfo
417422
self._ns_type_to_typeinfo[(ns_meta_bytes, type_meta_bytes)] = typeinfo
418423
self._types_info[cls] = typeinfo
419-
if type_id > 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)):
424+
if type_id is not None and type_id != 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)):
420425
if type_id not in self._type_id_to_typeinfo or not internal:
421426
self._type_id_to_typeinfo[type_id] = typeinfo
422427
self._types_info[cls] = typeinfo
@@ -469,12 +474,12 @@ def get_typeinfo(self, cls, create=True):
469474
if self.language == Language.PYTHON:
470475
if isinstance(serializer, EnumSerializer):
471476
type_id = TypeId.NAMED_ENUM
472-
elif type(serializer) is PickleSerializer:
473-
type_id = PickleSerializer.PICKLE_TYPE_ID
474477
elif isinstance(serializer, FunctionSerializer):
475478
type_id = TypeId.NAMED_EXT
476-
elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)):
479+
elif isinstance(serializer, (ObjectSerializer, StatefulSerializer)):
477480
type_id = TypeId.NAMED_EXT
481+
elif self._internal_py_serializer_map.get(type(serializer)) is not None:
482+
type_id = self._internal_py_serializer_map.get(type(serializer))[1]
478483
if not self.require_registration:
479484
if isinstance(serializer, DataClassSerializer):
480485
type_id = TypeId.NAMED_STRUCT
@@ -502,35 +507,33 @@ def _create_serializer(self, cls):
502507
serializer = DataClassSerializer(self.fory, cls)
503508
elif issubclass(cls, enum.Enum):
504509
serializer = EnumSerializer(self.fory, cls)
510+
elif ("builtin_function_or_method" in str(cls) or "cython_function_or_method" in str(cls)) and "<locals>" not in str(cls):
511+
serializer = NativeFuncMethodSerializer(self.fory, cls)
512+
elif cls is type(self.initialize):
513+
# Handle bound method objects
514+
serializer = MethodSerializer(self.fory, cls)
515+
elif issubclass(cls, type):
516+
# Handle Python type objects and metaclass such as numpy._DTypeMeta(i.e. np.dtype)
517+
serializer = TypeSerializer(self.fory, cls)
518+
elif cls is array.array:
519+
# Handle array.array objects with DynamicPyArraySerializer
520+
# Note: This will use DynamicPyArraySerializer for all array.array objects
521+
serializer = DynamicPyArraySerializer(self.fory, cls)
505522
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or (
506523
hasattr(cls, "__reduce_ex__") and cls.__reduce_ex__ is not object.__reduce_ex__
507524
):
508525
# Use ReduceSerializer for objects that have custom __reduce__ or __reduce_ex__ methods
509526
# This has higher precedence than StatefulSerializer and ObjectSerializer
510527
# Only use it for objects with custom reduce methods, not default ones from the object
511-
module_name = getattr(cls, "__module__", "")
512-
if module_name.startswith("pandas.") or module_name == "builtins" or cls.__name__ in ("type", "function", "method"):
513-
# Exclude pandas, built-ins, and certain system types
514-
serializer = PickleSerializer(self.fory, cls)
515-
else:
516-
serializer = ReduceSerializer(self.fory, cls)
528+
serializer = ReduceSerializer(self.fory, cls)
517529
elif hasattr(cls, "__getstate__") and hasattr(cls, "__setstate__"):
518530
# Use StatefulSerializer for objects that support __getstate__ and __setstate__
519-
# But exclude certain types that have incompatible state methods
520-
module_name = getattr(cls, "__module__", "")
521-
if module_name.startswith("pandas."):
522-
# Pandas objects have __getstate__/__setstate__ but use incompatible pickle formats
523-
serializer = PickleSerializer(self.fory, cls)
524-
else:
525-
serializer = StatefulSerializer(self.fory, cls)
526-
elif (
527-
cls is not type
528-
and (hasattr(cls, "__dict__") or hasattr(cls, "__slots__"))
529-
and not (np and (issubclass(cls, np.dtype) or cls is type(np.dtype)))
530-
):
531+
serializer = StatefulSerializer(self.fory, cls)
532+
elif hasattr(cls, "__dict__") or hasattr(cls, "__slots__"):
531533
serializer = ObjectSerializer(self.fory, cls)
532534
else:
533-
serializer = PickleSerializer(self.fory, cls)
535+
# c-extension types will go to here
536+
serializer = UnsupportedSerializer(self.fory, cls)
534537
return serializer
535538

536539
def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes):

0 commit comments

Comments
 (0)