@@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion
2626import org .apache .avro .LogicalTypes .{TimestampMicros , TimestampMillis }
2727import org .apache .avro .Schema
2828import org .apache .avro .Schema .Type
29+ import org .apache .avro .Schema .Type ._
2930import org .apache .avro .generic .GenericData .{EnumSymbol , Fixed , Record }
3031import org .apache .avro .generic .GenericData .Record
3132import org .apache .avro .util .Utf8
@@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
7273 private lazy val decimalConversions = new DecimalConversion ()
7374
7475 private def newConverter (catalystType : DataType , avroType : Schema ): Converter = {
75- catalystType match {
76- case NullType =>
76+ ( catalystType, avroType.getType) match {
77+ case ( NullType , NULL ) =>
7778 (getter, ordinal) => null
78- case BooleanType =>
79+ case ( BooleanType , BOOLEAN ) =>
7980 (getter, ordinal) => getter.getBoolean(ordinal)
80- case ByteType =>
81+ case ( ByteType , INT ) =>
8182 (getter, ordinal) => getter.getByte(ordinal).toInt
82- case ShortType =>
83+ case ( ShortType , INT ) =>
8384 (getter, ordinal) => getter.getShort(ordinal).toInt
84- case IntegerType =>
85+ case ( IntegerType , INT ) =>
8586 (getter, ordinal) => getter.getInt(ordinal)
86- case LongType =>
87+ case ( LongType , LONG ) =>
8788 (getter, ordinal) => getter.getLong(ordinal)
88- case FloatType =>
89+ case ( FloatType , FLOAT ) =>
8990 (getter, ordinal) => getter.getFloat(ordinal)
90- case DoubleType =>
91+ case ( DoubleType , DOUBLE ) =>
9192 (getter, ordinal) => getter.getDouble(ordinal)
92- case d : DecimalType =>
93+ case (d : DecimalType , FIXED )
94+ if avroType.getLogicalType == LogicalTypes .decimal(d.precision, d.scale) =>
9395 (getter, ordinal) =>
9496 val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
9597 decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
9698 LogicalTypes .decimal(d.precision, d.scale))
9799
98- case StringType => avroType.getType match {
99- case Type .ENUM =>
100- import scala .collection .JavaConverters ._
101- val enumSymbols : Set [String ] = avroType.getEnumSymbols.asScala.toSet
102- (getter, ordinal) =>
103- val data = getter.getUTF8String(ordinal).toString
104- if (! enumSymbols.contains(data)) {
105- throw new IncompatibleSchemaException (
106- " Cannot write \" " + data + " \" since it's not defined in enum \" " +
107- enumSymbols.mkString(" \" , \" " ) + " \" " )
108- }
109- new EnumSymbol (avroType, data)
110- case _ =>
111- (getter, ordinal) => new Utf8 (getter.getUTF8String(ordinal).getBytes)
112- }
113- case BinaryType => avroType.getType match {
114- case Type .FIXED =>
115- val size = avroType.getFixedSize()
116- (getter, ordinal) =>
117- val data : Array [Byte ] = getter.getBinary(ordinal)
118- if (data.length != size) {
119- throw new IncompatibleSchemaException (
120- s " Cannot write ${data.length} ${if (data.length > 1 ) " bytes" else " byte" } of " +
121- " binary data into FIXED Type with size of " +
122- s " $size ${if (size > 1 ) " bytes" else " byte" }" )
123- }
124- new Fixed (avroType, data)
125- case _ =>
126- (getter, ordinal) => ByteBuffer .wrap(getter.getBinary(ordinal))
127- }
128- case DateType =>
100+ case (d : DecimalType , BYTES )
101+ if avroType.getLogicalType == LogicalTypes .decimal(d.precision, d.scale) =>
102+ (getter, ordinal) =>
103+ val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
104+ decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
105+ LogicalTypes .decimal(d.precision, d.scale))
106+
107+ case (StringType , ENUM ) =>
108+ val enumSymbols : Set [String ] = avroType.getEnumSymbols.asScala.toSet
109+ (getter, ordinal) =>
110+ val data = getter.getUTF8String(ordinal).toString
111+ if (! enumSymbols.contains(data)) {
112+ throw new IncompatibleSchemaException (
113+ " Cannot write \" " + data + " \" since it's not defined in enum \" " +
114+ enumSymbols.mkString(" \" , \" " ) + " \" " )
115+ }
116+ new EnumSymbol (avroType, data)
117+
118+ case (StringType , STRING ) =>
119+ (getter, ordinal) => new Utf8 (getter.getUTF8String(ordinal).getBytes)
120+
121+ case (BinaryType , FIXED ) =>
122+ val size = avroType.getFixedSize()
123+ (getter, ordinal) =>
124+ val data : Array [Byte ] = getter.getBinary(ordinal)
125+ if (data.length != size) {
126+ throw new IncompatibleSchemaException (
127+ s " Cannot write ${data.length} ${if (data.length > 1 ) " bytes" else " byte" } of " +
128+ " binary data into FIXED Type with size of " +
129+ s " $size ${if (size > 1 ) " bytes" else " byte" }" )
130+ }
131+ new Fixed (avroType, data)
132+
133+ case (BinaryType , BYTES ) =>
134+ (getter, ordinal) => ByteBuffer .wrap(getter.getBinary(ordinal))
135+
136+ case (DateType , INT ) =>
129137 (getter, ordinal) => getter.getInt(ordinal)
130- case TimestampType => avroType.getLogicalType match {
138+
139+ case (TimestampType , LONG ) => avroType.getLogicalType match {
131140 case _ : TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
132141 case _ : TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
133142 // For backward compatibility, if the Avro type is Long and it is not logical type,
@@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
137146 s " Cannot convert Catalyst Timestamp type to Avro logical type ${other}" )
138147 }
139148
140- case ArrayType (et, containsNull) =>
149+ case ( ArrayType (et, containsNull), ARRAY ) =>
141150 val elementConverter = newConverter(
142151 et, resolveNullableType(avroType.getElementType, containsNull))
143152 (getter, ordinal) => {
@@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
158167 java.util.Arrays .asList(result : _* )
159168 }
160169
161- case st : StructType =>
170+ case ( st : StructType , RECORD ) =>
162171 val structConverter = newStructConverter(st, avroType)
163172 val numFields = st.length
164173 (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
165174
166- case MapType (kt, vt, valueContainsNull) if kt == StringType =>
175+ case ( MapType (kt, vt, valueContainsNull), MAP ) if kt == StringType =>
167176 val valueConverter = newConverter(
168177 vt, resolveNullableType(avroType.getValueType, valueContainsNull))
169178 (getter, ordinal) =>
@@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
185194 result
186195
187196 case other =>
188- throw new IncompatibleSchemaException (s " Unexpected type: $other" )
197+ throw new IncompatibleSchemaException (s " Cannot convert Catalyst type $catalystType to " +
198+ s " Avro type $avroType. " )
189199 }
190200 }
191201
192202 private def newStructConverter (
193203 catalystStruct : StructType , avroStruct : Schema ): InternalRow => Record = {
204+ if (avroStruct.getType != RECORD ) {
205+ throw new IncompatibleSchemaException (s " Cannot convert Catalyst type $catalystStruct to " +
206+ s " Avro type $avroStruct. " )
207+ }
194208 val avroFields = avroStruct.getFields
195209 assert(avroFields.size() == catalystStruct.length)
196210 val fieldConverters = catalystStruct.zip(avroFields.asScala).map {
@@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
212226 }
213227
214228 private def resolveNullableType (avroType : Schema , nullable : Boolean ): Schema = {
215- if (nullable) {
229+ if (nullable && avroType.getType != NULL ) {
216230 // avro uses union to represent nullable type.
217231 val fields = avroType.getTypes.asScala
218232 assert(fields.length == 2 )
0 commit comments