diff --git a/.travis.yml b/.travis.yml index b95bf7b4..518513b3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,17 +11,39 @@ before_cache: - find $HOME/.sbt -name "*.lock" -delete matrix: include: + # ---- Spark 2.0.x ---------------------------------------------------------------------------- + # Spark 2.0.0, Scala 2.11, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.0.0, Scala 2.11, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" + # Spark 2.0.0, Scala 2.10, and Avro 1.7.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.0.0, Scala 2.10, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.10.4 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" + # ---- Spark 2.1.x ---------------------------------------------------------------------------- # Spark 2.1.0, Scala 2.11, and Avro 1.7.x - jdk: openjdk7 - scala: 2.11.8 + scala: 2.11.7 env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" + # Spark 2.1.0, Scala 2.11, and Avro 1.8.x + - jdk: openjdk7 + scala: 2.11.7 + env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" # Spark 2.1.0, Scala 2.10, and Avro 1.7.x - jdk: openjdk7 - scala: 2.10.6 + scala: 2.10.4 env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7" # Spark 2.1.0, Scala 2.10, and Avro 1.8.x - jdk: openjdk7 - scala: 2.10.6 + scala: 2.10.4 env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.1.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0" # Spark 2.2.0, Scala 2.11, and Avro 1.7.x - jdk: openjdk8 diff --git a/build.sbt b/build.sbt index a7faee8c..246f8000 100644 --- a/build.sbt +++ b/build.sbt @@ -1,14 +1,17 @@ -name := "spark-avro" -organization := "com.databricks" +lazy val commonSettings = Seq( + organization := "com.databricks", + scalaVersion := "2.11.7", + crossScalaVersions := Seq("2.10.5", "2.11.7") +) -scalaVersion := "2.11.8" +commonSettings -crossScalaVersions := Seq("2.10.6", "2.11.8") +name := "spark-avro" spName := "databricks/spark-avro" -sparkVersion := "2.1.0" +sparkVersion := "2.0.0" val testSparkVersion = settingKey[String]("The version of Spark to test against.") @@ -107,7 +110,7 @@ pomExtra := bintrayReleaseOnPublish in ThisBuild := false -import ReleaseTransformations._ +import sbtrelease.ReleasePlugin.autoImport.ReleaseTransformations._ // Add publishing to spark packages as another step. 
releaseProcess := Seq[ReleaseStep]( @@ -123,3 +126,26 @@ releaseProcess := Seq[ReleaseStep]( pushChanges, releaseStepTask(spPublish) ) + + +lazy val spark21xProj = project.in(file("spark-2.1.x")).settings( + commonSettings, + libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.1.0" % "provided" +).disablePlugins(SparkPackagePlugin) + + +lazy val spark20xProj = project.in(file("spark-2.0.x")).settings( + commonSettings, + libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.0.0" % "provided" +).disablePlugins(SparkPackagePlugin) + + +unmanagedClasspath in Test ++= { + (exportedProducts in (spark20xProj, Runtime)).value ++ + (exportedProducts in (spark21xProj, Runtime)).value +} + +products in (Compile, packageBin) ++= Seq( + (classDirectory in (spark20xProj, Compile)).value, + (classDirectory in (spark21xProj, Compile)).value +) \ No newline at end of file diff --git a/spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala b/spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala new file mode 100644 index 00000000..72855b43 --- /dev/null +++ b/spark-2.0.x/src/main/scala/com/databricks/spark/avro/Spark20AvroOutputWriterFactory.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.databricks.spark.avro + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{TaskAttemptContext, TaskAttemptID} +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class Spark20AvroOutputWriterFactory( + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriterFactory { + + def doGetDefaultWorkFile(path: String, context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId: TaskAttemptID = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + + val ot = Class.forName("com.databricks.spark.avro.AvroOutputWriter") + val meth = ot.getDeclaredConstructor( + classOf[String], classOf[TaskAttemptContext], classOf[StructType], + classOf[String], classOf[String], + classOf[Function3[String, TaskAttemptContext, String, Path]] + ) + meth.setAccessible(true) + meth.newInstance(path, context, schema, recordName, recordNamespace, doGetDefaultWorkFile _) + .asInstanceOf[OutputWriter] + } +} diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala b/spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala similarity index 55% rename from src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala rename to spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala index 3f3cbf07..6bfb5eb0 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala +++ b/spark-2.1.x/src/main/scala/com/databricks/spark/avro/Spark21AvroOutputWriterFactory.scala @@ -16,24 +16,35 @@ package com.databricks.spark.avro +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext - import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType -private[avro] class AvroOutputWriterFactory( +private[avro] class Spark21AvroOutputWriterFactory( schema: StructType, recordName: String, recordNamespace: String) extends OutputWriterFactory { - override def getFileExtension(context: TaskAttemptContext): String = { - ".avro" + def doGetDefaultWorkFile(path: String, context: TaskAttemptContext, extension: String): Path = { + new Path(path) } - override def newInstance( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + + val ot = Class.forName("com.databricks.spark.avro.AvroOutputWriter") + val meth = ot.getDeclaredConstructor( + classOf[String], classOf[TaskAttemptContext], classOf[StructType], + classOf[String], classOf[String], + classOf[Function3[String, TaskAttemptContext, String, Path]] + ) + meth.setAccessible(true) + meth.newInstance(path, context, schema, recordName, recordNamespace, doGetDefaultWorkFile _) + .asInstanceOf[OutputWriter] } + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" } diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala 
b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala new file mode 100644 index 00000000..41acfa2d --- /dev/null +++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala @@ -0,0 +1,801 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.spark.avro + +/** + * A Spark-SQL Encoder for Avro objects + */ + +import java.io._ +import java.util.{Map => JMap} + +import com.databricks.spark.avro.SchemaConverters.{IncompatibleSchemaException, SchemaType, resolveUnionType, toSqlType} +import org.apache.avro.Schema +import org.apache.avro.Schema.Parser +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.{GenericData, IndexedRecord} +import org.apache.avro.reflect.ReflectData +import org.apache.avro.specific.SpecificRecord +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable => _, _} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +/** + * A Spark-SQL Encoder for Avro objects + */ +object AvroEncoder { + /** + * Provides an Encoder for Avro objects of the given class + * + * @param avroClass the class of the Avro object for which to generate the Encoder + * @tparam T the type of the Avro class, must implement SpecificRecord + * @return an Encoder for the given Avro class + */ + def of[T <: SpecificRecord](avroClass: Class[T]): Encoder[T] = { + AvroExpressionEncoder.of(avroClass) + } + + /** + * Provides an Encoder for Avro objects implementing the given schema + * + * @param avroSchema the Schema of the Avro object for which to generate the Encoder + * @tparam T the type of the Avro class that implements the Schema, must implement IndexedRecord + * @return an Encoder for the given Avro Schema + */ + def of[T <: IndexedRecord](avroSchema: Schema): Encoder[T] = { + AvroExpressionEncoder.of(avroSchema) + } +} + +private[avro] object ObjectType { + val ot = Class.forName("org.apache.spark.sql.types.ObjectType") + val meth = ot.getDeclaredConstructor(classOf[Class[_]]) + meth.setAccessible(true) + + val cls = ot.getMethod("cls") + cls.setAccessible(true) + + def apply(cls: Class[_]): DataType = { + meth.newInstance(cls).asInstanceOf[DataType] + } + + def _isInstanceOf(obj: AnyRef): Boolean = { + ot.isInstance(obj) + } + + def unapply(arg: DataType): Option[Class[_]] = { + arg match { + case arg if ot.isInstance(arg) => { + 
Some(cls.invoke(arg).asInstanceOf[Class[_]]) + } + case _ => None + } + } +} + +private[avro] case class LambdaVariable( + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) extends LeafExpression + with Unevaluable with NonSQLExpression { + + override def genCode(ctx: CodegenContext): ExprCode = { + ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + } +} + +private[avro] object ExternalMapToCatalyst { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def apply( + inputMap: Expression, + keyType: DataType, + keyConverter: Expression => Expression, + valueType: DataType, + valueConverter: Expression => Expression, + valueNullable: Boolean): ExternalMapToCatalyst = { + val id = curId.getAndIncrement() + val keyName = "ExternalMapToCatalyst_key" + id + val valueName = "ExternalMapToCatalyst_value" + id + val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + + ExternalMapToCatalyst( + keyName, + keyType, + keyConverter(LambdaVariable(keyName, "false", keyType, false)), + valueName, + valueIsNull, + valueType, + valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), + inputMap + ) + } +} + +private[avro] case class ExternalMapToCatalyst private( + key: String, + keyType: DataType, + keyConverter: Expression, + value: String, + valueIsNull: String, + valueType: DataType, + valueConverter: Expression, + child: Expression) + extends UnaryExpression with NonSQLExpression { + + override def foldable: Boolean = false + + override def dataType: MapType = MapType( + keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputMap = child.genCode(ctx) + val genKeyConverter = keyConverter.genCode(ctx) + val genValueConverter = valueConverter.genCode(ctx) + val length = ctx.freshName("length") + val index = ctx.freshName("index") + val convertedKeys = ctx.freshName("convertedKeys") + val convertedValues = ctx.freshName("convertedValues") + val entry = ctx.freshName("entry") + val entries = ctx.freshName("entries") + + val (defineEntries, defineKeyValue) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + val javaIteratorCls = classOf[java.util.Iterator[_]].getName + val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName + + val defineEntries = + s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" + + val defineKeyValue = + s""" + final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + """ + + defineEntries -> defineKeyValue + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + val scalaIteratorCls = classOf[Iterator[_]].getName + val scalaMapEntryCls = classOf[Tuple2[_, _]].getName + + val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" + + val defineKeyValue = + s""" + final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) 
$entry._2(); + """ + + defineEntries -> defineKeyValue + } + + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { + s"boolean $valueIsNull = false;" + } else { + s"boolean $valueIsNull = $value == null;" + } + + val arrayCls = classOf[GenericArrayData].getName + val mapCls = classOf[ArrayBasedMapData].getName + val convertedKeyType = ctx.boxedType(keyConverter.dataType) + val convertedValueType = ctx.boxedType(valueConverter.dataType) + val code = + s""" + ${inputMap.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${inputMap.isNull}) { + final int $length = ${inputMap.value}.size(); + final Object[] $convertedKeys = new Object[$length]; + final Object[] $convertedValues = new Object[$length]; + int $index = 0; + $defineEntries + while($entries.hasNext()) { + $defineKeyValue + $valueNullCheck + ${genKeyConverter.code} + if (${genKeyConverter.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value}; + } + ${genValueConverter.code} + if (${genValueConverter.isNull}) { + $convertedValues[$index] = null; + } else { + $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value}; + } + $index++; + } + ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues)); + } + """ + ev.copy(code = code, isNull = inputMap.isNull) + } +} + + +class SerializableSchema(@transient var value: Schema) extends Externalizable { + def this() = this(null) + override def readExternal(in: ObjectInput): Unit = { + value = new Parser().parse(in.readObject().asInstanceOf[String]) + } + override def writeExternal(out: ObjectOutput): Unit = out.writeObject(value.toString) + def resolveUnion(datum: Any): Int = GenericData.get.resolveUnion(value, datum) +} + +object AvroExpressionEncoder { + def of[T <: SpecificRecord](avroClass: Class[T]): ExpressionEncoder[T] = { + val schema = avroClass.getMethod("getClassSchema").invoke(null).asInstanceOf[Schema] + assert(toSqlType(schema).dataType.isInstanceOf[StructType]) + + val serializer = AvroTypeInference.serializerFor(avroClass, schema) + val deserializer = AvroTypeInference.deserializerFor(schema) + + new ExpressionEncoder[T]( + toSqlType(schema).dataType.asInstanceOf[StructType], + flat = false, + serializer.flatten, + deserializer = deserializer, + ClassTag[T](avroClass)) + } + + def of[T <: IndexedRecord](schema: Schema): ExpressionEncoder[T] = { + assert(toSqlType(schema).dataType.isInstanceOf[StructType]) + + val avroClass = Option(ReflectData.get.getClass(schema)) + .map(_.asSubclass(classOf[SpecificRecord])) + .getOrElse(classOf[GenericData.Record]) + val serializer = AvroTypeInference.serializerFor(avroClass, schema) + val deserializer = AvroTypeInference.deserializerFor(schema) + + new ExpressionEncoder[T]( + toSqlType(schema).dataType.asInstanceOf[StructType], + flat = false, + serializer.flatten, + deserializer, + ClassTag[T](avroClass)) + } +} + +/** + * Utilities for providing Avro object serializers and deserializers + */ +private object AvroTypeInference { + /** + * Translates an Avro Schema type to a proper SQL DataType. The Java Objects that back data in + * generated Generic and Specific records sometimes do not align with those suggested by Avro + * ReflectData, so we infer the proper SQL DataType to serialize and deserialize based on + * nullability and the wrapping Schema type. 
+ */ + private def inferExternalType(avroSchema: Schema): DataType = { + toSqlType(avroSchema) match { + // the non-nullable primitive types + case SchemaType(BooleanType, false) => BooleanType + case SchemaType(IntegerType, false) => IntegerType + case SchemaType(LongType, false) => + if (avroSchema.getType == UNION) { + ObjectType(classOf[java.lang.Number]) + } else { + LongType + } + case SchemaType(FloatType, false) => FloatType + case SchemaType(DoubleType, false) => + if (avroSchema.getType == UNION) { + ObjectType(classOf[java.lang.Number]) + } else { + DoubleType + } + // the nullable primitive types + case SchemaType(BooleanType, true) => ObjectType(classOf[java.lang.Boolean]) + case SchemaType(IntegerType, true) => ObjectType(classOf[java.lang.Integer]) + case SchemaType(LongType, true) => ObjectType(classOf[java.lang.Long]) + case SchemaType(FloatType, true) => ObjectType(classOf[java.lang.Float]) + case SchemaType(DoubleType, true) => ObjectType(classOf[java.lang.Double]) + // the binary types + case SchemaType(BinaryType, _) => + if (avroSchema.getType == FIXED) { + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.Fixed])) + } else { + ObjectType(classOf[java.nio.ByteBuffer]) + } + // the referenced types + case SchemaType(ArrayType(_, _), _) => + ObjectType(classOf[java.util.List[Object]]) + case SchemaType(StringType, _) => + avroSchema.getType match { + case ENUM => + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.EnumSymbol])) + case _ => + ObjectType(classOf[CharSequence]) + } + case SchemaType(StructType(_), _) => + Option(ReflectData.get.getClass(avroSchema)) + .map(ObjectType(_)) + .getOrElse(ObjectType(classOf[GenericData.Record])) + case SchemaType(MapType(_, _, _), _) => + ObjectType(classOf[java.util.Map[Object, Object]]) + } + } + + /** + * Returns an expression that can be used to deserialize an InternalRow to an Avro object of + * type `T` that implements IndexedRecord and is compatible with the given Schema + */ + def deserializerFor[T <: IndexedRecord] (avroSchema: Schema): Expression = { + deserializerFor(avroSchema, None) + } + + private def deserializerFor(avroSchema: Schema, path: Option[Expression]): Expression = { + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + def getPath: Expression = path.getOrElse( + GetColumnByOrdinal(0, inferExternalType(avroSchema))) + + avroSchema.getType match { + case BOOLEAN => + NewInstance( + classOf[java.lang.Boolean], + getPath :: Nil, + ObjectType(classOf[java.lang.Boolean])) + case INT => + NewInstance( + classOf[java.lang.Integer], + getPath :: Nil, + ObjectType(classOf[java.lang.Integer])) + case LONG => + NewInstance( + classOf[java.lang.Long], + getPath :: Nil, + ObjectType(classOf[java.lang.Long])) + case FLOAT => + NewInstance( + classOf[java.lang.Float], + getPath :: Nil, + ObjectType(classOf[java.lang.Float])) + case DOUBLE => + NewInstance( + classOf[java.lang.Double], + getPath :: Nil, + ObjectType(classOf[java.lang.Double])) + + case BYTES => + StaticInvoke( + classOf[java.nio.ByteBuffer], + ObjectType(classOf[java.nio.ByteBuffer]), + "wrap", + getPath :: Nil) + case FIXED => + val fixedClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.Fixed]) + if (fixedClass == classOf[GenericData.Fixed]) { + NewInstance( + fixedClass, + Invoke( + Literal.fromObject( + 
new SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: + getPath :: + Nil, + ObjectType(fixedClass)) + } else { + NewInstance( + fixedClass, + getPath :: Nil, + ObjectType(fixedClass)) + } + + case STRING => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case ENUM => + val enumClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.EnumSymbol]) + if (enumClass == classOf[GenericData.EnumSymbol]) { + NewInstance( + enumClass, + Invoke( + Literal.fromObject( + new SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: + Invoke(getPath, "toString", ObjectType(classOf[String])) :: + Nil, + ObjectType(enumClass)) + } else { + StaticInvoke( + enumClass, + ObjectType(enumClass), + "valueOf", + Invoke(getPath, "toString", ObjectType(classOf[String])) :: Nil) + } + + case ARRAY => + val elementSchema = avroSchema.getElementType + val elementType = toSqlType(elementSchema).dataType + val array = Invoke( + MapObjects(element => + deserializerFor(elementSchema, Some(element)), + getPath, + elementType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + classOf[java.util.Arrays], + ObjectType(classOf[java.util.List[Object]]), + "asList", + array :: Nil) + + case MAP => + val valueSchema = avroSchema.getValueType + val valueType = inferExternalType(valueSchema) match { + case t if t == ObjectType(classOf[java.lang.CharSequence]) => + StringType + case other => other + } + + val keyData = Invoke( + MapObjects( + p => deserializerFor(Schema.create(STRING), Some(p)), + Invoke(getPath, "keyArray", ArrayType(StringType)), + StringType), + "array", + ObjectType(classOf[Array[Any]])) + val valueData = Invoke( + MapObjects( + p => deserializerFor(valueSchema, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueType)), + valueType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case UNION => + val (resolvedSchema, _) = resolveUnionType(avroSchema) + if (resolvedSchema.getType == RECORD && + avroSchema.getTypes.asScala.filterNot(_.getType == NULL).length > 1) { + // A Union resolved to a record that originally had more than 1 type when filtered + // of its nulls must be complex + val bottom = Literal.create(null, ObjectType(classOf[Object])).asInstanceOf[Expression] + + resolvedSchema.getFields.foldLeft(bottom) { (tree: Expression, field: Schema.Field) => + val fieldValue = ObjectCast( + deserializerFor(field.schema, Some(addToPath(field.name))), + ObjectType(classOf[Object])) + + If(IsNull(fieldValue), tree, fieldValue) + } + } else { + deserializerFor(resolvedSchema, path) + } + + case RECORD => + val args = avroSchema.getFields.map { field => + val position = Literal(field.pos) + val argument = deserializerFor(field.schema, Some(addToPath(field.name))) + (position, argument) + }.toList + + val recordClass = Option(ReflectData.get.getClass(avroSchema)) + .getOrElse(classOf[GenericData.Record]) + val newInstance = if (recordClass == classOf[GenericData.Record]) { + NewInstance( + recordClass, + Invoke( + Literal.fromObject( + new SerializableSchema(avroSchema), + ObjectType(classOf[SerializableSchema])), + "value", + ObjectType(classOf[Schema]), + Nil) :: Nil, + ObjectType(recordClass)) + } else { + NewInstance( + recordClass, + Nil, + ObjectType(recordClass)) + } + + 
val result = InitializeAvroObject(newInstance, args) + + if (path.nonEmpty) { + If(IsNull(getPath), + Literal.create(null, ObjectType(recordClass)), + result) + } else { + result + } + + case NULL => + /* + * Encountering NULL at this level implies it was the type of a Field, which should never + * be the case + */ + throw new IncompatibleSchemaException("Null type should only be used in Union types") + } + } + + /** + * Returns an expression that can be used to serialize an Avro object with a class of type `T` + * that is compatible with the given Schema to an InternalRow + */ + def serializerFor[T <: IndexedRecord](avroClass: Class[T], avroSchema: Schema): + CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(avroClass), nullable = true) + serializerFor(inputObject, avroSchema, topLevel = true).asInstanceOf[CreateNamedStruct] + } + + def serializerFor( + inputObject: Expression, + avroSchema: Schema, + topLevel: Boolean = false): Expression = { + + def toCatalystArray(inputObject: Expression, schema: Schema): Expression = { + val elementType = inferExternalType(schema) + + if (ObjectType._isInstanceOf(elementType)) { + MapObjects(element => + serializerFor(element, schema), + Invoke( + inputObject, + "toArray", + ObjectType(classOf[Array[Object]])), + elementType) + } else { + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = ArrayType(elementType, containsNull = false)) + } + } + + def toCatalystMap(inputObject: Expression, schema: Schema): Expression = { + val valueSchema = schema.getValueType + val valueType = inferExternalType(valueSchema) + + ExternalMapToCatalyst( + inputObject, + ObjectType(classOf[org.apache.avro.util.Utf8]), + serializerFor(_, Schema.create(STRING)), + valueType, + serializerFor(_, valueSchema), + true) + } + + if (!ObjectType._isInstanceOf(inputObject.dataType)) { + inputObject + } else { + avroSchema.getType match { + case BOOLEAN => + Invoke(inputObject, "booleanValue", BooleanType) + case INT => + Invoke(inputObject, "intValue", IntegerType) + case LONG => + Invoke(inputObject, "longValue", LongType) + case FLOAT => + Invoke(inputObject, "floatValue", FloatType) + case DOUBLE => + Invoke(inputObject, "doubleValue", DoubleType) + + case BYTES => + Invoke(inputObject, "array", BinaryType) + case FIXED => + Invoke(inputObject, "bytes", BinaryType) + + case STRING => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil) + + case ENUM => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil) + + case ARRAY => + val elementSchema = avroSchema.getElementType + toCatalystArray(inputObject, elementSchema) + + case MAP => + toCatalystMap(inputObject, avroSchema) + + case UNION => + val unionWithoutNulls = Schema.createUnion( + avroSchema.getTypes.asScala.filterNot(_.getType == NULL)) + val (resolvedSchema, nullable) = resolveUnionType(avroSchema) + if (resolvedSchema.getType == RECORD && unionWithoutNulls.getTypes.length > 1) { + // A Union resolved to a record that originally had more than 1 type when filtered + // of its nulls must be complex + val complexStruct = CreateNamedStruct( + resolvedSchema.getFields.zipWithIndex.flatMap { case (field, index) => + val unionIndex = Invoke( + Literal.fromObject( + new SerializableSchema(unionWithoutNulls), + ObjectType(classOf[SerializableSchema])), + "resolveUnion", + IntegerType, + 
inputObject :: Nil) + + val fieldValue = If(EqualTo(Literal(index), unionIndex), + serializerFor( + ObjectCast( + inputObject, + inferExternalType(field.schema())), + field.schema), + Literal.create(null, toSqlType(field.schema()).dataType)) + + Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil}) + + complexStruct + + } else { + if (nullable) { + serializerFor(inputObject, resolvedSchema) + } else { + serializerFor( + AssertNotNull(inputObject, Seq(avroSchema.getTypes.toString)), + resolvedSchema) + } + } + + case RECORD => + val createStruct = CreateNamedStruct( + avroSchema.getFields.flatMap { field => + val fieldValue = Invoke( + inputObject, + "get", + inferExternalType(field.schema), + Literal(field.pos) :: Nil) + Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil}) + if (topLevel) { + createStruct + } else { + If(IsNull(inputObject), + Literal.create(null, createStruct.dataType), + createStruct) + } + + case NULL => + /* + * Encountering NULL at this level implies it was the type of a Field, which should never + * be the case + */ + throw new IncompatibleSchemaException("Null type should only be used in Union types") + } + } + } + + /** + * Initializes an Avro Record instance (that implements the IndexedRecord interface) by calling + * the `put` method on the Record instance with the provided position and value arguments + * + * @param objectInstance an expression that will evaluate to the Record instance + * @param args a sequence of expression pairs that will respectively evaluate to the index of + * the record in which to insert, and the argument value to insert + */ + private case class InitializeAvroObject( + objectInstance: Expression, + args: List[(Expression, Expression)]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = objectInstance.nullable + override def children: Seq[Expression] = objectInstance +: args.map { case (_, v) => v } + override def dataType: DataType = objectInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val instanceGen = objectInstance.genCode(ctx) + + val avroInstance = ctx.freshName("avroObject") + val avroInstanceJavaType = ctx.javaType(objectInstance.dataType) + ctx.addMutableState(avroInstanceJavaType, avroInstance, "") + + val initialize = args.map { + case (posExpr, argExpr) => + val posGen = posExpr.genCode(ctx) + val argGen = argExpr.genCode(ctx) + s""" + ${posGen.code} + ${argGen.code} + $avroInstance.put(${posGen.value}, ${argGen.value}); + """ + } + + val initExpressions = ctx.splitExpressions(ctx.INPUT_ROW, initialize) + val code = + s""" + ${instanceGen.code} + $avroInstance = ${instanceGen.value}; + if (!${instanceGen.isNull}) { + $initExpressions + } + """ + ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) + } + } + + /** + * Casts an expression to another object. + * + * @param value The value to cast + * @param resultType The type the value should be cast to.
+ */ + private case class ObjectCast( + value : Expression, + resultType: DataType) extends Expression with NonSQLExpression { + + override def nullable: Boolean = value.nullable + override def dataType: DataType = resultType + override def children: Seq[Expression] = value :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + + val javaType = ctx.javaType(resultType) + val obj = value.genCode(ctx) + + val code = s""" + ${obj.code} + final $javaType ${ev.value} = ($javaType) ${obj.value}; + """ + + ev.copy(code = code, isNull = obj.isNull) + } + } +} diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index 809bd395..53f514e8 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -24,32 +24,43 @@ import java.util.HashMap import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import scala.collection.immutable.Map +import scala.collection.immutable.Map import org.apache.avro.generic.GenericData.Record import org.apache.avro.generic.GenericRecord -import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.mapred.AvroKey import org.apache.avro.mapreduce.AvroKeyOutputFormat +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID} - +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ +import scala.collection.immutable.Map + // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
private[avro] class AvroOutputWriter( path: String, context: TaskAttemptContext, schema: StructType, recordName: String, - recordNamespace: String) extends OutputWriter { + recordNamespace: String, + workPathFunc: (String, TaskAttemptContext, String) => Path) extends OutputWriter { private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) + // copy of the old conversion logic after api change in SPARK-19085 - private lazy val internalRowConverter = - CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row] + // + // Need to use reflection for Spark versions < 2.2.0 + // + private lazy val internalRowConverter = { + val clazz = Class.forName("org.apache.spark.sql.catalyst.CatalystTypeConverters$") + val m = clazz.getDeclaredMethod("createToScalaConverter", classOf[DataType]) + val obj = clazz.getField("MODULE$").get(null) + m.invoke(obj, schema).asInstanceOf[InternalRow => Row] + } /** * Overrides the couple of methods responsible for generating the output streams / files so @@ -59,7 +70,7 @@ private[avro] class AvroOutputWriter( new AvroKeyOutputFormat[GenericRecord]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) + workPathFunc(path, context, extension) } @throws(classOf[IOException]) diff --git a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala index bfbadd7c..6d90e392 100644 --- a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala @@ -20,21 +20,17 @@ import java.io._ import java.net.URI import java.util.zip.Deflater -import scala.util.control.NonFatal - import com.databricks.spark.avro.DefaultSource.{AvroSchema, IgnoreFilesWithoutExtensionProperty, SerializableConfiguration} -import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.avro.{Schema, SchemaBuilder} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import org.apache.avro.file.{DataFileConstants, DataFileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.Job -import org.slf4j.LoggerFactory - import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -43,6 +39,9 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType +import org.slf4j.LoggerFactory + +import scala.util.control.NonFatal private[avro] class DefaultSource extends FileFormat with DataSourceRegister { private val log = LoggerFactory.getLogger(getClass) @@ -142,7 +141,18 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister { log.error(s"unsupported compression codec $unknown") } - new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + val clz = spark.version match { + case v if v.startsWith("2.0.") => { + Class.forName("com.databricks.spark.avro.Spark20AvroOutputWriterFactory") + } + case v => { + 
Class.forName("com.databricks.spark.avro.Spark21AvroOutputWriterFactory") + } + } + + val m = clz.getDeclaredConstructor(classOf[StructType], classOf[String], classOf[String]) + m.setAccessible(true) + m.newInstance(dataSchema, recordName, recordNamespace).asInstanceOf[OutputWriterFactory] } override def buildReader( diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index cbaed7fa..41f7e303 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -18,13 +18,11 @@ package com.databricks.spark.avro import java.nio.ByteBuffer import scala.collection.JavaConverters._ - import org.apache.avro.generic.GenericData.Fixed import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.SchemaBuilder._ import org.apache.avro.Schema.Type._ - import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ @@ -74,38 +72,50 @@ object SchemaConverters { nullable = false) case UNION => - if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { - // In case of a union with null, eliminate it and make a recursive call - val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) - if (remainingUnionTypes.size == 1) { - toSqlType(remainingUnionTypes.head).copy(nullable = true) - } else { - toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) - } - } else avroSchema.getTypes.asScala.map(_.getType) match { - case Seq(t1) => - toSqlType(avroSchema.getTypes.get(0)) - case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => - SchemaType(LongType, nullable = false) - case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => - SchemaType(DoubleType, nullable = false) - case _ => - // Convert complex unions to struct types where field names are member0, member1, etc. - // This is consistent with the behavior when converting between Avro and Parquet. - val fields = avroSchema.getTypes.asScala.zipWithIndex.map { - case (s, i) => - val schemaType = toSqlType(s) - // All fields are nullable because only one of them is set at a time - StructField(s"member$i", schemaType.dataType, nullable = true) - } - - SchemaType(StructType(fields), nullable = false) + resolveUnionType(avroSchema) match { + case (schema, nullable) => toSqlType(schema).copy(nullable = nullable) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") } } + /** + * Resolves an avro UNION type to an SQL-compatible avro type. Converts complex unions to records + * if necessary. 
+ */ + def resolveUnionType(avroSchema: Schema, nullable: Boolean = false): (Schema, Boolean) = { + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it, and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + (remainingUnionTypes.head, true) + } else { + resolveUnionType(Schema.createUnion(remainingUnionTypes.asJava), nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + (avroSchema.getTypes.get(0), true) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + (Schema.create(LONG), false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + (Schema.create(DOUBLE), false) + case _ => + // Convert complex unions to records where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val record = SchemaBuilder.record(avroSchema.getName).fields() + avroSchema.getTypes.asScala.zipWithIndex.foreach { + case (s, i) => + // All fields are nullable because only one of them is set at a time + record.name(s"member$i").`type`(SchemaBuilder.unionOf() + .`type`(Schema.create(NULL)).and + .`type`(s).endUnion()) + .withDefault(null) + } + (record.endRecord(), false) + } + } + /** * This function converts sparkSQL StructType into avro schema. This method uses two other * converter methods in order to do the conversion. diff --git a/src/test/java/com/databricks/spark/avro/ByteArray.java b/src/test/java/com/databricks/spark/avro/ByteArray.java new file mode 100644 index 00000000..d29b38d7 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/ByteArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ByteArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ByteArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"bytes\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public ByteArray() {} + + /** + * All-args constructor. + */ + public ByteArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. 
+ * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new ByteArray RecordBuilder */ + public static ByteArray.Builder newBuilder() { + return new ByteArray.Builder(); + } + + /** Creates a new ByteArray RecordBuilder by copying an existing Builder */ + public static ByteArray.Builder newBuilder(ByteArray.Builder other) { + return new ByteArray.Builder(other); + } + + /** Creates a new ByteArray RecordBuilder by copying an existing ByteArray instance */ + public static ByteArray.Builder newBuilder(ByteArray other) { + return new ByteArray.Builder(other); + } + + /** + * RecordBuilder for ByteArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(ByteArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(ByteArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ByteArray instance */ + private Builder(ByteArray other) { + super(ByteArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public ByteArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public ByteArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ByteArray build() { + try { + ByteArray record = new ByteArray(); + record.value = fieldSetFlags()[0] ? this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/DoubleArray.java b/src/test/java/com/databricks/spark/avro/DoubleArray.java new file mode 100644 index 00000000..470fdc70 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/DoubleArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class DoubleArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"DoubleArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"double\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. 
+ */ + public DoubleArray() {} + + /** + * All-args constructor. + */ + public DoubleArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new DoubleArray RecordBuilder */ + public static DoubleArray.Builder newBuilder() { + return new DoubleArray.Builder(); + } + + /** Creates a new DoubleArray RecordBuilder by copying an existing Builder */ + public static DoubleArray.Builder newBuilder(DoubleArray.Builder other) { + return new DoubleArray.Builder(other); + } + + /** Creates a new DoubleArray RecordBuilder by copying an existing DoubleArray instance */ + public static DoubleArray.Builder newBuilder(DoubleArray other) { + return new DoubleArray.Builder(other); + } + + /** + * RecordBuilder for DoubleArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(DoubleArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(DoubleArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing DoubleArray instance */ + private Builder(DoubleArray other) { + super(DoubleArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public DoubleArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public DoubleArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public DoubleArray build() { + try { + DoubleArray record = new DoubleArray(); + record.value = fieldSetFlags()[0] ? 
this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/Feature.java b/src/test/java/com/databricks/spark/avro/Feature.java new file mode 100644 index 00000000..c421a183 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/Feature.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class Feature extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Feature\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"key\",\"type\":\"string\"},{\"name\":\"value\",\"type\":[{\"type\":\"record\",\"name\":\"DoubleArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"double\"}}]},{\"type\":\"record\",\"name\":\"StringArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"string\"}}]},{\"type\":\"record\",\"name\":\"ByteArray\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"bytes\"}}]}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.CharSequence key; + @Deprecated public java.lang.Object value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public Feature() {} + + /** + * All-args constructor. + */ + public Feature(java.lang.CharSequence key, java.lang.Object value) { + this.key = key; + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return key; + case 1: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: key = (java.lang.CharSequence)value$; break; + case 1: value = (java.lang.Object)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'key' field. + */ + public java.lang.CharSequence getKey() { + return key; + } + + /** + * Sets the value of the 'key' field. + * @param value the value to set. + */ + public void setKey(java.lang.CharSequence value) { + this.key = value; + } + + /** + * Gets the value of the 'value' field. + */ + public java.lang.Object getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. 
+ */ + public void setValue(java.lang.Object value) { + this.value = value; + } + + /** Creates a new Feature RecordBuilder */ + public static Feature.Builder newBuilder() { + return new Feature.Builder(); + } + + /** Creates a new Feature RecordBuilder by copying an existing Builder */ + public static Feature.Builder newBuilder(Feature.Builder other) { + return new Feature.Builder(other); + } + + /** Creates a new Feature RecordBuilder by copying an existing Feature instance */ + public static Feature.Builder newBuilder(Feature other) { + return new Feature.Builder(other); + } + + /** + * RecordBuilder for Feature instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.CharSequence key; + private java.lang.Object value; + + /** Creates a new Builder */ + private Builder() { + super(Feature.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(Feature.Builder other) { + super(other); + if (isValidValue(fields()[0], other.key)) { + this.key = data().deepCopy(fields()[0].schema(), other.key); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.value)) { + this.value = data().deepCopy(fields()[1].schema(), other.value); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing Feature instance */ + private Builder(Feature other) { + super(Feature.SCHEMA$); + if (isValidValue(fields()[0], other.key)) { + this.key = data().deepCopy(fields()[0].schema(), other.key); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.value)) { + this.value = data().deepCopy(fields()[1].schema(), other.value); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'key' field */ + public java.lang.CharSequence getKey() { + return key; + } + + /** Sets the value of the 'key' field */ + public Feature.Builder setKey(java.lang.CharSequence value) { + validate(fields()[0], value); + this.key = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'key' field has been set */ + public boolean hasKey() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'key' field */ + public Feature.Builder clearKey() { + key = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'value' field */ + public java.lang.Object getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public Feature.Builder setValue(java.lang.Object value) { + validate(fields()[1], value); + this.value = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'value' field */ + public Feature.Builder clearValue() { + value = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public Feature build() { + try { + Feature record = new Feature(); + record.key = fieldSetFlags()[0] ? this.key : (java.lang.CharSequence) defaultValue(fields()[0]); + record.value = fieldSetFlags()[1] ? 
this.value : (java.lang.Object) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleEnums.java b/src/test/java/com/databricks/spark/avro/SimpleEnums.java new file mode 100644 index 00000000..5989c620 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleEnums.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum SimpleEnums { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"namespace\":\"com.databricks.spark.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleFixed.java b/src/test/java/com/databricks/spark/avro/SimpleFixed.java new file mode 100644 index 00000000..184b51f5 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleFixed.java @@ -0,0 +1,42 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; + +import org.apache.avro.Schema; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +@SuppressWarnings("all") +@org.apache.avro.specific.FixedSize(16) +@org.apache.avro.specific.AvroGenerated +public class SimpleFixed extends org.apache.avro.specific.SpecificFixed { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"namespace\":\"com.databricks.spark.avro\",\"size\":16}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + + /** Creates a new SimpleFixed */ + public SimpleFixed() { + super(); + } + + /** Creates a new SimpleFixed with the given bytes */ + public SimpleFixed(byte[] bytes) { + super(bytes); + } + + public void writeExternal(ObjectOutput out) throws IOException { + // + } + + public void readExternal(ObjectInput in) throws IOException { + // + } + + public Schema getSchema() { + return getClassSchema(); + } +} diff --git a/src/test/java/com/databricks/spark/avro/SimpleRecord.java b/src/test/java/com/databricks/spark/avro/SimpleRecord.java new file mode 100644 index 00000000..a36161ed --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/SimpleRecord.java @@ -0,0 +1,195 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class SimpleRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"SimpleRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public int nested1; + @Deprecated public java.lang.CharSequence nested2; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. 
If that is desired then + * one should use newBuilder(). + */ + public SimpleRecord() {} + + /** + * All-args constructor. + */ + public SimpleRecord(java.lang.Integer nested1, java.lang.CharSequence nested2) { + this.nested1 = nested1; + this.nested2 = nested2; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return nested1; + case 1: return nested2; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: nested1 = (java.lang.Integer)value$; break; + case 1: nested2 = (java.lang.CharSequence)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'nested1' field. + */ + public java.lang.Integer getNested1() { + return nested1; + } + + /** + * Sets the value of the 'nested1' field. + * @param value the value to set. + */ + public void setNested1(java.lang.Integer value) { + this.nested1 = value; + } + + /** + * Gets the value of the 'nested2' field. + */ + public java.lang.CharSequence getNested2() { + return nested2; + } + + /** + * Sets the value of the 'nested2' field. + * @param value the value to set. + */ + public void setNested2(java.lang.CharSequence value) { + this.nested2 = value; + } + + /** Creates a new SimpleRecord RecordBuilder */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder() { + return new com.databricks.spark.avro.SimpleRecord.Builder(); + } + + /** Creates a new SimpleRecord RecordBuilder by copying an existing Builder */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord.Builder other) { + return new com.databricks.spark.avro.SimpleRecord.Builder(other); + } + + /** Creates a new SimpleRecord RecordBuilder by copying an existing SimpleRecord instance */ + public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord other) { + return new com.databricks.spark.avro.SimpleRecord.Builder(other); + } + + /** + * RecordBuilder for SimpleRecord instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private int nested1; + private java.lang.CharSequence nested2; + + /** Creates a new Builder */ + private Builder() { + super(com.databricks.spark.avro.SimpleRecord.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(com.databricks.spark.avro.SimpleRecord.Builder other) { + super(other); + if (isValidValue(fields()[0], other.nested1)) { + this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested2)) { + this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing SimpleRecord instance */ + private Builder(com.databricks.spark.avro.SimpleRecord other) { + super(com.databricks.spark.avro.SimpleRecord.SCHEMA$); + if (isValidValue(fields()[0], other.nested1)) { + this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.nested2)) { + this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'nested1' field */ + public java.lang.Integer getNested1() { + return nested1; + } + + /** Sets the value of the 'nested1' field */ + public com.databricks.spark.avro.SimpleRecord.Builder setNested1(int value) { + validate(fields()[0], value); + this.nested1 = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'nested1' field has been set */ + public boolean hasNested1() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'nested1' field */ + public com.databricks.spark.avro.SimpleRecord.Builder clearNested1() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'nested2' field */ + public java.lang.CharSequence getNested2() { + return nested2; + } + + /** Sets the value of the 'nested2' field */ + public com.databricks.spark.avro.SimpleRecord.Builder setNested2(java.lang.CharSequence value) { + validate(fields()[1], value); + this.nested2 = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'nested2' field has been set */ + public boolean hasNested2() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'nested2' field */ + public com.databricks.spark.avro.SimpleRecord.Builder clearNested2() { + nested2 = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public SimpleRecord build() { + try { + SimpleRecord record = new SimpleRecord(); + record.nested1 = fieldSetFlags()[0] ? this.nested1 : (java.lang.Integer) defaultValue(fields()[0]); + record.nested2 = fieldSetFlags()[1] ? 
this.nested2 : (java.lang.CharSequence) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/StringArray.java b/src/test/java/com/databricks/spark/avro/StringArray.java new file mode 100644 index 00000000..ce980d12 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/StringArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class StringArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"StringArray\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"value\",\"type\":{\"type\":\"array\",\"items\":\"string\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List value; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use {@link \#newBuilder()}. + */ + public StringArray() {} + + /** + * All-args constructor. + */ + public StringArray(java.util.List value) { + this.value = value; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return value; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: value = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'value' field. + */ + public java.util.List getValue() { + return value; + } + + /** + * Sets the value of the 'value' field. + * @param value the value to set. + */ + public void setValue(java.util.List value) { + this.value = value; + } + + /** Creates a new StringArray RecordBuilder */ + public static StringArray.Builder newBuilder() { + return new StringArray.Builder(); + } + + /** Creates a new StringArray RecordBuilder by copying an existing Builder */ + public static StringArray.Builder newBuilder(StringArray.Builder other) { + return new StringArray.Builder(other); + } + + /** Creates a new StringArray RecordBuilder by copying an existing StringArray instance */ + public static StringArray.Builder newBuilder(StringArray other) { + return new StringArray.Builder(other); + } + + /** + * RecordBuilder for StringArray instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List value; + + /** Creates a new Builder */ + private Builder() { + super(StringArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(StringArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing StringArray instance */ + private Builder(StringArray other) { + super(StringArray.SCHEMA$); + if (isValidValue(fields()[0], other.value)) { + this.value = data().deepCopy(fields()[0].schema(), other.value); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'value' field */ + public java.util.List getValue() { + return value; + } + + /** Sets the value of the 'value' field */ + public StringArray.Builder setValue(java.util.List value) { + validate(fields()[0], value); + this.value = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'value' field has been set */ + public boolean hasValue() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'value' field */ + public StringArray.Builder clearValue() { + value = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public StringArray build() { + try { + StringArray record = new StringArray(); + record.value = fieldSetFlags()[0] ? this.value : (java.util.List) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/java/com/databricks/spark/avro/TestRecord.java b/src/test/java/com/databricks/spark/avro/TestRecord.java new file mode 100644 index 00000000..dd323bb7 --- /dev/null +++ b/src/test/java/com/databricks/spark/avro/TestRecord.java @@ -0,0 +1,893 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package com.databricks.spark.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class TestRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"TestRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"boolean\",\"type\":\"boolean\",\"default\":true},{\"name\":\"int\",\"type\":\"int\",\"default\":0},{\"name\":\"long\",\"type\":\"long\",\"default\":0},{\"name\":\"float\",\"type\":\"float\",\"default\":0.0},{\"name\":\"double\",\"type\":\"double\",\"default\":0.0},{\"name\":\"string\",\"type\":\"string\",\"default\":\"value\"},{\"name\":\"bytes\",\"type\":\"bytes\",\"default\":\"ΓΏ\"},{\"name\":\"nested\",\"type\":{\"type\":\"record\",\"name\":\"SimpleRecord\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]},\"default\":{\"nested1\":0,\"nested2\":\"string\"}},{\"name\":\"enum\",\"type\":{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},\"default\":\"SPADES\"},{\"name\":\"fixed\",\"type\":{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"size\":16},\"default\":\"string_length_16\"},{\"name\":\"intArray\",\"type\":{\"type\":\"array\",\"items\":\"int\"},\"default\":[1,2,3]},{\"name\":\"stringArray\",\"type\":{\"type\":\"array\",\"items\":\"string\"},\"default\":[\"a\",\"b\",\"c\"]},{\"name\":\"recordArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleRecord\"},\"default\":[{\"nested1\":0,\"nested2\":\"value\"},{\"nested1\":0,\"nested2\":\"value\"}]},{\"name\":\"enumArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleEnums\"},\"default\":[\"SPADES\",\"HEARTS\",\"SPADES\"]},{\"name\":\"fixedArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleFixed\"},\"default\":[\"foo\",\"bar\",\"baz\"]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean boolean$; + @Deprecated public int int$; + @Deprecated public long long$; + @Deprecated public float float$; + @Deprecated public double double$; + @Deprecated public java.lang.CharSequence string; + @Deprecated public java.nio.ByteBuffer bytes; + @Deprecated public com.databricks.spark.avro.SimpleRecord nested; + @Deprecated public com.databricks.spark.avro.SimpleEnums enum$; + @Deprecated public com.databricks.spark.avro.SimpleFixed fixed; + @Deprecated public java.util.List intArray; + @Deprecated public java.util.List stringArray; + @Deprecated public java.util.List recordArray; + @Deprecated public java.util.List enumArray; + @Deprecated public java.util.List fixedArray; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public TestRecord() {} + + /** + * All-args constructor. 
+ */ + public TestRecord(java.lang.Boolean boolean$, java.lang.Integer int$, java.lang.Long long$, java.lang.Float float$, java.lang.Double double$, java.lang.CharSequence string, java.nio.ByteBuffer bytes, com.databricks.spark.avro.SimpleRecord nested, com.databricks.spark.avro.SimpleEnums enum$, com.databricks.spark.avro.SimpleFixed fixed, java.util.List intArray, java.util.List stringArray, java.util.List recordArray, java.util.List enumArray, java.util.List fixedArray) { + this.boolean$ = boolean$; + this.int$ = int$; + this.long$ = long$; + this.float$ = float$; + this.double$ = double$; + this.string = string; + this.bytes = bytes; + this.nested = nested; + this.enum$ = enum$; + this.fixed = fixed; + this.intArray = intArray; + this.stringArray = stringArray; + this.recordArray = recordArray; + this.enumArray = enumArray; + this.fixedArray = fixedArray; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return boolean$; + case 1: return int$; + case 2: return long$; + case 3: return float$; + case 4: return double$; + case 5: return string; + case 6: return bytes; + case 7: return nested; + case 8: return enum$; + case 9: return fixed; + case 10: return intArray; + case 11: return stringArray; + case 12: return recordArray; + case 13: return enumArray; + case 14: return fixedArray; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: boolean$ = (java.lang.Boolean)value$; break; + case 1: int$ = (java.lang.Integer)value$; break; + case 2: long$ = (java.lang.Long)value$; break; + case 3: float$ = (java.lang.Float)value$; break; + case 4: double$ = (java.lang.Double)value$; break; + case 5: string = (java.lang.CharSequence)value$; break; + case 6: bytes = (java.nio.ByteBuffer)value$; break; + case 7: nested = (com.databricks.spark.avro.SimpleRecord)value$; break; + case 8: enum$ = (com.databricks.spark.avro.SimpleEnums)value$; break; + case 9: fixed = (com.databricks.spark.avro.SimpleFixed)value$; break; + case 10: intArray = (java.util.List)value$; break; + case 11: stringArray = (java.util.List)value$; break; + case 12: recordArray = (java.util.List)value$; break; + case 13: enumArray = (java.util.List)value$; break; + case 14: fixedArray = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'boolean$' field. + */ + public java.lang.Boolean getBoolean$() { + return boolean$; + } + + /** + * Sets the value of the 'boolean$' field. + * @param value the value to set. + */ + public void setBoolean$(java.lang.Boolean value) { + this.boolean$ = value; + } + + /** + * Gets the value of the 'int$' field. + */ + public java.lang.Integer getInt$() { + return int$; + } + + /** + * Sets the value of the 'int$' field. + * @param value the value to set. + */ + public void setInt$(java.lang.Integer value) { + this.int$ = value; + } + + /** + * Gets the value of the 'long$' field. + */ + public java.lang.Long getLong$() { + return long$; + } + + /** + * Sets the value of the 'long$' field. + * @param value the value to set. + */ + public void setLong$(java.lang.Long value) { + this.long$ = value; + } + + /** + * Gets the value of the 'float$' field. 
+ */ + public java.lang.Float getFloat$() { + return float$; + } + + /** + * Sets the value of the 'float$' field. + * @param value the value to set. + */ + public void setFloat$(java.lang.Float value) { + this.float$ = value; + } + + /** + * Gets the value of the 'double$' field. + */ + public java.lang.Double getDouble$() { + return double$; + } + + /** + * Sets the value of the 'double$' field. + * @param value the value to set. + */ + public void setDouble$(java.lang.Double value) { + this.double$ = value; + } + + /** + * Gets the value of the 'string' field. + */ + public java.lang.CharSequence getString() { + return string; + } + + /** + * Sets the value of the 'string' field. + * @param value the value to set. + */ + public void setString(java.lang.CharSequence value) { + this.string = value; + } + + /** + * Gets the value of the 'bytes' field. + */ + public java.nio.ByteBuffer getBytes() { + return bytes; + } + + /** + * Sets the value of the 'bytes' field. + * @param value the value to set. + */ + public void setBytes(java.nio.ByteBuffer value) { + this.bytes = value; + } + + /** + * Gets the value of the 'nested' field. + */ + public com.databricks.spark.avro.SimpleRecord getNested() { + return nested; + } + + /** + * Sets the value of the 'nested' field. + * @param value the value to set. + */ + public void setNested(com.databricks.spark.avro.SimpleRecord value) { + this.nested = value; + } + + /** + * Gets the value of the 'enum$' field. + */ + public com.databricks.spark.avro.SimpleEnums getEnum$() { + return enum$; + } + + /** + * Sets the value of the 'enum$' field. + * @param value the value to set. + */ + public void setEnum$(com.databricks.spark.avro.SimpleEnums value) { + this.enum$ = value; + } + + /** + * Gets the value of the 'fixed' field. + */ + public com.databricks.spark.avro.SimpleFixed getFixed() { + return fixed; + } + + /** + * Sets the value of the 'fixed' field. + * @param value the value to set. + */ + public void setFixed(com.databricks.spark.avro.SimpleFixed value) { + this.fixed = value; + } + + /** + * Gets the value of the 'intArray' field. + */ + public java.util.List getIntArray() { + return intArray; + } + + /** + * Sets the value of the 'intArray' field. + * @param value the value to set. + */ + public void setIntArray(java.util.List value) { + this.intArray = value; + } + + /** + * Gets the value of the 'stringArray' field. + */ + public java.util.List getStringArray() { + return stringArray; + } + + /** + * Sets the value of the 'stringArray' field. + * @param value the value to set. + */ + public void setStringArray(java.util.List value) { + this.stringArray = value; + } + + /** + * Gets the value of the 'recordArray' field. + */ + public java.util.List getRecordArray() { + return recordArray; + } + + /** + * Sets the value of the 'recordArray' field. + * @param value the value to set. + */ + public void setRecordArray(java.util.List value) { + this.recordArray = value; + } + + /** + * Gets the value of the 'enumArray' field. + */ + public java.util.List getEnumArray() { + return enumArray; + } + + /** + * Sets the value of the 'enumArray' field. + * @param value the value to set. + */ + public void setEnumArray(java.util.List value) { + this.enumArray = value; + } + + /** + * Gets the value of the 'fixedArray' field. + */ + public java.util.List getFixedArray() { + return fixedArray; + } + + /** + * Sets the value of the 'fixedArray' field. + * @param value the value to set. 
+ */ + public void setFixedArray(java.util.List value) { + this.fixedArray = value; + } + + /** Creates a new TestRecord RecordBuilder */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder() { + return new com.databricks.spark.avro.TestRecord.Builder(); + } + + /** Creates a new TestRecord RecordBuilder by copying an existing Builder */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord.Builder other) { + return new com.databricks.spark.avro.TestRecord.Builder(other); + } + + /** Creates a new TestRecord RecordBuilder by copying an existing TestRecord instance */ + public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord other) { + return new com.databricks.spark.avro.TestRecord.Builder(other); + } + + /** + * RecordBuilder for TestRecord instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean boolean$; + private int int$; + private long long$; + private float float$; + private double double$; + private java.lang.CharSequence string; + private java.nio.ByteBuffer bytes; + private com.databricks.spark.avro.SimpleRecord nested; + private com.databricks.spark.avro.SimpleEnums enum$; + private com.databricks.spark.avro.SimpleFixed fixed; + private java.util.List intArray; + private java.util.List stringArray; + private java.util.List recordArray; + private java.util.List enumArray; + private java.util.List fixedArray; + + /** Creates a new Builder */ + private Builder() { + super(com.databricks.spark.avro.TestRecord.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(com.databricks.spark.avro.TestRecord.Builder other) { + super(other); + if (isValidValue(fields()[0], other.boolean$)) { + this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int$)) { + this.int$ = data().deepCopy(fields()[1].schema(), other.int$); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long$)) { + this.long$ = data().deepCopy(fields()[2].schema(), other.long$); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float$)) { + this.float$ = data().deepCopy(fields()[3].schema(), other.float$); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double$)) { + this.double$ = data().deepCopy(fields()[4].schema(), other.double$); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.string)) { + this.string = data().deepCopy(fields()[5].schema(), other.string); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.bytes)) { + this.bytes = data().deepCopy(fields()[6].schema(), other.bytes); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.nested)) { + this.nested = data().deepCopy(fields()[7].schema(), other.nested); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.enum$)) { + this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.fixed)) { + this.fixed = data().deepCopy(fields()[9].schema(), other.fixed); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.intArray)) { + this.intArray = data().deepCopy(fields()[10].schema(), other.intArray); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], 
other.stringArray)) { + this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.recordArray)) { + this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.enumArray)) { + this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.fixedArray)) { + this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray); + fieldSetFlags()[14] = true; + } + } + + /** Creates a Builder by copying an existing TestRecord instance */ + private Builder(com.databricks.spark.avro.TestRecord other) { + super(com.databricks.spark.avro.TestRecord.SCHEMA$); + if (isValidValue(fields()[0], other.boolean$)) { + this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int$)) { + this.int$ = data().deepCopy(fields()[1].schema(), other.int$); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long$)) { + this.long$ = data().deepCopy(fields()[2].schema(), other.long$); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float$)) { + this.float$ = data().deepCopy(fields()[3].schema(), other.float$); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double$)) { + this.double$ = data().deepCopy(fields()[4].schema(), other.double$); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.string)) { + this.string = data().deepCopy(fields()[5].schema(), other.string); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.bytes)) { + this.bytes = data().deepCopy(fields()[6].schema(), other.bytes); + fieldSetFlags()[6] = true; + } + if (isValidValue(fields()[7], other.nested)) { + this.nested = data().deepCopy(fields()[7].schema(), other.nested); + fieldSetFlags()[7] = true; + } + if (isValidValue(fields()[8], other.enum$)) { + this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$); + fieldSetFlags()[8] = true; + } + if (isValidValue(fields()[9], other.fixed)) { + this.fixed = data().deepCopy(fields()[9].schema(), other.fixed); + fieldSetFlags()[9] = true; + } + if (isValidValue(fields()[10], other.intArray)) { + this.intArray = data().deepCopy(fields()[10].schema(), other.intArray); + fieldSetFlags()[10] = true; + } + if (isValidValue(fields()[11], other.stringArray)) { + this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray); + fieldSetFlags()[11] = true; + } + if (isValidValue(fields()[12], other.recordArray)) { + this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray); + fieldSetFlags()[12] = true; + } + if (isValidValue(fields()[13], other.enumArray)) { + this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray); + fieldSetFlags()[13] = true; + } + if (isValidValue(fields()[14], other.fixedArray)) { + this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray); + fieldSetFlags()[14] = true; + } + } + + /** Gets the value of the 'boolean$' field */ + public java.lang.Boolean getBoolean$() { + return boolean$; + } + + /** Sets the value of the 'boolean$' field */ + public com.databricks.spark.avro.TestRecord.Builder setBoolean$(boolean value) { + validate(fields()[0], value); + this.boolean$ = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'boolean$' 
field has been set */ + public boolean hasBoolean$() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'boolean$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearBoolean$() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int$' field */ + public java.lang.Integer getInt$() { + return int$; + } + + /** Sets the value of the 'int$' field */ + public com.databricks.spark.avro.TestRecord.Builder setInt$(int value) { + validate(fields()[1], value); + this.int$ = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int$' field has been set */ + public boolean hasInt$() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearInt$() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long$' field */ + public java.lang.Long getLong$() { + return long$; + } + + /** Sets the value of the 'long$' field */ + public com.databricks.spark.avro.TestRecord.Builder setLong$(long value) { + validate(fields()[2], value); + this.long$ = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long$' field has been set */ + public boolean hasLong$() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearLong$() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float$' field */ + public java.lang.Float getFloat$() { + return float$; + } + + /** Sets the value of the 'float$' field */ + public com.databricks.spark.avro.TestRecord.Builder setFloat$(float value) { + validate(fields()[3], value); + this.float$ = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float$' field has been set */ + public boolean hasFloat$() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFloat$() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double$' field */ + public java.lang.Double getDouble$() { + return double$; + } + + /** Sets the value of the 'double$' field */ + public com.databricks.spark.avro.TestRecord.Builder setDouble$(double value) { + validate(fields()[4], value); + this.double$ = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double$' field has been set */ + public boolean hasDouble$() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearDouble$() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'string' field */ + public java.lang.CharSequence getString() { + return string; + } + + /** Sets the value of the 'string' field */ + public com.databricks.spark.avro.TestRecord.Builder setString(java.lang.CharSequence value) { + validate(fields()[5], value); + this.string = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'string' field has been set */ + public boolean hasString() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'string' field */ + public com.databricks.spark.avro.TestRecord.Builder clearString() { + string = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'bytes' field */ + public java.nio.ByteBuffer getBytes() { + return bytes; + } + + /** Sets the 
value of the 'bytes' field */ + public com.databricks.spark.avro.TestRecord.Builder setBytes(java.nio.ByteBuffer value) { + validate(fields()[6], value); + this.bytes = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'bytes' field has been set */ + public boolean hasBytes() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'bytes' field */ + public com.databricks.spark.avro.TestRecord.Builder clearBytes() { + bytes = null; + fieldSetFlags()[6] = false; + return this; + } + + /** Gets the value of the 'nested' field */ + public com.databricks.spark.avro.SimpleRecord getNested() { + return nested; + } + + /** Sets the value of the 'nested' field */ + public com.databricks.spark.avro.TestRecord.Builder setNested(com.databricks.spark.avro.SimpleRecord value) { + validate(fields()[7], value); + this.nested = value; + fieldSetFlags()[7] = true; + return this; + } + + /** Checks whether the 'nested' field has been set */ + public boolean hasNested() { + return fieldSetFlags()[7]; + } + + /** Clears the value of the 'nested' field */ + public com.databricks.spark.avro.TestRecord.Builder clearNested() { + nested = null; + fieldSetFlags()[7] = false; + return this; + } + + /** Gets the value of the 'enum$' field */ + public com.databricks.spark.avro.SimpleEnums getEnum$() { + return enum$; + } + + /** Sets the value of the 'enum$' field */ + public com.databricks.spark.avro.TestRecord.Builder setEnum$(com.databricks.spark.avro.SimpleEnums value) { + validate(fields()[8], value); + this.enum$ = value; + fieldSetFlags()[8] = true; + return this; + } + + /** Checks whether the 'enum$' field has been set */ + public boolean hasEnum$() { + return fieldSetFlags()[8]; + } + + /** Clears the value of the 'enum$' field */ + public com.databricks.spark.avro.TestRecord.Builder clearEnum$() { + enum$ = null; + fieldSetFlags()[8] = false; + return this; + } + + /** Gets the value of the 'fixed' field */ + public com.databricks.spark.avro.SimpleFixed getFixed() { + return fixed; + } + + /** Sets the value of the 'fixed' field */ + public com.databricks.spark.avro.TestRecord.Builder setFixed(com.databricks.spark.avro.SimpleFixed value) { + validate(fields()[9], value); + this.fixed = value; + fieldSetFlags()[9] = true; + return this; + } + + /** Checks whether the 'fixed' field has been set */ + public boolean hasFixed() { + return fieldSetFlags()[9]; + } + + /** Clears the value of the 'fixed' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFixed() { + fixed = null; + fieldSetFlags()[9] = false; + return this; + } + + /** Gets the value of the 'intArray' field */ + public java.util.List getIntArray() { + return intArray; + } + + /** Sets the value of the 'intArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setIntArray(java.util.List value) { + validate(fields()[10], value); + this.intArray = value; + fieldSetFlags()[10] = true; + return this; + } + + /** Checks whether the 'intArray' field has been set */ + public boolean hasIntArray() { + return fieldSetFlags()[10]; + } + + /** Clears the value of the 'intArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearIntArray() { + intArray = null; + fieldSetFlags()[10] = false; + return this; + } + + /** Gets the value of the 'stringArray' field */ + public java.util.List getStringArray() { + return stringArray; + } + + /** Sets the value of the 'stringArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setStringArray(java.util.List value) { 
+ validate(fields()[11], value); + this.stringArray = value; + fieldSetFlags()[11] = true; + return this; + } + + /** Checks whether the 'stringArray' field has been set */ + public boolean hasStringArray() { + return fieldSetFlags()[11]; + } + + /** Clears the value of the 'stringArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearStringArray() { + stringArray = null; + fieldSetFlags()[11] = false; + return this; + } + + /** Gets the value of the 'recordArray' field */ + public java.util.List getRecordArray() { + return recordArray; + } + + /** Sets the value of the 'recordArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setRecordArray(java.util.List value) { + validate(fields()[12], value); + this.recordArray = value; + fieldSetFlags()[12] = true; + return this; + } + + /** Checks whether the 'recordArray' field has been set */ + public boolean hasRecordArray() { + return fieldSetFlags()[12]; + } + + /** Clears the value of the 'recordArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearRecordArray() { + recordArray = null; + fieldSetFlags()[12] = false; + return this; + } + + /** Gets the value of the 'enumArray' field */ + public java.util.List getEnumArray() { + return enumArray; + } + + /** Sets the value of the 'enumArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setEnumArray(java.util.List value) { + validate(fields()[13], value); + this.enumArray = value; + fieldSetFlags()[13] = true; + return this; + } + + /** Checks whether the 'enumArray' field has been set */ + public boolean hasEnumArray() { + return fieldSetFlags()[13]; + } + + /** Clears the value of the 'enumArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearEnumArray() { + enumArray = null; + fieldSetFlags()[13] = false; + return this; + } + + /** Gets the value of the 'fixedArray' field */ + public java.util.List getFixedArray() { + return fixedArray; + } + + /** Sets the value of the 'fixedArray' field */ + public com.databricks.spark.avro.TestRecord.Builder setFixedArray(java.util.List value) { + validate(fields()[14], value); + this.fixedArray = value; + fieldSetFlags()[14] = true; + return this; + } + + /** Checks whether the 'fixedArray' field has been set */ + public boolean hasFixedArray() { + return fieldSetFlags()[14]; + } + + /** Clears the value of the 'fixedArray' field */ + public com.databricks.spark.avro.TestRecord.Builder clearFixedArray() { + fixedArray = null; + fieldSetFlags()[14] = false; + return this; + } + + @Override + public TestRecord build() { + try { + TestRecord record = new TestRecord(); + record.boolean$ = fieldSetFlags()[0] ? this.boolean$ : (java.lang.Boolean) defaultValue(fields()[0]); + record.int$ = fieldSetFlags()[1] ? this.int$ : (java.lang.Integer) defaultValue(fields()[1]); + record.long$ = fieldSetFlags()[2] ? this.long$ : (java.lang.Long) defaultValue(fields()[2]); + record.float$ = fieldSetFlags()[3] ? this.float$ : (java.lang.Float) defaultValue(fields()[3]); + record.double$ = fieldSetFlags()[4] ? this.double$ : (java.lang.Double) defaultValue(fields()[4]); + record.string = fieldSetFlags()[5] ? this.string : (java.lang.CharSequence) defaultValue(fields()[5]); + record.bytes = fieldSetFlags()[6] ? this.bytes : (java.nio.ByteBuffer) defaultValue(fields()[6]); + record.nested = fieldSetFlags()[7] ? this.nested : (com.databricks.spark.avro.SimpleRecord) defaultValue(fields()[7]); + record.enum$ = fieldSetFlags()[8] ? 
this.enum$ : (com.databricks.spark.avro.SimpleEnums) defaultValue(fields()[8]); + record.fixed = fieldSetFlags()[9] ? this.fixed : (com.databricks.spark.avro.SimpleFixed) defaultValue(fields()[9]); + record.intArray = fieldSetFlags()[10] ? this.intArray : (java.util.List) defaultValue(fields()[10]); + record.stringArray = fieldSetFlags()[11] ? this.stringArray : (java.util.List) defaultValue(fields()[11]); + record.recordArray = fieldSetFlags()[12] ? this.recordArray : (java.util.List) defaultValue(fields()[12]); + record.enumArray = fieldSetFlags()[13] ? this.enumArray : (java.util.List) defaultValue(fields()[13]); + record.fixedArray = fieldSetFlags()[14] ? this.fixedArray : (java.util.List) defaultValue(fields()[14]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/src/test/resources/specific.avsc b/src/test/resources/specific.avsc new file mode 100644 index 00000000..dbbc1da6 --- /dev/null +++ b/src/test/resources/specific.avsc @@ -0,0 +1,40 @@ +{ + "namespace": "com.databricks.spark.avro", + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "boolean", "type": "boolean", "default": true}, + {"name": "int", "type": "int", "default": 0}, + {"name": "long", "type": "long", "default": 0}, + {"name": "float", "type": "float", "default": 0.0}, + {"name": "double", "type": "double", "default": 0.0}, + {"name": "string", "type": "string", "default": "value"}, + {"name": "bytes", "type": "bytes", "default": "\u00ff"}, + {"name": "nested", "type": { + "type": "record", "name": "SimpleRecord", "fields": [ + {"name": "nested1", "type": "int", "default": 0}, + {"name": "nested2", "type": "string", "default": "string"}]}, + "default": {"nested1": 0, "nested2": "string"}}, + {"name": "enum", "type": { + "name": "SimpleEnums", "type": "enum", "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}, + "default": "SPADES"}, + {"name": "fixed", "type": { + "name": "SimpleFixed", "type": "fixed", "size": 16}, + "default": "string_length_16"}, + {"name": "intArray", + "type": {"type": "array", "items": "int"}, + "default": [1, 2, 3]}, + {"name": "stringArray", + "type": {"type": "array", "items": "string"}, + "default": ["a", "b", "c"]}, + {"name": "recordArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleRecord"}, + "default": [{"nested1": 0, "nested2": "value"}, {"nested1": 0, "nested2": "value"}]}, + {"name": "enumArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleEnums"}, + "default": ["SPADES", "HEARTS", "SPADES"]}, + {"name": "fixedArray", + "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleFixed"}, + "default": ["foo", "bar", "baz"]} + ] +} \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index e9509613..80ff4726 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -17,21 +17,28 @@ package com.databricks.spark.avro import java.io._ +import java.nio.ByteBuffer import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.{TimeZone, UUID} -import scala.collection.JavaConversions._ -import org.apache.avro.Schema +import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.file.DataFileWriter -import 
org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, GenericRecordBuilder} import org.apache.commons.io.FileUtils -import org.apache.spark.sql._ +import org.apache.hadoop.fs +import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException + +import scala.collection.JavaConversions._ class AvroSuite extends FunSuite with BeforeAndAfterAll { val episodesFile = "src/test/resources/episodes.avro" @@ -41,10 +48,16 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { override protected def beforeAll(): Unit = { super.beforeAll() + + val sc = new SparkConf() + sc.registerAvroSchemas(Feature.getClassSchema) + spark = SparkSession.builder() .master("local[2]") .appName("AvroSuite") .config("spark.sql.files.maxPartitionBytes", 1024) + .config(sc) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .getOrCreate() } @@ -74,7 +87,7 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } @@ -425,7 +438,7 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING com.databricks.spark.avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -577,18 +590,19 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { test("SQL test insert overwrite") { TestUtils.withTempDir { tempDir => - val tempEmptyDir = s"$tempDir/sqlOverwrite" + val tempEmptyDir = new Path(s"$tempDir/sqlOverwrite") // Create a temp directory for table that will be overwritten - new File(tempEmptyDir).mkdirs() + val local = fs.FileSystem.getLocal(spark.sparkContext.hadoopConfiguration) + local.mkdirs(tempEmptyDir) spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING com.databricks.spark.avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING com.databricks.spark.avro |OPTIONS (path "$tempEmptyDir") @@ -692,6 +706,259 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { } } + test("generic record converts to row and back") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + .name("nested2").`type`("string").withDefault("string").endRecord() + + val schema = SchemaBuilder.record("record").fields() + .name("boolean").`type`("boolean").withDefault(false) + .name("int").`type`("int").withDefault(0) + .name("long").`type`("long").withDefault(0L) + .name("float").`type`("float").withDefault(0.0F) + .name("double").`type`("double").withDefault(0.0) + .name("string").`type`("string").withDefault("string") + 
.name("bytes").`type`("bytes").withDefault(java.nio.ByteBuffer.wrap("bytes".getBytes)) + .name("nested").`type`(nested).withDefault(new GenericRecordBuilder(nested).build) + .name("enum").`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS")) + .withDefault("SPADES") + .name("int_array").`type`( + SchemaBuilder.array().items().`type`("int")) + .withDefault(java.util.Arrays.asList(1, 2, 3)) + .name("string_array").`type`( + SchemaBuilder.array().items().`type`("string")) + .withDefault(java.util.Arrays.asList("a", "b", "c")) + .name("record_array").`type`( + SchemaBuilder.array.items.`type`(nested)) + .withDefault(java.util.Arrays.asList( + new GenericRecordBuilder(nested).build, + new GenericRecordBuilder(nested).build)) + .name("enum_array").`type`( + SchemaBuilder.array.items.`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS"))) + .withDefault(java.util.Arrays.asList("SPADES", "HEARTS", "SPADES")) + .name("fixed_array").`type`( + SchemaBuilder.array.items().`type`( + SchemaBuilder.fixed("simple_fixed").size(3))) + .withDefault(java.util.Arrays.asList("foo", "bar", "baz")) + .name("fixed").`type`(SchemaBuilder.fixed("simple_fixed").size(16)) + .withDefault("string_length_16") + .endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(record == recordFromRow) + } + + test("specific record converts to row and back") { + val schemaPath = "src/test/resources/specific.avsc" + val schema = new Schema.Parser().parse(new File(schemaPath)) + val record = TestRecord.newBuilder().build() + + val classEncoder = AvroEncoder.of[TestRecord](classOf[TestRecord]) + val classExpressionEncoder = classEncoder.asInstanceOf[ExpressionEncoder[TestRecord]] + val classRow = classExpressionEncoder.toRow(record) + val classRecordFromRow = classExpressionEncoder.resolveAndBind().fromRow(classRow) + + assert(record == classRecordFromRow) + + val schemaEncoder = AvroEncoder.of[TestRecord](schema) + val schemaExpressionEncoder = schemaEncoder.asInstanceOf[ExpressionEncoder[TestRecord]] + val schemaRow = schemaExpressionEncoder.toRow(record) + val schemaRecordFromRow = schemaExpressionEncoder.resolveAndBind().fromRow(schemaRow) + + assert(record == schemaRecordFromRow) + } + + test("encoder resolves union types to rows") { + val schema = SchemaBuilder.record("record").fields() + .name("int_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("int").endUnion) + .withDefault(null) + .name("string_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("string").endUnion) + .withDefault(null) + .name("int_long_union").`type`( + SchemaBuilder.unionOf.`type`("int").and.`type`("long").endUnion) + .withDefault(0) + .name("float_double_union").`type`( + SchemaBuilder.unionOf.`type`("float").and.`type`("double").endUnion) + .withDefault(0.0) + .endRecord + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(record.get(0) == recordFromRow.get(0)) + 
assert(record.get(1) == recordFromRow.get(1)) + assert(record.get(2) == recordFromRow.get(2)) + assert(record.get(3) == recordFromRow.get(3)) + + record.put(0, 0) + record.put(1, "value") + + val updatedRow = expressionEncoder.toRow(record) + val updatedRecordFromRow = expressionEncoder.resolveAndBind().fromRow(updatedRow) + + assert(record.get(0) == updatedRecordFromRow.get(0)) + assert(record.get(1) == updatedRecordFromRow.get(1)) + } + + test("encoder resolves map types to rows") { + val intMap = new java.util.HashMap[java.lang.String, java.lang.Integer] + intMap.put("foo", 1) + intMap.put("bar", 2) + intMap.put("baz", 3) + + val stringMap = new java.util.HashMap[java.lang.String, java.lang.String] + stringMap.put("foo", "a") + stringMap.put("bar", "b") + stringMap.put("baz", "c") + + val schema = SchemaBuilder.record("record").fields() + .name("int_map").`type`( + SchemaBuilder.map.values.`type`("int")).withDefault(intMap) + .name("string_map").`type`( + SchemaBuilder.map.values.`type`("string")).withDefault(stringMap) + .endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + val rowIntMap = recordFromRow.get(0) + assert(intMap == rowIntMap) + + val rowStringMap = recordFromRow.get(1) + assert(stringMap == rowStringMap) + } + + test("encoder resolves complex unions to rows") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + .name("nested2").`type`("string").withDefault("foo").endRecord() + val schema = SchemaBuilder.record("record").fields() + .name("int_float_string_record").`type`( + SchemaBuilder.unionOf() + .`type`("null").and() + .`type`("int").and() + .`type`("float").and() + .`type`("string").and() + .`type`(nested).endUnion() + ).withDefault(null).endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + var row = expressionEncoder.toRow(record) + var recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1F) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, "bar") + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 
4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, new GenericRecordBuilder(nested).build()) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(record == recordFromRow) + } + + test("create Dataset from SpecificRecords with unions") { + val sparkSession = spark + import sparkSession.implicits._ + + implicit val enc = AvroEncoder.of(classOf[Feature]) + + val rdd = sparkSession.sparkContext + .parallelize(Seq(1)).mapPartitions { iter => + iter.map { _ => + val ls = StringArray.newBuilder().setValue(List("foo", "bar", "baz")).build() + + Feature.newBuilder().setKey("FOOBAR").setValue(ls).build() + } + } + + val ds = rdd.toDS() + assert(ds.count() == 1) + } + + test("create Dataset from GenericRecord") { + val sparkSession = spark + import sparkSession.implicits._ + + val schema: Schema = + SchemaBuilder + .record("GenericRecordTest") + .namespace("com.databricks.spark.avro") + .fields() + .requiredString("field1") + .name("enumVal").`type`().enumeration("letters").symbols("a", "b", "c").enumDefault("a") + .name("fixedVal").`type`().fixed("MD5").size(16).fixedDefault(ByteBuffer.allocate(16)) + .endRecord() + + implicit val enc = AvroEncoder.of[GenericData.Record](schema) + + val genericRecords = (1 to 10) map { i => + new GenericRecordBuilder(schema) + .set("field1", "field-" + i) + .build() + } + + val rdd: RDD[GenericData.Record] = sparkSession.sparkContext + .parallelize(genericRecords) + + val ds = rdd.toDS() + + assert(ds.count() == genericRecords.size) + } + case class NestedBottom(id: Int, data: String) case class NestedMiddle(id: Int, data: NestedBottom)
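The tests above cover the full surface of the new AvroEncoder API added by this patch: building an encoder from an Avro Schema or from a generated SpecificRecord class, and supplying it implicitly so that rdd.toDS() works. As a minimal, hypothetical usage sketch (not part of this patch), the same pattern can be applied to the generated SimpleRecord fixture; as in the Feature-based test above, records are built inside the partition closure so the driver never has to serialize the Avro objects itself:

package com.databricks.spark.avro

import org.apache.spark.sql.SparkSession

// Illustrative sketch only; the object name is made up and the calls mirror the tests above.
object AvroEncoderUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("AvroEncoderUsageSketch")
      .getOrCreate()
    import spark.implicits._

    // Encoder for the generated SpecificRecord class, as introduced by this patch.
    implicit val enc = AvroEncoder.of(classOf[SimpleRecord])

    // Construct records on the executors, mirroring the Feature-based Dataset test.
    val ds = spark.sparkContext.parallelize(1 to 3).mapPartitions { iter =>
      iter.map(i => SimpleRecord.newBuilder().setNested1(i).setNested2("rec-" + i).build())
    }.toDS()

    assert(ds.count() == 3)
    spark.stop()
  }
}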