[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF #21650
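For context, this change lets a row-at-a-time Python UDF and a Scalar Pandas UDF appear in the same query; previously such a query failed with "Can not mix vectorized and non-vectorized UDFs" (see the test removed below). A minimal sketch of the usage being enabled — the UDF names and the session setup are illustrative, and it assumes pandas and PyArrow are installed:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.getOrCreate()

# Row-at-a-time (batched) Python UDF: receives one Python int per row.
@udf('long')
def plus_one(x):
    return x + 1

# Vectorized Scalar Pandas UDF: receives a pandas Series per batch.
@pandas_udf('long')
def times_two(s):
    return s * 2

df = spark.range(3).toDF('v')

# Both UDF types in a single projection, including one chained over the other.
df.select(plus_one(col('v')), times_two(col('v')), times_two(plus_one(col('v')))).show()
```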
Changes from 7 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self): | |
| 'Result vector from pandas_udf was not the required length'): | ||
| df.select(raise_exception(col('id'))).collect() | ||
|
|
||
| def test_vectorized_udf_mix_udf(self): | ||
| from pyspark.sql.functions import pandas_udf, udf, col | ||
| df = self.spark.range(10) | ||
| row_by_row_udf = udf(lambda x: x, LongType()) | ||
| pd_udf = pandas_udf(lambda x: x, LongType()) | ||
| with QuietTest(self.sc): | ||
| with self.assertRaisesRegexp( | ||
| Exception, | ||
| 'Can not mix vectorized and non-vectorized UDFs'): | ||
| df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() | ||
|
|
||
| def test_vectorized_udf_chained(self): | ||
| from pyspark.sql.functions import pandas_udf, col | ||
| df = self.spark.range(10) | ||
|
|
@@ -5060,6 +5049,147 @@ def test_type_annotation(self): | |
| df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) | ||
| self.assertEqual(df.first()[0], 0) | ||
|
|
||
| def test_mixed_udf(self): | ||
| import pandas as pd | ||
| from pyspark.sql.functions import col, udf, pandas_udf | ||
|
|
||
| df = self.spark.range(0, 1).toDF('v') | ||
|
|
||
| # Test mixture of multiple UDFs and Pandas UDFs | ||
|
|
||
| @udf('int') | ||
| def f1(x): | ||
| assert type(x) == int | ||
| return x + 1 | ||
|
|
||
| @pandas_udf('int') | ||
| def f2(x): | ||
| assert type(x) == pd.Series | ||
| return x + 10 | ||
|
|
||
| @udf('int') | ||
| def f3(x): | ||
| assert type(x) == int | ||
| return x + 100 | ||
|
|
||
| @pandas_udf('int') | ||
| def f4(x): | ||
| assert type(x) == pd.Series | ||
| return x + 1000 | ||
|
|
||
| # Test mixed udfs in a single projection | ||
| df1 = df \ | ||
| .withColumn('f1', f1(col('v'))) \ | ||
| .withColumn('f2', f2(col('v'))) \ | ||
| .withColumn('f3', f3(col('v'))) \ | ||
| .withColumn('f4', f4(col('v'))) \ | ||
| .withColumn('f2_f1', f2(col('f1'))) \ | ||
| .withColumn('f3_f1', f3(col('f1'))) \ | ||
| .withColumn('f4_f1', f4(col('f1'))) \ | ||
| .withColumn('f3_f2', f3(col('f2'))) \ | ||
| .withColumn('f4_f2', f4(col('f2'))) \ | ||
| .withColumn('f4_f3', f4(col('f3'))) \ | ||
| .withColumn('f3_f2_f1', f3(col('f2_f1'))) \ | ||
| .withColumn('f4_f2_f1', f4(col('f2_f1'))) \ | ||
| .withColumn('f4_f3_f1', f4(col('f3_f1'))) \ | ||
| .withColumn('f4_f3_f2', f4(col('f3_f2'))) \ | ||
| .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) | ||
|
|
||
| # Test mixed udfs in a single expression | ||
| df2 = df \ | ||
| .withColumn('f1', f1(col('v'))) \ | ||
| .withColumn('f2', f2(col('v'))) \ | ||
| .withColumn('f3', f3(col('v'))) \ | ||
| .withColumn('f4', f4(col('v'))) \ | ||
| .withColumn('f2_f1', f2(f1(col('v')))) \ | ||
| .withColumn('f3_f1', f3(f1(col('v')))) \ | ||
| .withColumn('f4_f1', f4(f1(col('v')))) \ | ||
| .withColumn('f3_f2', f3(f2(col('v')))) \ | ||
| .withColumn('f4_f2', f4(f2(col('v')))) \ | ||
| .withColumn('f4_f3', f4(f3(col('v')))) \ | ||
| .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ | ||
| .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ | ||
| .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ | ||
| .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ | ||
| .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) | ||
|
|
||
| # expected result | ||
| df3 = df \ | ||
| .withColumn('f1', df['v'] + 1) \ | ||
| .withColumn('f2', df['v'] + 10) \ | ||
| .withColumn('f3', df['v'] + 100) \ | ||
| .withColumn('f4', df['v'] + 1000) \ | ||
| .withColumn('f2_f1', df['v'] + 11) \ | ||
| .withColumn('f3_f1', df['v'] + 101) \ | ||
| .withColumn('f4_f1', df['v'] + 1001) \ | ||
| .withColumn('f3_f2', df['v'] + 110) \ | ||
| .withColumn('f4_f2', df['v'] + 1010) \ | ||
| .withColumn('f4_f3', df['v'] + 1100) \ | ||
| .withColumn('f3_f2_f1', df['v'] + 111) \ | ||
| .withColumn('f4_f2_f1', df['v'] + 1011) \ | ||
| .withColumn('f4_f3_f1', df['v'] + 1101) \ | ||
| .withColumn('f4_f3_f2', df['v'] + 1110) \ | ||
| .withColumn('f4_f3_f2_f1', df['v'] + 1111) | ||
|
|
||
| self.assertEquals(df3.collect(), df1.collect()) | ||
| self.assertEquals(df3.collect(), df2.collect()) | ||
|
|
||
| def test_mixed_udf_and_sql(self): | ||
| import pandas as pd | ||
| from pyspark.sql.functions import udf, pandas_udf | ||
|
|
||
| df = self.spark.range(0, 1).toDF('v') | ||
|
|
||
| # Test mixture of UDFs, Pandas UDFs and SQL expression. | ||
|
|
||
| @udf('int') | ||
| def f1(x): | ||
| assert type(x) == int | ||
| return x + 1 | ||
|
|
||
| def f2(x): | ||
|
||
| return x + 10 | ||
|
|
||
| @pandas_udf('int') | ||
| def f3(x): | ||
| assert type(x) == pd.Series | ||
| return x + 100 | ||
|
|
||
| df1 = df.withColumn('f1', f1(df['v'])) \ | ||
| .withColumn('f2', f2(df['v'])) \ | ||
| .withColumn('f3', f3(df['v'])) \ | ||
| .withColumn('f1_f2', f1(f2(df['v']))) \ | ||
| .withColumn('f1_f3', f1(f3(df['v']))) \ | ||
| .withColumn('f2_f1', f2(f1(df['v']))) \ | ||
| .withColumn('f2_f3', f2(f3(df['v']))) \ | ||
| .withColumn('f3_f1', f3(f1(df['v']))) \ | ||
|
Member
Looks like the combination of f1 and f3 duplicates a few tests in

Contributor (Author)
Yeah, the test is written to cover many combinations, so there are some duplicated cases. Do you prefer that I remove these?

Member
Yeah, I know it's still minor since the elapsed time will be virtually the same, but the build/test time has recently been an issue, and I wonder if there's a better way than avoiding duplicated tests for now.

Member
It was discussed here #21845

Contributor (Author)
I see. I don't think it's necessary (we would only remove a few cases and, as you said, the test time is virtually the same), and keeping them helps the readability of the tests (so it doesn't look like some test cases are missing). But if that's the preferred practice, I can remove the duplicate cases in the next commit.

Member
I am okay to leave it here too since it's clear they are virtually the same, but let's remove duplicated or orthogonal tests next time.

Contributor (Author)
Gotcha. I will keep that in mind next time.
||
| .withColumn('f3_f2', f3(f2(df['v']))) \ | ||
| .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ | ||
| .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ | ||
| .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ | ||
| .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ | ||
| .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ | ||
| .withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) | ||
|
|
||
| # expected result | ||
| df2 = df.withColumn('f1', df['v'] + 1) \ | ||
| .withColumn('f2', df['v'] + 10) \ | ||
| .withColumn('f3', df['v'] + 100) \ | ||
| .withColumn('f1_f2', df['v'] + 11) \ | ||
| .withColumn('f1_f3', df['v'] + 101) \ | ||
| .withColumn('f2_f1', df['v'] + 11) \ | ||
| .withColumn('f2_f3', df['v'] + 110) \ | ||
| .withColumn('f3_f1', df['v'] + 101) \ | ||
| .withColumn('f3_f2', df['v'] + 110) \ | ||
| .withColumn('f1_f2_f3', df['v'] + 111) \ | ||
| .withColumn('f1_f3_f2', df['v'] + 111) \ | ||
| .withColumn('f2_f1_f3', df['v'] + 111) \ | ||
| .withColumn('f2_f3_f1', df['v'] + 111) \ | ||
| .withColumn('f3_f1_f2', df['v'] + 111) \ | ||
| .withColumn('f3_f2_f1', df['v'] + 111) | ||
|
|
||
| self.assertEquals(df2.collect(), df1.collect()) | ||
|
||
|
|
||
|
|
||
| @unittest.skipIf( | ||
| not _have_pandas or not _have_pyarrow, | ||
|
|
@@ -5487,6 +5617,22 @@ def dummy_pandas_udf(df): | |
| F.col('temp0.key') == F.col('temp1.key')) | ||
| self.assertEquals(res.count(), 5) | ||
|
|
||
| def test_mixed_scalar_udfs_followed_by_grouby_apply(self): | ||
| # Test Pandas UDF and scalar Python UDF followed by groupby apply | ||
| from pyspark.sql.functions import udf, pandas_udf, PandasUDFType | ||
| import pandas as pd | ||
|
Member
Not a big deal at all really, but I would swap the import order (third-party, then pyspark).
||
|
|
||
| df = self.spark.range(0, 10).toDF('v1') | ||
| df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ | ||
| .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) | ||
|
|
||
| result = df.groupby() \ | ||
| .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), | ||
| 'sum int', | ||
| PandasUDFType.GROUPED_MAP)) | ||
|
|
||
| self.assertEquals(result.collect()[0]['sum'], 165) | ||
|
|
||
|
|
||
| @unittest.skipIf( | ||
| not _have_pandas or not _have_pyarrow, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ import scala.collection.mutable | |
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.api.python.PythonEvalType | ||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} | ||
|
|
@@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { | |
| */ | ||
| object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | ||
|
|
||
| private def hasPythonUDF(e: Expression): Boolean = { | ||
| private def hasScalarPythonUDF(e: Expression): Boolean = { | ||
| e.find(PythonUDF.isScalarPythonUDF).isDefined | ||
| } | ||
|
|
||
| private def canEvaluateInPython(e: PythonUDF): Boolean = { | ||
| e.children match { | ||
| // single PythonUDF child could be chained and evaluated in Python | ||
| case Seq(u: PythonUDF) => canEvaluateInPython(u) | ||
| // Python UDF can't be evaluated directly in JVM | ||
| case children => !children.exists(hasPythonUDF) | ||
| private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = { | ||
| if (e.evalType != evalType) { | ||
|
||
| false | ||
| } else { | ||
| e.children match { | ||
| // single PythonUDF child could be chained and evaluated in Python | ||
| case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType) | ||
| // Python UDF can't be evaluated directly in JVM | ||
| case children => !children.exists(hasScalarPythonUDF) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { | ||
| case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) | ||
| case e => e.children.flatMap(collectEvaluatableUDF) | ||
| private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match { | ||
|
||
| case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) => | ||
| Seq(udf) | ||
| case e => e.children.flatMap(collectEvaluableUDF(_, evalType)) | ||
| } | ||
|
|
||
| /** | ||
| * Collect evaluable UDFs from the current node. | ||
| * | ||
| * This function collects Python UDFs or Scalar Pandas UDFs from the expressions of the input | ||
| * node, and returns a list of UDFs of the same eval type. | ||
|
||
| * | ||
| * If the expressions contain both UDF eval types, this function will only return the Python | ||
| * UDFs. | ||
| * | ||
| * The caller should call this function multiple times until all evaluable UDFs are collected. | ||
|
||
| */ | ||
| private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = { | ||
| val pythonUDFs = | ||
| plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF)) | ||
|
|
||
| if (pythonUDFs.isEmpty) { | ||
| plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF)) | ||
| } else { | ||
| pythonUDFs | ||
|
||
| } | ||
| } | ||
|
|
||
| def apply(plan: SparkPlan): SparkPlan = plan transformUp { | ||
| // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker | ||
| // Therefore we don't need to extract the UDFs | ||
| case plan: FlatMapGroupsInPandasExec => plan | ||
|
||
| case plan: SparkPlan => extract(plan) | ||
| } | ||
|
|
||
| /** | ||
| * Extract all the PythonUDFs from the current operator and evaluate them before the operator. | ||
| */ | ||
| private def extract(plan: SparkPlan): SparkPlan = { | ||
| val udfs = plan.expressions.flatMap(collectEvaluatableUDF) | ||
| val udfs = collectEvaluableUDFs(plan) | ||
| // ignore the PythonUDF that come from second/third aggregate, which is not used | ||
| .filter(udf => udf.references.subsetOf(plan.inputSet)) | ||
| if (udfs.isEmpty) { | ||
|
|
@@ -167,7 +191,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | |
| case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => | ||
| BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) | ||
| case _ => | ||
| throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") | ||
| throw new AnalysisException( | ||
|
||
| "Expected either Scalar Pandas UDFs or Batched UDFs but got both") | ||
| } | ||
|
|
||
| attributeMap ++= validUdfs.zip(resultAttrs) | ||
|
|
@@ -205,7 +230,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | |
| case filter: FilterExec => | ||
| val (candidates, nonDeterministic) = | ||
| splitConjunctivePredicates(filter.condition).partition(_.deterministic) | ||
| val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) | ||
| val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) | ||
| if (pushDown.nonEmpty) { | ||
| val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) | ||
| FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) | ||
|
|
||
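The docstring of `collectEvaluableUDFs` above describes the key planner change: each pass of the rule collects UDFs of a single eval type (batched Python UDFs first, Scalar Pandas UDFs only when no batched ones are found), so a plan mixing both kinds ends up with separate `BatchEvalPythonExec` and `ArrowEvalPythonExec` nodes, as the new suite below verifies. A rough PySpark illustration — the UDF names and session setup are hypothetical — of observing that split from the Python side:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.getOrCreate()

@udf('long')
def add_one(x):          # batched (row-at-a-time) Python UDF
    return x + 1

@pandas_udf('long')
def add_ten(s):          # vectorized Scalar Pandas UDF
    return s + 10

df = spark.range(5)

# A Pandas UDF applied on top of a batched Python UDF: the two UDFs are
# extracted in separate passes, so the physical plan is expected to contain
# a BatchEvalPython node feeding an ArrowEvalPython node rather than one
# combined evaluation node.
df.select(add_ten(add_one(col('id')))).explain()
```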
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.execution.python | ||
|
|
||
| import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} | ||
| import org.apache.spark.sql.functions.col | ||
| import org.apache.spark.sql.test.SharedSQLContext | ||
|
|
||
| class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext { | ||
| import testImplicits.newProductEncoder | ||
| import testImplicits.localSeqToDatasetHolder | ||
|
|
||
| val batchedPythonUDF = new MyDummyPythonUDF | ||
| val scalarPandasUDF = new MyDummyScalarPandasUDF | ||
|
|
||
| private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect { | ||
| case b: BatchEvalPythonExec => b | ||
| } | ||
|
|
||
| private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect { | ||
| case b: ArrowEvalPythonExec => b | ||
| } | ||
|
|
||
| test("Chained Batched Python UDFs should be combined to a single physical node") { | ||
| val df = Seq(("Hello", 4)).toDF("a", "b") | ||
| val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) | ||
| .withColumn("d", batchedPythonUDF(col("c"))) | ||
| val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) | ||
| assert(pythonEvalNodes.size == 1) | ||
| } | ||
|
|
||
| test("Chained Scalar Pandas UDFs should be combined to a single physical node") { | ||
| val df = Seq(("Hello", 4)).toDF("a", "b") | ||
| val df2 = df.withColumn("c", scalarPandasUDF(col("a"))) | ||
| .withColumn("d", scalarPandasUDF(col("c"))) | ||
| val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) | ||
| assert(arrowEvalNodes.size == 1) | ||
| } | ||
|
|
||
| test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") { | ||
| val df = Seq(("Hello", 4)).toDF("a", "b") | ||
| val df2 = df.withColumn("c", batchedPythonUDF(col("a"))) | ||
| .withColumn("d", scalarPandasUDF(col("b"))) | ||
|
|
||
| val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) | ||
| val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) | ||
| assert(pythonEvalNodes.size == 1) | ||
| assert(arrowEvalNodes.size == 1) | ||
| } | ||
|
|
||
| test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") { | ||
| val df = Seq(("Hello", 4)).toDF("a", "b") | ||
| val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) | ||
| .withColumn("c2", batchedPythonUDF(col("c1"))) | ||
| .withColumn("d1", scalarPandasUDF(col("a"))) | ||
| .withColumn("d2", scalarPandasUDF(col("d1"))) | ||
|
|
||
| val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) | ||
| val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) | ||
| assert(pythonEvalNodes.size == 1) | ||
| assert(arrowEvalNodes.size == 1) | ||
| } | ||
|
|
||
| test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") { | ||
| val df = Seq(("Hello", 4)).toDF("a", "b") | ||
| val df2 = df.withColumn("c1", batchedPythonUDF(col("a"))) | ||
| .withColumn("d1", scalarPandasUDF(col("c1"))) | ||
| .withColumn("c2", batchedPythonUDF(col("d1"))) | ||
| .withColumn("d2", scalarPandasUDF(col("c2"))) | ||
|
|
||
| val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) | ||
| val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) | ||
| assert(pythonEvalNodes.size == 2) | ||
| assert(arrowEvalNodes.size == 2) | ||
| } | ||
| } | ||
|
|
||
|
|
Member
This looks like it's testing UDF + UDF.

Contributor (Author)
Yeah, the test is written to cover many combinations, so some of them might not be mixed UDFs. Do you prefer that I remove these cases?