@@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8
@@ -86,8 +87,18 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
case (LONG, LongType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])

case (LONG, TimestampType) => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case (LONG, TimestampType) => avroType.getLogicalType match {
case _: TimestampMillis => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
case _: TimestampMicros => (updater, ordinal, value) =>
updater.setLong(ordinal, value.asInstanceOf[Long])
case null => (updater, ordinal, value) =>
Member: Ditto, add a default case.

// For backward compatibility, if the Avro type is LONG and has no logical type,
// the value is processed as a timestamp with millisecond precision.
updater.setLong(ordinal, value.asInstanceOf[Long] * 1000)
Contributor: Let's add a comment to say it's for backward compatibility reasons. Also, we should only do this when the logical type is null; for other logical types, we should fail here.

case other => throw new IncompatibleSchemaException(
s"Cannot convert Avro logical type ${other} to Catalyst Timestamp type.")
}
Contributor: We should add a default case and throw IncompatibleSchemaException, in case Avro adds more logical types for the long type in the future.


case (LONG, DateType) => (updater, ordinal, value) =>
updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt)
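
Note (illustration only, not part of the diff): Catalyst stores TimestampType values as microseconds since the Unix epoch, which is why the timestamp-millis and null branches above multiply by 1000 while the timestamp-micros branch passes the value through unchanged. A minimal Scala sketch of the two conversions, with hypothetical helper names:

  // Catalyst TimestampType is microseconds since the Unix epoch.
  def millisToCatalystMicros(millis: Long): Long = millis * 1000L  // timestamp-millis
  def microsToCatalystMicros(micros: Long): Long = micros          // timestamp-micros
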
@@ -112,8 +112,8 @@ private[avro] class AvroFileFormat extends FileFormat
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
val outputAvroSchema = SchemaConverters.toAvroType(
dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace)
val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false,
parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType)

AvroJob.setOutputKeySchema(job, outputAvroSchema)
val COMPRESS_KEY = "mapred.output.compress"
@@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType

/**
* Options for Avro Reader and Writer stored in case insensitive manner.
@@ -79,4 +80,14 @@ class AvroOptions(
val compression: String = {
parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
}

/**
* Avro timestamp type used when Spark writes data to Avro files.
* Currently supported types are `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS`.
* TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores the number of microseconds
* from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with millisecond precision,
* which means Spark has to truncate the microsecond portion of its timestamp value.
* The related configuration is set via SQLConf, and it is not exposed as an option.
*/
val outputTimestampType: AvroOutputTimestampType.Value = SQLConf.get.avroOutputTimestampType
}
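
Usage note (a sketch under stated assumptions, not part of the diff): because the timestamp type is read from SQLConf rather than from a per-write option, it is switched at the session level. Assuming an active SparkSession named spark and a DataFrame df with a TimestampType column:

  // The key and accepted values come from the SQLConf entry added in this PR.
  spark.conf.set("spark.sql.avro.outputTimestampType", "TIMESTAMP_MILLIS")
  df.write.format("avro").save("/tmp/avro-millis-out")  // illustrative output path
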
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.JavaConverters._

import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type.NULL
import org.apache.avro.generic.GenericData.Record
@@ -92,8 +93,15 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
(getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
case DateType =>
(getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY
case TimestampType =>
(getter, ordinal) => getter.getLong(ordinal) / 1000
case TimestampType => avroType.getLogicalType match {
case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
// For backward compatibility, if the Avro type is LONG and has no logical type,
// output the timestamp value with millisecond precision.
case null => (getter, ordinal) => getter.getLong(ordinal) / 1000
case other => throw new IncompatibleSchemaException(
s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
}

case ArrayType(et, containsNull) =>
val elementConverter = newConverter(
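
Note (illustration only): on the write path, the division by 1000 for timestamp-millis is lossy; this is the truncation of the microsecond portion mentioned in the AvroOptions and SQLConf documentation. A small worked example:

  val catalystMicros = 1534567890123456L  // Catalyst internal value (microseconds)
  val avroMillis = catalystMicros / 1000  // 1534567890123L, the trailing 456 microseconds are dropped
  val avroMicros = catalystMicros         // written unchanged for timestamp-micros
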
@@ -19,9 +19,11 @@ package org.apache.spark.sql.avro

import scala.collection.JavaConverters._

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType
import org.apache.spark.sql.types._

/**
@@ -42,7 +44,10 @@ object SchemaConverters {
case BYTES => SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => SchemaType(LongType, nullable = false)
case LONG => avroSchema.getLogicalType match {
case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
case _ => SchemaType(LongType, nullable = false)
}
case FIXED => SchemaType(BinaryType, nullable = false)
case ENUM => SchemaType(StringType, nullable = false)

@@ -103,31 +108,49 @@ object SchemaConverters {
catalystType: DataType,
nullable: Boolean = false,
recordName: String = "topLevelRecord",
prevNameSpace: String = ""): Schema = {
prevNameSpace: String = "",
outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS)
: Schema = {
val builder = if (nullable) {
SchemaBuilder.builder().nullable()
} else {
SchemaBuilder.builder()
}

catalystType match {
case BooleanType => builder.booleanType()
case ByteType | ShortType | IntegerType => builder.intType()
case LongType => builder.longType()
case DateType => builder.longType()
case TimestampType => builder.longType()
case TimestampType =>
val timestampType = outputTimestampType match {
case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis()
case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros()
case other =>
throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
}
if (nullable) {
val avroType = timestampType.addToSchema(SchemaBuilder.builder().longType())
builder.`type`(avroType)
} else {
timestampType.addToSchema(builder.longType())
}
case FloatType => builder.floatType()
case DoubleType => builder.doubleType()
case _: DecimalType | StringType => builder.stringType()
case BinaryType => builder.bytesType()
case ArrayType(et, containsNull) =>
builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace))
builder.array()
.items(toAvroType(et, containsNull, recordName, prevNameSpace, outputTimestampType))
case MapType(StringType, vt, valueContainsNull) =>
builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace))
builder.map()
.values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace, outputTimestampType))
case st: StructType =>
val nameSpace = s"$prevNameSpace.$recordName"
val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
st.foreach { f =>
val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace)
val fieldAvroType =
toAvroType(f.dataType, f.nullable, f.name, nameSpace, outputTimestampType)
fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
}
fieldsAssembler.endRecord()
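
Note (a standalone sketch, not part of the diff): the Avro API used in toAvroType above can be exercised directly to see what the reader side later matches on:

  import org.apache.avro.{LogicalTypes, SchemaBuilder}

  val tsSchema = LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType())
  // tsSchema.getType is LONG; tsSchema.getLogicalType is a TimestampMicros instance,
  // which toSqlType maps to Catalyst TimestampType.
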
Binary file added external/avro/src/test/resources/timestamp.avro
@@ -41,6 +41,7 @@ import org.apache.spark.sql.types._
class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val episodesAvro = testFile("episodes.avro")
val testAvro = testFile("test.avro")
val timestampAvro = testFile("timestamp.avro")
Contributor: At least we should provide how the binary file is generated, or just do a roundtrip test: Spark writes Avro files and then reads them back.

Member Author: The schema and data are stated in https://github.com/apache/spark/pull/21935/files#diff-9364b0610f92b3cc35a4bc43a80751bfR397 and should be easy to reproduce from the test cases. The other test file, episodesAvro, also doesn't document how it was generated.

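A minimal sketch of the roundtrip the reviewer suggests, assuming the same test utilities (withTempPath, checkAnswer) and implicits used in the tests below; it avoids depending on how the checked-in binary was produced:

  withTempPath { dir =>
    val input = Seq(1L, 666L).toDF("ts").select('ts.cast(TimestampType))
    input.write.format("avro").save(dir.toString)
    checkAnswer(spark.read.format("avro").load(dir.toString), input.collect())
  }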

override protected def beforeAll(): Unit = {
super.beforeAll()
@@ -331,6 +332,84 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}

test("Logical type: timestamp_millis") {
val sparkSession = spark
import sparkSession.implicits._

val expected =
Seq(1L, 666L).toDF("timestamp_millis").select('timestamp_millis.cast(TimestampType)).collect()
val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis)

checkAnswer(df, expected)

withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}

test("Logical type: timestamp_micros") {
val sparkSession = spark
import sparkSession.implicits._

val expected =
Seq(2L, 999L).toDF("timestamp_micros").select('timestamp_micros.cast(TimestampType)).collect()
val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros)

checkAnswer(df, expected)

withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}

test("Logical type: specify different output timestamp types") {
val sparkSession = spark
import sparkSession.implicits._

val df = spark.read.format("avro").load(timestampAvro)

val expected = Seq((1L, 2L), (666L, 999L))
.toDF("timestamp_millis", "timestamp_micros")
.select('timestamp_millis.cast(TimestampType), 'timestamp_micros.cast(TimestampType))
.collect()

Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType =>
withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) {
withTempPath { dir =>
df.write.format("avro").save(dir.toString)
checkAnswer(spark.read.format("avro").load(dir.toString), expected)
}
}
}
}

test("Logical type: user specified schema") {
val sparkSession = spark
import sparkSession.implicits._

val expected = Seq((1L, 2L), (666L, 999L))
.toDF("timestamp_millis", "timestamp_micros")
.select('timestamp_millis.cast(TimestampType), 'timestamp_micros.cast(TimestampType))
.collect()

val avroSchema = s"""
{
"namespace": "logical",
"type": "record",
"name": "test",
"fields": [
{"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}},
{"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}
]
}
"""
val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro)

checkAnswer(df, expected)
}
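
Side note (illustration, not in the diff): the user-specified schema above can also be parsed with Avro's own API to inspect the logical types that SchemaConverters.toSqlType keys on:

  import org.apache.avro.Schema

  val parsed = new Schema.Parser().parse(avroSchema)
  // Both fields are physically LONG; the attached logical type drives the Catalyst mapping.
  parsed.getField("timestamp_millis").schema().getLogicalType  // a TimestampMillis instance
  parsed.getField("timestamp_micros").schema().getLogicalType  // a TimestampMicros instance
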

test("Array data types") {
withTempPath { dir =>
val testSchema = StructType(Seq(
@@ -511,7 +590,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

// Timestamps are converted to longs
val times = spark.read.format("avro").load(avroDir).select("Time").collect()
assert(times.map(_(0)).toSet == Set(666, 777, 42))
assert(times.map(_(0)).toSet ==
Set(new Timestamp(666), new Timestamp(777), new Timestamp(42)))

// DecimalType should be converted to string
val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
@@ -1436,6 +1436,21 @@ object SQLConf {
.intConf
.createWithDefault(20)

object AvroOutputTimestampType extends Enumeration {
val TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value
}

val AVRO_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.avro.outputTimestampType")
.doc("Sets which Avro timestamp type to use when Spark writes data to Avro files. " +
"TIMESTAMP_MICROS is a logical timestamp type in Avro, which stores number of " +
"microseconds from the Unix epoch. TIMESTAMP_MILLIS is also logical, but with " +
"millisecond precision, which means Spark has to truncate the microsecond portion of its " +
"timestamp value.")
.stringConf
.transform(_.toUpperCase(Locale.ROOT))
.checkValues(AvroOutputTimestampType.values.map(_.toString))
.createWithDefault(AvroOutputTimestampType.TIMESTAMP_MICROS.toString)

val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
.doc("Compression codec used in writing of AVRO files. Default codec is snappy.")
.stringConf
@@ -1835,6 +1850,9 @@ class SQLConf extends Serializable with Logging {

def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE)

def avroOutputTimestampType: AvroOutputTimestampType.Value =
AvroOutputTimestampType.withName(getConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE))

def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC)

def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL)