[SPARK-27297] [SQL] Add higher order functions to scala API #24232
Status: Closed

nvander1 wants to merge 42 commits into apache:master from nvander1:feature/add_higher_order_functions_to_scala_api.

Changes shown are from 27 of the 42 commits.
Commits (42):

1. 9cf1ebf Adds higher order functions to scala API (nvander1)
2. efc6ba4 Add (Scala-specifc) note to higher order functions (nvander1)
3. b9dceec Follow style guide more closely (nvander1)
4. 1fb46a3 Fix scalastyle issues (nvander1)
5. 03d602f Add java-specific version of higher order function api (nvander1)
6. 6bf07d8 Do not prematurely bind lambda variables (nvander1)
7. b03399a Resolve conflict between Java Function and Scala Function (HyukjinKwon)
8. 79d6f84 Adds higher order functions to scala API (nvander1)
9. 7adaf9c Add (Scala-specifc) note to higher order functions (nvander1)
10. ac5c1c2 Follow style guide more closely (nvander1)
11. 40ac418 Fix scalastyle issues (nvander1)
12. fb5f8ef Add java-specific version of higher order function api (nvander1)
13. 85979d4 Do not prematurely bind lambda variables (nvander1)
14. 5d389d2 Merge branch 'fix-24232' of git://github.com/HyukjinKwon/spark into H… (nvander1)
15. 5d77d6b Merge branch 'master' into feature/add_higher_order_functions_to_scal… (nvander1)
16. a8c7ecd Add forall to org.apache.spark.sql.functions (nvander1)
17. 96fb0ad Add "@since 3.0.0" to new functions (nvander1)
18. 5fa3e71 Add tests for Java transform function (nvander1)
19. 0bfa483 Add tests for Java map_filter function (nvander1)
20. 815e9f6 Add tests for Java filter function (nvander1)
21. 47b100b Add tests for Java exists function (nvander1)
22. 4baf084 Add test for Java API forall (nvander1)
23. 9c0f70e Merge branch 'master' into feature/add_higher_order_functions_to_scal… (nvander1)
24. 06b4c82 Add test for Java API: aggregate (nvander1)
25. 412ece5 Add test for Java API: map_zip_with (nvander1)
26. c49e7d3 Add java tests for transform_keys, transform_values (nvander1)
27. 182a08b Add tests for java zip_with function (nvander1)
28. ef6b6bb Remove JavaFunction overloads and add Java transform test (nvander1)
29. a543c90 Merge branch 'tmp' into feature/add_higher_order_functions_to_scala_api (nvander1)
30. 527c0cb Remove (Scala-specifc) from higher order functions (nvander1)
31. 013187f Remove java tests from DataFrameFunctionsSuite (nvander1)
32. 554a992 Add simple java test for filter (nvander1)
33. 0433756 Add simple java test for exists (nvander1)
34. f371413 Add simple java test for forall (nvander1)
35. c3e320c Add java test for aggregate (nvander1)
36. 84ccf55 Add java aggregate test with finish (nvander1)
37. e43033b Add java test for zip_with (nvander1)
38. c1c76a9 Add java test for transformKeys (nvander1)
39. 10a5f2e Add java test for transform_values (nvander1)
40. 722f0e6 Add java test for map_filter and map_zip_with (nvander1)
41. 1bf2654 Fix style nits (nvander1)
42. 64c0f87 Fix linter errors in imports (nvander1)
```diff
@@ -24,6 +24,7 @@ import scala.util.Try
 import scala.util.control.NonFatal

 import org.apache.spark.annotation.Stable
+import org.apache.spark.api.java.function.{Function => JavaFunction, Function2 => JavaFunction2, Function3 => JavaFunction3}
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
```
@@ -3385,6 +3386,320 @@ object functions {

```scala
    ArrayExcept(col1.expr, col2.expr)
  }

  private def createLambda(f: Column => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val function = f(Column(x)).expr
    LambdaFunction(function, Seq(x))
  }

  private def createLambda(f: (Column, Column) => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val function = f(Column(x), Column(y)).expr
    LambdaFunction(function, Seq(x, y))
  }

  private def createLambda(f: (Column, Column, Column) => Column) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val z = UnresolvedNamedLambdaVariable(Seq("z"))
    val function = f(Column(x), Column(y), Column(z)).expr
    LambdaFunction(function, Seq(x, y, z))
  }

  private def createLambda(f: JavaFunction[Column, Column]) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val function = f.call(Column(x)).expr
    LambdaFunction(function, Seq(x))
  }

  private def createLambda(f: JavaFunction2[Column, Column, Column]) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val function = f.call(Column(x), Column(y)).expr
    LambdaFunction(function, Seq(x, y))
  }

  private def createLambda(f: JavaFunction3[Column, Column, Column, Column]) = {
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val y = UnresolvedNamedLambdaVariable(Seq("y"))
    val z = UnresolvedNamedLambdaVariable(Seq("z"))
    val function = f.call(Column(x), Column(y), Column(z)).expr
    LambdaFunction(function, Seq(x, y, z))
  }
```
```scala
  /**
   * (Scala-specific) Returns an array of elements after applying a transformation to each element
   * in the input array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: Column => Column): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Returns an array of elements after applying a transformation to each element
   * in the input array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: (Column, Column) => Column): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Returns whether a predicate holds for one or more elements in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def exists(column: Column, f: Column => Column): Column = withExpr {
    ArrayExists(column.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Returns whether a predicate holds for every element in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def forall(column: Column, f: Column => Column): Column = withExpr {
    ArrayForAll(column.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Returns an array of elements for which a predicate holds in a given array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def filter(column: Column, f: Column => Column): Column = withExpr {
    ArrayFilter(column.expr, createLambda(f))
  }
```
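The behavior documented for `transform`, `exists`, `forall`, and `filter` above can be sketched on plain Java lists. This is only an illustration of the documented semantics (the class and method names here are hypothetical), not the actual implementation, which builds unevaluated Catalyst `LambdaFunction` expressions rather than executing the lambda eagerly.

```java
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

public class ArrayHofSketch {
    // transform: apply f to every element of the array
    static <T, R> List<R> transform(List<T> xs, Function<T, R> f) {
        return xs.stream().map(f).collect(Collectors.toList());
    }

    // exists: the predicate holds for at least one element
    static <T> boolean exists(List<T> xs, Predicate<T> p) {
        return xs.stream().anyMatch(p);
    }

    // forall: the predicate holds for every element
    static <T> boolean forall(List<T> xs, Predicate<T> p) {
        return xs.stream().allMatch(p);
    }

    // filter: keep only the elements for which the predicate holds
    static <T> List<T> filter(List<T> xs, Predicate<T> p) {
        return xs.stream().filter(p).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> xs = List.of(1, 2, 3);
        System.out.println(transform(xs, x -> x + 1));   // [2, 3, 4]
        System.out.println(exists(xs, x -> x > 2));      // true
        System.out.println(forall(xs, x -> x > 0));      // true
        System.out.println(filter(xs, x -> x % 2 == 1)); // [1, 3]
    }
}
```

In the Spark API the lambda instead receives `Column` arguments and is invoked once, at plan-construction time, to build an expression tree.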
```scala
  /**
   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state. The final state is converted into the final result
   * by applying a finish function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column,
      finish: Column => Column): Column = withExpr {
    ArrayAggregate(
      expr.expr,
      zero.expr,
      createLambda(merge),
      createLambda(finish)
    )
  }

  /**
   * (Scala-specific) Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column =
    aggregate(expr, zero, merge, c => c)

  /**
   * (Scala-specific) Merge two given arrays, element-wise, into a single array using a function.
   * If one array is shorter, nulls are appended at the end to match the length of the longer
   * array, before applying the function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr {
    ZipWith(left.expr, right.expr, createLambda(f))
  }
```
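The documented semantics of `aggregate` (a left fold plus a finish function) and of `zip_with` (element-wise combination with null padding of the shorter array) can likewise be sketched on plain Java lists. The names below are hypothetical and this is a semantics illustration only, not the Catalyst implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

public class FoldSketch {
    // aggregate: left-fold the array with merge, then apply finish to the final state
    static <T, S, R> R aggregate(List<T> xs, S zero,
                                 BiFunction<S, T, S> merge, Function<S, R> finish) {
        S state = zero;
        for (T x : xs) {
            state = merge.apply(state, x);
        }
        return finish.apply(state);
    }

    // zip_with: combine element-wise; the shorter side is padded with nulls
    static <A, B, R> List<R> zipWith(List<A> left, List<B> right,
                                     BiFunction<A, B, R> f) {
        int n = Math.max(left.size(), right.size());
        List<R> out = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            A a = i < left.size() ? left.get(i) : null;
            B b = i < right.size() ? right.get(i) : null;
            out.add(f.apply(a, b));
        }
        return out;
    }

    public static void main(String[] args) {
        // sum of squares: fold with merge, identity finish (like the two-argument overload)
        int total = aggregate(List.of(1, 2, 3), 0, (acc, x) -> acc + x * x, s -> s);
        System.out.println(total); // 14

        List<String> zipped = zipWith(List.of("a", "b"), List.of(1, 2, 3),
                (s, i) -> s + "-" + i);
        System.out.println(zipped); // [a-1, b-2, null-3]
    }
}
```

The two-argument `aggregate` overload in the diff is exactly this pattern with an identity finish function (`c => c`).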
```scala
  /**
   * (Scala-specific) Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new keys for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    TransformKeys(expr.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new values for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    TransformValues(expr.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Returns a map whose key-value pairs satisfy a predicate.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr {
    MapFilter(expr.expr, createLambda(f))
  }

  /**
   * (Scala-specific) Merge two given maps, key-wise, into a single map using a function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_zip_with(left: Column, right: Column,
      f: (Column, Column, Column) => Column): Column = withExpr {
    MapZipWith(left.expr, right.expr, createLambda(f))
  }
```
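The map variants take the key and the value as separate lambda arguments. A minimal sketch of the documented behavior of `transform_values` and `map_filter` on plain Java maps (hypothetical helper names, illustration only):

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;

public class MapHofSketch {
    // transform_values: f(key, value) produces the new value for each pair
    static <K, V, W> Map<K, W> transformValues(Map<K, V> m, BiFunction<K, V, W> f) {
        Map<K, W> out = new LinkedHashMap<>();
        m.forEach((k, v) -> out.put(k, f.apply(k, v)));
        return out;
    }

    // map_filter: keep only the pairs for which the predicate holds
    static <K, V> Map<K, V> mapFilter(Map<K, V> m, BiPredicate<K, V> p) {
        Map<K, V> out = new LinkedHashMap<>();
        m.forEach((k, v) -> { if (p.test(k, v)) out.put(k, v); });
        return out;
    }

    public static void main(String[] args) {
        Map<String, Integer> m = new LinkedHashMap<>();
        m.put("a", 1);
        m.put("b", 2);
        System.out.println(transformValues(m, (k, v) -> v * 10)); // {a=10, b=20}
        System.out.println(mapFilter(m, (k, v) -> v > 1));        // {b=2}
    }
}
```

`transform_keys` is the mirror image (the result of `f` becomes the new key), and `map_zip_with` passes three arguments: the key plus the value from each of the two maps.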
```scala
  /**
   * (Java-specific) Returns an array of elements after applying a transformation to each element
   * in the input array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: JavaFunction[Column, Column]): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }

  /**
   * (Java-specific) Returns an array of elements after applying a transformation to each element
   * in the input array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform(column: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr {
    ArrayTransform(column.expr, createLambda(f))
  }

  /**
   * (Java-specific) Returns whether a predicate holds for one or more elements in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def exists(column: Column, f: JavaFunction[Column, Column]): Column = withExpr {
    ArrayExists(column.expr, createLambda(f))
  }

  /**
   * (Java-specific) Returns whether a predicate holds for every element in the array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def forall(column: Column, f: JavaFunction[Column, Column]): Column = withExpr {
    ArrayForAll(column.expr, createLambda(f))
  }

  /**
   * (Java-specific) Returns an array of elements for which a predicate holds in a given array.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def filter(column: Column, f: JavaFunction[Column, Column]): Column = withExpr {
    ArrayFilter(column.expr, createLambda(f))
  }

  /**
   * (Java-specific) Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state. The final state is converted into the final result
   * by applying a finish function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column],
      finish: JavaFunction[Column, Column]): Column = withExpr {
    ArrayAggregate(
      expr.expr,
      zero.expr,
      createLambda(merge),
      createLambda(finish)
    )
  }

  /**
   * (Java-specific) Applies a binary operator to an initial state and all elements in the array,
   * and reduces this to a single state.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def aggregate(expr: Column, zero: Column, merge: JavaFunction2[Column, Column, Column]): Column =
    aggregate(
      expr, zero, merge, new JavaFunction[Column, Column] { def call(c: Column): Column = c })

  /**
   * (Java-specific) Merge two given arrays, element-wise, into a single array using a function.
   * If one array is shorter, nulls are appended at the end to match the length of the longer
   * array, before applying the function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def zip_with(left: Column, right: Column, f: JavaFunction2[Column, Column, Column]): Column =
    withExpr {
      ZipWith(left.expr, right.expr, createLambda(f))
    }

  /**
   * (Java-specific) Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new keys for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_keys(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr {
    TransformKeys(expr.expr, createLambda(f))
  }

  /**
   * (Java-specific) Applies a function to every key-value pair in a map and returns
   * a map with the results of those applications as the new values for the pairs.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def transform_values(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr {
    TransformValues(expr.expr, createLambda(f))
  }

  /**
   * (Java-specific) Returns a map whose key-value pairs satisfy a predicate.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_filter(expr: Column, f: JavaFunction2[Column, Column, Column]): Column = withExpr {
    MapFilter(expr.expr, createLambda(f))
  }

  /**
   * (Java-specific) Merge two given maps, key-wise, into a single map using a function.
   *
   * @group collection_funcs
   * @since 3.0.0
   */
  def map_zip_with(left: Column, right: Column,
      f: JavaFunction3[Column, Column, Column, Column]): Column = withExpr {
    MapZipWith(left.expr, right.expr, createLambda(f))
  }

  /**
   * Creates a new row for each element in the given array or map column.
   * Uses the default column name `col` for elements in the array and
```
Review comments:

- "But how do we support this in Java?"
- "Could we change the signatures to accept scala.runtime.AbstractFunctions instead to avoid using the Function traits?"
- "Let's add (Scala-specific) at least for each doc. BTW, please take a look at the style guide at https://github.com/databricks/scala-style-guide"
- "Actually a better idea would probably be to use java functional interfaces."
- "And of course we would use the existing functional interfaces first from java.util.function, but I don't think there are any that accept three parameters like some of the functions here require."
- "It appears these interfaces already exist in the source tree: https://github.com/apache/spark/blob/v2.4.0/core/src/main/java/org/apache/spark/api/java/function/Function3.java. I'll come back later to add Java-specific APIs that utilize these."
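The last two comments note that `java.util.function` stops at two-argument interfaces (`BiFunction`), which is why the PR reuses Spark's own `org.apache.spark.api.java.function.Function3` for the three-argument `map_zip_with` callback. A minimal sketch of what such a three-argument functional interface looks like (the `TriFunction` name here is hypothetical, standing in for `Function3`):

```java
public class TriFunctionSketch {
    // Hypothetical three-argument functional interface, analogous in shape to
    // org.apache.spark.api.java.function.Function3 mentioned in the thread.
    @FunctionalInterface
    interface TriFunction<A, B, C, R> {
        R apply(A a, B b, C c);
    }

    public static void main(String[] args) {
        // A map_zip_with-style callback: (key, leftValue, rightValue) -> result
        TriFunction<String, Integer, Integer, String> f =
                (k, l, r) -> k + ":" + (l + r);
        System.out.println(f.apply("a", 1, 2)); // a:3
    }
}
```

Because the interface has a single abstract method, Java callers can pass a lambda directly, which is what makes the Java-specific overloads in the diff ergonomic.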