[SPARK-21190][PYSPARK] Python Vectorized UDFs #18659
@@ -81,6 +81,12 @@ class SpecialLengths(object):
    NULL = -5


class PythonEvalType(object):
    NON_UDF = 0
    SQL_BATCHED_UDF = 1
    SQL_ARROW_UDF = 2


class Serializer(object):

    def dump_stream(self, iterator, stream):
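The new SQL_ARROW_UDF eval type is what lets the worker switch from the pickled row-at-a-time path to a columnar one. A minimal sketch of that kind of dispatch is below; the helper function is hypothetical and only the class names come from this diff (import locations assumed):

```python
# Hypothetical helper, not actual worker.py code: pick a serializer based on
# the eval type the JVM sends with each task.  Imports assumed from this diff.
from pyspark.serializers import (ArrowPandasSerializer, BatchedSerializer,
                                 PickleSerializer, PythonEvalType)


def serializer_for(eval_type):
    if eval_type == PythonEvalType.SQL_ARROW_UDF:
        # Columnar path: Arrow record batches converted to/from pandas.Series.
        return ArrowPandasSerializer()
    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
        # Existing row-at-a-time path: pickled batches of rows.
        return BatchedSerializer(PickleSerializer(), 100)
    else:  # PythonEvalType.NON_UDF
        return PickleSerializer()
```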
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
    Serializes an Arrow stream.
    """

    def dumps(self, obj):
        raise NotImplementedError
    def dumps(self, batch):
        import pyarrow as pa
        import io
        sink = io.BytesIO()
        writer = pa.RecordBatchFileWriter(sink, batch.schema)
        writer.write_batch(batch)
        writer.close()
        return sink.getvalue()

    def loads(self, obj):
        import pyarrow as pa
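The dumps/loads pair frames each RecordBatch with Arrow's file format. A quick round-trip sketch of the same pyarrow calls (data and column name made up; in newer pyarrow releases the reader lives under pyarrow.ipc):

```python
import io

import pandas as pd
import pyarrow as pa

# Build a one-column batch the same way the serializer does.
batch = pa.RecordBatch.from_arrays([pa.Array.from_pandas(pd.Series([1, 2, 3]))], ["_0"])

# dumps(): write the batch in the Arrow file format into an in-memory buffer.
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
payload = sink.getvalue()

# loads() counterpart: read the bytes back as an Arrow table.
table = pa.ipc.open_file(io.BytesIO(payload)).read_all()
assert table.num_rows == 3
```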
@@ -199,6 +211,33 @@ def __repr__(self):
        return "ArrowSerializer"


class ArrowPandasSerializer(ArrowSerializer):

    def __init__(self):
        super(ArrowPandasSerializer, self).__init__()
Member
Do we need this?

Member (Author)
No, that was leftovers. I'll remove it in a follow-up.
    def dumps(self, series):
        """
        Make an ArrowRecordBatch from a Pandas Series and serialize
        """
        import pyarrow as pa
Member
Should we catch the ImportError here?

Member (Author)
Yeah, it would probably be best to handle it the same way as in ... That got me thinking that it is a little weird to have an SQLConf "spark.sql.execution.arrow.enable" that is set for ...

Member
Ah, hm .. let me check the previous discussions and think about this a bit more.

Member
I am okay with leaving it as is here. I think we should catch and throw it with better messages in all cases (probably at entry points) later, but let's talk about this in another place later.
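For reference, the option discussed above could look roughly like the sketch below (not part of this PR; the helper name and message are made up):

```python
def require_arrow():
    """Raise a friendlier error when pyarrow is not installed (illustrative only)."""
    try:
        import pyarrow  # noqa: F401
    except ImportError as e:
        raise ImportError("pyarrow is required for vectorized UDFs "
                          "(pip install pyarrow): %s" % e)
```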
        if not isinstance(series, (list, tuple)):
            series = [series]
        arrs = [pa.Array.from_pandas(s) for s in series]
        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

        return super(ArrowPandasSerializer, self).dumps(batch)

    def loads(self, obj):
        """
        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
        """
        table = super(ArrowPandasSerializer, self).loads(obj)
        return [c.to_pandas() for c in table.itercolumns()]

    def __repr__(self):
        return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

    """
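A hedged usage sketch of the new serializer: a list of pandas Series goes in, Arrow file-format bytes come out, and loads() returns the columns as pandas Series again (import path assumed from the module this diff touches):

```python
import pandas as pd

from pyspark.serializers import ArrowPandasSerializer  # import path assumed

ser = ArrowPandasSerializer()
payload = ser.dumps([pd.Series([1, 2, 3]), pd.Series([0.1, 0.2, 0.3])])
cols = ser.loads(payload)   # list of pandas.Series, one per column
assert list(cols[0]) == [1, 2, 3]
```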
@@ -2032,7 +2032,7 @@ class UserDefinedFunction(object):

    .. versionadded:: 1.3
    """
    def __init__(self, func, returnType, name=None):
    def __init__(self, func, returnType, name=None, vectorized=False):
        if not callable(func):
            raise TypeError(
                "Not a function or callable (__call__ is not defined): "
@@ -2046,6 +2046,7 @@ def __init__(self, func, returnType, name=None):
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self._vectorized = vectorized

    @property
    def returnType(self):
@@ -2077,7 +2078,7 @@ def _create_judf(self):
        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt)
            self._name, wrapped_func, jdt, self._vectorized)
        return judf

    def __call__(self, *cols):
@@ -2112,7 +2113,7 @@ def wrapper(*args):


@since(1.3)
def udf(f=None, returnType=StringType()):
def udf(f=None, returnType=StringType(), vectorized=False):
    """Creates a :class:`Column` expression representing a user defined function (UDF).

    .. note:: The user-defined functions must be deterministic. Due to optimization,
@@ -2142,18 +2143,18 @@ def udf(f=None, returnType=StringType()):
    | 8| JOHN DOE| 22|
    +----------+--------------+------------+
    """
    def _udf(f, returnType=StringType()):
        udf_obj = UserDefinedFunction(f, returnType)
    def _udf(f, returnType=StringType(), vectorized=False):
        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
        return udf_obj._wrapped()

    # decorator @udf, @udf() or @udf(dataType())
    if f is None or isinstance(f, (str, DataType)):
        # If DataType has been passed as a positional argument
        # for decorator use it as a returnType
        return_type = f or returnType
        return functools.partial(_udf, returnType=return_type)
        return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
    else:
        return _udf(f=f, returnType=returnType)
        return _udf(f=f, returnType=returnType, vectorized=vectorized)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
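From the user's side, the new flag could be exercised roughly as below; this is a hedged sketch, with `spark` assumed to be an existing SparkSession and the column name coming from spark.range():

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType


@udf(LongType(), vectorized=True)
def plus_one(v):
    # v is a pandas.Series covering a whole Arrow batch, not a single value.
    return v + 1


df = spark.range(0, 1000)
df.select(plus_one(df["id"])).show()
```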
@@ -82,7 +82,6 @@ private[sql] object ArrowConverters {

    val root = VectorSchemaRoot.create(arrowSchema, allocator)
    val arrowWriter = ArrowWriter.create(root)
Member
(Looks unrelated change)
    var closed = false

    context.addTaskCompletionListener { _ =>
@@ -0,0 +1,127 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils


/**
 * A physical plan that evaluates a [[PythonUDF]],
 */
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val inputRDD = child.execute().map(_.copy())
    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

    inputRDD.mapPartitions { iter =>

      // The queue used to buffer input rows so we can drain it to
      // combine input with output from Python.
      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
      TaskContext.get().addTaskCompletionListener({ ctx =>
        queue.close()
      })

      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

      // flatten all the arguments
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = inputs.map { input =>
        input.map { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
            allInputs.indexWhere(_.semanticEquals(e))
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
          }
        }.toArray
      }.toArray
      val projection = newMutableProjection(allInputs, child.output)
      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
        StructField(s"_$i", dt)
      })

      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
      // For each row, add it to the queue.
      val projectedRowIter = iter.map { inputRow =>
        queue.add(inputRow.asInstanceOf[UnsafeRow])
        projection(inputRow)
      }
      val context = TaskContext.get()

      val inputIterator = ArrowConverters.toPayloadIterator(
        projectedRowIter, schema, conf.arrowMaxRecordsPerBatch, context).
        map(_.asPythonSerializable)

      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex.
        map { case (attr, i) => attr.withName(s"_$i") })

      // Output iterator for results from Python.
      val outputIterator = new PythonRunner(
        pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_ARROW_UDF, argOffsets).
        compute(inputIterator, context.partitionId(), context)

      val joined = new JoinedRow
      val resultProj = UnsafeProjection.create(output, output)

      val outputRowIterator = ArrowConverters.fromPayloadIterator(
        outputIterator.map(new ArrowPayload(_)), context)

      assert(schemaOut.equals(outputRowIterator.schema))

      outputRowIterator.map { outputRow =>
        resultProj(joined(queue.remove(), outputRow))
      }
    }
  }
}
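The argOffsets bookkeeping above de-duplicates repeated argument expressions so each column is projected and shipped to Python only once, and each UDF refers to its arguments by index into the flattened list. A small standalone Python sketch of the same idea, with plain strings standing in for Catalyst expressions:

```python
def compute_arg_offsets(udf_args):
    """Flatten and de-duplicate UDF arguments; return per-UDF index lists."""
    all_inputs, offsets = [], []
    for args in udf_args:
        indices = []
        for a in args:
            if a not in all_inputs:
                all_inputs.append(a)
            indices.append(all_inputs.index(a))
        offsets.append(indices)
    return all_inputs, offsets


# Two UDFs sharing the column "x": it is projected once and referenced twice.
inputs, offsets = compute_arg_offsets([["x", "y"], ["x", "z"]])
assert inputs == ["x", "y", "z"]
assert offsets == [[0, 1], [0, 2]]
```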
The new udf parameter is a pandas Series, so I think it's more accurate to call it SQL_PANDAS_UDF.