From 78f846ae751cf2f2438aa7126567284a6be0e673 Mon Sep 17 00:00:00 2001
From: Dooyoung Hwang
Date: Wed, 2 Dec 2020 16:28:42 +0900
Subject: [PATCH 1/5] Make FETCH_PRIOR not re-iterate from the beginning
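
Previously, a FETCH_FIRST or FETCH_PRIOR request rebuilt the iterator and
walked the result from row 0 before skipping forward to the target offset.
This patch introduces IteratorWithFetch, which tracks the start of the
current fetch block itself. For the collected (array-backed) case this
avoids rewinding to row 0 entirely; for the incremental-collect case the
iterator is only re-created when the position actually moves backwards.

Illustrative sketch of the intended semantics (not part of the diff; assumes
a 10-row result and a client fetch size of 3):

    val it = new ArrayIteratorWithFetch[Int]((0 until 10).toArray)
    it.fetchNext()                       // fetch block starts at position 0
    (0 until 3).foreach(_ => it.next())  // returns rows 0, 1, 2
    it.fetchNext()                       // fetch block starts at position 3
    (0 until 3).foreach(_ => it.next())  // returns rows 3, 4, 5
    it.fetchPrior(3)                     // back to max(3 - 3, 0) = 0
    it.getPosition                       // 0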
Resetting to resultOffset=$targetOffset") - var off = 0 - while (off < targetOffset && iter.hasNext) { - iter.next() - off += 1 - } - off - } else { // FETCH_NEXT - previousFetchEndOffset - } - } - - resultRowSet.setStartOffset(resultOffset) - previousFetchStartOffset = resultOffset + if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchFirst() + else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL) + else iter.fetchNext() + resultRowSet.setStartOffset(iter.getPosition) + val fetchStartOffset = iter.getPosition if (!iter.hasNext) { resultRowSet } else { @@ -206,11 +169,10 @@ private[hive] class SparkExecuteStatementOperation( } resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) curRow += 1 - resultOffset += 1 } - previousFetchEndOffset = resultOffset + val fetchEndOffset = iter.getPosition log.info(s"Returning result set with ${curRow} rows from offsets " + - s"[$previousFetchStartOffset, $previousFetchEndOffset) with $statementId") + s"[$fetchStartOffset, $fetchEndOffset) with $statementId") resultRowSet } } @@ -326,14 +288,12 @@ private[hive] class SparkExecuteStatementOperation( logDebug(result.queryExecution.toString()) HiveThriftServer2.eventManager.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { - resultList = None - result.toLocalIterator.asScala - } else { - resultList = Some(result.collect()) - resultList.get.iterator - } + iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) { + new IterableIteratorWithFetch[SparkRow](new Iterable[SparkRow] { + override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala + }) + } else { + new ArrayIteratorWithFetch[SparkRow](result.collect()) } dataTypes = result.schema.fields.map(_.dataType) } catch { @@ -421,3 +381,76 @@ object SparkExecuteStatementOperation { new TableSchema(schema.asJava) } } + +private[hive] sealed trait IteratorWithFetch[A] extends Iterator[A] { + def fetchNext(): Unit + + def fetchPrior(size: Long): Unit + + def fetchFirst(): Unit + + def getPosition: Long +} + +private[hive] class ArrayIteratorWithFetch[A](src: Array[A]) extends IteratorWithFetch[A] { + private var fetchStart: Long = 0 + + private var position: Long = 0 + + override def fetchNext(): Unit = fetchStart = position + + override def fetchPrior(size: Long): Unit = { + position = (fetchStart - size max 0) min src.length + fetchStart = position + } + + override def fetchFirst(): Unit = { + fetchStart = 0 + position = 0 + } + + override def getPosition: Long = position + + override def hasNext: Boolean = position < src.length + + override def next(): A = { + position += 1 + src(position.toInt - 1) + } +} + +private[hive] class IterableIteratorWithFetch[A]( + iterable: Iterable[A] +) extends IteratorWithFetch[A] { + private var iter: Iterator[A] = iterable.iterator + + private var fetchStart: Long = 0 + + private var position: Long = 0 + + override def fetchNext(): Unit = fetchStart = position + + override def fetchPrior(size: Long): Unit = { + val newPos = fetchStart - size max 0 + if (newPos < position) fetchFirst() + while (position < newPos && hasNext) next() + fetchStart = position + } + + override def fetchFirst(): Unit = { + if (position != 0) { + iter = iterable.iterator + position = 0 + fetchStart = 0 + } + } + + override def getPosition: Long = position + + override def hasNext: Boolean = iter.hasNext + + override def next(): A = { + position += 1 + 
+    iter.next()
+  }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
index c8bb6d9ee082..1affd1f9087c 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
@@ -59,6 +59,78 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite with SharedSpark
     assert(columns.get(1).getComment() == "")
   }
 
+  test("Test fetchFirst, fetchNext, fetchPrior with IteratorWithFetch") {
+    val testData = 0 until 10
+
+    def iteratorTest(iterWithFetch: IteratorWithFetch[Int]): Unit = {
+      def assertNext(expected: Seq[Int]): Unit = {
+        expected.foreach(n => assertResult(n)(iterWithFetch.next()))
+      }
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 0)
+      assertNext(0 until 3)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 3)
+      assertNext(3 until 6)
+
+      iterWithFetch.fetchPrior(2)
+      assert(iterWithFetch.getPosition == 1)
+      assertNext(1 until 3)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 3)
+      assertNext(3 until 6)
+
+      iterWithFetch.fetchPrior(10)
+      assert(iterWithFetch.getPosition == 0)
+      assertNext(0 until 3)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 3)
+      assertNext(3 until 10)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 10)
+      assertNext(Seq.empty[Int])
+
+      iterWithFetch.fetchPrior(3)
+      assert(iterWithFetch.getPosition == 7)
+      assertNext(7 until 10)
+
+      iterWithFetch.fetchFirst()
+      assert(iterWithFetch.getPosition == 0)
+      assertNext(0 until 3)
+
+      iterWithFetch.fetchFirst()
+      assert(iterWithFetch.getPosition == 0)
+      assertNext(0 until 3)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 3)
+      assertNext(3 until 10)
+
+      iterWithFetch.fetchNext()
+      assert(iterWithFetch.getPosition == 10)
+      assertNext(Seq.empty[Int])
+
+      iterWithFetch.fetchPrior(-3)
+      assert(iterWithFetch.getPosition == 10)
+      assertNext(Seq.empty[Int])
+
+      iterWithFetch.fetchPrior(20)
+      assert(iterWithFetch.getPosition == 0)
+      assertNext(0 until 1)
+
+      iterWithFetch.fetchPrior(-3)
+      assert(iterWithFetch.getPosition == 3)
+      assertNext(3 until 10)
+    }
+    iteratorTest(new ArrayIteratorWithFetch[Int](testData.toArray))
+    iteratorTest(new IterableIteratorWithFetch[Int](testData))
+  }
+
   Seq(
     (OperationState.CANCELED, (_: SparkExecuteStatementOperation).cancel()),
     (OperationState.TIMEDOUT, (_: SparkExecuteStatementOperation).timeoutCancel()),

From 7e4907031107114ce1982160ba88ec37dbc08715 Mon Sep 17 00:00:00 2001
From: Dooyoung Hwang
Date: Wed, 2 Dec 2020 22:53:47 +0900
Subject: [PATCH 2/5] Add a FetchIterator trait that supports
 setRelativePosition & setAbsolutePosition
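
Pull the iterator out of SparkExecuteStatementOperation into its own
FetchIterator trait, with an array-backed implementation for the collected
case and an iterable-backed one for incremental collect. The operation then
maps fetch orientations onto the trait, roughly (as wired up in the diff
below):

    FETCH_FIRST     -> iter.setAbsolutePosition(0)
    FETCH_PRIOR     -> iter.setRelativePosition(-maxRowsL)
    everything else -> iter.fetchNext()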
---
 .../sql/hive/thriftserver/FetchIterator.scala |  92 +++++++++++++++++++
 .../SparkExecuteStatementOperation.scala      |  83 +----------------
 .../thriftserver/FetchIteratorSuite.scala     |  88 ++++++++++++++++
 .../SparkExecuteStatementOperationSuite.scala |  72 ---------------
 4 files changed, 185 insertions(+), 150 deletions(-)
 create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
 create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
new file mode 100644
index 000000000000..a9bad9fa4569
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+private[hive] sealed trait FetchIterator[A] extends Iterator[A] {
+  def fetchNext(): Unit
+
+  def setRelativePosition(diff: Long): Unit
+
+  def setAbsolutePosition(pos: Long): Unit
+
+  def getPosition: Long
+}
+
+private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A] {
+  private var fetchStart: Long = 0
+
+  private var position: Long = 0
+
+  override def fetchNext(): Unit = fetchStart = position
+
+  override def setRelativePosition(diff: Long): Unit = {
+    setAbsolutePosition(fetchStart + diff)
+  }
+
+  override def setAbsolutePosition(pos: Long): Unit = {
+    position = (pos max 0) min src.length
+    fetchStart = position
+  }
+
+  override def getPosition: Long = position
+
+  override def hasNext: Boolean = position < src.length
+
+  override def next(): A = {
+    position += 1
+    src(position.toInt - 1)
+  }
+}
+
+private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends FetchIterator[A] {
+  private var iter: Iterator[A] = iterable.iterator
+
+  private var fetchStart: Long = 0
+
+  private var position: Long = 0
+
+  override def fetchNext(): Unit = fetchStart = position
+
+  override def setRelativePosition(diff: Long): Unit = {
+    setAbsolutePosition(fetchStart + diff)
+  }
+
+  override def setAbsolutePosition(pos: Long): Unit = {
+    val newPos = pos max 0
+    if (newPos < position) resetPosition()
+    while (position < newPos && hasNext) next()
+    fetchStart = position
+  }
+
+  override def getPosition: Long = position
+
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): A = {
+    position += 1
+    iter.next()
+  }
+
+  private def resetPosition(): Unit = {
+    if (position != 0) {
+      iter = iterable.iterator
+      position = 0
+      fetchStart = 0
+    }
+  }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 7180d55821b0..5558bc3cc536 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -69,7 +69,7 @@ private[hive] class SparkExecuteStatementOperation(
 
   private var result: DataFrame = _
 
-  private var iter: IteratorWithFetch[SparkRow] = _
+  private var iter: FetchIterator[SparkRow] = _
   private var dataTypes: Array[DataType] = _
 
   private lazy val resultSchema: TableSchema = {
@@ -143,8 +143,8 @@ private[hive] class SparkExecuteStatementOperation(
     val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
 
     // Reset iter when FETCH_FIRST or FETCH_PRIOR
-    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchFirst()
-    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL)
+    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.setAbsolutePosition(0)
+    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.setRelativePosition(-maxRowsL)
     else iter.fetchNext()
     resultRowSet.setStartOffset(iter.getPosition)
     val fetchStartOffset = iter.getPosition
@@ -289,11 +289,11 @@ private[hive] class SparkExecuteStatementOperation(
     HiveThriftServer2.eventManager.onStatementParsed(statementId,
       result.queryExecution.toString())
     iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
-      new IterableIteratorWithFetch[SparkRow](new Iterable[SparkRow] {
+      new IterableFetchIterator[SparkRow](new Iterable[SparkRow] {
        override def iterator: Iterator[SparkRow] = result.toLocalIterator.asScala
       })
     } else {
-      new ArrayIteratorWithFetch[SparkRow](result.collect())
+      new ArrayFetchIterator[SparkRow](result.collect())
     }
     dataTypes = result.schema.fields.map(_.dataType)
   } catch {
@@ -381,76 +381,3 @@ object SparkExecuteStatementOperation {
     new TableSchema(schema.asJava)
   }
 }
-
-private[hive] sealed trait IteratorWithFetch[A] extends Iterator[A] {
-  def fetchNext(): Unit
-
-  def fetchPrior(size: Long): Unit
-
-  def fetchFirst(): Unit
-
-  def getPosition: Long
-}
-
-private[hive] class ArrayIteratorWithFetch[A](src: Array[A]) extends IteratorWithFetch[A] {
-  private var fetchStart: Long = 0
-
-  private var position: Long = 0
-
-  override def fetchNext(): Unit = fetchStart = position
-
-  override def fetchPrior(size: Long): Unit = {
-    position = (fetchStart - size max 0) min src.length
-    fetchStart = position
-  }
-
-  override def fetchFirst(): Unit = {
-    fetchStart = 0
-    position = 0
-  }
-
-  override def getPosition: Long = position
-
-  override def hasNext: Boolean = position < src.length
-
-  override def next(): A = {
-    position += 1
-    src(position.toInt - 1)
-  }
-}
-
-private[hive] class IterableIteratorWithFetch[A](
-    iterable: Iterable[A]
-) extends IteratorWithFetch[A] {
-  private var iter: Iterator[A] = iterable.iterator
-
-  private var fetchStart: Long = 0
-
-  private var position: Long = 0
-
-  override def fetchNext(): Unit = fetchStart = position
-
-  override def fetchPrior(size: Long): Unit = {
-    val newPos = fetchStart - size max 0
-    if (newPos < position) fetchFirst()
-    while (position < newPos && hasNext) next()
-    fetchStart = position
-  }
-
-  override def fetchFirst(): Unit = {
-    if (position != 0) {
-      iter = iterable.iterator
-      position = 0
-      fetchStart = 0
-    }
-  }
-
-  override def getPosition: Long = position
-
-  override def hasNext: Boolean = iter.hasNext
-
-  override def next(): A = {
-    position += 1
-    iter.next()
-  }
-}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
new file mode 100644
index 000000000000..6111e55d168e
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import org.apache.spark.SparkFunSuite
+
+class FetchIteratorSuite extends SparkFunSuite {
+
+  test("Test setRelativePosition, setAbsolutePosition, fetchNext FetchIterator") {
+    val testData = 0 until 10
+
+    def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
+      def getRows(maxRowCount: Int): Seq[Int] = {
+        for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next()
+      }
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 0)
+      assertResult(0 until 3)(getRows(3))
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 3)
+      assertResult(3 until 6)(getRows(3))
+
+      fetchIter.setRelativePosition(-2)
+      assert(fetchIter.getPosition == 1)
+      assertResult(1 until 4)(getRows(3))
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 4)
+      assertResult(4 until 10)(getRows(10))
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(1))
+
+      fetchIter.setRelativePosition(-3)
+      assert(fetchIter.getPosition == 7)
+      assertResult(7 until 10)(getRows(3))
+
+      fetchIter.setAbsolutePosition(0)
+      assert(fetchIter.getPosition == 0)
+      assertResult(0 until 1)(getRows(1))
+
+      fetchIter.setAbsolutePosition(20)
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(1))
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(1))
+
+      fetchIter.setRelativePosition(3)
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(1))
+
+      fetchIter.setRelativePosition(-20)
+      assert(fetchIter.getPosition == 0)
+      assertResult(0 until 3)(getRows(3))
+
+      fetchIter.setAbsolutePosition(-20)
+      assert(fetchIter.getPosition == 0)
+      assertResult(0 until 3)(getRows(3))
+
+      fetchIter.fetchNext()
+      assert(fetchIter.getPosition == 3)
+      assertResult(3 until 10)(getRows(10))
+    }
+
+    iteratorTest(new ArrayFetchIterator[Int](testData.toArray))
+    iteratorTest(new IterableFetchIterator[Int](testData))
+  }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
index 1affd1f9087c..c8bb6d9ee082 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala
@@ -59,78 +59,6 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite with SharedSpark
     assert(columns.get(1).getComment() == "")
   }
 
-  test("Test fetchFirst, fetchNext, fetchPrior with IteratorWithFetch") {
-    val testData = 0 until 10
-
-    def iteratorTest(iterWithFetch: IteratorWithFetch[Int]): Unit = {
-      def assertNext(expected: Seq[Int]): Unit = {
-        expected.foreach(n => assertResult(n)(iterWithFetch.next()))
-      }
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 0)
-      assertNext(0 until 3)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 3)
-      assertNext(3 until 6)
-
-      iterWithFetch.fetchPrior(2)
-      assert(iterWithFetch.getPosition == 1)
-      assertNext(1 until 3)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 3)
-      assertNext(3 until 6)
-
-      iterWithFetch.fetchPrior(10)
-      assert(iterWithFetch.getPosition == 0)
-      assertNext(0 until 3)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 3)
-      assertNext(3 until 10)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 10)
-      assertNext(Seq.empty[Int])
-
-      iterWithFetch.fetchPrior(3)
-      assert(iterWithFetch.getPosition == 7)
-      assertNext(7 until 10)
-
-      iterWithFetch.fetchFirst()
-      assert(iterWithFetch.getPosition == 0)
-      assertNext(0 until 3)
-
-      iterWithFetch.fetchFirst()
-      assert(iterWithFetch.getPosition == 0)
-      assertNext(0 until 3)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 3)
-      assertNext(3 until 10)
-
-      iterWithFetch.fetchNext()
-      assert(iterWithFetch.getPosition == 10)
-      assertNext(Seq.empty[Int])
-
-      iterWithFetch.fetchPrior(-3)
-      assert(iterWithFetch.getPosition == 10)
-      assertNext(Seq.empty[Int])
-
-      iterWithFetch.fetchPrior(20)
-      assert(iterWithFetch.getPosition == 0)
-      assertNext(0 until 1)
-
-      iterWithFetch.fetchPrior(-3)
-      assert(iterWithFetch.getPosition == 3)
-      assertNext(3 until 10)
-    }
-    iteratorTest(new ArrayIteratorWithFetch[Int](testData.toArray))
-    iteratorTest(new IterableIteratorWithFetch[Int](testData))
-  }
-
   Seq(
     (OperationState.CANCELED, (_: SparkExecuteStatementOperation).cancel()),
     (OperationState.TIMEDOUT, (_: SparkExecuteStatementOperation).timeoutCancel()),

From f451868a8b2b1ac62ee5886102e6786271eef713 Mon Sep 17 00:00:00 2001
From: Dooyoung Hwang
Date: Fri, 4 Dec 2020 00:03:51 +0900
Subject: [PATCH 3/5] Refactor FetchIterator

- Remove setRelativePosition & setAbsolutePosition.
- Add fetchPrior, fetchAbsolute and getFetchStart.
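
fetchPrior is expressed relative to the start of the previous fetch block
rather than the current position, and getFetchStart lets the operation log
the [start, end) offsets of each fetch. The default implementation in the
trait below is simply:

    def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)

For example (mirroring the new test): after fetching rows [3, 8),
fetchPrior(2) moves both the fetch start and the position back to 1.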
---
 .../sql/hive/thriftserver/FetchIterator.scala |  41 +++++--
 .../SparkExecuteStatementOperation.scala      |   8 +-
 .../thriftserver/FetchIteratorSuite.scala     | 112 ++++++++++++------
 3 files changed, 110 insertions(+), 51 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
index a9bad9fa4569..b9db657952b5 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/FetchIterator.scala
@@ -18,11 +18,30 @@
 package org.apache.spark.sql.hive.thriftserver
 
 private[hive] sealed trait FetchIterator[A] extends Iterator[A] {
+  /**
+   * Begin a fetch block, forward from the current position.
+   * Resets the fetch start offset.
+   */
   def fetchNext(): Unit
 
-  def setRelativePosition(diff: Long): Unit
-
-  def setAbsolutePosition(pos: Long): Unit
+  /**
+   * Begin a fetch block, moving the iterator back `offset` rows from the start of the previous
+   * fetch block.
+   * Resets the fetch start offset.
+   *
+   * @param offset the number of rows to move the fetch start position backward.
+   */
+  def fetchPrior(offset: Long): Unit = fetchAbsolute(getFetchStart - offset)
+
+  /**
+   * Begin a fetch block, moving the iterator to the given position.
+   * Resets the fetch start offset.
+   *
+   * @param pos the index to which the iterator position is moved.
+   */
+  def fetchAbsolute(pos: Long): Unit
+
+  def getFetchStart: Long
 
   def getPosition: Long
 }
@@ -34,15 +53,13 @@ private[hive] class ArrayFetchIterator[A](src: Array[A]) extends FetchIterator[A
 
   override def fetchNext(): Unit = fetchStart = position
 
-  override def setRelativePosition(diff: Long): Unit = {
-    setAbsolutePosition(fetchStart + diff)
-  }
-
-  override def setAbsolutePosition(pos: Long): Unit = {
+  override def fetchAbsolute(pos: Long): Unit = {
     position = (pos max 0) min src.length
     fetchStart = position
   }
 
+  override def getFetchStart: Long = fetchStart
+
   override def getPosition: Long = position
 
   override def hasNext: Boolean = position < src.length
@@ -62,17 +79,15 @@ private[hive] class IterableFetchIterator[A](iterable: Iterable[A]) extends Fetc
 
   override def fetchNext(): Unit = fetchStart = position
 
-  override def setRelativePosition(diff: Long): Unit = {
-    setAbsolutePosition(fetchStart + diff)
-  }
-
-  override def setAbsolutePosition(pos: Long): Unit = {
+  override def fetchAbsolute(pos: Long): Unit = {
     val newPos = pos max 0
     if (newPos < position) resetPosition()
     while (position < newPos && hasNext) next()
     fetchStart = position
   }
 
+  override def getFetchStart: Long = fetchStart
+
   override def getPosition: Long = position
 
   override def hasNext: Boolean = iter.hasNext
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 5558bc3cc536..7f1ba49e5329 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -143,11 +143,10 @@ private[hive] class SparkExecuteStatementOperation(
     val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
 
     // Reset iter when FETCH_FIRST or FETCH_PRIOR
-    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.setAbsolutePosition(0)
-    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.setRelativePosition(-maxRowsL)
+    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchAbsolute(0)
+    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL)
     else iter.fetchNext()
     resultRowSet.setStartOffset(iter.getPosition)
-    val fetchStartOffset = iter.getPosition
     if (!iter.hasNext) {
       resultRowSet
     } else {
@@ -170,9 +169,8 @@
       resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]])
       curRow += 1
     }
-    val fetchEndOffset = iter.getPosition
     log.info(s"Returning result set with ${curRow} rows from offsets " +
-      s"[$fetchStartOffset, $fetchEndOffset) with $statementId")
+      s"[${iter.getFetchStart}, ${iter.getPosition}) with $statementId")
     resultRowSet
   }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
index 6111e55d168e..bd5b669dabf1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
@@ -21,67 +21,113 @@ import org.apache.spark.SparkFunSuite
 
 class FetchIteratorSuite extends SparkFunSuite {
 
-  test("Test setRelativePosition, setAbsolutePosition, fetchNext FetchIterator") {
+  private def getRows(fetchIter: FetchIterator[Int], maxRowCount: Int): Seq[Int] = {
+    for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next()
+  }
+
+  test("Test fetchNext and fetchPrior") {
     val testData = 0 until 10
 
     def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
-      def getRows(maxRowCount: Int): Seq[Int] = {
-        for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next()
-      }
+      fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 0)
+      assert(fetchIter.getPosition == 0)
+      assertResult(0 until 2)(getRows(fetchIter, 2))
+      assert(fetchIter.getFetchStart == 0)
+      assert(fetchIter.getPosition == 2)
 
       fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 2)
+      assert(fetchIter.getPosition == 2)
+      assertResult(2 until 3)(getRows(fetchIter, 1))
+      assert(fetchIter.getFetchStart == 2)
+      assert(fetchIter.getPosition == 3)
+
+      fetchIter.fetchPrior(2)
+      assert(fetchIter.getFetchStart == 0)
       assert(fetchIter.getPosition == 0)
-      assertResult(0 until 3)(getRows(3))
+      assertResult(0 until 3)(getRows(fetchIter, 3))
+      assert(fetchIter.getFetchStart == 0)
+      assert(fetchIter.getPosition == 3)
 
       fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 3)
       assert(fetchIter.getPosition == 3)
-      assertResult(3 until 6)(getRows(3))
+      assertResult(3 until 8)(getRows(fetchIter, 5))
+      assert(fetchIter.getFetchStart == 3)
+      assert(fetchIter.getPosition == 8)
 
-      fetchIter.setRelativePosition(-2)
+      fetchIter.fetchPrior(2)
+      assert(fetchIter.getFetchStart == 1)
       assert(fetchIter.getPosition == 1)
-      assertResult(1 until 4)(getRows(3))
+      assertResult(1 until 4)(getRows(fetchIter, 3))
+      assert(fetchIter.getFetchStart == 1)
+      assert(fetchIter.getPosition == 4)
 
       fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 4)
       assert(fetchIter.getPosition == 4)
-      assertResult(4 until 10)(getRows(10))
+      assertResult(4 until 10)(getRows(fetchIter, 10))
+      assert(fetchIter.getFetchStart == 4)
+      assert(fetchIter.getPosition == 10)
 
       fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 10)
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(fetchIter, 10))
+      assert(fetchIter.getFetchStart == 10)
       assert(fetchIter.getPosition == 10)
-      assertResult(Seq.empty[Int])(getRows(1))
-
-      fetchIter.setRelativePosition(-3)
-      assert(fetchIter.getPosition == 7)
-      assertResult(7 until 10)(getRows(3))
 
-      fetchIter.setAbsolutePosition(0)
+      fetchIter.fetchPrior(20)
+      assert(fetchIter.getFetchStart == 0)
       assert(fetchIter.getPosition == 0)
-      assertResult(0 until 1)(getRows(1))
+      assertResult(0 until 3)(getRows(fetchIter, 3))
+      assert(fetchIter.getFetchStart == 0)
+      assert(fetchIter.getPosition == 3)
+    }
+    iteratorTest(new ArrayFetchIterator[Int](testData.toArray))
+    iteratorTest(new IterableFetchIterator[Int](testData))
+  }
 
-      fetchIter.setAbsolutePosition(20)
-      assert(fetchIter.getPosition == 10)
-      assertResult(Seq.empty[Int])(getRows(1))
+  test("Test fetchAbsolute") {
+    val testData = 0 until 10
 
+    def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
       fetchIter.fetchNext()
+      assert(fetchIter.getFetchStart == 0)
       assert(fetchIter.getPosition == 0)
+      assertResult(0 until 5)(getRows(fetchIter, 5))
+      assert(fetchIter.getFetchStart == 0)
+      assert(fetchIter.getPosition == 5)
+
+      fetchIter.fetchAbsolute(2)
+      assert(fetchIter.getFetchStart == 2)
+      assert(fetchIter.getPosition == 2)
+      assertResult(2 until 5)(getRows(fetchIter, 3))
+      assert(fetchIter.getFetchStart == 2)
+      assert(fetchIter.getPosition == 5)
+
+      fetchIter.fetchAbsolute(7)
+      assert(fetchIter.getFetchStart == 7)
+      assert(fetchIter.getPosition == 7)
+      assertResult(7 until 8)(getRows(fetchIter, 1))
+      assert(fetchIter.getFetchStart == 7)
+      assert(fetchIter.getPosition == 8)
 
-      fetchIter.setRelativePosition(3)
+      fetchIter.fetchAbsolute(20)
+      assert(fetchIter.getFetchStart == 10)
+      assert(fetchIter.getPosition == 10)
+      assertResult(Seq.empty[Int])(getRows(fetchIter, 1))
+      assert(fetchIter.getFetchStart == 10)
       assert(fetchIter.getPosition == 10)
-      assertResult(Seq.empty[Int])(getRows(1))
-
-      fetchIter.setRelativePosition(-20)
-      assert(fetchIter.getPosition == 0)
-      assertResult(0 until 3)(getRows(3))
 
-      fetchIter.setAbsolutePosition(-20)
+      fetchIter.fetchAbsolute(0)
+      assert(fetchIter.getFetchStart == 0)
       assert(fetchIter.getPosition == 0)
-      assertResult(0 until 3)(getRows(3))
-
-      fetchIter.fetchNext()
+      assertResult(0 until 3)(getRows(fetchIter, 3))
+      assert(fetchIter.getFetchStart == 0)
       assert(fetchIter.getPosition == 3)
-      assertResult(3 until 10)(getRows(10))
     }
-
     iteratorTest(new ArrayFetchIterator[Int](testData.toArray))
     iteratorTest(new IterableFetchIterator[Int](testData))
   }
 }

From 2b1b7008ddd2c05b1f2bdbf864d690b91e34a30f Mon Sep 17 00:00:00 2001
From: Dooyoung Hwang
Date: Fri, 4 Dec 2020 11:22:57 +0900
Subject: [PATCH 4/5] Add the JIRA issue ID to the test suite's test names
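
Prefix the FetchIteratorSuite test names with SPARK-33655. This also removes
the "Reset iter when FETCH_FIRST or FETCH_PRIOR" comment from
SparkExecuteStatementOperation, which no longer describes the code now that
every fetch orientation goes through FetchIterator.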
---
 .../hive/thriftserver/SparkExecuteStatementOperation.scala   | 1 -
 .../spark/sql/hive/thriftserver/FetchIteratorSuite.scala     | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 7f1ba49e5329..733efe3730cd 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -142,7 +142,6 @@ private[hive] class SparkExecuteStatementOperation(
     setHasResultSet(true)
     val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
 
-    // Reset iter when FETCH_FIRST or FETCH_PRIOR
     if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchAbsolute(0)
     else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL)
     else iter.fetchNext()
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
index bd5b669dabf1..0fbdb8a9050c 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/FetchIteratorSuite.scala
@@ -25,7 +25,7 @@ class FetchIteratorSuite extends SparkFunSuite {
     for (_ <- 0 until maxRowCount if fetchIter.hasNext) yield fetchIter.next()
   }
 
-  test("Test fetchNext and fetchPrior") {
+  test("SPARK-33655: Test fetchNext and fetchPrior") {
     val testData = 0 until 10
 
     def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {
@@ -89,7 +89,7 @@ class FetchIteratorSuite extends SparkFunSuite {
     iteratorTest(new IterableFetchIterator[Int](testData))
   }
 
-  test("Test fetchAbsolute") {
+  test("SPARK-33655: Test fetchAbsolute") {
     val testData = 0 until 10
 
     def iteratorTest(fetchIter: FetchIterator[Int]): Unit = {

From fb9356212c7a1db798eadb4be30044d481ab332d Mon Sep 17 00:00:00 2001
From: Dooyoung Hwang
Date: Mon, 7 Dec 2020 13:43:34 +0900
Subject: [PATCH 5/5] Fix style error

- Add curly braces for if-else
---
 .../thriftserver/SparkExecuteStatementOperation.scala     | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 733efe3730cd..c4ae035e1f83 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -142,9 +142,13 @@ private[hive] class SparkExecuteStatementOperation(
     setHasResultSet(true)
     val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
 
-    if (order.equals(FetchOrientation.FETCH_FIRST)) iter.fetchAbsolute(0)
-    else if (order.equals(FetchOrientation.FETCH_PRIOR)) iter.fetchPrior(maxRowsL)
-    else iter.fetchNext()
+    if (order.equals(FetchOrientation.FETCH_FIRST)) {
+      iter.fetchAbsolute(0)
+    } else if (order.equals(FetchOrientation.FETCH_PRIOR)) {
+      iter.fetchPrior(maxRowsL)
+    } else {
+      iter.fetchNext()
+    }
     resultRowSet.setStartOffset(iter.getPosition)
     if (!iter.hasNext) {
       resultRowSet