Skip to content

Commit 7f6e43f

Browse files
committed
Replace Arrow file format with Arrow stream format instead of having a new conf.
1 parent 4a23c52 commit 7f6e43f

10 files changed

Lines changed: 42 additions & 134 deletions

File tree

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ private[spark] object PythonEvalType {
3636
val NON_UDF = 0
3737
val SQL_BATCHED_UDF = 1
3838
val SQL_PANDAS_UDF = 2
39-
val SQL_PANDAS_UDF_STREAM = 3
4039
}
4140

4241
/**

python/pyspark/serializers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ class PythonEvalType(object):
8686
NON_UDF = 0
8787
SQL_BATCHED_UDF = 1
8888
SQL_PANDAS_UDF = 2
89-
SQL_PANDAS_UDF_STREAM = 3
9089

9190

9291
class Serializer(object):
@@ -235,7 +234,7 @@ def cast_series(s, t):
235234

236235
class ArrowPandasSerializer(ArrowSerializer):
237236
"""
238-
Serializes Pandas.Series as Arrow data.
237+
Serializes Pandas.Series as Arrow data with Arrow file format.
239238
"""
240239

241240
def dumps(self, series):
@@ -259,7 +258,7 @@ def __repr__(self):
259258

260259
class ArrowStreamPandasSerializer(Serializer):
261260
"""
262-
(De)serializes a vectorized(Apache Arrow) stream.
261+
Serializes Pandas.Series as Arrow data with Arrow streaming format.
263262
"""
264263

265264
def load_stream(self, stream):

python/pyspark/sql/tests.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,16 +3376,6 @@ def test_vectorized_udf_empty_partition(self):
33763376
res = df.select(f(col('id')))
33773377
self.assertEquals(df.collect(), res.collect())
33783378

3379-
3380-
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
3381-
class ArrowStreamVectorizedUDFTests(VectorizedUDFTests):
3382-
3383-
@classmethod
3384-
def setUpClass(cls):
3385-
VectorizedUDFTests.setUpClass()
3386-
cls.spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")
3387-
3388-
33893379
if __name__ == "__main__":
33903380
from pyspark.sql.tests import *
33913381
if xmlrunner:

python/pyspark/worker.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pyspark.files import SparkFiles
3232
from pyspark.serializers import write_with_length, write_int, read_long, \
3333
write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
34-
BatchedSerializer, ArrowPandasSerializer, ArrowStreamPandasSerializer
34+
BatchedSerializer, ArrowStreamPandasSerializer
3535
from pyspark.sql.types import toArrowType
3636
from pyspark import shuffle
3737

@@ -98,10 +98,10 @@ def read_single_udf(pickleSer, infile, eval_type):
9898
else:
9999
row_func = chain(row_func, f)
100100
# the last returnType will be the return type of UDF
101-
if eval_type == PythonEvalType.SQL_BATCHED_UDF:
102-
return arg_offsets, wrap_udf(row_func, return_type)
103-
else:
101+
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
104102
return arg_offsets, wrap_pandas_udf(row_func, return_type)
103+
else:
104+
return arg_offsets, wrap_udf(row_func, return_type)
105105

106106

107107
def read_udfs(pickleSer, infile, eval_type):
@@ -123,8 +123,6 @@ def read_udfs(pickleSer, infile, eval_type):
123123
func = lambda _, it: map(mapper, it)
124124

125125
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
126-
ser = ArrowPandasSerializer()
127-
elif eval_type == PythonEvalType.SQL_PANDAS_UDF_STREAM:
128126
ser = ArrowStreamPandasSerializer()
129127
else:
130128
ser = BatchedSerializer(PickleSerializer(), 100)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -925,13 +925,6 @@ object SQLConf {
925925
.intConf
926926
.createWithDefault(10000)
927927

928-
val ARROW_EXECUTION_STREAM_ENABLE =
929-
buildConf("spark.sql.execution.arrow.stream.enable")
930-
.internal()
931-
.doc("When using Apache Arrow, use Arrow stream protocol if possible.")
932-
.booleanConf
933-
.createWithDefault(false)
934-
935928
object Deprecated {
936929
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
937930
}
@@ -1210,8 +1203,6 @@ class SQLConf extends Serializable with Logging {
12101203

12111204
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
12121205

1213-
def arrowStreamEnable: Boolean = getConf(ARROW_EXECUTION_STREAM_ENABLE)
1214-
12151206
/** ********************** SQLConf functionality methods ************ */
12161207

12171208
/** Set Spark SQL configuration properties. */

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
102102

103103
/** A sequence of rules that will be applied in order to the physical plan before execution. */
104104
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
105-
python.ExtractPythonUDFs(sparkSession.sessionState.conf),
105+
python.ExtractPythonUDFs,
106106
PlanSubqueries(sparkSession),
107107
new ReorderJoinPredicates,
108108
EnsureRequirements(sparkSession.sessionState.conf),

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
package org.apache.spark.sql.execution.python
1919

20+
import scala.collection.JavaConverters._
21+
2022
import org.apache.spark.TaskContext
2123
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
2224
import org.apache.spark.sql.catalyst.InternalRow
2325
import org.apache.spark.sql.catalyst.expressions._
2426
import org.apache.spark.sql.execution.SparkPlan
25-
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
2627
import org.apache.spark.sql.types.StructType
2728

2829
/**
@@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
3940
iter: Iterator[InternalRow],
4041
schema: StructType,
4142
context: TaskContext): Iterator[InternalRow] = {
42-
val inputIterator = ArrowConverters.toPayloadIterator(
43-
iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
44-
45-
// Output iterator for results from Python.
46-
val outputIterator = new PythonUDFRunner(
47-
funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
48-
.compute(inputIterator, context.partitionId(), context)
49-
50-
val outputRowIterator = ArrowConverters.fromPayloadIterator(
51-
outputIterator.map(new ArrowPayload(_)), context)
52-
53-
// Verify that the output schema is correct
54-
if (outputRowIterator.hasNext) {
55-
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
56-
.map { case (attr, i) => attr.withName(s"_$i") })
57-
assert(schemaOut.equals(outputRowIterator.schema),
58-
s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
59-
}
6043

61-
outputRowIterator
44+
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
45+
.map { case (attr, i) => attr.withName(s"_$i") })
46+
47+
val columnarBatchIter = new ArrowPythonRunner(
48+
funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
49+
PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
50+
.compute(iter, context.partitionId(), context)
51+
52+
new Iterator[InternalRow] {
53+
54+
var currentIter = if (columnarBatchIter.hasNext) {
55+
val batch = columnarBatchIter.next()
56+
assert(schemaOut.equals(batch.schema),
57+
s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}")
58+
batch.rowIterator.asScala
59+
} else {
60+
Iterator.empty
61+
}
62+
63+
override def hasNext: Boolean = currentIter.hasNext || {
64+
if (columnarBatchIter.hasNext) {
65+
currentIter = columnarBatchIter.next().rowIterator.asScala
66+
hasNext
67+
} else {
68+
false
69+
}
70+
}
71+
72+
override def next(): InternalRow = currentIter.next()
73+
}
6274
}
6375
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala renamed to sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.apache.spark.util.Utils
3737
/**
3838
* Similar to `PythonUDFRunner`, but exchanges data with the Python worker via Arrow stream.
3939
*/
40-
class ArrowStreamPythonUDFRunner(
40+
class ArrowPythonRunner(
4141
funcs: Seq[ChainedPythonFunctions],
4242
batchSize: Int,
4343
bufferSize: Int,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala

Lines changed: 0 additions & 76 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.execution
2828
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
29-
import org.apache.spark.sql.internal.SQLConf
3029

3130

3231
/**
@@ -91,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
9190
* This has the limitation that the input to the Python UDF is not allowed to include attributes from
9291
* multiple child operators.
9392
*/
94-
case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with PredicateHelper {
93+
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
9594

9695
private def hasPythonUDF(e: Expression): Boolean = {
9796
e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -142,11 +141,7 @@ case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with Predica
142141

143142
val evaluation = validUdfs.partition(_.vectorized) match {
144143
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
145-
if (conf.arrowStreamEnable) {
146-
ArrowStreamEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
147-
} else {
148-
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
149-
}
144+
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
150145
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
151146
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
152147
case _ =>

0 commit comments

Comments
 (0)