Skip to content

Commit 6634819

Browse files
gengliangwang and cloud-fan
authored and committed
[SPARK-25718][SQL] Detect recursive reference in Avro schema and throw exception
## What changes were proposed in this pull request?

Avro schemas allow recursive references, e.g. the schema for a linked list in https://avro.apache.org/docs/1.8.2/spec.html#schema_record

```
{
  "type": "record",
  "name": "LongList",
  "aliases": ["LinkedLongs"],                      // old name for this
  "fields" : [
    {"name": "value", "type": "long"},             // each element has a long
    {"name": "next", "type": ["null", "LongList"]} // optional next element
  ]
}
```

In current Spark SQL, it is impossible to convert such a schema to a `StructType`. Running `SchemaConverters.toSqlType(avroSchema)` results in a stack overflow exception. We should detect the recursive reference and throw an exception for it.

## How was this patch tested?

New unit test case.

Closes #22709 from gengliangwang/avroRecursiveRef.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 2eaf058)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 765cbca commit 6634819

2 files changed

Lines changed: 84 additions & 7 deletions

File tree

external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ object SchemaConverters {
4343
* This function takes an avro schema and returns a sql schema.
4444
*/
4545
def toSqlType(avroSchema: Schema): SchemaType = {
46+
toSqlTypeHelper(avroSchema, Set.empty)
47+
}
48+
49+
def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
4650
avroSchema.getType match {
4751
case INT => avroSchema.getLogicalType match {
4852
case _: Date => SchemaType(DateType, nullable = false)
@@ -67,21 +71,28 @@ object SchemaConverters {
6771
case ENUM => SchemaType(StringType, nullable = false)
6872

6973
case RECORD =>
74+
if (existingRecordNames.contains(avroSchema.getFullName)) {
75+
throw new IncompatibleSchemaException(s"""
76+
|Found recursive reference in Avro schema, which can not be processed by Spark:
77+
|${avroSchema.toString(true)}
78+
""".stripMargin)
79+
}
80+
val newRecordNames = existingRecordNames + avroSchema.getFullName
7081
val fields = avroSchema.getFields.asScala.map { f =>
71-
val schemaType = toSqlType(f.schema())
82+
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
7283
StructField(f.name, schemaType.dataType, schemaType.nullable)
7384
}
7485

7586
SchemaType(StructType(fields), nullable = false)
7687

7788
case ARRAY =>
78-
val schemaType = toSqlType(avroSchema.getElementType)
89+
val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
7990
SchemaType(
8091
ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
8192
nullable = false)
8293

8394
case MAP =>
84-
val schemaType = toSqlType(avroSchema.getValueType)
95+
val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
8596
SchemaType(
8697
MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
8798
nullable = false)
@@ -91,13 +102,14 @@ object SchemaConverters {
91102
// In case of a union with null, eliminate it and make a recursive call
92103
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
93104
if (remainingUnionTypes.size == 1) {
94-
toSqlType(remainingUnionTypes.head).copy(nullable = true)
105+
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
95106
} else {
96-
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
107+
toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
108+
.copy(nullable = true)
97109
}
98110
} else avroSchema.getTypes.asScala.map(_.getType) match {
99111
case Seq(t1) =>
100-
toSqlType(avroSchema.getTypes.get(0))
112+
toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
101113
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
102114
SchemaType(LongType, nullable = false)
103115
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
@@ -107,7 +119,7 @@ object SchemaConverters {
107119
// This is consistent with the behavior when converting between Avro and Parquet.
108120
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
109121
case (s, i) =>
110-
val schemaType = toSqlType(s)
122+
val schemaType = toSqlTypeHelper(s, existingRecordNames)
111123
// All fields are nullable because only one of them is set at a time
112124
StructField(s"member$i", schemaType.dataType, nullable = true)
113125
}

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,4 +1266,69 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
12661266
checkCodec(df, path, "xz")
12671267
}
12681268
}
1269+
1270+
private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
1271+
val message = intercept[IncompatibleSchemaException] {
1272+
SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema))
1273+
}.getMessage
1274+
1275+
assert(message.contains("Found recursive reference in Avro schema"))
1276+
}
1277+
1278+
test("Detect recursive loop") {
1279+
checkSchemaWithRecursiveLoop("""
1280+
|{
1281+
| "type": "record",
1282+
| "name": "LongList",
1283+
| "fields" : [
1284+
| {"name": "value", "type": "long"}, // each element has a long
1285+
| {"name": "next", "type": ["null", "LongList"]} // optional next element
1286+
| ]
1287+
|}
1288+
""".stripMargin)
1289+
1290+
checkSchemaWithRecursiveLoop("""
1291+
|{
1292+
| "type": "record",
1293+
| "name": "LongList",
1294+
| "fields": [
1295+
| {
1296+
| "name": "value",
1297+
| "type": {
1298+
| "type": "record",
1299+
| "name": "foo",
1300+
| "fields": [
1301+
| {
1302+
| "name": "parent",
1303+
| "type": "LongList"
1304+
| }
1305+
| ]
1306+
| }
1307+
| }
1308+
| ]
1309+
|}
1310+
""".stripMargin)
1311+
1312+
checkSchemaWithRecursiveLoop("""
1313+
|{
1314+
| "type": "record",
1315+
| "name": "LongList",
1316+
| "fields" : [
1317+
| {"name": "value", "type": "long"},
1318+
| {"name": "array", "type": {"type": "array", "items": "LongList"}}
1319+
| ]
1320+
|}
1321+
""".stripMargin)
1322+
1323+
checkSchemaWithRecursiveLoop("""
1324+
|{
1325+
| "type": "record",
1326+
| "name": "LongList",
1327+
| "fields" : [
1328+
| {"name": "value", "type": "long"},
1329+
| {"name": "map", "type": {"type": "map", "values": "LongList"}}
1330+
| ]
1331+
|}
1332+
""".stripMargin)
1333+
}
12691334
}

0 commit comments

Comments (0)