Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 4 additions & 52 deletions python/pyfory/_fory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import enum
import logging
import os
import warnings
from abc import ABC, abstractmethod
from typing import Union, Iterable, TypeVar

Expand All @@ -37,9 +36,6 @@
except ImportError:
np = None

from cloudpickle import Pickler

from pickle import Unpickler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -105,8 +101,6 @@ class Fory:
"serialization_context",
"require_type_registration",
"buffer",
"pickler",
"unpickler",
"_buffer_callback",
"_buffers",
"metastring_resolver",
Expand Down Expand Up @@ -160,17 +154,6 @@ def __init__(
self.type_resolver.initialize()

self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
"Type registration is disabled, unknown types can be deserialized which may be insecure.",
RuntimeWarning,
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
self.unpickler = None
else:
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self._buffer_callback = None
self._buffers = None
self._unsupported_callback = None
Expand Down Expand Up @@ -237,9 +220,7 @@ def _serialize(
) -> Union[Buffer, bytes]:
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = Pickler(buffer)
else:
if buffer is None:
self.buffer.writer_index = 0
buffer = self.buffer
if self.language == Language.XLANG:
Expand Down Expand Up @@ -493,21 +474,11 @@ def read_buffer_object(self, buffer) -> Buffer:

def handle_unsupported_write(self, buffer, obj):
if self._unsupported_callback is None or self._unsupported_callback(obj):
buffer.write_bool(True)
self.pickler.dump(obj)
else:
buffer.write_bool(False)
raise NotImplementedError(f"{type(obj)} is not supported for write")

def handle_unsupported_read(self, buffer):
in_band = buffer.read_bool()
if in_band:
unpickler = self.unpickler
if unpickler is None:
self.unpickler = unpickler = Unpickler(buffer)
return unpickler.load()
else:
assert self._unsupported_objects is not None
return next(self._unsupported_objects)
assert self._unsupported_objects is not None
return next(self._unsupported_objects)

def write_ref_pyobject(self, buffer, value, typeinfo=None):
if self.ref_resolver.write_ref_or_null(buffer, value):
Expand All @@ -525,7 +496,6 @@ def reset_write(self):
self.type_resolver.reset_write()
self.serialization_context.reset_write()
self.metastring_resolver.reset_write()
self.pickler.clear_memo()
self._buffer_callback = None
self._unsupported_callback = None

Expand All @@ -535,7 +505,6 @@ def reset_read(self):
self.type_resolver.reset_read()
self.serialization_context.reset_read()
self.metastring_resolver.reset_write()
self.unpickler = None
self._buffers = None
self._unsupported_objects = None

Expand All @@ -562,20 +531,3 @@ def throw_depth_limit_exceeded_exception(self):
"1",
"true",
}


class _PicklerStub:
def dump(self, o):
raise ValueError(
f"Type {type(o)} is not registered, "
f"pickle is not allowed when type registration enabled, "
f"Please register the type or pass unsupported_callback"
)

def clear_memo(self):
pass


class _UnpicklerStub:
def load(self):
raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback")
99 changes: 51 additions & 48 deletions python/pyfory/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import TypeVar, Union
from enum import Enum

from pyfory._serialization import ENABLE_FORY_CYTHON_SERIALIZATION
from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION
from pyfory import Language
from pyfory.error import TypeUnregisteredError

Expand All @@ -35,9 +35,6 @@
NDArraySerializer,
PyArraySerializer,
DynamicPyArraySerializer,
_PickleStub,
PickleStrongCacheStub,
PickleCacheStub,
NoneSerializer,
BooleanSerializer,
ByteSerializer,
Expand All @@ -56,15 +53,16 @@
SetSerializer,
EnumSerializer,
SliceSerializer,
PickleCacheSerializer,
PickleStrongCacheSerializer,
PickleSerializer,
DataClassSerializer,
DataClassStubSerializer,
StatefulSerializer,
ReduceSerializer,
FunctionSerializer,
ObjectSerializer,
TypeSerializer,
MethodSerializer,
UnsupportedSerializer,
NativeFuncMethodSerializer,
)
from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder
from pyfory.meta.meta_compressor import DeflaterMetaCompressor
Expand All @@ -78,6 +76,7 @@
Float64Type,
load_class,
is_struct_type,
record_class_factory,
)
from pyfory._fory import (
DYNAMIC_TYPE_ID,
Expand Down Expand Up @@ -171,6 +170,7 @@ class TypeResolver:
"_meta_shared_typeinfo",
"meta_share",
"serialization_context",
"_internal_py_serializer_map",
)

def __init__(self, fory, meta_share=False):
Expand Down Expand Up @@ -199,35 +199,42 @@ def __init__(self, fory, meta_share=False):
self.typename_decoder = MetaStringDecoder("$", "_")
self.meta_compressor = DeflaterMetaCompressor()
self.meta_share = meta_share
self._internal_py_serializer_map = {}

def initialize(self):
self._initialize_xlang()
self._initialize_common()
if self.fory.language == Language.PYTHON:
self._initialize_py()
else:
self._initialize_xlang()
self.serialization_context = self.fory.serialization_context

def _initialize_py(self):
register = functools.partial(self._register_type, internal=True)
register(
_PickleStub,
type_id=PickleSerializer.PICKLE_TYPE_ID,
serializer=PickleSerializer,
)
register(
PickleStrongCacheStub,
type_id=97,
serializer=PickleStrongCacheSerializer(self.fory),
)
register(
PickleCacheStub,
type_id=98,
serializer=PickleCacheSerializer(self.fory),
)
register(type(None), serializer=NoneSerializer)
register(tuple, serializer=TupleSerializer)
register(slice, serializer=SliceSerializer)
register(np.ndarray, serializer=NDArraySerializer)
register(array.array, serializer=DynamicPyArraySerializer)
self._internal_py_serializer_map = {
ReduceSerializer: (self._stub_cls("__Reduce__"), self._next_type_id()),
TypeSerializer: (self._stub_cls("__Type__"), self._next_type_id()),
MethodSerializer: (self._stub_cls("__Method__"), self._next_type_id()),
NativeFuncMethodSerializer: (self._stub_cls("__NativeFunction__"), self._next_type_id()),
}
for serializer, (stub_cls, type_id) in self._internal_py_serializer_map.items():
register(stub_cls, serializer=serializer, type_id=type_id)

@staticmethod
def _stub_cls(name: str):
return record_class_factory(name, [])

def _initialize_xlang(self):
register = functools.partial(self._register_type, internal=True)
register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer)
register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer)

def _initialize_common(self):
register = functools.partial(self._register_type, internal=True)
register(None, type_id=TypeId.NA, serializer=NoneSerializer)
register(bool, type_id=TypeId.BOOL, serializer=BooleanSerializer)
Expand Down Expand Up @@ -258,7 +265,6 @@ def _initialize_xlang(self):
type_id=typeid,
serializer=PyArraySerializer(self.fory, ftype, typeid),
)
register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer)
if np:
# overwrite pyarray with same type id.
# if pyarray are needed, one must annotate that value with XXXArrayType
Expand All @@ -274,7 +280,6 @@ def _initialize_xlang(self):
type_id=typeid,
serializer=Numpy1DArraySerializer(self.fory, ftype, dtype),
)
register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer)
register(list, type_id=TypeId.LIST, serializer=ListSerializer)
register(set, type_id=TypeId.SET, serializer=SetSerializer)
register(dict, type_id=TypeId.MAP, serializer=MapSerializer)
Expand Down Expand Up @@ -447,7 +452,7 @@ def __register_type(
self._named_type_to_typeinfo[(namespace, typename)] = typeinfo
self._ns_type_to_typeinfo[(ns_meta_bytes, type_meta_bytes)] = typeinfo
self._types_info[cls] = typeinfo
if type_id > 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)):
if type_id is not None and type_id != 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)):
if type_id not in self._type_id_to_typeinfo or not internal:
self._type_id_to_typeinfo[type_id] = typeinfo
self._types_info[cls] = typeinfo
Expand Down Expand Up @@ -500,12 +505,12 @@ def get_typeinfo(self, cls, create=True):
if self.language == Language.PYTHON:
if isinstance(serializer, EnumSerializer):
type_id = TypeId.NAMED_ENUM
elif type(serializer) is PickleSerializer:
type_id = PickleSerializer.PICKLE_TYPE_ID
elif isinstance(serializer, FunctionSerializer):
type_id = TypeId.NAMED_EXT
elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)):
elif isinstance(serializer, (ObjectSerializer, StatefulSerializer)):
type_id = TypeId.NAMED_EXT
elif self._internal_py_serializer_map.get(type(serializer)) is not None:
type_id = self._internal_py_serializer_map.get(type(serializer))[1]
if not self.require_registration:
if isinstance(serializer, DataClassSerializer):
type_id = TypeId.NAMED_STRUCT
Expand Down Expand Up @@ -552,35 +557,33 @@ def _create_serializer(self, cls):
serializer = DataClassStubSerializer(self.fory, cls, xlang=not self.fory.is_py)
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fory, cls)
elif ("builtin_function_or_method" in str(cls) or "cython_function_or_method" in str(cls)) and "<locals>" not in str(cls):
serializer = NativeFuncMethodSerializer(self.fory, cls)
elif cls is type(self.initialize):
# Handle bound method objects
serializer = MethodSerializer(self.fory, cls)
elif issubclass(cls, type):
# Handle Python type objects and metaclass such as numpy._DTypeMeta(i.e. np.dtype)
serializer = TypeSerializer(self.fory, cls)
elif cls is array.array:
# Handle array.array objects with DynamicPyArraySerializer
# Note: This will use DynamicPyArraySerializer for all array.array objects
serializer = DynamicPyArraySerializer(self.fory, cls)
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or (
hasattr(cls, "__reduce_ex__") and cls.__reduce_ex__ is not object.__reduce_ex__
):
# Use ReduceSerializer for objects that have custom __reduce__ or __reduce_ex__ methods
# This has higher precedence than StatefulSerializer and ObjectSerializer
# Only use it for objects with custom reduce methods, not default ones from the object
module_name = getattr(cls, "__module__", "")
if module_name.startswith("pandas.") or module_name == "builtins" or cls.__name__ in ("type", "function", "method"):
# Exclude pandas, built-ins, and certain system types
serializer = PickleSerializer(self.fory, cls)
else:
serializer = ReduceSerializer(self.fory, cls)
serializer = ReduceSerializer(self.fory, cls)
elif hasattr(cls, "__getstate__") and hasattr(cls, "__setstate__"):
# Use StatefulSerializer for objects that support __getstate__ and __setstate__
# But exclude certain types that have incompatible state methods
module_name = getattr(cls, "__module__", "")
if module_name.startswith("pandas."):
# Pandas objects have __getstate__/__setstate__ but use incompatible pickle formats
serializer = PickleSerializer(self.fory, cls)
else:
serializer = StatefulSerializer(self.fory, cls)
elif (
cls is not type
and (hasattr(cls, "__dict__") or hasattr(cls, "__slots__"))
and not (np and (issubclass(cls, np.dtype) or cls is type(np.dtype)))
):
serializer = StatefulSerializer(self.fory, cls)
elif hasattr(cls, "__dict__") or hasattr(cls, "__slots__"):
serializer = ObjectSerializer(self.fory, cls)
else:
serializer = PickleSerializer(self.fory, cls)
# c-extension types will go to here
serializer = UnsupportedSerializer(self.fory, cls)
return serializer

def is_registered_by_name(self, cls):
Expand Down
Loading
Loading