@@ -71,6 +71,14 @@ private[spark] class PythonRDD(
}
}

private[spark] case class PythonFunction(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A helper class to run Python UDFs in Spark.
75 changes: 75 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -278,6 +278,12 @@ def map(self, f):
"""
return self.rdd.map(f)

@ignore_unicode_prefix
@since(2.0)
def mapPartitions2(self, func):
""" TODO """
return PipelinedDataFrame(self, func)

@ignore_unicode_prefix
@since(1.3)
def flatMap(self, f):
@@ -1354,6 +1360,75 @@ def toPandas(self):
drop_duplicates = dropDuplicates


class PipelinedDataFrame(DataFrame):

""" TODO """

def __init__(self, prev, func, output_schema=None):
self.output_schema = output_schema
self._schema = None
self.is_cached = False
self.sql_ctx = prev.sql_ctx
self._sc = self.sql_ctx and self.sql_ctx._sc
self._jdf_val = None
self._lazy_rdd = None

if output_schema is not None:
# This transformation only applies a schema, so copy the pipelined function and parent
# JVM DataFrame from prev.
self.func = func
self._prev_jdf = prev._prev_jdf
elif not isinstance(prev, PipelinedDataFrame) or not prev.is_cached:
# This transformation is the first in its stage:
self.func = func
self._prev_jdf = prev._jdf
else:
self.func = _pipeline_func(prev.func, func)
self._prev_jdf = prev._prev_jdf # maintain the pipeline

def applySchema(self, schema):
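""" Returns a new :class:`PipelinedDataFrame` that applies ``schema`` to the records produced by the pipelined function. """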
return PipelinedDataFrame(self, self.func, schema)

@property
def _jdf(self):
if self._jdf_val is None:
if self.output_schema is None:
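# No schema declared yet: results stay as pickled Python objects in a single binary
# column; the "pickled" metadata flag lets the JVM side pass the bytes through unchanged.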
schema = StructType().add("binary", BinaryType(), False, {"pickled": True})
final_func = self.func
elif isinstance(self.output_schema, StructType):
schema = self.output_schema
to_row = lambda iterator: map(schema.toInternal, iterator)
final_func = _pipeline_func(self.func, to_row)
else:
data_type = self.output_schema
schema = StructType().add("value", data_type)
converter = lambda obj: (data_type.toInternal(obj), )
to_row = lambda iterator: map(converter, iterator)
final_func = _pipeline_func(self.func, to_row)

wrapped_func = self._wrap_function(final_func)
self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json())

return self._jdf_val

def _wrap_function(self, f):
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer

ser = AutoBatchedSerializer(PickleSerializer())
command = (lambda _, iterator: f(iterator), None, ser, ser)
sc = self._sc
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
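
For reference, a minimal usage sketch of the new API, mirroring the test added in
python/pyspark/sql/tests.py below; the SQLContext (sqlCtx) and the pyspark.sql.types imports
are assumed to be in scope:

df = sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
# Each transformation returns a PipelinedDataFrame; the schema is applied lazily.
ds = df.mapPartitions2(
    lambda iterator: ({"key": row.key + 1, "value": row.value} for row in iterator))
ds = ds.applySchema(StructType().add("key", IntegerType()).add("value", StringType()))
ds.select("key").collect()  # [Row(key=2), Row(key=3)]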


def _pipeline_func(prev_func, next_func):
if prev_func is None:
return next_func
else:
return lambda iterator: next_func(prev_func(iterator))
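
To illustrate the helper above, a minimal sketch (assuming plain Python iterables, outside
Spark) of how _pipeline_func chains two iterator-level functions into a single pass:

def add_one(iterator):
    return (x + 1 for x in iterator)

def double(iterator):
    return (x * 2 for x in iterator)

composed = _pipeline_func(add_one, double)  # same as lambda it: double(add_one(it))
list(composed(iter([1, 2, 3])))             # [4, 6, 8]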


def _to_scala_map(sc, jm):
"""
Convert a dict into a JVM Map.
24 changes: 24 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1153,6 +1153,30 @@ def test_functions_broadcast(self):
# planner should not crash without a join
broadcast(df1)._jdf.queryExecution().executedPlan()

def test_dataset(self):
ds = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

func = lambda row: {"key": row.key + 1, "value": row.value} # convert row to python dict
ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator))
schema = StructType().add("key", IntegerType()).add("value", StringType())
ds3 = ds2.applySchema(schema)
result = ds3.select("key").collect()
self.assertEqual(result[0][0], 2)
self.assertEqual(result[1][0], 3)

schema = StructType().add("value", StringType()) # use a different but compatible schema
ds3 = ds2.applySchema(schema)
result = ds3.collect()
self.assertEqual(result[0][0], "1")
self.assertEqual(result[1][0], "2")

func = lambda row: row.key * 3
ds2 = ds.mapPartitions2(lambda iterator: map(func, iterator))
ds3 = ds2.applySchema(IntegerType()) # use a flat schema
result = ds3.collect()
self.assertEqual(result[0][0], 3)
self.assertEqual(result[1][0], 6)


class HiveContextSQLTests(ReusedPySparkTestCase):

@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@@ -91,6 +92,13 @@ case class MapPartitions(
override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}

case class PythonMapPartitions(
func: PythonFunction,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def expressions: Seq[Expression] = Nil
}

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -208,8 +216,6 @@ case class CoGroup(
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode with ObjectOperator {

override def producedAttributes: AttributeSet = outputSet

override def deserializers: Seq[(Expression, Seq[Attribute])] =
// The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to resolve
// the `keyDeserializer` based on either of them; here we pick the left one.
9 changes: 8 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -27,7 +27,7 @@ import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.api.python.{PythonFunction, PythonRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
@@ -1761,6 +1761,13 @@ class DataFrame private[sql](
}
}

protected[sql] def pythonMapPartitions(
func: PythonFunction,
schemaJson: String): DataFrame = withPlan {
val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]
PythonMapPartitions(func, schema.toAttributes, logicalPlan)
}

/**
* Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
* an execution.
@@ -371,6 +371,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case logical.PythonMapPartitions(func, output, child) =>
execution.PythonMapPartitions(func, output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
@@ -17,12 +17,19 @@

package org.apache.spark.sql.execution

import scala.collection.JavaConverters._

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.types.{BinaryType, ObjectType, StructType}

/**
* Helper functions for physical operators that work with user defined objects.
@@ -67,6 +74,72 @@ case class MapPartitions(
}
}

case class PythonMapPartitions(
func: PythonFunction,
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {

override def expressions: Seq[Expression] = Nil

private def isPickled(schema: StructType): Boolean = {
schema.length == 1 && schema.head.dataType == BinaryType &&
schema.head.metadata.contains("pickled")
}

override protected def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val childIsPickled = isPickled(child.schema)
val outputIsPickled = isPickled(schema)

inputRDD.mapPartitions { iter =>
val inputIterator = if (childIsPickled) {
iter.map(_.getBinary(0))
} else {
EvaluatePython.registerPicklers() // register pickler for Row

val pickle = new Pickler

// Input iterator to Python: rows are grouped into batches of 100, converted to Java
// objects and pickled before being sent to the Python worker.
iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
EvaluatePython.toJava(row, child.schema)
}.toArray
pickle.dumps(toBePickled)
}
}

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator =
new PythonRunner(
func.command,
func.envVars,
func.pythonIncludes,
func.pythonExec,
func.pythonVer,
func.broadcastVars,
func.accumulator,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)

if (outputIsPickled) {
outputIterator.map(bytes => InternalRow(bytes))

Review comment (Contributor Author): To avoid copying the bytes, I create safe rows here. However, according to #10511, operators should always produce unsafe rows. The Python UDF operator (BatchPythonEvaluation) also produces safe rows, which may have the same problem. Should we bring back the requireUnsafeRow machinery? In cases like this one, converting to unsafe rows is expensive and may not bring much benefit. cc @davies

Review comment (Contributor): BatchPythonEvaluation will produce UnsafeRow.

Review comment (Contributor Author): Oh sorry, I missed the unsafe projection at the very end. Then we can probably add an unsafe projection here too.

} else {
val unpickle = new Unpickler
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map(result => EvaluatePython.fromJava(result, schema).asInstanceOf[InternalRow])
}
}
}
}

/**
* Applies the given function to each input row, appending the encoded result at the end of the row.
*/