1717
1818package org .apache .spark .sql .execution .python
1919
20+ import scala .collection .JavaConverters ._
21+
2022import org .apache .spark .TaskContext
2123import org .apache .spark .api .python .{ChainedPythonFunctions , PythonEvalType }
2224import org .apache .spark .sql .catalyst .InternalRow
2325import org .apache .spark .sql .catalyst .expressions ._
2426import org .apache .spark .sql .execution .SparkPlan
25- import org .apache .spark .sql .execution .arrow .{ArrowConverters , ArrowPayload }
2627import org .apache .spark .sql .types .StructType
2728
2829/**
@@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
3940 iter : Iterator [InternalRow ],
4041 schema : StructType ,
4142 context : TaskContext ): Iterator [InternalRow ] = {
42- val inputIterator = ArrowConverters .toPayloadIterator(
43- iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
44-
45- // Output iterator for results from Python.
46- val outputIterator = new PythonUDFRunner (
47- funcs, bufferSize, reuseWorker, PythonEvalType .SQL_PANDAS_UDF , argOffsets)
48- .compute(inputIterator, context.partitionId(), context)
49-
50- val outputRowIterator = ArrowConverters .fromPayloadIterator(
51- outputIterator.map(new ArrowPayload (_)), context)
52-
53- // Verify that the output schema is correct
54- if (outputRowIterator.hasNext) {
55- val schemaOut = StructType .fromAttributes(output.drop(child.output.length).zipWithIndex
56- .map { case (attr, i) => attr.withName(s " _ $i" ) })
57- assert(schemaOut.equals(outputRowIterator.schema),
58- s " Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}" )
59- }
6043
61- outputRowIterator
44+ val schemaOut = StructType .fromAttributes(output.drop(child.output.length).zipWithIndex
45+ .map { case (attr, i) => attr.withName(s " _ $i" ) })
46+
47+ val columnarBatchIter = new ArrowPythonRunner (
48+ funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
49+ PythonEvalType .SQL_PANDAS_UDF , argOffsets, schema)
50+ .compute(iter, context.partitionId(), context)
51+
52+ new Iterator [InternalRow ] {
53+
54+ var currentIter = if (columnarBatchIter.hasNext) {
55+ val batch = columnarBatchIter.next()
56+ assert(schemaOut.equals(batch.schema),
57+ s " Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}" )
58+ batch.rowIterator.asScala
59+ } else {
60+ Iterator .empty
61+ }
62+
63+ override def hasNext : Boolean = currentIter.hasNext || {
64+ if (columnarBatchIter.hasNext) {
65+ currentIter = columnarBatchIter.next().rowIterator.asScala
66+ hasNext
67+ } else {
68+ false
69+ }
70+ }
71+
72+ override def next (): InternalRow = currentIter.next()
73+ }
6274 }
6375}
0 commit comments