diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 471a6069e8..5d34076265 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -2064,20 +2064,63 @@ public void resetRead() {} public void resetWrite() {} + private static final GenericType OBJECT_GENERIC_TYPE = GenericType.build(Object.class); + @CodegenInvoke public GenericType getGenericTypeInStruct(Class cls, String genericTypeStr) { Map map = - extRegistry.classGenericTypes.computeIfAbsent(cls, k -> new HashMap<>()); - GenericType genericType = map.get(genericTypeStr); - if (genericType == null) { - for (Field field : ReflectionUtils.getFields(cls, true)) { - Type type = field.getGenericType(); - TypeRef typeRef = TypeRef.of(type); - genericType = buildGenericType(typeRef); - map.put(type.getTypeName(), genericType); - } + extRegistry.classGenericTypes.computeIfAbsent(cls, this::buildGenericMap); + return map.getOrDefault(genericTypeStr, OBJECT_GENERIC_TYPE); + } + + /** + * Build a map of nested generic type name to generic type for all fields in the class. + * + * @param cls the class to build the map of nested generic type name to generic type for all + * fields in the class + * @return a map of nested generic type name to generic type for all fields in the class + */ + protected Map buildGenericMap(Class cls) { + Map map = new HashMap<>(); + Map map2 = new HashMap<>(); + for (Field field : ReflectionUtils.getFields(cls, true)) { + Type type = field.getGenericType(); + GenericType genericType = buildGenericType(type); + buildGenericMap(map, genericType); + TypeRef typeRef = TypeRef.of(type); + buildGenericMap(map2, typeRef); + } + map.putAll(map2); + return map; + } + + private void buildGenericMap(Map map, TypeRef typeRef) { + if (map.containsKey(typeRef.getType().getTypeName())) { + return; + } + map.put(typeRef.getType().getTypeName(), buildGenericType(typeRef)); + Class rawType = typeRef.getRawType(); + if (TypeUtils.isMap(rawType)) { + Tuple2, TypeRef> kvTypes = TypeUtils.getMapKeyValueType(typeRef); + buildGenericMap(map, kvTypes.f0); + buildGenericMap(map, kvTypes.f1); + } else if (TypeUtils.isCollection(rawType)) { + TypeRef elementType = TypeUtils.getElementType(typeRef); + buildGenericMap(map, elementType); + } else if (rawType.isArray()) { + TypeRef arrayComponent = TypeUtils.getArrayComponent(typeRef); + buildGenericMap(map, arrayComponent); + } + } + + private void buildGenericMap(Map map, GenericType genericType) { + if (map.containsKey(genericType.getType().getTypeName())) { + return; + } + map.put(genericType.getType().getTypeName(), genericType); + for (GenericType t : genericType.getTypeParameters()) { + buildGenericMap(map, t); } - return genericType; } @Override diff --git a/java/fory-core/src/main/java/org/apache/fory/type/GenericType.java b/java/fory-core/src/main/java/org/apache/fory/type/GenericType.java index f5e7333f4f..36c9a3eb7e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/GenericType.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/GenericType.java @@ -166,6 +166,10 @@ public TypeRef getTypeRef() { return typeRef; } + public Type getType() { + return typeRef.getType(); + } + public Class getCls() { return cls; } diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java index 2095d2000d..03fc675d70 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java @@ -414,6 +414,14 @@ public static Class getComponentIfArray(Class type) { return type; } + public static TypeRef getArrayComponent(TypeRef type) { + if (type.getType() instanceof GenericArrayType) { + Type componentType = ((GenericArrayType) (type.getType())).getGenericComponentType(); + return TypeRef.of(componentType); + } + return TypeRef.of(getArrayComponentInfo(type.getRawType()).f0); + } + public static Class getArrayComponent(Class type) { return getArrayComponentInfo(type).f0; } diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java index 2bb24606e8..a286de792d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java @@ -97,6 +97,11 @@ public static Object[][] referenceTrackingConfig() { return new Object[][] {{false}, {true}}; } + @DataProvider + public static Object[][] compatible() { + return new Object[][] {{false}, {true}}; + } + @DataProvider public static Object[][] trackingRefFory() { return new Object[][] { diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java index b19b76e261..4e00a79536 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java @@ -1048,4 +1048,30 @@ public void testChunkArrayGeneric() { State state2 = (State) fory2.deserialize(bytes); Assert.assertEquals(state2.map.get("bar"), new String[] {"bar"}); } + + @Data + public static class OuterClass { + private Map f1 = new HashMap<>(); + private TestEnum f2; + } + + @Data + public static class InnerClass { + int f1; + } + + @Test(dataProvider = "compatible") + public void testNestedMapGenericCodegen(boolean compatible) { + Fory fory = + builder() + .withCodegen(true) + .withCompatibleMode( + compatible ? CompatibleMode.COMPATIBLE : CompatibleMode.SCHEMA_CONSISTENT) + .requireClassRegistration(false) + .build(); + + OuterClass value = new OuterClass(); + value.f1.put("aaa", null); + serDeCheck(fory, value); + } } diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 60562f9b44..3900c2bf28 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -18,7 +18,6 @@ import enum import logging import os -import warnings from abc import ABC, abstractmethod from typing import Union, Iterable, TypeVar @@ -37,9 +36,6 @@ except ImportError: np = None -from cloudpickle import Pickler - -from pickle import Unpickler logger = logging.getLogger(__name__) @@ -101,11 +97,8 @@ class Fory: "ref_tracking", "ref_resolver", "type_resolver", - "serialization_context", "require_type_registration", "buffer", - "pickler", - "unpickler", "_buffer_callback", "_buffers", "metastring_resolver", @@ -115,7 +108,6 @@ class Fory: "max_depth", "depth", ) - serialization_context: "SerializationContext" def __init__( self, @@ -152,19 +144,7 @@ def __init__( self.metastring_resolver = MetaStringResolver() self.type_resolver = TypeResolver(self) self.type_resolver.initialize() - self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) - if not require_type_registration: - warnings.warn( - "Type registration is disabled, unknown types can be deserialized which may be insecure.", - RuntimeWarning, - stacklevel=2, - ) - self.pickler = Pickler(self.buffer) - self.unpickler = None - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() self._buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -231,9 +211,7 @@ def _serialize( ) -> Union[Buffer, bytes]: self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback - if buffer is not None: - self.pickler = Pickler(buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -463,21 +441,11 @@ def read_buffer_object(self, buffer) -> Buffer: def handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): - buffer.write_bool(True) - self.pickler.dump(obj) - else: - buffer.write_bool(False) + raise NotImplementedError(f"{type(obj)} is not supported for write") def handle_unsupported_read(self, buffer): - in_band = buffer.read_bool() - if in_band: - unpickler = self.unpickler - if unpickler is None: - self.unpickler = unpickler = Unpickler(buffer) - return unpickler.load() - else: - assert self._unsupported_objects is not None - return next(self._unsupported_objects) + assert self._unsupported_objects is not None + return next(self._unsupported_objects) def write_ref_pyobject(self, buffer, value, typeinfo=None): if self.ref_resolver.write_ref_or_null(buffer, value): @@ -507,9 +475,7 @@ def throw_depth_limit_exceeded_exception(self): def reset_write(self): self.ref_resolver.reset_write() self.type_resolver.reset_write() - self.serialization_context.reset() self.metastring_resolver.reset_write() - self.pickler.clear_memo() self._buffer_callback = None self._unsupported_callback = None @@ -517,9 +483,7 @@ def reset_read(self): self.depth = 0 self.ref_resolver.reset_read() self.type_resolver.reset_read() - self.serialization_context.reset() self.metastring_resolver.reset_write() - self.unpickler = None self._buffers = None self._unsupported_objects = None @@ -562,20 +526,3 @@ def reset(self): "1", "true", } - - -class _PicklerStub: - def dump(self, o): - raise ValueError( - f"Type {type(o)} is not registered, " - f"pickle is not allowed when type registration enabled, " - f"Please register the type or pass unsupported_callback" - ) - - def clear_memo(self): - pass - - -class _UnpicklerStub: - def load(self): - raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback") diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py index f72baac63c..f107f664e3 100644 --- a/python/pyfory/_registry.py +++ b/python/pyfory/_registry.py @@ -25,7 +25,7 @@ from typing import TypeVar, Union from enum import Enum -from pyfory._serialization import ENABLE_FORY_CYTHON_SERIALIZATION +from pyfory import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory import Language from pyfory.error import TypeUnregisteredError @@ -35,9 +35,6 @@ NDArraySerializer, PyArraySerializer, DynamicPyArraySerializer, - _PickleStub, - PickleStrongCacheStub, - PickleCacheStub, NoneSerializer, BooleanSerializer, ByteSerializer, @@ -56,14 +53,15 @@ SetSerializer, EnumSerializer, SliceSerializer, - PickleCacheSerializer, - PickleStrongCacheSerializer, - PickleSerializer, DataClassSerializer, StatefulSerializer, ReduceSerializer, FunctionSerializer, ObjectSerializer, + TypeSerializer, + MethodSerializer, + UnsupportedSerializer, + NativeFuncMethodSerializer, ) from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder from pyfory.type import ( @@ -75,6 +73,7 @@ Float32Type, Float64Type, load_class, + record_class_factory, ) from pyfory._fory import ( DYNAMIC_TYPE_ID, @@ -158,9 +157,10 @@ class TypeResolver: "metastring_resolver", "language", "_type_id_to_typeinfo", + "_internal_py_serializer_map", ) - def __init__(self, fory): + def __init__(self, fory, meta_share=False): self.fory = fory self.metastring_resolver = fory.metastring_resolver self.language = fory.language @@ -182,34 +182,43 @@ def __init__(self, fory): self.namespace_decoder = MetaStringDecoder(".", "_") self.typename_encoder = MetaStringEncoder("$", "_") self.typename_decoder = MetaStringDecoder("$", "_") + self._internal_py_serializer_map = {} def initialize(self): - self._initialize_xlang() + self._initialize_common() if self.fory.language == Language.PYTHON: self._initialize_py() + else: + self._initialize_xlang() def _initialize_py(self): register = functools.partial(self._register_type, internal=True) - register( - _PickleStub, - type_id=PickleSerializer.PICKLE_TYPE_ID, - serializer=PickleSerializer, - ) - register( - PickleStrongCacheStub, - type_id=97, - serializer=PickleStrongCacheSerializer(self.fory), - ) - register( - PickleCacheStub, - type_id=98, - serializer=PickleCacheSerializer(self.fory), - ) register(type(None), serializer=NoneSerializer) register(tuple, serializer=TupleSerializer) register(slice, serializer=SliceSerializer) + register(np.ndarray, serializer=NDArraySerializer) + register(array.array, serializer=DynamicPyArraySerializer) + if not self.require_registration: + self._internal_py_serializer_map = { + ReduceSerializer: (self._stub_cls("__Reduce__"), self._next_type_id()), + TypeSerializer: (self._stub_cls("__Type__"), self._next_type_id()), + MethodSerializer: (self._stub_cls("__Method__"), self._next_type_id()), + FunctionSerializer: (self._stub_cls("__Function__"), self._next_type_id()), + NativeFuncMethodSerializer: (self._stub_cls("__NativeFunction__"), self._next_type_id()), + } + for serializer, (stub_cls, type_id) in self._internal_py_serializer_map.items(): + register(stub_cls, serializer=serializer, type_id=type_id) + + @staticmethod + def _stub_cls(name: str): + return record_class_factory(name, []) def _initialize_xlang(self): + register = functools.partial(self._register_type, internal=True) + register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer) + register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer) + + def _initialize_common(self): register = functools.partial(self._register_type, internal=True) register(None, type_id=TypeId.NA, serializer=NoneSerializer) register(bool, type_id=TypeId.BOOL, serializer=BooleanSerializer) @@ -240,7 +249,6 @@ def _initialize_xlang(self): type_id=typeid, serializer=PyArraySerializer(self.fory, ftype, typeid), ) - register(array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer) if np: # overwrite pyarray with same type id. # if pyarray are needed, one must annotate that value with XXXArrayType @@ -256,7 +264,6 @@ def _initialize_xlang(self): type_id=typeid, serializer=Numpy1DArraySerializer(self.fory, ftype, dtype), ) - register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer) register(list, type_id=TypeId.LIST, serializer=ListSerializer) register(set, type_id=TypeId.SET, serializer=SetSerializer) register(dict, type_id=TypeId.MAP, serializer=MapSerializer) @@ -348,10 +355,6 @@ def _register_xtype( if issubclass(cls, enum.Enum): serializer = EnumSerializer(self.fory, cls) type_id = TypeId.NAMED_ENUM if type_id is None else ((type_id << 8) + TypeId.ENUM) - elif cls is types.FunctionType: - # Use FunctionSerializer for function types (including lambdas) - serializer = FunctionSerializer(self.fory, cls) - type_id = TypeId.NAMED_EXT if type_id is None else ((type_id << 8) + TypeId.EXT) else: serializer = DataClassSerializer(self.fory, cls, xlang=True) type_id = TypeId.NAMED_STRUCT if type_id is None else ((type_id << 8) + TypeId.STRUCT) @@ -416,7 +419,7 @@ def __register_type( self._named_type_to_typeinfo[(namespace, typename)] = typeinfo self._ns_type_to_typeinfo[(ns_meta_bytes, type_meta_bytes)] = typeinfo self._types_info[cls] = typeinfo - if type_id > 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)): + if type_id is not None and type_id != 0 and (self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)): if type_id not in self._type_id_to_typeinfo or not internal: self._type_id_to_typeinfo[type_id] = typeinfo self._types_info[cls] = typeinfo @@ -461,7 +464,7 @@ def get_typeinfo(self, cls, create=True): return type_info elif not create: return None - if self.language != Language.PYTHON or (self.require_registration and not issubclass(cls, Enum)): + if self.require_registration and not issubclass(cls, Enum): raise TypeUnregisteredError(f"{cls} not registered") logger.info("Type %s not registered", cls) serializer = self._create_serializer(cls) @@ -469,12 +472,10 @@ def get_typeinfo(self, cls, create=True): if self.language == Language.PYTHON: if isinstance(serializer, EnumSerializer): type_id = TypeId.NAMED_ENUM - elif type(serializer) is PickleSerializer: - type_id = PickleSerializer.PICKLE_TYPE_ID - elif isinstance(serializer, FunctionSerializer): - type_id = TypeId.NAMED_EXT - elif isinstance(serializer, (ObjectSerializer, StatefulSerializer, ReduceSerializer)): + elif isinstance(serializer, (ObjectSerializer, StatefulSerializer)): type_id = TypeId.NAMED_EXT + elif self._internal_py_serializer_map.get(type(serializer)) is not None: + type_id = self._internal_py_serializer_map.get(type(serializer))[1] if not self.require_registration: if isinstance(serializer, DataClassSerializer): type_id = TypeId.NAMED_STRUCT @@ -502,35 +503,33 @@ def _create_serializer(self, cls): serializer = DataClassSerializer(self.fory, cls) elif issubclass(cls, enum.Enum): serializer = EnumSerializer(self.fory, cls) + elif ("builtin_function_or_method" in str(cls) or "cython_function_or_method" in str(cls)) and "" not in str(cls): + serializer = NativeFuncMethodSerializer(self.fory, cls) + elif cls is type(self.initialize): + # Handle bound method objects + serializer = MethodSerializer(self.fory, cls) + elif issubclass(cls, type): + # Handle Python type objects and metaclass such as numpy._DTypeMeta(i.e. np.dtype) + serializer = TypeSerializer(self.fory, cls) + elif cls is array.array: + # Handle array.array objects with DynamicPyArraySerializer + # Note: This will use DynamicPyArraySerializer for all array.array objects + serializer = DynamicPyArraySerializer(self.fory, cls) elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or ( hasattr(cls, "__reduce_ex__") and cls.__reduce_ex__ is not object.__reduce_ex__ ): # Use ReduceSerializer for objects that have custom __reduce__ or __reduce_ex__ methods # This has higher precedence than StatefulSerializer and ObjectSerializer # Only use it for objects with custom reduce methods, not default ones from the object - module_name = getattr(cls, "__module__", "") - if module_name.startswith("pandas.") or module_name == "builtins" or cls.__name__ in ("type", "function", "method"): - # Exclude pandas, built-ins, and certain system types - serializer = PickleSerializer(self.fory, cls) - else: - serializer = ReduceSerializer(self.fory, cls) + serializer = ReduceSerializer(self.fory, cls) elif hasattr(cls, "__getstate__") and hasattr(cls, "__setstate__"): # Use StatefulSerializer for objects that support __getstate__ and __setstate__ - # But exclude certain types that have incompatible state methods - module_name = getattr(cls, "__module__", "") - if module_name.startswith("pandas."): - # Pandas objects have __getstate__/__setstate__ but use incompatible pickle formats - serializer = PickleSerializer(self.fory, cls) - else: - serializer = StatefulSerializer(self.fory, cls) - elif ( - cls is not type - and (hasattr(cls, "__dict__") or hasattr(cls, "__slots__")) - and not (np and (issubclass(cls, np.dtype) or cls is type(np.dtype))) - ): + serializer = StatefulSerializer(self.fory, cls) + elif hasattr(cls, "__dict__") or hasattr(cls, "__slots__"): serializer = ObjectSerializer(self.fory, cls) else: - serializer = PickleSerializer(self.fory, cls) + # c-extension types will go to here + serializer = UnsupportedSerializer(self.fory, cls) return serializer def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes): diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx index 833fbed246..3360610dd3 100644 --- a/python/pyfory/_serialization.pyx +++ b/python/pyfory/_serialization.pyx @@ -30,7 +30,6 @@ from typing import TypeVar, Union, Iterable from pyfory._util import get_bit, set_bit, clear_bit from pyfory import _fory as fmod from pyfory._fory import Language -from pyfory._fory import _PicklerStub, _UnpicklerStub, Pickler, Unpickler from pyfory._fory import _ENABLE_TYPE_REGISTRATION_FORCIBLY from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -590,10 +589,7 @@ cdef class Fory: cdef readonly MapRefResolver ref_resolver cdef readonly TypeResolver type_resolver cdef readonly MetaStringResolver metastring_resolver - cdef readonly SerializationContext serialization_context cdef Buffer buffer - cdef public object pickler # pickle.Pickler - cdef public object unpickler # Optional[pickle.Unpickler] cdef object _buffer_callback cdef object _buffers # iterator cdef object _unsupported_callback @@ -634,20 +630,7 @@ cdef class Fory: self.metastring_resolver = MetaStringResolver() self.type_resolver = TypeResolver(self) self.type_resolver.initialize() - self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) - if not require_type_registration: - warnings.warn( - "Type registration is disabled, unknown types can be deserialized " - "which may be insecure.", - RuntimeWarning, - stacklevel=2, - ) - self.pickler = Pickler(self.buffer) - else: - self.pickler = _PicklerStub() - self.unpickler = _UnpicklerStub() - self.unpickler = None self._buffer_callback = None self._buffers = None self._unsupported_callback = None @@ -702,9 +685,7 @@ cdef class Fory: self, obj, Buffer buffer, buffer_callback=None, unsupported_callback=None): self._buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback - if buffer is not None: - self.pickler = Pickler(self.buffer) - else: + if buffer is None: self.buffer.writer_index = 0 buffer = self.buffer if self.language == Language.XLANG: @@ -832,8 +813,6 @@ cdef class Fory: cpdef inline _deserialize( self, Buffer buffer, buffers=None, unsupported_objects=None): - if not self.require_type_registration: - self.unpickler = Unpickler(buffer) if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) if self.language == Language.XLANG: @@ -939,7 +918,6 @@ cdef class Fory: self, Buffer buffer, Serializer serializer=None): if serializer is None: serializer = self.type_resolver.read_typeinfo(buffer).serializer - self.depth += 1 self.inc_depth() o = serializer.xread(buffer) self.depth -= 1 @@ -983,22 +961,13 @@ cdef class Fory: buffer.reader_index += size return buf - cpdef inline handle_unsupported_write(self, Buffer buffer, obj): + cpdef handle_unsupported_write(self, buffer, obj): if self._unsupported_callback is None or self._unsupported_callback(obj): - buffer.write_bool(True) - self.pickler.dump(obj) - else: - buffer.write_bool(False) + raise NotImplementedError(f"{type(obj)} is not supported for write") - cpdef inline handle_unsupported_read(self, Buffer buffer): - cdef c_bool in_band = buffer.read_bool() - if in_band: - if self.unpickler is None: - self.unpickler.buffer = Unpickler(buffer) - return self.unpickler.load() - else: - assert self._unsupported_objects is not None - return next(self._unsupported_objects) + cpdef handle_unsupported_read(self, buffer): + assert self._unsupported_objects is not None + return next(self._unsupported_objects) cpdef inline write_ref_pyobject( self, Buffer buffer, value, TypeInfo typeinfo=None): @@ -1016,7 +985,9 @@ cdef class Fory: return ref_resolver.get_read_object() # indicates that the object is first read. cdef TypeInfo typeinfo = self.type_resolver.read_typeinfo(buffer) + self.inc_depth() o = typeinfo.serializer.read(buffer) + self.depth -= 1 ref_resolver.set_read_object(ref_id, o) return o @@ -1024,8 +995,6 @@ cdef class Fory: self.ref_resolver.reset_write() self.type_resolver.reset_write() self.metastring_resolver.reset_write() - self.serialization_context.reset() - self.pickler.clear_memo() self._unsupported_callback = None cpdef inline reset_read(self): @@ -1033,9 +1002,7 @@ cdef class Fory: self.ref_resolver.reset_read() self.type_resolver.reset_read() self.metastring_resolver.reset_read() - self.serialization_context.reset() self._buffers = None - self.unpickler = None self._unsupported_objects = None cpdef inline reset(self): @@ -1095,29 +1062,6 @@ cpdef inline read_nullable_pystr(Buffer buffer): return None -@cython.final -cdef class SerializationContext: - cdef dict objects - - def __init__(self): - self.objects = dict() - - def add(self, key, obj): - self.objects[id(key)] = obj - - def __contains__(self, key): - return id(key) in self.objects - - def __getitem__(self, key): - return self.objects[id(key)] - - def get(self, key): - return self.objects.get(id(key)) - - def reset(self): - if len(self.objects) > 0: - self.objects.clear() - cdef class Serializer: cdef readonly Fory fory cdef readonly object type_ @@ -1650,10 +1594,12 @@ cdef class TupleSerializer(CollectionSerializer): else: self._read_same_type_ref(buffer, len_, tuple_, typeinfo) else: + self.fory.inc_depth() for i in range(len_): elem = get_next_element(buffer, ref_resolver, type_resolver, is_py) Py_INCREF(elem) PyTuple_SET_ITEM(tuple_, i, elem) + self.fory.dec_depth() return tuple_ cpdef inline _add_element(self, object collection_, int64_t index, object element): diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 77691c6a89..1a0666ed87 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -19,7 +19,7 @@ import logging import platform import time -from abc import ABC, abstractmethod +from abc import ABC from typing import Dict from pyfory._fory import NOT_NULL_INT64_FLAG @@ -74,13 +74,11 @@ def write(self, buffer, value): def read(self, buffer): raise NotImplementedError - @abstractmethod def xwrite(self, buffer, value): - pass + raise NotImplementedError - @abstractmethod def xread(self, buffer): - pass + raise NotImplementedError @classmethod def support_subclass(cls) -> bool: diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py index 6c455424e7..1b6f32f18f 100644 --- a/python/pyfory/_struct.py +++ b/python/pyfory/_struct.py @@ -90,14 +90,11 @@ def visit_customized(self, field_name, type_, types_path=None): return None def visit_other(self, field_name, type_, types_path=None): - from pyfory.serializer import PickleSerializer # Local import - if is_subclass(type_, enum.Enum): return self.fory.type_resolver.get_serializer(type_) if type_ not in basic_types and not is_py_array_type(type_): return None serializer = self.fory.type_resolver.get_serializer(type_) - assert not isinstance(serializer, (PickleSerializer,)) return serializer @@ -199,14 +196,11 @@ def visit_customized(self, field_name, type_, types_path=None): self._hash = self._compute_field_hash(self._hash, hash_value) def visit_other(self, field_name, type_, types_path=None): - from pyfory.serializer import PickleSerializer # Local import - typeinfo = self.fory.type_resolver.get_typeinfo(type_, create=False) if typeinfo is None: id_ = 0 else: serializer = typeinfo.serializer - assert not isinstance(serializer, (PickleSerializer,)) id_ = typeinfo.type_id assert id_ is not None, serializer id_ = abs(id_) diff --git a/python/pyfory/format/__init__.py b/python/pyfory/format/__init__.py index 3bc70502fc..f6fd1d8f5a 100644 --- a/python/pyfory/format/__init__.py +++ b/python/pyfory/format/__init__.py @@ -41,8 +41,7 @@ ) except (ImportError, AttributeError) as e: warnings.warn( - f"Fory format initialization failed, please ensure pyarrow is installed " - f"with version which fory is compiled with: {e}", + f"Fory format initialization failed, please ensure pyarrow is installed with version which fory is compiled with: {e}", RuntimeWarning, stacklevel=2, ) diff --git a/python/pyfory/format/tests/test_encoder.py b/python/pyfory/format/tests/test_encoder.py index ac9dbdfdce..b4b01fc1c9 100644 --- a/python/pyfory/format/tests/test_encoder.py +++ b/python/pyfory/format/tests/test_encoder.py @@ -63,9 +63,7 @@ def test_encoder_with_schema(): @require_pyarrow def test_dict(): dict_ = {"f1": 1, "f2": "str"} - encoder = pyfory.create_row_encoder( - pa.schema([("f1", pa.int32()), ("f2", pa.utf8())]) - ) + encoder = pyfory.create_row_encoder(pa.schema([("f1", pa.int32()), ("f2", pa.utf8())])) row = encoder.to_row(dict_) new_obj = encoder.from_row(row) assert new_obj.f1 == dict_["f1"] @@ -74,9 +72,7 @@ def test_dict(): @require_pyarrow def test_ints(): - cls = pyfory.record_class_factory( - "TestNumeric", ["f" + str(i) for i in range(1, 9)] - ) + cls = pyfory.record_class_factory("TestNumeric", ["f" + str(i) for i in range(1, 9)]) schema = pa.schema( [ ("f1", pa.int64()), diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index aba3809c4f..d478eb9ecd 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -17,17 +17,16 @@ import array import builtins +import importlib +import inspect import itertools import marshal import logging import os -import pickle import types import typing import warnings -from weakref import WeakValueDictionary -import pyfory.lib.mmh3 from pyfory.buffer import Buffer from pyfory.codegen import ( gen_write_nullable_basic_stmts, @@ -35,9 +34,9 @@ compile_function, ) from pyfory.error import TypeNotCompatibleError -from pyfory.lib.collection import WeakIdentityKeyDictionary from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG from pyfory import Language +from typing import List try: import numpy as np @@ -137,82 +136,31 @@ def read(self, buffer): return None -class _PickleStub: - pass +class TypeSerializer(Serializer): + """Serializer for Python type objects (classes).""" - -class PickleStrongCacheStub: - pass - - -class PickleCacheStub: - pass - - -class PickleStrongCacheSerializer(Serializer): - """If we can't create weak ref to object, use this cache serializer instead. - clear cache by threshold to avoid memory leak.""" - - __slots__ = "_cached", "_clear_threshold", "_counter" - - def __init__(self, fory, clear_threshold: int = 1000): - super().__init__(fory, PickleStrongCacheStub) - self._cached = {} - self._clear_threshold = clear_threshold + def __init__(self, fory, cls): + super().__init__(fory, cls) + self.cls = cls def write(self, buffer, value): - serialized = self._cached.get(value) - if serialized is None: - serialized = pickle.dumps(value) - self._cached[value] = serialized - buffer.write_bytes_and_size(serialized) - if len(self._cached) == self._clear_threshold: - self._cached.clear() + # Serialize the type by its module and name + module_name = getattr(value, "__module__", "") + type_name = getattr(value, "__name__", "") + buffer.write_string(module_name) + buffer.write_string(type_name) def read(self, buffer): - return pickle.loads(buffer.read_bytes_and_size()) + module_name = buffer.read_string() + type_name = buffer.read_string() - def xwrite(self, buffer, value): - raise NotImplementedError - - def xread(self, buffer): - raise NotImplementedError - - -class PickleCacheSerializer(Serializer): - __slots__ = "_cached", "_reverse_cached" - - def __init__(self, fory): - super().__init__(fory, PickleCacheStub) - self._cached = WeakIdentityKeyDictionary() - self._reverse_cached = WeakValueDictionary() - - def write(self, buffer, value): - cache = self._cached.get(value) - if cache is None: - serialized = pickle.dumps(value) - value_hash = pyfory.lib.mmh3.hash_buffer(serialized)[0] - cache = value_hash, serialized - self._cached[value] = cache - buffer.write_int64(cache[0]) - buffer.write_bytes_and_size(cache[1]) - - def read(self, buffer): - value_hash = buffer.read_int64() - value = self._reverse_cached.get(value_hash) - if value is None: - value = pickle.loads(buffer.read_bytes_and_size()) - self._reverse_cached[value_hash] = value + # Import the module and get the type + if module_name and module_name != "builtins": + module = __import__(module_name, fromlist=[type_name]) + return getattr(module, type_name) else: - size = buffer.read_int32() - buffer.skip(size) - return value - - def xwrite(self, buffer, value): - raise NotImplementedError - - def xread(self, buffer): - raise NotImplementedError + # Handle built-in types + return getattr(builtins, type_name, type) class PandasRangeIndexSerializer(Serializer): @@ -290,27 +238,26 @@ def xread(self, buffer): "1", ) -# Moved from L32 to here, after all Serializer base classes and specific serializers -# like ListSerializer, MapSerializer, PickleSerializer are defined or imported -# and before DataClassSerializer which uses ComplexTypeVisitor from _struct. + from pyfory._struct import _get_hash, _sort_fields, ComplexTypeVisitor class DataClassSerializer(Serializer): - def __init__(self, fory, clz: type, xlang: bool = False): + def __init__(self, fory, clz: type, xlang: bool = False, field_names: List[str] = None, serializers: List[Serializer] = None): super().__init__(fory, clz) self._xlang = xlang # This will get superclass type hints too. self._type_hints = typing.get_type_hints(clz) - self._field_names = self._get_field_names(clz) + self._field_names = field_names or self._get_field_names(clz) self._has_slots = hasattr(clz, "__slots__") if self._xlang: - self._serializers = [None] * len(self._field_names) - visitor = ComplexTypeVisitor(fory) - for index, key in enumerate(self._field_names): - serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) - self._serializers[index] = serializer + self._serializers = serializers or [None] * len(self._field_names) + if serializers is None: + visitor = ComplexTypeVisitor(fory) + for index, key in enumerate(self._field_names): + serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) + self._serializers[index] = serializer self._serializers, self._field_names = _sort_fields(fory.type_resolver, self._field_names, self._serializers) self._hash = 0 # Will be computed on first xwrite/xread self._generated_xwrite_method = self._gen_xwrite_method() @@ -443,13 +390,13 @@ def _gen_xwrite_method(self): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers - # Compute hash at generation time since we're in xlang mode - if self._hash == 0: - self._hash = _get_hash(self.fory, self._field_names, self._type_hints) stmts = [ f'"""xwrite method for {self.type_}"""', - f"{buffer}.write_int32({self._hash})", ] + # Compute hash at generation time since we're in xlang mode + if self._hash == 0: + self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + stmts.append(f"{buffer}.write_int32({self._hash})") if not self._has_slots: stmts.append(f"{value_dict} = {value}.__dict__") for index, field_name in enumerate(self._field_names): @@ -487,18 +434,27 @@ def _gen_xread_method(self): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers + stmts = [ + f'"""xread method for {self.type_}"""', + ] # Compute hash at generation time since we're in xlang mode if self._hash == 0: self._hash = _get_hash(self.fory, self._field_names, self._type_hints) - stmts = [ - f'"""xread method for {self.type_}"""', - f"read_hash = {buffer}.read_int32()", - f"if read_hash != {self._hash}:", - f""" raise TypeNotCompatibleError( + stmts.extend( + [ + f"read_hash = {buffer}.read_int32()", + f"if read_hash != {self._hash}:", + f""" raise TypeNotCompatibleError( f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""", - f"{obj} = {obj_class}.__new__({obj_class})", - f"{ref_resolver}.reference({obj})", - ] + ] + ) + stmts.extend( + [ + f"{obj} = {obj_class}.__new__({obj_class})", + f"{ref_resolver}.reference({obj})", + ] + ) + if not self._has_slots: stmts.append(f"{obj_dict} = {obj}.__dict__") @@ -666,12 +622,18 @@ def write(self, buffer, value: array.array): def read(self, buffer): typecode = buffer.read_string() data = buffer.read_bytes_and_size() - arr = array.array(typecode, []) + arr = array.array(typecode[0], []) # Take first character arr.frombytes(data) return arr class DynamicPyArraySerializer(Serializer): + """Serializer for dynamic Python arrays that handles any typecode.""" + + def __init__(self, fory, cls): + super().__init__(fory, cls) + self._serializer = ReduceSerializer(fory, cls) + def xwrite(self, buffer, value): itemsize, ftype, type_id = typecode_dict[value.typecode] view = memoryview(value) @@ -692,11 +654,10 @@ def xread(self, buffer): return arr def write(self, buffer, value): - buffer.write_varuint32(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + self._serializer.write(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + return self._serializer.read(buffer) if np: @@ -731,6 +692,7 @@ def __init__(self, fory, ftype, dtype): super().__init__(fory, ftype) self.dtype = dtype self.itemsize, self.format, self.typecode, self.type_id = _np_dtypes_dict[self.dtype] + self._serializer = ReduceSerializer(fory, np.ndarray) def xwrite(self, buffer, value): assert value.itemsize == self.itemsize @@ -752,11 +714,10 @@ def xread(self, buffer): return np.frombuffer(data, dtype=self.dtype) def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + self._serializer.write(buffer, value) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + return self._serializer.read(buffer) class NDArraySerializer(Serializer): @@ -775,11 +736,32 @@ def xread(self, buffer): raise NotImplementedError("Multi-dimensional array not supported currently") def write(self, buffer, value): - buffer.write_int8(PickleSerializer.PICKLE_TYPE_ID) - self.fory.handle_unsupported_write(buffer, value) + # Serialize numpy ND array using native format + dtype = value.dtype + fory = self.fory + fory.serialize_ref(buffer, dtype) + buffer.write_varuint32(len(value.shape)) + for dim in value.shape: + buffer.write_varuint32(dim) + if dtype.kind == "O": + buffer.write_varint32(len(value)) + for item in value: + fory.serialize_ref(buffer, item) + else: + data = value.tobytes() + buffer.write_bytes_and_size(data) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + fory = self.fory + dtype = fory.deserialize_ref(buffer) + ndim = buffer.read_varuint32() + shape = tuple(buffer.read_varuint32() for _ in range(ndim)) + if dtype.kind == "O": + length = buffer.read_varint32() + items = [fory.deserialize_ref(buffer) for _ in range(length)] + return np.array(items, dtype=object) + data = buffer.read_bytes_and_size() + return np.frombuffer(data, dtype=dtype).reshape(shape) class BytesSerializer(CrossLanguageCompatibleSerializer): @@ -886,39 +868,58 @@ def write(self, buffer, value): # Handle different __reduce__ return formats if isinstance(reduce_result, str): # Case 1: Just a global name (simple case) - self.fory.serialize_ref(buffer, ("global", reduce_result, None, None, None)) + reduce_data = ("global", reduce_result) elif isinstance(reduce_result, tuple): if len(reduce_result) == 2: # Case 2: (callable, args) callable_obj, args = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, None, None)) + reduce_data = ("callable", callable_obj, args) elif len(reduce_result) == 3: # Case 3: (callable, args, state) callable_obj, args, state = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, None)) + reduce_data = ("callable", callable_obj, args, state) elif len(reduce_result) == 4: # Case 4: (callable, args, state, listitems) callable_obj, args, state, listitems = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, listitems)) + reduce_data = ("callable", callable_obj, args, state, listitems) elif len(reduce_result) == 5: # Case 5: (callable, args, state, listitems, dictitems) callable_obj, args, state, listitems, dictitems = reduce_result - self.fory.serialize_ref(buffer, ("callable", callable_obj, args, state, listitems, dictitems)) + reduce_data = ("callable", callable_obj, args, state, listitems, dictitems) else: raise ValueError(f"Invalid __reduce__ result length: {len(reduce_result)}") else: raise ValueError(f"Invalid __reduce__ result type: {type(reduce_result)}") + buffer.write_varuint32(len(reduce_data)) + fory = self.fory + for item in reduce_data: + fory.serialize_ref(buffer, item) def read(self, buffer): - reduce_data = self.fory.deserialize_ref(buffer) + reduce_data_num_items = buffer.read_varuint32() + assert reduce_data_num_items <= 6, buffer + reduce_data = [None] * 6 + fory = self.fory + for i in range(reduce_data_num_items): + reduce_data[i] = fory.deserialize_ref(buffer) if reduce_data[0] == "global": # Case 1: Global name global_name = reduce_data[1] # Import and return the global object - module_name, obj_name = global_name.rsplit(".", 1) - module = __import__(module_name, fromlist=[obj_name]) - return getattr(module, obj_name) + if "." in global_name: + module_name, obj_name = global_name.rsplit(".", 1) + module = __import__(module_name, fromlist=[obj_name]) + return getattr(module, obj_name) + else: + # Handle case where global_name doesn't contain a dot + # This might be a built-in type or a simple name + try: + import builtins + + return getattr(builtins, global_name) + except AttributeError: + raise ValueError(f"Cannot resolve global name: {global_name}") elif reduce_data[0] == "callable": # Case 2-5: Callable with args and optional state/items callable_obj = reduce_data[1] @@ -977,33 +978,41 @@ class FunctionSerializer(CrossLanguageCompatibleSerializer): def _serialize_function(self, buffer, func): """Serialize a function by capturing all its components.""" # Get function metadata - is_method = hasattr(func, "__self__") - if is_method: + instance = getattr(func, "__self__", None) + if instance is not None and not inspect.ismodule(instance): # Handle bound methods - self_obj = func.__self__ + self_obj = instance func_name = func.__name__ # Serialize as a tuple (is_method, self_obj, method_name) - buffer.write_bool(True) # is a method + buffer.write_int8(0) # is a method # For the 'self' object, we need to use fory's serialization self.fory.serialize_ref(buffer, self_obj) buffer.write_string(func_name) return + import types # Regular function or lambda code = func.__code__ name = func.__name__ - defaults = func.__defaults__ - closure = func.__closure__ - globals_dict = func.__globals__ module = func.__module__ qualname = func.__qualname__ + if "" not in qualname and module != "__main__": + buffer.write_int8(1) # Not a method + buffer.write_string(name) + buffer.write_string(module) + return + # Serialize function metadata - buffer.write_bool(False) # Not a method + buffer.write_int8(2) # Not a method buffer.write_string(name) buffer.write_string(module) buffer.write_string(qualname) + defaults = func.__defaults__ + closure = func.__closure__ + globals_dict = func.__globals__ + # Instead of trying to serialize the code object in parts, use marshal # which is specifically designed for code objects marshalled_code = marshal.dumps(code) @@ -1071,16 +1080,21 @@ def _serialize_function(self, buffer, func): def _deserialize_function(self, buffer): """Deserialize a function from its components.""" - import sys # Check if it's a method - is_method = buffer.read_bool() - if is_method: + func_type_id = buffer.read_int8() + if func_type_id == 0: # Handle bound methods self_obj = self.fory.deserialize_ref(buffer) method_name = buffer.read_string() return getattr(self_obj, method_name) + if func_type_id == 1: + name = buffer.read_string() + module = buffer.read_string() + mod = importlib.import_module(module) + return getattr(mod, name) + # Regular function or lambda name = buffer.read_string() module = buffer.read_string() @@ -1128,7 +1142,7 @@ def _deserialize_function(self, buffer): # Create a globals dictionary with module's globals as the base func_globals = {} try: - mod = sys.modules.get(module) + mod = importlib.import_module(module) if mod: func_globals.update(mod.__dict__) except (KeyError, AttributeError): @@ -1156,12 +1170,10 @@ def _deserialize_function(self, buffer): return func def xwrite(self, buffer, value): - """Serialize a function for cross-language compatibility.""" - self._serialize_function(buffer, value) + raise NotImplementedError() def xread(self, buffer): - """Deserialize a function for cross-language compatibility.""" - return self._deserialize_function(buffer) + raise NotImplementedError() def write(self, buffer, value): """Serialize a function for Python-only mode.""" @@ -1172,20 +1184,56 @@ def read(self, buffer): return self._deserialize_function(buffer) -class PickleSerializer(Serializer): - PICKLE_TYPE_ID = 96 +class NativeFuncMethodSerializer(Serializer): + def write(self, buffer, func): + name = func.__name__ + buffer.write_string(name) + obj = getattr(func, "__self__", None) + if obj is None or inspect.ismodule(obj): + buffer.write_bool(True) + module = func.__module__ + buffer.write_string(module) + else: + buffer.write_bool(False) + self.fory.serialize_ref(buffer, obj) - def xwrite(self, buffer, value): - raise NotImplementedError + def read(self, buffer): + name = buffer.read_string() + if buffer.read_bool(): + module = buffer.read_string() + mod = importlib.import_module(module) + return getattr(mod, name) + else: + obj = self.fory.deserialize_ref(buffer) + return getattr(obj, name) - def xread(self, buffer): - raise NotImplementedError + +class MethodSerializer(Serializer): + """Serializer for bound method objects.""" + + def __init__(self, fory, cls): + super().__init__(fory, cls) + self.cls = cls def write(self, buffer, value): - self.fory.handle_unsupported_write(buffer, value) + # Serialize bound method as (instance, method_name) + instance = value.__self__ + method_name = value.__func__.__name__ + + self.fory.serialize_ref(buffer, instance) + buffer.write_string(method_name) def read(self, buffer): - return self.fory.handle_unsupported_read(buffer) + instance = self.fory.deserialize_ref(buffer) + method_name = buffer.read_string() + + return getattr(instance, method_name) + + def xwrite(self, buffer, value): + return self.write(buffer, value) + + def xread(self, buffer): + return self.read(buffer) class ObjectSerializer(Serializer): @@ -1245,3 +1293,17 @@ def __new__(cls, fory, clz): stacklevel=2, ) return DataClassSerializer(fory, clz, xlang=True) + + +class UnsupportedSerializer(Serializer): + def write(self, buffer, value): + self.fory.handle_unsupported_write(value) + + def read(self, buffer): + return self.fory.handle_unsupported_read(buffer) + + def xwrite(self, buffer, value): + raise NotImplementedError(f"{self.type_} is not supported for xwrite") + + def xread(self, buffer): + raise NotImplementedError(f"{self.type_} is not supported for xread") diff --git a/python/pyfory/tests/benchmark.py b/python/pyfory/tests/benchmark.py index ab6abe1b70..75c883296e 100644 --- a/python/pyfory/tests/benchmark.py +++ b/python/pyfory/tests/benchmark.py @@ -33,13 +33,9 @@ def test_encode(): assert foo == encoder.from_row(row) t1 = timeit.timeit(lambda: encoder.to_row(foo), number=iter_nums) - print( - "encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1 / iter_nums) - ) + print("encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1 / iter_nums)) t2 = timeit.timeit(lambda: pickle.dumps(foo), number=iter_nums) - print( - "pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 / iter_nums) - ) + print("pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 / iter_nums)) @pytest.mark.skip(reason="take too long") @@ -51,18 +47,10 @@ def test_decode(): row = encoder.to_row(foo) assert foo == encoder.from_row(row) t1 = timeit.timeit(lambda: encoder.from_row(row), number=iter_nums) - print( - "encoder take {0} for {1} times, avg: {2}, size {3}".format( - t1, iter_nums, t1 / iter_nums, row.size_bytes() - ) - ) + print("encoder take {0} for {1} times, avg: {2}, size {3}".format(t1, iter_nums, t1 / iter_nums, row.size_bytes())) pickled_data = pickle.dumps(foo) t2 = timeit.timeit(lambda: pickle.loads(pickled_data), number=iter_nums) - print( - "pickle take {0} for {1} times, avg: {2}, size {3}".format( - t2, iter_nums, t2 / iter_nums, len(pickled_data) - ) - ) + print("pickle take {0} for {1} times, avg: {2}, size {3}".format(t2, iter_nums, t2 / iter_nums, len(pickled_data))) if __name__ == "__main__": diff --git a/python/pyfory/tests/record.py b/python/pyfory/tests/record.py index 2f56a9ad81..31ebd66a81 100644 --- a/python/pyfory/tests/record.py +++ b/python/pyfory/tests/record.py @@ -117,9 +117,7 @@ def foo_schema(): ("f4", pa.map_(pa.string(), pa.int32())), ("f5", pa.list_(pa.int32())), ("f6", pa.int32()), - pa.field( - "f7", bar_struct, metadata={"cls": fory.get_qualified_classname(Bar)} - ), + pa.field("f7", bar_struct, metadata={"cls": fory.get_qualified_classname(Bar)}), ], metadata={"cls": fory.get_qualified_classname(Foo)}, ) diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py index 3ba9c388ed..cefd6abf5a 100644 --- a/python/pyfory/tests/test_buffer.py +++ b/python/pyfory/tests/test_buffer.py @@ -217,10 +217,7 @@ def check_varuint64(buf: Buffer, value: int, bytes_written: int): assert buf.writer_index == buf.reader_index assert value == varint # test slow read branch in `read_varint64` - assert ( - buf.slice(reader_index, buf.reader_index - reader_index).read_varuint64() - == value - ) + assert buf.slice(reader_index, buf.reader_index - reader_index).read_varuint64() == value def test_write_buffer(): diff --git a/python/pyfory/tests/test_codegen.py b/python/pyfory/tests/test_codegen.py index 3b2243b29a..b73d2465e2 100644 --- a/python/pyfory/tests/test_codegen.py +++ b/python/pyfory/tests/test_codegen.py @@ -43,8 +43,6 @@ def _debug_compiled(x): def test_compile_function(): - code, func = codegen.compile_function( - "test_compile_function", ["x"], ["print(1)", "print(2)", "return x"], {} - ) + code, func = codegen.compile_function("test_compile_function", ["x"], ["print(1)", "print(2)", "return x"], {}) print(code) assert func(100) == 100 diff --git a/python/pyfory/tests/test_metastring.py b/python/pyfory/tests/test_metastring.py index f21e09585c..d470de2994 100644 --- a/python/pyfory/tests/test_metastring.py +++ b/python/pyfory/tests/test_metastring.py @@ -196,7 +196,5 @@ def test_non_ascii_encoding_and_non_utf8(): non_ascii_string = "こんにちは" # Non-ASCII string - with pytest.raises( - ValueError, match="Unsupported character for LOWER_SPECIAL encoding: こ" - ): + with pytest.raises(ValueError, match="Unsupported character for LOWER_SPECIAL encoding: こ"): encoder.encode_with_encoding(non_ascii_string, Encoding.LOWER_SPECIAL) diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 6df8875b4c..6e5f1cfe2f 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -18,7 +18,6 @@ import array import datetime import gc -import io import os import pickle import weakref @@ -308,37 +307,6 @@ def ser_de(fory, obj): return fory.deserialize(binary) -def test_pickle(): - buf = Buffer.allocate(32) - pickler = pickle.Pickler(buf) - pickler.dump(b"abc") - buf.write_int32(-1) - pickler.dump("abcd") - assert buf.writer_index - 4 == len(pickle.dumps(b"abc")) + len(pickle.dumps("abcd")) - print(f"writer_index {buf.writer_index}") - - bytes_io_ = io.BytesIO(buf) - unpickler = pickle.Unpickler(bytes_io_) - assert unpickler.load() == b"abc" - bytes_io_.seek(bytes_io_.tell() + 4) - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index} {bytes_io_.tell()}") - - if pa: - pa_buf = pa.BufferReader(buf) - unpickler = pickle.Unpickler(pa_buf) - assert unpickler.load() == b"abc" - pa_buf.seek(pa_buf.tell() + 4) - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index} {pa_buf.tell()} {buf.reader_index}") - - unpickler = pickle.Unpickler(buf) - assert unpickler.load() == b"abc" - buf.reader_index = buf.reader_index + 4 - assert unpickler.load() == "abcd" - print(f"reader_index {buf.reader_index}") - - @require_pyarrow def test_serialize_arrow(): record_batch = create_record_batch(10000) @@ -454,13 +422,16 @@ def xread(self, buffer): assert isinstance(fory.deserialize(fory.serialize(A.B.C())), A.B.C) -def test_pickle_fallback(): +def test_np_types(): fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) o1 = [1, True, np.dtype(np.int32)] data1 = fory.serialize(o1) new_o1 = fory.deserialize(data1) assert o1 == new_o1 + +def test_pandas_dataframe(): + fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) df = pd.DataFrame({"a": list(range(10))}) df2 = fory.deserialize(fory.serialize(df)) assert df2.equals(df) @@ -545,19 +516,6 @@ def test_duplicate_serialize(): assert ser_de(fory, EnumClass.E4) == EnumClass.E4 -@dataclass(unsafe_hash=True) -class CacheClass1: - f1: int - - -def test_cache_serializer(): - fory = Fory(language=Language.PYTHON, ref_tracking=True) - fory.register_type(CacheClass1, serializer=pyfory.PickleStrongCacheSerializer(fory)) - assert ser_de(fory, CacheClass1(1)) == CacheClass1(1) - fory.register_type(CacheClass1, serializer=pyfory.PickleCacheSerializer(fory)) - assert ser_de(fory, CacheClass1(1)) == CacheClass1(1) - - def test_pandas_range_index(): fory = Fory(language=Language.PYTHON, ref_tracking=True, require_type_registration=False) fory.register_type(pd.RangeIndex, serializer=pyfory.PandasRangeIndexSerializer(fory)) diff --git a/python/pyfory/type.py b/python/pyfory/type.py index add96c783f..2c2daae64f 100644 --- a/python/pyfory/type.py +++ b/python/pyfory/type.py @@ -386,30 +386,18 @@ def visit_other(self, field_name, type_, types_path=None): def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None): types_path = list(types_path or []) types_path.append(type_) - origin = ( - typing.get_origin(type_) - if hasattr(typing, "get_origin") - else getattr(type_, "__origin__", type_) - ) + origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else getattr(type_, "__origin__", type_) origin = origin or type_ - args = ( - typing.get_args(type_) - if hasattr(typing, "get_args") - else getattr(type_, "__args__", ()) - ) + args = typing.get_args(type_) if hasattr(typing, "get_args") else getattr(type_, "__args__", ()) if args: if origin is list or origin == typing.List: elem_type = args[0] return visitor.visit_list(field_name, elem_type, types_path=types_path) elif origin is dict or origin == typing.Dict: key_type, value_type = args - return visitor.visit_dict( - field_name, key_type, value_type, types_path=types_path - ) + return visitor.visit_dict(field_name, key_type, value_type, types_path=types_path) else: - raise TypeError( - f"Collection types should be {list, dict} instead of {type_}" - ) + raise TypeError(f"Collection types should be {list, dict} instead of {type_}") else: if is_function(origin) or not hasattr(origin, "__annotations__"): return visitor.visit_other(field_name, type_, types_path=types_path) diff --git a/python/pyproject.toml b/python/pyproject.toml index 58b81674e3..81f9253dba 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,7 +50,6 @@ classifiers = [ ] keywords = ["fory", "serialization", "multi-language", "fast", "row-format", "jit", "codegen", "polymorphic", "zero-copy"] dependencies = [ - "cloudpickle", ] [project.optional-dependencies]