-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28445][SQL][PYTHON] Fix error when PythonUDF is used in both group by and aggregate expression #25215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This comment has been minimized.
This comment has been minimized.
SparkR CRAN feasibility check (SPARK-24152) fails again..Emailed to CRAN for help. |
|
retest this please |
This comment has been minimized.
This comment has been minimized.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
Outdated
Show resolved
Hide resolved
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Outdated
Show resolved
Hide resolved
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Outdated
Show resolved
Hide resolved
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Outdated
Show resolved
Hide resolved
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Outdated
Show resolved
Hide resolved
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
|
retest this please |
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
Show resolved
Hide resolved
|
Test build #108001 has finished for PR 25215 at commit
|
|
Test build #108114 has finished for PR 25215 at commit
|
|
cc @cloud-fan and @mgaido91 too |
| * before aggregate. | ||
| * This must be executed after `ExtractPythonUDFFromAggregate` rule and before `ExtractPythonUDFs`. | ||
| */ | ||
| object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference between this rule and ExtractPythonUDFFromAggregate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ExtractPythonUDFFromAggregate pulls out Python UDFs which have aggregate expression or grouping key as input, like udf(sum(c)), and Python UDFs which have no input. Those UDFs pulled out are evaluated after aggregate.
This rule, ExtractGroupingPythonUDFFromAggregate, pulls out Python UDFs which are used in grouping keys, like SELECT count(*) FROM table GROUP BY udf(id). This kind of Python UDF is evaluated before aggregate.
| override lazy val canonicalized: Expression = { | ||
| val canonicalizedChildren = children.map(_.canonicalized) | ||
| // `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result. | ||
| Canonicalize.execute(this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we still need to run Canonicalize.execute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can be saved, yes.
| // in its arguments. Such PythonUDF was pull out by ExtractPythonUDFFromAggregate, too. | ||
| case p: PythonUDF if p.udfDeterministic => | ||
| val canonicalized = p.canonicalized.asInstanceOf[PythonUDF] | ||
| attributeMap.get(canonicalized).map(_.toAttribute).getOrElse(p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can put attributes in attributeMap, instead of Alias.
| .agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b")))) | ||
| val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) | ||
| .agg(pythonTestUDF(base("a") + 1), pythonTestUDF(count(base("b")))) | ||
| checkAnswer(df, df2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we create a test case for each of these checks? We can move the scalaTestUDF, pythonTestUDF and base to the class body.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok.
| "in grouping expression") | ||
| val alias = Alias(p, "groupingPythonUDF")() | ||
| projList += alias | ||
| attributeMap += ((p.canonicalized.asInstanceOf[PythonUDF], alias.toAttribute)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: if a python udf is handled before, this will replace it with a new attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok. let's replace it. since they are deterministic, it should be fine.
| // in its arguments. Such PythonUDF was pull out by ExtractPythonUDFFromAggregate, too. | ||
| case p: PythonUDF if p.udfDeterministic => | ||
| val canonicalized = p.canonicalized.asInstanceOf[PythonUDF] | ||
| attributeMap.getOrElse(canonicalized, p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: if we can't replace python udf here, we can't run the query. Maybe it's better to do attributeMap.get(...).getOrElse(fail)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we can't replace python udf here, it still can run. Like:
val df = base.groupBy(pythonTestUDF(base("a")))
.agg(sum(pythonTestUDF(base("a") + 1)))ExtractPythonUDFs will extract such udfs.
|
Test build #108546 has finished for PR 25215 at commit
|
|
Test build #108551 has finished for PR 25215 at commit
|
|
retest this please |
| // CheckAnalysis guarantees the arguments are deterministic. | ||
| // 2. PythonUDF in grouping key. Grouping key must be deterministic. | ||
| // 3. PythonUDF not in grouping key. It is either no arguments or with grouping key | ||
| // in its arguments. Such PythonUDF was pull out by ExtractPythonUDFFromAggregate, too. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tiny nit: spacing.
| (None, Some(1)), (Some(3), None), (None, None)).toDF("a", "b") | ||
|
|
||
| test("SPARK-28445: PythonUDF as grouping key and aggregate expressions") { | ||
| val df1 = base.groupBy(scalaTestUDF(base("a") + 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, thanks for changing this into DSL. It was rather a nit that needs some efforts.
|
Test build #108554 has finished for PR 25215 at commit
|
|
Merged to master. |
… by clause in 'udf-group-analytics.sql' ## What changes were proposed in this pull request? This PR is a followup of a fix as described in here: apache#25215 (comment) <details><summary>Diff comparing to 'group-analytics.sql'</summary> <p> ```diff diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-analytics.sql.out index 3439a05..de297ab 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-analytics.sql.out -13,9 +13,9 struct<> -- !query 1 -SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH CUBE +SELECT udf(a + b), b, udf(SUM(a - b)) FROM testData GROUP BY udf(a + b), b WITH CUBE -- !query 1 schema -struct<(a + b):int,b:int,sum((a - b)):bigint> +struct<CAST(udf(cast((a + b) as string)) AS INT):int,b:int,CAST(udf(cast(sum(cast((a - b) as bigint)) as string)) AS BIGINT):bigint> -- !query 1 output 2 1 0 2 NULL 0 -33,9 +33,9 NULL NULL 3 -- !query 2 -SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH CUBE +SELECT udf(a), udf(b), SUM(b) FROM testData GROUP BY udf(a), b WITH CUBE -- !query 2 schema -struct<a:int,b:int,sum(b):bigint> +struct<CAST(udf(cast(a as string)) AS INT):int,CAST(udf(cast(b as string)) AS INT):int,sum(b):bigint> -- !query 2 output 1 1 1 1 2 2 -52,9 +52,9 NULL NULL 9 -- !query 3 -SELECT a + b, b, SUM(a - b) FROM testData GROUP BY a + b, b WITH ROLLUP +SELECT udf(a + b), b, SUM(a - b) FROM testData GROUP BY a + b, b WITH ROLLUP -- !query 3 schema -struct<(a + b):int,b:int,sum((a - b)):bigint> +struct<CAST(udf(cast((a + b) as string)) AS INT):int,b:int,sum((a - b)):bigint> -- !query 3 output 2 1 0 2 NULL 0 -70,9 +70,9 NULL NULL 3 -- !query 4 -SELECT a, b, SUM(b) FROM testData GROUP BY a, b WITH ROLLUP +SELECT udf(a), b, udf(SUM(b)) FROM testData GROUP BY udf(a), b WITH ROLLUP -- !query 4 schema -struct<a:int,b:int,sum(b):bigint> +struct<CAST(udf(cast(a as string)) AS INT):int,b:int,CAST(udf(cast(sum(cast(b as bigint)) as string)) AS BIGINT):bigint> -- !query 4 output 1 1 1 1 2 2 -97,7 +97,7 struct<> -- !query 6 -SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY course, year +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY ROLLUP(course, year) ORDER BY udf(course), year -- !query 6 schema struct<course:string,year:int,sum(earnings):bigint> -- !query 6 output -111,7 +111,7 dotNET 2013 48000 -- !query 7 -SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, year +SELECT course, year, SUM(earnings) FROM courseSales GROUP BY CUBE(course, year) ORDER BY course, udf(year) -- !query 7 schema struct<course:string,year:int,sum(earnings):bigint> -- !query 7 output -127,9 +127,9 dotNET 2013 48000 -- !query 8 -SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year) +SELECT course, udf(year), SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course, year) -- !query 8 schema -struct<course:string,year:int,sum(earnings):bigint> +struct<course:string,CAST(udf(cast(year as string)) AS INT):int,sum(earnings):bigint> -- !query 8 output Java NULL 50000 NULL 2012 35000 -138,26 +138,26 dotNET NULL 63000 -- !query 9 -SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(course) +SELECT course, year, udf(SUM(earnings)) FROM courseSales GROUP BY course, year GROUPING SETS(course) -- !query 9 schema -struct<course:string,year:int,sum(earnings):bigint> +struct<course:string,year:int,CAST(udf(cast(sum(cast(earnings as bigint)) as string)) AS BIGINT):bigint> -- !query 9 output Java NULL 50000 dotNET NULL 63000 -- !query 10 -SELECT course, year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year) +SELECT udf(course), year, SUM(earnings) FROM courseSales GROUP BY course, year GROUPING SETS(year) -- !query 10 schema -struct<course:string,year:int,sum(earnings):bigint> +struct<CAST(udf(cast(course as string)) AS STRING):string,year:int,sum(earnings):bigint> -- !query 10 output NULL 2012 35000 NULL 2013 78000 -- !query 11 -SELECT course, SUM(earnings) AS sum FROM courseSales -GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +SELECT course, udf(SUM(earnings)) AS sum FROM courseSales +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, udf(sum) -- !query 11 schema struct<course:string,sum:bigint> -- !query 11 output -173,7 +173,7 dotNET 63000 -- !query 12 SELECT course, SUM(earnings) AS sum, GROUPING_ID(course, earnings) FROM courseSales -GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY course, sum +GROUP BY course, earnings GROUPING SETS((), (course), (course, earnings)) ORDER BY udf(course), sum -- !query 12 schema struct<course:string,sum:bigint,grouping_id(course, earnings):int> -- !query 12 output -188,10 +188,10 dotNET 63000 1 -- !query 13 -SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales +SELECT udf(course), udf(year), GROUPING(course), GROUPING(year), GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -- !query 13 schema -struct<course:string,year:int,grouping(course):tinyint,grouping(year):tinyint,grouping_id(course, year):int> +struct<CAST(udf(cast(course as string)) AS STRING):string,CAST(udf(cast(year as string)) AS INT):int,grouping(course):tinyint,grouping(year):tinyint,grouping_id(course, year):int> -- !query 13 output Java 2012 0 0 0 Java 2013 0 0 0 -205,7 +205,7 dotNET NULL 0 1 1 -- !query 14 -SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year +SELECT course, udf(year), GROUPING(course) FROM courseSales GROUP BY course, udf(year) -- !query 14 schema struct<> -- !query 14 output -214,7 +214,7 grouping() can only be used with GroupingSets/Cube/Rollup; -- !query 15 -SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year +SELECT course, udf(year), GROUPING_ID(course, year) FROM courseSales GROUP BY udf(course), year -- !query 15 schema struct<> -- !query 15 output -223,7 +223,7 grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 16 -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, udf(year) -- !query 16 schema struct<course:string,year:int,grouping__id:int> -- !query 16 output -240,7 +240,7 NULL NULL 3 -- !query 17 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, udf(year) -- !query 17 schema struct<course:string,year:int> -- !query 17 output -250,7 +250,7 dotNET NULL -- !query 18 -SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0 +SELECT course, udf(year) FROM courseSales GROUP BY udf(course), year HAVING GROUPING(course) > 0 -- !query 18 schema struct<> -- !query 18 output -259,7 +259,7 grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 19 -SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0 +SELECT course, udf(udf(year)) FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0 -- !query 19 schema struct<> -- !query 19 output -268,9 +268,9 grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 20 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 +SELECT udf(course), year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 -- !query 20 schema -struct<course:string,year:int> +struct<CAST(udf(cast(course as string)) AS STRING):string,year:int> -- !query 20 output Java NULL NULL 2012 -281,7 +281,7 dotNET NULL -- !query 21 SELECT course, year, GROUPING(course), GROUPING(year) FROM courseSales GROUP BY CUBE(course, year) -ORDER BY GROUPING(course), GROUPING(year), course, year +ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query 21 schema struct<course:string,year:int,grouping(course):tinyint,grouping(year):tinyint> -- !query 21 output -298,7 +298,7 NULL NULL 1 1 -- !query 22 SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(course, year) -ORDER BY GROUPING(course), GROUPING(year), course, year +ORDER BY GROUPING(course), GROUPING(year), course, udf(year) -- !query 22 schema struct<course:string,year:int,grouping_id(course, year):int> -- !query 22 output -314,7 +314,7 NULL NULL 3 -- !query 23 -SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course) +SELECT course, udf(year) FROM courseSales GROUP BY course, udf(year) ORDER BY GROUPING(course) -- !query 23 schema struct<> -- !query 23 output -323,7 +323,7 grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 24 -SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course) +SELECT course, udf(year) FROM courseSales GROUP BY course, udf(year) ORDER BY GROUPING_ID(course) -- !query 24 schema struct<> -- !query 24 output -332,7 +332,7 grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 25 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, udf(course), year -- !query 25 schema struct<course:string,year:int> -- !query 25 output -348,7 +348,7 NULL NULL -- !query 26 -SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) +SELECT udf(a + b) AS k1, udf(b) AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) -- !query 26 schema struct<k1:int,k2:int,sum((a - b)):bigint> -- !query 26 output -368,7 +368,7 NULL NULL 3 -- !query 27 -SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) +SELECT udf(udf(a + b)) AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) -- !query 27 schema struct<k:int,b:int,sum((a - b)):bigint> -- !query 27 output -386,9 +386,9 NULL NULL 3 -- !query 28 -SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) +SELECT udf(a + b), udf(udf(b)) AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) -- !query 28 schema -struct<(a + b):int,k:int,sum((a - b)):bigint> +struct<CAST(udf(cast((a + b) as string)) AS INT):int,k:int,sum((a - b)):bigint> -- !query 28 output NULL 1 3 NULL 2 0 ``` </p> </details> ## How was this patch tested? Tested as instructed in SPARK-27921. Closes apache#25362 from skonto/group-analytics-followup. Authored-by: Stavros Kontopoulos <[email protected]> Signed-off-by: HyukjinKwon <[email protected]>
@HyukjinKwon |
|
@shivusondur @HyukjinKwon The analysis exception by adding udf to group by, is caused by SPARK-28386, SPARK-26741. |
|
Thanks, @viirya. @shivusondur Can you comment the test out with those JIRA numbers? |
What changes were proposed in this pull request?
When PythonUDF is used in group by, and it is also in aggregate expression, like
It causes analysis exception in
CheckAnalysis, likeFirst,
CheckAnalysiscan't check semantic equality between PythonUDFs.Second, even we make it possible, runtime exception will be thrown
The cause is,
ExtractPythonUDFsextracts both PythonUDFs in group by and aggregate expression. The PythonUDFs are two different aliases now in the logical aggregate. In runtime, we can't bind the resulting expression in aggregate to its grouping and aggregate attributes.This patch proposes a rule
ExtractGroupingPythonUDFFromAggregateto extract PythonUDFs in group by and evaluate them before aggregate. We replace the group by PythonUDF in aggregate expression with aliased result.The query plan of query
SELECT pyUDF(a + 1), pyUDF(COUNT(b)) FROM testData GROUP BY pyUDF(a + 1), likeHow was this patch tested?
Added tests.