Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2064,22 +2064,65 @@ public void resetRead() {}

public void resetWrite() {}

private static final GenericType OBJECT_GENERIC_TYPE = GenericType.build(Object.class);

@CodegenInvoke
public GenericType getGenericTypeInStruct(Class<?> cls, String genericTypeStr) {
Map<String, GenericType> map =
extRegistry.classGenericTypes.computeIfAbsent(cls, k -> new HashMap<>());
GenericType genericType = map.get(genericTypeStr);
if (genericType == null) {
for (Field field : ReflectionUtils.getFields(cls, true)) {
Type type = field.getGenericType();
TypeRef<Object> typeRef = TypeRef.of(type);
genericType = buildGenericType(typeRef);
map.put(type.getTypeName(), genericType);
}
extRegistry.classGenericTypes.computeIfAbsent(cls, this::buildGenericMap);
return map.getOrDefault(genericTypeStr, OBJECT_GENERIC_TYPE);
}

/**
* Build a map of nested generic type name to generic type for all fields in the class.
*
* @param cls the class to build the map of nested generic type name to generic type for all
* fields in the class
* @return a map of nested generic type name to generic type for all fields in the class
*/
protected Map<String, GenericType> buildGenericMap(Class<?> cls) {
Map<String, GenericType> map = new HashMap<>();
Map<String, GenericType> map2 = new HashMap<>();
for (Field field : ReflectionUtils.getFields(cls, true)) {
Type type = field.getGenericType();
GenericType genericType = buildGenericType(type);
buildGenericMap(map, genericType);
TypeRef<?> typeRef = TypeRef.of(type);
buildGenericMap(map2, typeRef);
}
map.putAll(map2);
return map;
}

private void buildGenericMap(Map<String, GenericType> map, TypeRef<?> typeRef) {
if (map.containsKey(typeRef.getType().getTypeName())) {
return;
}
map.put(typeRef.getType().getTypeName(), buildGenericType(typeRef));
Class<?> rawType = typeRef.getRawType();
if (TypeUtils.isMap(rawType)) {
Tuple2<TypeRef<?>, TypeRef<?>> kvTypes = TypeUtils.getMapKeyValueType(typeRef);
buildGenericMap(map, kvTypes.f0);
buildGenericMap(map, kvTypes.f1);
} else if (TypeUtils.isCollection(rawType)) {
TypeRef<?> elementType = TypeUtils.getElementType(typeRef);
buildGenericMap(map, elementType);
} else if (rawType.isArray()) {
TypeRef<?> arrayComponent = TypeUtils.getArrayComponent(typeRef);
buildGenericMap(map, arrayComponent);
}
return genericType;
}

private void buildGenericMap(Map<String, GenericType> map, GenericType genericType) {
if (map.containsKey(genericType.getType().getTypeName())) {
return;
}
map.put(genericType.getType().getTypeName(), genericType);
for (GenericType t : genericType.getTypeParameters()) {
buildGenericMap(map, t);
}
}

@Override
public GenericType buildGenericType(TypeRef<?> typeRef) {
return GenericType.build(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.fory.resolver;

import java.lang.reflect.Field;
import java.lang.reflect.Type;
import org.apache.fory.Fory;
import org.apache.fory.annotation.Internal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ public TypeRef<?> getTypeRef() {
return typeRef;
}

public Type getType() {
return typeRef.getType();
}

public Class<?> getCls() {
return cls;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,14 @@ public static Class<?> getComponentIfArray(Class<?> type) {
return type;
}

public static TypeRef<?> getArrayComponent(TypeRef<?> type) {
if (type.getType() instanceof GenericArrayType) {
Type componentType = ((GenericArrayType) (type.getType())).getGenericComponentType();
return TypeRef.of(componentType);
}
return TypeRef.of(getArrayComponentInfo(type.getRawType()).f0);
}

public static Class<?> getArrayComponent(Class<?> type) {
return getArrayComponentInfo(type).f0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1048,4 +1048,30 @@ public void testChunkArrayGeneric() {
State state2 = (State) fory2.deserialize(bytes);
Assert.assertEquals(state2.map.get("bar"), new String[] {"bar"});
}

@Data
public static class OuterClass {
private Map<String, InnerClass> f1 = new HashMap<>();
private TestEnum f2;
}

@Data
public static class InnerClass {
int f1;
}

@Test(dataProvider = "compatible")
public void testNestedMapGenericCodegen(boolean compatible) {
Fory fory =
builder()
.withCodegen(true)
.withCompatibleMode(
compatible ? CompatibleMode.COMPATIBLE : CompatibleMode.SCHEMA_CONSISTENT)
.requireClassRegistration(false)
.build();

OuterClass value = new OuterClass();
value.f1.put("aaa", null);
serDeCheck(fory, value);
}
}
61 changes: 4 additions & 57 deletions python/pyfory/_fory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import enum
import logging
import os
import warnings
from abc import ABC, abstractmethod
from typing import Union, Iterable, TypeVar

Expand All @@ -37,9 +36,6 @@
except ImportError:
np = None

from cloudpickle import Pickler

from pickle import Unpickler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,11 +97,8 @@ class Fory:
"ref_tracking",
"ref_resolver",
"type_resolver",
"serialization_context",
"require_type_registration",
"buffer",
"pickler",
"unpickler",
"_buffer_callback",
"_buffers",
"metastring_resolver",
Expand All @@ -115,7 +108,6 @@ class Fory:
"max_depth",
"depth",
)
serialization_context: "SerializationContext"

def __init__(
self,
Expand Down Expand Up @@ -152,19 +144,7 @@ def __init__(
self.metastring_resolver = MetaStringResolver()
self.type_resolver = TypeResolver(self)
self.type_resolver.initialize()
self.serialization_context = SerializationContext()
self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
"Type registration is disabled, unknown types can be deserialized which may be insecure.",
RuntimeWarning,
stacklevel=2,
)
self.pickler = Pickler(self.buffer)
self.unpickler = None
else:
self.pickler = _PicklerStub()
self.unpickler = _UnpicklerStub()
self._buffer_callback = None
self._buffers = None
self._unsupported_callback = None
Expand Down Expand Up @@ -231,9 +211,7 @@ def _serialize(
) -> Union[Buffer, bytes]:
self._buffer_callback = buffer_callback
self._unsupported_callback = unsupported_callback
if buffer is not None:
self.pickler = Pickler(buffer)
else:
if buffer is None:
self.buffer.writer_index = 0
buffer = self.buffer
if self.language == Language.XLANG:
Expand Down Expand Up @@ -463,21 +441,11 @@ def read_buffer_object(self, buffer) -> Buffer:

def handle_unsupported_write(self, buffer, obj):
if self._unsupported_callback is None or self._unsupported_callback(obj):
buffer.write_bool(True)
self.pickler.dump(obj)
else:
buffer.write_bool(False)
raise NotImplementedError(f"{type(obj)} is not supported for write")

def handle_unsupported_read(self, buffer):
in_band = buffer.read_bool()
if in_band:
unpickler = self.unpickler
if unpickler is None:
self.unpickler = unpickler = Unpickler(buffer)
return unpickler.load()
else:
assert self._unsupported_objects is not None
return next(self._unsupported_objects)
assert self._unsupported_objects is not None
return next(self._unsupported_objects)

def write_ref_pyobject(self, buffer, value, typeinfo=None):
if self.ref_resolver.write_ref_or_null(buffer, value):
Expand Down Expand Up @@ -507,19 +475,15 @@ def throw_depth_limit_exceeded_exception(self):
def reset_write(self):
self.ref_resolver.reset_write()
self.type_resolver.reset_write()
self.serialization_context.reset()
self.metastring_resolver.reset_write()
self.pickler.clear_memo()
self._buffer_callback = None
self._unsupported_callback = None

def reset_read(self):
self.depth = 0
self.ref_resolver.reset_read()
self.type_resolver.reset_read()
self.serialization_context.reset()
self.metastring_resolver.reset_write()
self.unpickler = None
self._buffers = None
self._unsupported_objects = None

Expand Down Expand Up @@ -562,20 +526,3 @@ def reset(self):
"1",
"true",
}


class _PicklerStub:
def dump(self, o):
raise ValueError(
f"Type {type(o)} is not registered, "
f"pickle is not allowed when type registration enabled, "
f"Please register the type or pass unsupported_callback"
)

def clear_memo(self):
pass


class _UnpicklerStub:
def load(self):
raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback")
Loading
Loading