diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
new file mode 100644
index 000000000000..b9db657952b5
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+/**
+ * An [[Iterator]] that also tracks fetch blocks: the current position and the position at which
+ * the current fetch block started.
+ */
+private[hive] sealed trait FetchIterator[A] extends Iterator[A] {
+  /**
+   * Begin a fetch block, forward from the current position.
+   * Resets the fetch start offset.
+   */
+  def fetchNext(): Unit
+
+  /**
+   * Begin a fetch block, moving the iterator back by `offset` from the start of the previous
+   * fetch block.
+   * Resets the fetch start offset.
+   *
+   * @param offset the number of positions to move the fetch start backwards by.
+   */
+  def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)
+
+  /**
+   * Begin a fetch block, moving the iterator to the given position.
+   * Resets the fetch start offset.
+   *
+   * @param pos the index to move the iterator position to.
+   */
+  def fetchAbsolute(pos: Long): Unit
+
+  /** The position at which the current fetch block started. */
+  def getFetchStart: Long
+
+  /** The current position: the index of the element that a following `next()` call returns. */
+  def getPosition: Long
+}
+
+/**
+ * A [[FetchIterator]] over an in-memory array; repositioning only changes the index.
+ */
+private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
+  private var fetchStart: Long = 0
+
+  private var position: Long = 0
+
+  override def fetchNext(): Unit = fetchStart = position
+
+  /** Moves to `pos`, clamped to the range `[0, src.length]`. */
+  override def fetchAbsolute(pos: Long): Unit = {
+    position = (pos max 0) min src.length
+    fetchStart = position
+  }
+
+  override def getFetchStart: Long = fetchStart
+
+  override def getPosition: Long = position
+
+  override def hasNext: Boolean = position < src.length
+
+  override def next(): A = {
+    // Check before touching `position`: calling next() on an exhausted iterator must not push
+    // the position past the end of `src`, and should fail the way Iterator users expect.
+    if (!hasNext) {
+      throw new NoSuchElementException("next on exhausted ArrayFetchIterator")
+    }
+    val elem = src(position.toInt)
+    position += 1
+    elem
+  }
+}
+
+/**
+ * A [[FetchIterator]] over an [[Iterable]]. Moving backwards obtains a fresh iterator from
+ * `iterable` and skips forward to the requested position.
+ */
+private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
+  private var iter: Iterator[A] = iterable.iterator
+
+  private var fetchStart: Long = 0
+
+  private var position: Long = 0
+
+  override def fetchNext(): Unit = fetchStart = position
+
+  /**
+   * Moves to `pos`, treating negative values as 0. If the underlying iterator is exhausted
+   * first, the position stops at the number of elements it produced.
+   */
+  override def fetchAbsolute(pos: Long): Unit = {
+    val newPos = pos max 0
+    if (newPos < position) resetPosition()
+    while (position < newPos && hasNext) next()
+    fetchStart = position
+  }
+
+  override def getFetchStart: Long = fetchStart
+
+  override def getPosition: Long = position
+
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): A = {
+    // Advance the underlying iterator first: if it throws, `position` stays equal to the number
+    // of elements actually consumed.
+    val elem = iter.next()
+    position += 1
+    elem
+  }
+
+  private def resetPosition(): Unit = {
+    if (position != 0) {
+      iter = iterable.iterator
+      position = 0
+      fetchStart = 0
+    }
+  }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index f7a4be959181..c4ae035e1f83 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -69,13 +69,7 @@ private[hive] class SparkExecuteStatementOperation( private var result: DataFrame = _ - // We cache the returned rows to get iterators again in case the user wants to use FETCH_FIRST. - // This is only used when `spark.sql.thriftServer.incrementalCollect` is set to `false`. - // In case of `true`, this will be `None` and FETCH_FIRST will trigger re-execution. - private var resultList: Option[Array[SparkRow]] = _ - private var previousFetchEndOffset: Long = 0 - private var previousFetchStartOffset: Long = 0 - private var iter: Iterator[SparkRow] = _ + private var iter: FetchIterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ private lazy val resultSchema: TableSchema = { @@ -148,43 +142,14 @@ private[hive] class SparkExecuteStatementOperation( setHasResultSet(true) val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false) - // Reset iter when FETCH_FIRST or FETCH_PRIOR - if ((order.equals(FetchOrientation.FETCH_FIRST) || - order.equals(FetchOrientation.FETCH_PRIOR)) && previousFetchEndOffset != 0) { - // Reset the iterator to the beginning of the query. - iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { - resultList = None - result.toLocalIterator.asScala - } else { - if (resultList.isEmpty) { - resultList = Some(result.collect()) - } - resultList.get.iterator - } - } - - var resultOffset = { - if (order.equals(FetchOrientation.FETCH_FIRST)) { - logInfo(s"FETCH_FIRST request with $statementId. Resetting to resultOffset=0") - 0 - } else if (order.equals(FetchOrientation.FETCH_PRIOR)) { - // TODO: FETCH_PRIOR should be handled more efficiently than rewinding to beginning and - // reiterating. - val targetOffset = math.max(previousFetchStartOffset - maxRowsL, 0) - logInfo(s"FETCH_PRIOR request with $statementId. 
Resetting to resultOffset=$targetOffset") - var off = 0 - while (off < targetOffset && iter.hasNext) { - iter.next() - off += 1 - } - off - } else { // FETCH_NEXT - previousFetchEndOffset - } + if (order.equals(FetchOrientation.FETCH_FIRST)) { + iter.fetchAbsolute(0) + } else if (order.equals(FetchOrientation.FETCH_PRIOR)) { + iter.fetchPrior(maxRowsL) + } else { + iter.fetchNext() } - - resultRowSet.setStartOffset(resultOffset) - previousFetchStartOffset = resultOffset + resultRowSet.setStartOffset(iter.getPosition) if (!iter.hasNext) { resultRowSet } else { @@ -206,11 +171,9 @@ private[hive] class SparkExecuteStatementOperation( } resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) curRow += 1 - resultOffset += 1 } - previousFetchEndOffset = resultOffset log.info(s"Returning result set with ${curRow} rows from offsets " + - s"[$previousFetchStartOffset, $previousFetchEndOffset) with $statementId") + s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId") resultRowSet } } @@ -326,14 +289,12 @@ private[hive] class SparkExecuteStatementOperation( logDebug(result.queryExecution.toString()) HiveThriftServer2.eventManager.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { - resultList = None - result.toLocalIterator.asScala - } else { - resultList = Some(result.collect()) - resultList.get.iterator - } + iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { + new IterableFetchIterator[SparkRow](new Iterable[SparkRow] { + override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala + }) + } else { + new ArrayFetchIterator[SparkRow](result.collect()) } dataTypes = result.schema.fields.map(_.dataType) } catch { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
new file mode 100644
index 000000000000..0fbdb8a9050c
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Exercises both [[FetchIterator]] implementations with the same sequences of fetch calls.
+ */
+class FetchIteratorSuite extends SparkFunSuite {
+
+  /** Reads at most `maxRowCount` rows from `fetchIter`, stopping early once it has no more. */
+  private def getRows(fetchIter: FetchIterator[Int], maxRowCount: Int): Seq[Int] = {
+    Iterator.range(0, maxRowCount)
+      .takeWhile(_ => fetchIter.hasNext)
+      .map(_ => fetchIter.next())
+      .toVector
+  }
+
+  /**
+   * Verifies one fetch block. Right after the block has begun, the fetch start and the position
+   * must both equal `start`. Reading up to `maxRowCount` rows must then return `rows`, keep the
+   * fetch start at `start` and leave the position at `end`.
+   */
+  private def checkBlock(
+      fetchIter: FetchIterator[Int],
+      start: Long,
+      maxRowCount: Int,
+      rows: Seq[Int],
+      end: Long): Unit = {
+    assert(fetchIter.getFetchStart == start)
+    assert(fetchIter.getPosition == start)
+    assertResult(rows)(getRows(fetchIter, maxRowCount))
+    assert(fetchIter.getFetchStart == start)
+    assert(fetchIter.getPosition == end)
+  }
+
+  /** Runs `scenario` once against each implementation, both backed by the values 0 to 9. */
+  private def forEachImplementation(scenario: FetchIterator[Int] => Unit): Unit = {
+    val testData = 0 until 10
+    scenario(new ArrayFetchIterator[Int](testData.toArray))
+    scenario(new IterableFetchIterator[Int](testData))
+  }
+
+  test("SPARK-33655: Test fetchNext and fetchPrior") {
+    forEachImplementation { fetchIter =>
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 0, maxRowCount = 2, rows = 0 until 2, end = 2)
+
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 2, maxRowCount = 1, rows = 2 until 3, end = 3)
+
+      fetchIter.fetchPrior(2)
+      checkBlock(fetchIter, start = 0, maxRowCount = 3, rows = 0 until 3, end = 3)
+
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 3, maxRowCount = 5, rows = 3 until 8, end = 8)
+
+      fetchIter.fetchPrior(2)
+      checkBlock(fetchIter, start = 1, maxRowCount = 3, rows = 1 until 4, end = 4)
+
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 4, maxRowCount = 10, rows = 4 until 10, end = 10)
+
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 10, maxRowCount = 10, rows = Seq.empty[Int], end = 10)
+
+      fetchIter.fetchPrior(20)
+      checkBlock(fetchIter, start = 0, maxRowCount = 3, rows = 0 until 3, end = 3)
+    }
+  }
+
+  test("SPARK-33655: Test fetchAbsolute") {
+    forEachImplementation { fetchIter =>
+      fetchIter.fetchNext()
+      checkBlock(fetchIter, start = 0, maxRowCount = 5, rows = 0 until 5, end = 5)
+
+      fetchIter.fetchAbsolute(2)
+      checkBlock(fetchIter, start = 2, maxRowCount = 3, rows = 2 until 5, end = 5)
+
+      fetchIter.fetchAbsolute(7)
+      checkBlock(fetchIter, start = 7, maxRowCount = 1, rows = 7 until 8, end = 8)
+
+      fetchIter.fetchAbsolute(20)
+      checkBlock(fetchIter, start = 10, maxRowCount = 1, rows = Seq.empty[Int], end = 10)
+
+      fetchIter.fetchAbsolute(0)
+      checkBlock(fetchIter, start = 0, maxRowCount = 3, rows = 0 until 3, end = 3)
+    }
+  }
+}