diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala new file mode 100644 index 0000000000000..ae6fb45ba8d53 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URI + +import org.apache.avro.file.{DataFileReader} +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.mapred.FsInput +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkConf +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.sql.v2.avro.AvroScan + +class AvroRowReaderSuite + extends QueryTest + with SharedSparkSession { + + import testImplicits._ + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") // need this for BatchScanExec + + test("SPARK-33314: hasNextRow and nextRow properly handle consecutive calls") { + withTempPath { dir => + Seq((1), (2), (3)) + .toDF("value") + .coalesce(1) + .write + .format("avro") + .save(dir.getCanonicalPath) + + val df = spark.read.format("avro").load(dir.getCanonicalPath) + val fileScan = df.queryExecution.executedPlan collectFirst { + case BatchScanExec(_, f: AvroScan) => f + } + val filePath = fileScan.get.fileIndex.inputFiles(0) + val fileSize = new File(new URI(filePath)).length + val in = new FsInput(new Path(new URI(filePath)), new Configuration()) + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + + val it = new Iterator[InternalRow] with AvroUtils.RowReader { + override val fileReader = reader + override val deserializer = new AvroDeserializer( + reader.getSchema, + StructType(new StructField("value", IntegerType, true) :: Nil), + CORRECTED, + new NoopFilters) + override val stopPosition = fileSize + + override def hasNext: Boolean = hasNextRow + + override def next: InternalRow = nextRow + } + assert(it.hasNext == true) + assert(it.next.getInt(0) == 1) + // test no intervening next + assert(it.hasNext == true) + assert(it.hasNext == true) + // test no intervening hasNext + assert(it.next.getInt(0) == 2) + assert(it.next.getInt(0) == 3) + assert(it.hasNext == false) + assertThrows[NoSuchElementException] { + it.next + } + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 6c04417289292..552fa6e3685c3 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.avro import java.io._ -import java.net.{URI, URL} +import java.net.URL import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.util.{Locale, UUID} @@ -31,16 +31,12 @@ import org.apache.avro.Schema.Type._ import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} -import org.apache.avro.mapred.FsInput import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException} import org.apache.spark.TestUtils.assertExceptionMsg import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.IntervalData -import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA, UTC} @@ -2247,61 +2243,3 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { } } } - -class AvroRowReaderSuite - extends QueryTest - with SharedSparkSession { - - import testImplicits._ - - override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_LIST, "") // need this for BatchScanExec - - test("SPARK-33314: hasNextRow and nextRow properly handle consecutive calls") { - withTempPath { dir => - Seq((1), (2), (3)) - .toDF("value") - .coalesce(1) - .write - .format("avro") - .save(dir.getCanonicalPath) - - val df = spark.read.format("avro").load(dir.getCanonicalPath) - val fileScan = df.queryExecution.executedPlan collectFirst { - case BatchScanExec(_, f: AvroScan) => f - } - val filePath = fileScan.get.fileIndex.inputFiles(0) - val fileSize = new File(new URI(filePath)).length - val in = new FsInput(new Path(new URI(filePath)), new Configuration()) - val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) - - val it = new Iterator[InternalRow] with AvroUtils.RowReader { - override val fileReader = reader - override val deserializer = new AvroDeserializer( - reader.getSchema, - StructType(new StructField("value", IntegerType, true) :: Nil), - CORRECTED, - new NoopFilters) - override val stopPosition = fileSize - - override def hasNext: Boolean = hasNextRow - - override def next: InternalRow = nextRow - } - assert(it.hasNext == true) - assert(it.next.getInt(0) == 1) - // test no intervening next - assert(it.hasNext == true) - assert(it.hasNext == true) - // test no intervening hasNext - assert(it.next.getInt(0) == 2) - assert(it.next.getInt(0) == 3) - assert(it.hasNext == false) - assertThrows[NoSuchElementException] { - it.next - } - } - } -}