diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
index 6ed330d92f5e6..7732c83aaf7c4 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.avro
 
 import java.io.ByteArrayOutputStream
 
+import org.apache.avro.Schema
 import org.apache.avro.generic.GenericDatumWriter
 import org.apache.avro.io.{BinaryEncoder, EncoderFactory}
 
@@ -26,12 +27,16 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{BinaryType, DataType}
 
-case class CatalystDataToAvro(child: Expression) extends UnaryExpression {
+case class CatalystDataToAvro(
+    child: Expression,
+    jsonFormatSchema: Option[String]) extends UnaryExpression {
 
   override def dataType: DataType = BinaryType
 
   @transient private lazy val avroType =
-    SchemaConverters.toAvroType(child.dataType, child.nullable)
+    jsonFormatSchema
+      .map(new Schema.Parser().parse)
+      .getOrElse(SchemaConverters.toAvroType(child.dataType, child.nullable))
 
   @transient private lazy val serializer =
     new AvroSerializer(child.dataType, avroType, child.nullable)
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
index 5ed7828510d54..a6ae3906c6d80 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
@@ -72,6 +72,19 @@ object functions {
    */
   @Experimental
   def to_avro(data: Column): Column = {
-    new Column(CatalystDataToAvro(data.expr))
+    new Column(CatalystDataToAvro(data.expr, None))
+  }
+
+  /**
+   * Converts a column into binary of Avro format.
+   *
+   * @param data the data column.
+   * @param jsonFormatSchema user-specified output Avro schema in JSON string format.
+   *
+   * @since 3.0.0
+   */
+  @Experimental
+  def to_avro(data: Column, jsonFormatSchema: String): Column = {
+    new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
   }
 }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
index 80dd4c535ad9c..27915562fded0 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.avro
 
 import org.apache.avro.Schema
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
@@ -38,12 +38,12 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
 
   private def checkResult(data: Literal, schema: String, expected: Any): Unit = {
     checkEvaluation(
-      AvroDataToCatalyst(CatalystDataToAvro(data), schema, Map.empty),
+      AvroDataToCatalyst(CatalystDataToAvro(data, None), schema, Map.empty),
       prepareExpectedResult(expected))
   }
 
   protected def checkUnsupportedRead(data: Literal, schema: String): Unit = {
-    val binary = CatalystDataToAvro(data)
+    val binary = CatalystDataToAvro(data, None)
     intercept[Exception] {
       AvroDataToCatalyst(binary, schema, Map("mode" -> "FAILFAST")).eval()
     }
@@ -209,4 +209,41 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
       checkUnsupportedRead(input, avroSchema)
     }
   }
+
+  test("user-specified output schema") {
+    val data = Literal("SPADES")
+    val jsonFormatSchema =
+      """
+        |{ "type": "enum",
+        |  "name": "Suit",
+        |  "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+        |}
+      """.stripMargin
+
+    val message = intercept[SparkException] {
+      AvroDataToCatalyst(
+        CatalystDataToAvro(
+          data,
+          None),
+        jsonFormatSchema,
+        options = Map.empty).eval()
+    }.getMessage
+    assert(message.contains("Malformed records are detected in record parsing."))
+
+    checkEvaluation(
+      AvroDataToCatalyst(
+        CatalystDataToAvro(
+          data,
+          Some(jsonFormatSchema)),
+        jsonFormatSchema,
+        options = Map.empty),
+      data.eval())
+  }
+
+  test("invalid user-specified output schema") {
+    val message = intercept[IncompatibleSchemaException] {
+      CatalystDataToAvro(Literal("SPADES"), Some("\"long\"")).eval()
+    }.getMessage
+    assert(message == "Cannot convert Catalyst type StringType to Avro type \"long\".")
+  }
 }
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index e07b625284175..711de6532e28c 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -69,7 +69,7 @@ def from_avro(data, jsonFormatSchema, options={}):
 
 @ignore_unicode_prefix
 @since(3.0)
-def to_avro(data):
+def to_avro(data, jsonFormatSchema=""):
     """
     Converts a column into binary of avro format.
 
@@ -77,18 +77,27 @@ def to_avro(data):
     application as per the deployment section of "Apache Avro Data Source Guide".
 
     :param data: the data column.
+    :param jsonFormatSchema: user-specified output Avro schema in JSON string format.
 
     >>> from pyspark.sql import Row
     >>> from pyspark.sql.avro.functions import to_avro
-    >>> data = [(1, Row(name='Alice', age=2))]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(to_avro(df.value).alias("avro")).collect()
-    [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
+    >>> data = ['SPADES']
+    >>> df = spark.createDataFrame(data, "string")
+    >>> df.select(to_avro(df.value).alias("suit")).collect()
+    [Row(suit=bytearray(b'\\x00\\x0cSPADES'))]
+    >>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
+    ...     "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
+    >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suit")).collect()
+    [Row(suit=bytearray(b'\\x02\\x00'))]
     """
     sc = SparkContext._active_spark_context
     try:
-        jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+        if jsonFormatSchema == "":
+            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+        else:
+            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(
+                _to_java_column(data), jsonFormatSchema)
    except TypeError as e:
        if str(e) == "'JavaPackage' object is not callable":
            _print_missing_jar("Avro", "avro", "avro", sc.version)
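
For reference, a minimal end-to-end sketch of the new Scala overload (not part of the patch): it assumes a local SparkSession with the spark-avro module on the classpath, and the object, column, and alias names are illustrative only. The expected byte sequences are the same ones shown in the Python doctest above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.functions.to_avro

object ToAvroSuitExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq("SPADES").toDF("value")

    // String columns built this way are nullable, so the user-specified
    // schema is a union with "null", mirroring the Python doctest above.
    val jsonFormatSchema =
      """["null", {"type": "enum", "name": "Suit",
        |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]""".stripMargin

    // Without a schema, the Avro type is derived from the Catalyst type,
    // so the value is encoded as an Avro string: 00 0C 53 50 41 44 45 53.
    df.select(to_avro($"value").as("inferred")).show()

    // With the schema, the same value is encoded as an enum index: 02 00.
    df.select(to_avro($"value", jsonFormatSchema).as("asEnum")).show()

    spark.stop()
  }
}

Taking the schema as a JSON string keeps the new overload symmetric with from_avro, which already accepts its jsonFormatSchema parameter the same way.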