
Commit ab19730

gengliangwang authored and dbtsai committed
[SPARK-25104][SQL] Avro: Validate user specified output schema
## What changes were proposed in this pull request?

With the code changes in #21847, Spark can write out Avro files according to a user-provided output schema. To make this more robust and user friendly, this patch validates the Avro schema against the Catalyst schema before tasks are launched. It also adds support for writing the logical decimal type as BYTES (by default it is written as FIXED).

## How was this patch tested?

Unit tests.

Closes #22094 from gengliangwang/AvroSerializerMatch.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
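As a rough illustration of the user-facing behaviour (not part of the commit itself), the snippet below writes a decimal column with a user-specified Avro output schema that declares the field as BYTES with the decimal logical type; it mirrors the new test added in AvroLogicalTypeSuite. The column name, sample data, and output path are made up for the example, and the external Avro module is assumed to be on the classpath.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("avro-decimal-bytes").getOrCreate()
import spark.implicits._

// Illustrative input: a DECIMAL(4, 2) column named "price".
val df = Seq("13.12", "50.99").toDF("price")
  .selectExpr("CAST(price AS DECIMAL(4, 2)) AS price")

// User-specified output schema: the decimal field is declared as BYTES instead of
// the default FIXED. The field type is a union with "null" because the column is nullable.
val avroSchema =
  """
    |{
    |  "type" : "record",
    |  "name" : "topLevelRecord",
    |  "fields" : [ {
    |    "name" : "price",
    |    "type" : [ {
    |      "type" : "bytes",
    |      "logicalType" : "decimal",
    |      "precision" : 4,
    |      "scale" : 2
    |    }, "null" ]
    |  } ]
    |}
  """.stripMargin

// With this patch the specified schema is validated up front; an incompatible schema
// fails with IncompatibleSchemaException before any task is launched.
df.write.format("avro").option("avroSchema", avroSchema).save("/tmp/decimal_as_bytes")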
1 parent c220cc4 commit ab19730

3 files changed

Lines changed: 158 additions & 47 deletions


external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala

Lines changed: 61 additions & 47 deletions
@@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
 import org.apache.avro.generic.GenericData.Record
 import org.apache.avro.util.Utf8
@@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
   private lazy val decimalConversions = new DecimalConversion()

   private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
-    catalystType match {
-      case NullType =>
+    (catalystType, avroType.getType) match {
+      case (NullType, NULL) =>
         (getter, ordinal) => null
-      case BooleanType =>
+      case (BooleanType, BOOLEAN) =>
         (getter, ordinal) => getter.getBoolean(ordinal)
-      case ByteType =>
+      case (ByteType, INT) =>
         (getter, ordinal) => getter.getByte(ordinal).toInt
-      case ShortType =>
+      case (ShortType, INT) =>
         (getter, ordinal) => getter.getShort(ordinal).toInt
-      case IntegerType =>
+      case (IntegerType, INT) =>
         (getter, ordinal) => getter.getInt(ordinal)
-      case LongType =>
+      case (LongType, LONG) =>
         (getter, ordinal) => getter.getLong(ordinal)
-      case FloatType =>
+      case (FloatType, FLOAT) =>
         (getter, ordinal) => getter.getFloat(ordinal)
-      case DoubleType =>
+      case (DoubleType, DOUBLE) =>
         (getter, ordinal) => getter.getDouble(ordinal)
-      case d: DecimalType =>
+      case (d: DecimalType, FIXED)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
         (getter, ordinal) =>
           val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
           decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
             LogicalTypes.decimal(d.precision, d.scale))

-      case StringType => avroType.getType match {
-        case Type.ENUM =>
-          import scala.collection.JavaConverters._
-          val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
-          (getter, ordinal) =>
-            val data = getter.getUTF8String(ordinal).toString
-            if (!enumSymbols.contains(data)) {
-              throw new IncompatibleSchemaException(
-                "Cannot write \"" + data + "\" since it's not defined in enum \"" +
-                  enumSymbols.mkString("\", \"") + "\"")
-            }
-            new EnumSymbol(avroType, data)
-        case _ =>
-          (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
-      }
-      case BinaryType => avroType.getType match {
-        case Type.FIXED =>
-          val size = avroType.getFixedSize()
-          (getter, ordinal) =>
-            val data: Array[Byte] = getter.getBinary(ordinal)
-            if (data.length != size) {
-              throw new IncompatibleSchemaException(
-                s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
-                  "binary data into FIXED Type with size of " +
-                  s"$size ${if (size > 1) "bytes" else "byte"}")
-            }
-            new Fixed(avroType, data)
-        case _ =>
-          (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
-      }
-      case DateType =>
+      case (d: DecimalType, BYTES)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (StringType, ENUM) =>
+        val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+        (getter, ordinal) =>
+          val data = getter.getUTF8String(ordinal).toString
+          if (!enumSymbols.contains(data)) {
+            throw new IncompatibleSchemaException(
+              "Cannot write \"" + data + "\" since it's not defined in enum \"" +
+                enumSymbols.mkString("\", \"") + "\"")
+          }
+          new EnumSymbol(avroType, data)
+
+      case (StringType, STRING) =>
+        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+      case (BinaryType, FIXED) =>
+        val size = avroType.getFixedSize()
+        (getter, ordinal) =>
+          val data: Array[Byte] = getter.getBinary(ordinal)
+          if (data.length != size) {
+            throw new IncompatibleSchemaException(
+              s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " +
+                "binary data into FIXED Type with size of " +
+                s"$size ${if (size > 1) "bytes" else "byte"}")
+          }
+          new Fixed(avroType, data)
+
+      case (BinaryType, BYTES) =>
+        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+      case (DateType, INT) =>
         (getter, ordinal) => getter.getInt(ordinal)
-      case TimestampType => avroType.getLogicalType match {
+
+      case (TimestampType, LONG) => avroType.getLogicalType match {
           case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
           case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
           // For backward compatibility, if the Avro type is Long and it is not logical type,
@@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
             s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
         }

-      case ArrayType(et, containsNull) =>
+      case (ArrayType(et, containsNull), ARRAY) =>
         val elementConverter = newConverter(
           et, resolveNullableType(avroType.getElementType, containsNull))
         (getter, ordinal) => {
@@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
           java.util.Arrays.asList(result: _*)
         }

-      case st: StructType =>
+      case (st: StructType, RECORD) =>
         val structConverter = newStructConverter(st, avroType)
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))

-      case MapType(kt, vt, valueContainsNull) if kt == StringType =>
+      case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
         val valueConverter = newConverter(
           vt, resolveNullableType(avroType.getValueType, valueContainsNull))
         (getter, ordinal) =>
@@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
           result

       case other =>
-        throw new IncompatibleSchemaException(s"Unexpected type: $other")
+        throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " +
+          s"Avro type $avroType.")
     }
   }

   private def newStructConverter(
       catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
+    if (avroStruct.getType != RECORD) {
+      throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
+        s"Avro type $avroStruct.")
+    }
     val avroFields = avroStruct.getFields
     assert(avroFields.size() == catalystStruct.length)
     val fieldConverters = catalystStruct.zip(avroFields.asScala).map {
@@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
   }

   private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
-    if (nullable) {
+    if (nullable && avroType.getType != NULL) {
       // avro uses union to represent nullable type.
       val fields = avroType.getTypes.asScala
       assert(fields.length == 2)
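
The practical effect of the tightened pattern match above is that an incompatible Catalyst/Avro pairing is rejected when the serializer is constructed, before any task is launched, with the new "Cannot convert Catalyst type ..." message. A minimal sketch of a single case, assuming it runs inside the org.apache.spark.sql.avro package the way the new AvroSuite test below does (AvroSerializer is an internal class, not public API):

import org.apache.avro.Schema
import org.apache.spark.sql.types.LongType

// Catalyst LongType paired with an Avro "int" schema: no case of the
// (catalystType, avroType.getType) match applies, so construction throws
// IncompatibleSchemaException("Cannot convert Catalyst type LongType to Avro type ...").
val avroIntSchema = Schema.create(Schema.Type.INT)
new AvroSerializer(LongType, avroIntSchema, nullable = false)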

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala

Lines changed: 40 additions & 0 deletions
@@ -267,6 +267,46 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU
     }
   }

+  test("Logical type: write Decimal with BYTES type") {
+    val specifiedSchema = """
+      {
+        "type" : "record",
+        "name" : "topLevelRecord",
+        "namespace" : "topLevelRecord",
+        "fields" : [ {
+          "name" : "bytes",
+          "type" : [ {
+            "type" : "bytes",
+            "namespace" : "topLevelRecord.bytes",
+            "logicalType" : "decimal",
+            "precision" : 4,
+            "scale" : 2
+          }, "null" ]
+        }, {
+          "name" : "fixed",
+          "type" : [ {
+            "type" : "bytes",
+            "logicalType" : "decimal",
+            "precision" : 4,
+            "scale" : 2
+          }, "null" ]
+        } ]
+      }
+    """
+    withTempDir { dir =>
+      val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath)
+      assert(specifiedSchema != avroSchema)
+      val expected =
+        decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
+      val df = spark.read.format("avro").load(avroFile)
+
+      withTempPath { path =>
+        df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString)
+        checkAnswer(spark.read.format("avro").load(path.toString), expected)
+      }
+    }
+  }
+
   test("Logical type: Decimal with too large precision") {
     withTempDir { dir =>
       val schema = new Schema.Parser().parse("""{

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 57 additions & 0 deletions
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._

 import org.apache.avro.Schema
 import org.apache.avro.Schema.{Field, Type}
+import org.apache.avro.Schema.Type._
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
 import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
@@ -850,6 +851,62 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }

+  test("throw exception if unable to write with user provided Avro schema") {
+    val input: Seq[(DataType, Schema.Type)] = Seq(
+      (NullType, NULL),
+      (BooleanType, BOOLEAN),
+      (ByteType, INT),
+      (ShortType, INT),
+      (IntegerType, INT),
+      (LongType, LONG),
+      (FloatType, FLOAT),
+      (DoubleType, DOUBLE),
+      (BinaryType, BYTES),
+      (DateType, INT),
+      (TimestampType, LONG),
+      (DecimalType(4, 2), BYTES)
+    )
+    def assertException(f: () => AvroSerializer) {
+      val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] {
+        f()
+      }.getMessage
+      assert(message.contains("Cannot convert Catalyst type"))
+    }
+
+    def resolveNullable(schema: Schema, nullable: Boolean): Schema = {
+      if (nullable && schema.getType != NULL) {
+        Schema.createUnion(schema, Schema.create(NULL))
+      } else {
+        schema
+      }
+    }
+    for {
+      i <- input
+      j <- input
+      nullable <- Seq(true, false)
+    } if (i._2 != j._2) {
+      val avroType = resolveNullable(Schema.create(j._2), nullable)
+      val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable)
+      val avroMapType = resolveNullable(Schema.createMap(avroType), nullable)
+      val name = "foo"
+      val avroField = new Field(name, avroType, "", null)
+      val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava)
+      val avroRecordType = resolveNullable(recordSchema, nullable)
+
+      val catalystType = i._1
+      val catalystArrayType = ArrayType(catalystType, nullable)
+      val catalystMapType = MapType(StringType, catalystType, nullable)
+      val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable)))
+
+      for {
+        avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType)
+        catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType)
+      } {
+        assertException(() => new AvroSerializer(catalyst, avro, nullable))
+      }
+    }
+  }
+
   test("reading from invalid path throws exception") {

     // Directory given has no avro files
