22 changes: 17 additions & 5 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_ARROW_UDF = 2
Contributor: The new udf parameter is a pandas Series, so I think it's more accurate to call it SQL_PANDAS_UDF.

}

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
Seq(ChainedPythonFunctions(Seq(func))),
bufferSize,
reuse_worker,
PythonEvalType.NON_UDF,
Array(Array(0)))
}
}

@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
extends Logging {

@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(evalType)
if (evalType != PythonEvalType.NON_UDF) {
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
@@ -324,7 +337,6 @@ }
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
43 changes: 41 additions & 2 deletions python/pyspark/serializers.py
@@ -81,6 +81,12 @@ class SpecialLengths(object):
NULL = -5


class PythonEvalType(object):
NON_UDF = 0
SQL_BATCHED_UDF = 1
SQL_ARROW_UDF = 2
Contributor: ditto (the same SQL_PANDAS_UDF naming suggestion as above).



class Serializer(object):

def dump_stream(self, iterator, stream):
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
Serializes an Arrow stream.
"""

def dumps(self, obj):
raise NotImplementedError
def dumps(self, batch):
import pyarrow as pa
import io
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
return sink.getvalue()

def loads(self, obj):
import pyarrow as pa
@@ -199,6 +211,33 @@ def __repr__(self):
return "ArrowSerializer"


class ArrowPandasSerializer(ArrowSerializer):

def __init__(self):
super(ArrowPandasSerializer, self).__init__()
Member: Do we need this?

Member Author: No, that was leftovers; I'll remove it in a follow-up.


def dumps(self, series):
"""
Make an ArrowRecordBatch from a Pandas Series and serialize
"""
import pyarrow as pa
Member: Should we catch ImportError?

Member Author: Yeah, it would probably be best to handle it the same way as in toPandas().

That got me thinking that it is a little weird to have an SQLConf "spark.sql.execution.arrow.enable" that is set for toPandas() but has no bearing on pandas_udf. It doesn't need to, since it is an explicit call, but it seems a little contradictory; what do you think?

Member: Ah, hm, let me check the previous discussions and think about this a bit more.

Member (@HyukjinKwon, Sep 16, 2017): I am okay with leaving it as is here. I think we should catch and throw it with better messages in all cases (probably at entry points) later, but let's talk about this in another place later.

if not isinstance(series, (list, tuple)):
series = [series]
arrs = [pa.Array.from_pandas(s) for s in series]
batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
Member: I'd use xrange.

return super(ArrowPandasSerializer, self).dumps(batch)

def loads(self, obj):
"""
Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
"""
table = super(ArrowPandasSerializer, self).loads(obj)
return [c.to_pandas() for c in table.itercolumns()]

def __repr__(self):
return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

"""
15 changes: 8 additions & 7 deletions python/pyspark/sql/functions.py
@@ -2032,7 +2032,7 @@ class UserDefinedFunction(object):

.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None):
def __init__(self, func, returnType, name=None, vectorized=False):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2046,6 +2046,7 @@ def __init__(self, func, returnType, name=None):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self._vectorized = vectorized

@property
def returnType(self):
@@ -2077,7 +2078,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt)
self._name, wrapped_func, jdt, self._vectorized)
return judf

def __call__(self, *cols):
@@ -2112,7 +2113,7 @@ def wrapper(*args):


@since(1.3)
def udf(f=None, returnType=StringType()):
def udf(f=None, returnType=StringType(), vectorized=False):
Member Author: @felixcheung does this fit your idea for a more generic decorator? Not exclusively labeled as pandas_udf, just enable vectorization with a flag, e.g. @udf(DoubleType(), vectorized=True).

Contributor: I think @pandas_udf(DoubleType()) is better than @udf(DoubleType(), vectorized=True); it is more concise.

Contributor: As we discussed in the email, we should also accept the data type in string format.

Contributor: And also **kwargs to bring the size information.

Member Author: It seems like the consensus is for pandas_udf and I'm fine with that too. I'll make that change and the others brought up here.

"""Creates a :class:`Column` expression representing a user defined function (UDF).

.. note:: The user-defined functions must be deterministic. Due to optimization,
Expand Down Expand Up @@ -2142,18 +2143,18 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
def _udf(f, returnType=StringType()):
udf_obj = UserDefinedFunction(f, returnType)
def _udf(f, returnType=StringType(), vectorized=False):
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()

# decorator @udf, @udf() or @udf(dataType())
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type)
return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
else:
return _udf(f=f, returnType=returnType)
return _udf(f=f, returnType=returnType, vectorized=vectorized)
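
To make the decorator discussion above concrete, here is a hedged usage sketch of the vectorized flag as this diff defines it. The @pandas_udf decorator the reviewers prefer is not part of this change, and the `spark` session and column names below are assumptions for the example.

from pyspark.sql.functions import col, udf
from pyspark.sql.types import LongType

# With vectorized=True the wrapped function receives a pandas.Series per input
# column and must return a pandas.Series of the same length.
@udf(LongType(), vectorized=True)
def plus_one(v):
    return v + 1

df = spark.range(8)                       # assumes an active SparkSession named `spark`
df.select(plus_one(col("id"))).show()

Once the rename discussed above lands, the same example would presumably read @pandas_udf(LongType()) instead of @udf(LongType(), vectorized=True).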


blacklist = ['map', 'since', 'ignore_unicode_prefix']
20 changes: 13 additions & 7 deletions python/pyspark/worker.py
@@ -30,7 +30,8 @@
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowPandasSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -85,7 +86,7 @@ def read_single_udf(pickleSer, infile):
return arg_offsets, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
def read_udfs(pickleSer, infile, eval_type):
num_udfs = read_int(infile)
udfs = {}
call_udf = []
@@ -102,7 +103,12 @@ def read_udfs(pickleSer, infile):
mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = BatchedSerializer(PickleSerializer(), 100)

if eval_type == PythonEvalType.SQL_ARROW_UDF:
ser = ArrowPandasSerializer()
else:
ser = BatchedSerializer(PickleSerializer(), 100)

# profiling is not supported for UDF
return func, None, ser, ser

@@ -159,11 +165,11 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
is_sql_udf = read_int(infile)
if is_sql_udf:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
eval_type = read_int(infile)
if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
else:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

init_time = time.time()

@@ -82,7 +82,6 @@ private[sql] object ArrowConverters {

val root = VectorSchemaRoot.create(arrowSchema, allocator)
val arrowWriter = ArrowWriter.create(root)

Member: (Looks unrelated change)

var closed = false

context.addTaskCompletionListener { _ =>
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils


/**
* A physical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil

override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
Member (@viirya, Sep 18, 2017): producedAttributes and collectFunctions look duplicated between ArrowEvalPythonExec and BatchEvalPythonExec. We can de-duplicate them, maybe in a later PR.

Member Author: Yes, these functions are duplicated, as well as some code in doExecute(). I could add a common base class like EvalPythonExec to clean this up, and maybe move them to the same file?

udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

inputRDD.mapPartitions { iter =>

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
TaskContext.get().addTaskCompletionListener({ ctx =>
queue.close()
})

val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
val projection = newMutableProjection(allInputs, child.output)
val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
})

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
Member: The comment is wrong now. We don't group input rows here.

val projectedRowIter = iter.map { inputRow =>
queue.add(inputRow.asInstanceOf[UnsafeRow])
projection(inputRow)
}

val context = TaskContext.get()

val inputIterator = ArrowConverters.toPayloadIterator(
projectedRowIter, schema, conf.arrowMaxRecordsPerBatch, context).
map(_.asPythonSerializable)

val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex.
map { case (attr, i) => attr.withName(s"_$i") })

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_ARROW_UDF, argOffsets).
compute(inputIterator, context.partitionId(), context)
Member: nit: I think we usually write it in a style like:

val outputIterator = new PythonRunner(
    pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
  .compute(inputIterator, context.partitionId(), context)

There are similar styles above, e.g. map { case (attr, i) => ... -> .map { case (attr, i) => ....


val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)

val outputRowIterator = ArrowConverters.fromPayloadIterator(
outputIterator.map(new ArrowPayload(_)), context)

assert(schemaOut.equals(outputRowIterator.schema))
Member Author: @felixcheung, I think you had also brought up checking that the return type matches what was defined in the UDF. This is done here.


outputRowIterator.map { outputRow =>
resultProj(joined(queue.remove(), outputRow))
}
}
}
}
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -140,7 +140,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
val outputIterator = new PythonRunner(
pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
@@ -138,7 +138,16 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)

val evaluation = validUdfs.partition(_.vectorized) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
}

attributeMap ++= validUdfs.zip(resultAttrs)
evaluation
} else {
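
As a hedged illustration of the new restriction from the Python side (the UDF and column names below are invented for the example): keeping each kind of UDF in its own query works, while a single projection that mixes a vectorized UDF with a plain one is expected to fail planning with the IllegalArgumentException above.

from pyspark.sql.functions import col, udf
from pyspark.sql.types import LongType

vec_plus_one = udf(lambda v: v + 1, LongType(), vectorized=True)   # pandas.Series in/out
row_plus_one = udf(lambda x: x + 1, LongType())                    # one row at a time

df = spark.range(4)                       # assumes an active SparkSession named `spark`

# Each of these uses only one kind of UDF, so extraction picks a single exec node.
df.select(vec_plus_one(col("id"))).show()
df.select(row_plus_one(col("id"))).show()

# Both kinds in the same projection end up in the same extracted batch of UDFs,
# so this should hit the IllegalArgumentException thrown above.
df.select(vec_plus_one(col("id")), row_plus_one(col("id"))).show()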