Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
9cf1ebf
Adds higher order functions to scala API
nvander1 Mar 28, 2019
efc6ba4
Add (Scala-specifc) note to higher order functions
nvander1 Mar 28, 2019
b9dceec
Follow style guide more closely
nvander1 Mar 28, 2019
1fb46a3
Fix scalastyle issues
nvander1 Mar 28, 2019
03d602f
Add java-specific version of higher order function api
nvander1 Mar 28, 2019
6bf07d8
Do not prematurely bind lambda variables
nvander1 Jun 14, 2019
b03399a
Resolve conflict between Java Function and Scala Function
HyukjinKwon Jul 25, 2019
79d6f84
Adds higher order functions to scala API
nvander1 Mar 28, 2019
7adaf9c
Add (Scala-specifc) note to higher order functions
nvander1 Mar 28, 2019
ac5c1c2
Follow style guide more closely
nvander1 Mar 28, 2019
40ac418
Fix scalastyle issues
nvander1 Mar 28, 2019
fb5f8ef
Add java-specific version of higher order function api
nvander1 Mar 28, 2019
85979d4
Do not prematurely bind lambda variables
nvander1 Jun 14, 2019
5d389d2
Merge branch 'fix-24232' of git://github.com/HyukjinKwon/spark into H…
nvander1 Aug 2, 2019
5d77d6b
Merge branch 'master' into feature/add_higher_order_functions_to_scal…
nvander1 Aug 6, 2019
a8c7ecd
Add forall to org.apache.spark.sql.functions
nvander1 Aug 6, 2019
96fb0ad
Add "@since 3.0.0" to new functions
nvander1 Aug 10, 2019
5fa3e71
Add tests for Java transform function
nvander1 Aug 10, 2019
0bfa483
Add tests for Java map_filter function
nvander1 Aug 10, 2019
815e9f6
Add tests for Java filter function
nvander1 Aug 10, 2019
47b100b
Add tests for Java exists function
nvander1 Aug 10, 2019
4baf084
Add test for Java API forall
nvander1 Aug 19, 2019
9c0f70e
Merge branch 'master' into feature/add_higher_order_functions_to_scal…
nvander1 Aug 19, 2019
06b4c82
Add test for Java API: aggregate
nvander1 Aug 19, 2019
412ece5
Add test for Java API: map_zip_with
nvander1 Aug 19, 2019
c49e7d3
Add java tests for transform_keys, transform_values
nvander1 Aug 21, 2019
182a08b
Add tests for java zip_with function
nvander1 Aug 21, 2019
ef6b6bb
Remove JavaFunction overloads and add Java transform test
nvander1 Aug 21, 2019
a543c90
Merge branch 'tmp' into feature/add_higher_order_functions_to_scala_api
nvander1 Aug 21, 2019
527c0cb
Remove (Scala-specifc) from higher order functions
nvander1 Aug 21, 2019
013187f
Remove java tests from DataFrameFunctionsSuite
nvander1 Sep 17, 2019
554a992
Add simple java test for filter
nvander1 Sep 18, 2019
0433756
Add simple java test for exists
nvander1 Sep 18, 2019
f371413
Add simple java test for forall
nvander1 Sep 18, 2019
c3e320c
Add java test for aggregate
nvander1 Sep 19, 2019
84ccf55
Add java aggregate test with finish
nvander1 Sep 19, 2019
e43033b
Add java test for zip_with
nvander1 Sep 19, 2019
c1c76a9
Add java test for transformKeys
nvander1 Sep 20, 2019
10a5f2e
Add java test for transform_values
nvander1 Sep 20, 2019
722f0e6
Add java test for map_filter and map_zip_with
nvander1 Sep 20, 2019
1bf2654
Fix style nits
nvander1 Oct 2, 2019
64c0f87
Fix linter errors in imports
nvander1 Oct 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,139 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods

/**
 * Helper methods for constructing higher order functions.
 *
 * Each helper derives the element/key/value types (and their nullability) from the data type of
 * the input expression, wraps the supplied Scala function in a [[LambdaFunction]] over fresh
 * [[NamedLambdaVariable]]s, builds the corresponding higher order function expression and binds
 * it through [[validateBinding]].
 */
object HigherOrderUtils {

  /**
   * Builds a one-argument [[LambdaFunction]].
   *
   * @param dt data type of the lambda variable
   * @param nullable whether the lambda variable is nullable
   * @param f Scala function producing the lambda body from the lambda variable
   * @return a [[LambdaFunction]] whose single argument is a variable named "arg"
   */
  def createLambda(
      dt: DataType,
      nullable: Boolean,
      f: Expression => Expression): Expression = {
    val lv = NamedLambdaVariable("arg", dt, nullable)
    val function = f(lv)
    LambdaFunction(function, Seq(lv))
  }

  /**
   * Builds a two-argument [[LambdaFunction]].
   *
   * @param dt1 data type of the first lambda variable
   * @param nullable1 whether the first lambda variable is nullable
   * @param dt2 data type of the second lambda variable
   * @param nullable2 whether the second lambda variable is nullable
   * @param f Scala function producing the lambda body from the two lambda variables
   * @return a [[LambdaFunction]] whose arguments are variables named "arg1" and "arg2"
   */
  def createLambda(
      dt1: DataType,
      nullable1: Boolean,
      dt2: DataType,
      nullable2: Boolean,
      f: (Expression, Expression) => Expression): Expression = {
    val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
    val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
    val function = f(lv1, lv2)
    LambdaFunction(function, Seq(lv1, lv2))
  }

  /**
   * Builds a three-argument [[LambdaFunction]].
   *
   * @param dt1 data type of the first lambda variable
   * @param nullable1 whether the first lambda variable is nullable
   * @param dt2 data type of the second lambda variable
   * @param nullable2 whether the second lambda variable is nullable
   * @param dt3 data type of the third lambda variable
   * @param nullable3 whether the third lambda variable is nullable
   * @param f Scala function producing the lambda body from the three lambda variables
   * @return a [[LambdaFunction]] whose arguments are variables named "arg1", "arg2" and "arg3"
   */
  def createLambda(
      dt1: DataType,
      nullable1: Boolean,
      dt2: DataType,
      nullable2: Boolean,
      dt3: DataType,
      nullable3: Boolean,
      f: (Expression, Expression, Expression) => Expression): Expression = {
    val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
    val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
    val lv3 = NamedLambdaVariable("arg3", dt3, nullable3)
    val function = f(lv1, lv2, lv3)
    LambdaFunction(function, Seq(lv1, lv2, lv3))
  }

  /**
   * Binding callback that checks an already-built lambda against the argument types the
   * enclosing higher order function expects, and returns it unchanged.
   *
   * @param e expression expected to be a [[LambdaFunction]]
   * @param argInfo expected (data type, nullable) pair for each lambda argument, in order
   * @return `e` itself, typed as a [[LambdaFunction]]
   * @throws IllegalArgumentException if `e` is not a [[LambdaFunction]]
   * @throws AssertionError if the argument count, data types or nullability do not match
   */
  def validateBinding(
      e: Expression,
      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
    case f: LambdaFunction =>
      assert(f.arguments.size == argInfo.size,
        s"Lambda function has ${f.arguments.size} argument(s) but ${argInfo.size} expected")
      f.arguments.zip(argInfo).foreach {
        case (arg, (dataType, nullable)) =>
          assert(arg.dataType == dataType,
            s"Lambda argument '${arg.name}' has type ${arg.dataType} but expected $dataType")
          assert(arg.nullable == nullable,
            s"Lambda argument '${arg.name}' has nullable = ${arg.nullable} " +
              s"but expected $nullable")
      }
      f
    // Expression is not sealed, so without this case a non-lambda input would surface as a
    // bare MatchError with no hint about what was expected.
    case other =>
      throw new IllegalArgumentException(
        s"Expected a LambdaFunction but got ${other.getClass.getName}: $other")
  }

  /** Returns the data type of `expr` as an [[ArrayType]], or fails with a descriptive error. */
  private def arrayTypeOf(expr: Expression): ArrayType = expr.dataType match {
    case at: ArrayType => at
    case other =>
      throw new IllegalArgumentException(
        s"Expected an array-typed expression but got ${other.catalogString}: $expr")
  }

  /** Returns the data type of `expr` as a [[MapType]], or fails with a descriptive error. */
  private def mapTypeOf(expr: Expression): MapType = expr.dataType match {
    case mt: MapType => mt
    case other =>
      throw new IllegalArgumentException(
        s"Expected a map-typed expression but got ${other.catalogString}: $expr")
  }

  // Array-based helpers

  /**
   * Builds a bound [[ArrayFilter]] over `expr`.
   *
   * @param expr array-typed expression
   * @param f predicate body built from the element variable
   * @throws IllegalArgumentException if `expr` is not array-typed
   */
  def filter(expr: Expression, f: Expression => Expression): Expression = {
    val ArrayType(et, cn) = arrayTypeOf(expr)
    ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[ArrayExists]] over `expr`.
   *
   * @param expr array-typed expression
   * @param f predicate body built from the element variable
   * @throws IllegalArgumentException if `expr` is not array-typed
   */
  def exists(expr: Expression, f: Expression => Expression): Expression = {
    val ArrayType(et, cn) = arrayTypeOf(expr)
    ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[ArrayTransform]] over `expr` using a one-argument lambda.
   *
   * @param expr array-typed expression
   * @param f lambda body built from the element variable
   * @throws IllegalArgumentException if `expr` is not array-typed
   */
  def transform(expr: Expression, f: Expression => Expression): Expression = {
    val ArrayType(et, cn) = arrayTypeOf(expr)
    ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[ArrayTransform]] over `expr` using a two-argument lambda.
   *
   * The second lambda variable is a non-nullable [[IntegerType]] (presumably the element
   * position -- confirm against [[ArrayTransform]]).
   *
   * @param expr array-typed expression
   * @param f lambda body built from the element variable and the integer variable
   * @throws IllegalArgumentException if `expr` is not array-typed
   */
  def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val ArrayType(et, cn) = arrayTypeOf(expr)
    ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[ArrayAggregate]] over `expr`.
   *
   * The accumulator variable takes the data type of `zero` and is always declared nullable.
   *
   * @param expr array-typed expression
   * @param zero initial accumulator value
   * @param merge lambda body built from (accumulator, element)
   * @param finish lambda body built from the final accumulator
   * @throws IllegalArgumentException if `expr` is not array-typed
   */
  def aggregate(
      expr: Expression,
      zero: Expression,
      merge: (Expression, Expression) => Expression,
      finish: Expression => Expression): Expression = {
    val ArrayType(et, cn) = arrayTypeOf(expr)
    val zeroType = zero.dataType
    ArrayAggregate(
      expr,
      zero,
      createLambda(zeroType, true, et, cn, merge),
      createLambda(zeroType, true, finish))
      .bind(validateBinding)
  }

  /**
   * Builds a bound [[ArrayAggregate]] over `expr` whose finish step returns the accumulator
   * unchanged.
   *
   * @param expr array-typed expression
   * @param zero initial accumulator value
   * @param merge lambda body built from (accumulator, element)
   */
  def aggregate(
      expr: Expression,
      zero: Expression,
      merge: (Expression, Expression) => Expression): Expression = {
    aggregate(expr, zero, merge, identity)
  }

  /**
   * Builds a bound [[ZipWith]] over two arrays.
   *
   * Both lambda variables are declared nullable regardless of the inputs' `containsNull`.
   *
   * @param left array-typed expression
   * @param right array-typed expression
   * @param f lambda body built from (left element, right element)
   * @throws IllegalArgumentException if either input is not array-typed
   */
  def zip_with(
      left: Expression,
      right: Expression,
      f: (Expression, Expression) => Expression): Expression = {
    val ArrayType(leftT, _) = arrayTypeOf(left)
    val ArrayType(rightT, _) = arrayTypeOf(right)
    ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding)
  }

  // Map-based helpers

  /**
   * Builds a bound [[TransformKeys]] over `expr`.
   *
   * @param expr map-typed expression
   * @param f lambda body built from (key, value); the key variable is non-nullable
   * @throws IllegalArgumentException if `expr` is not map-typed
   */
  def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val MapType(kt, vt, vcn) = mapTypeOf(expr)
    TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[TransformValues]] over `expr`.
   *
   * @param expr map-typed expression
   * @param f lambda body built from (key, value); the key variable is non-nullable
   * @throws IllegalArgumentException if `expr` is not map-typed
   */
  def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val MapType(kt, vt, vcn) = mapTypeOf(expr)
    TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[MapFilter]] over `expr`.
   *
   * @param expr map-typed expression
   * @param f predicate body built from (key, value); the key variable is non-nullable
   * @throws IllegalArgumentException if `expr` is not map-typed
   */
  def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val MapType(kt, vt, vcn) = mapTypeOf(expr)
    MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
  }

  /**
   * Builds a bound [[MapZipWith]] over two maps.
   *
   * The key variable takes the left map's key type and is non-nullable; both value variables
   * are declared nullable regardless of the inputs' `valueContainsNull`.
   *
   * @param left map-typed expression
   * @param right map-typed expression
   * @param f lambda body built from (key, left value, right value)
   * @throws IllegalArgumentException if either input is not map-typed
   */
  def map_zip_with(
      left: Expression,
      right: Expression,
      f: (Expression, Expression, Expression) => Expression): Expression = {
    val MapType(kt, vt1, _) = mapTypeOf(left)
    val MapType(_, vt2, _) = mapTypeOf(right)
    MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f))
      .bind(validateBinding)
  }
}

/**
* A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,102 +24,7 @@ import org.apache.spark.sql.types._

class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
import org.apache.spark.sql.catalyst.dsl.expressions._

/**
 * Wraps a one-argument Scala function in a [[LambdaFunction]] over a single
 * [[NamedLambdaVariable]] named "arg" with the given data type and nullability.
 */
private def createLambda(
    dt: DataType,
    nullable: Boolean,
    f: Expression => Expression): Expression = {
  val variable = NamedLambdaVariable("arg", dt, nullable)
  LambdaFunction(f(variable), Seq(variable))
}

/**
 * Wraps a two-argument Scala function in a [[LambdaFunction]] over two
 * [[NamedLambdaVariable]]s named "arg1" and "arg2".
 */
private def createLambda(
    dt1: DataType,
    nullable1: Boolean,
    dt2: DataType,
    nullable2: Boolean,
    f: (Expression, Expression) => Expression): Expression = {
  // Both variables are created before `f` runs, in argument order.
  val first = NamedLambdaVariable("arg1", dt1, nullable1)
  val second = NamedLambdaVariable("arg2", dt2, nullable2)
  LambdaFunction(f(first, second), Seq(first, second))
}

/**
 * Wraps a three-argument Scala function in a [[LambdaFunction]] over three
 * [[NamedLambdaVariable]]s named "arg1", "arg2" and "arg3".
 */
private def createLambda(
    dt1: DataType,
    nullable1: Boolean,
    dt2: DataType,
    nullable2: Boolean,
    dt3: DataType,
    nullable3: Boolean,
    f: (Expression, Expression, Expression) => Expression): Expression = {
  // All variables are created before `f` runs, in argument order.
  val first = NamedLambdaVariable("arg1", dt1, nullable1)
  val second = NamedLambdaVariable("arg2", dt2, nullable2)
  val third = NamedLambdaVariable("arg3", dt3, nullable3)
  LambdaFunction(f(first, second, third), Seq(first, second, third))
}

/**
 * Binding callback: asserts that the lambda's arguments match the expected
 * (data type, nullable) pairs in number, type and nullability, then returns the lambda as is.
 * Any expression other than a [[LambdaFunction]] falls out of the match.
 */
private def validateBinding(
    e: Expression,
    argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
  case lambda: LambdaFunction =>
    assert(lambda.arguments.size === argInfo.size)
    for ((argument, (expectedType, expectedNullable)) <- lambda.arguments.zip(argInfo)) {
      assert(argument.dataType === expectedType)
      assert(argument.nullable === expectedNullable)
    }
    lambda
}

/**
 * Builds a bound [[ArrayTransform]] whose one-argument lambda takes the array's element type
 * and `containsNull` flag. A non-array `expr` fails the match.
 */
def transform(expr: Expression, f: Expression => Expression): Expression = {
  expr.dataType match {
    case ArrayType(elementType, containsNull) =>
      val lambda = createLambda(elementType, containsNull, f)
      ArrayTransform(expr, lambda).bind(validateBinding)
  }
}

/**
 * Builds a bound [[ArrayTransform]] whose two-argument lambda takes the array element plus a
 * non-nullable [[IntegerType]] variable. A non-array `expr` fails the match.
 */
def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  expr.dataType match {
    case ArrayType(elementType, containsNull) =>
      val lambda = createLambda(elementType, containsNull, IntegerType, false, f)
      ArrayTransform(expr, lambda).bind(validateBinding)
  }
}

/**
 * Builds a bound [[ArrayFilter]] whose predicate lambda takes the array's element type and
 * `containsNull` flag. A non-array `expr` fails the match.
 */
def filter(expr: Expression, f: Expression => Expression): Expression = {
  expr.dataType match {
    case ArrayType(elementType, containsNull) =>
      val predicate = createLambda(elementType, containsNull, f)
      ArrayFilter(expr, predicate).bind(validateBinding)
  }
}

/**
 * Builds a bound [[TransformKeys]] whose lambda takes a non-nullable key variable and a value
 * variable following the map's `valueContainsNull`. A non-map `expr` fails the match.
 */
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  expr.dataType match {
    case MapType(keyType, valueType, valueContainsNull) =>
      val lambda = createLambda(keyType, false, valueType, valueContainsNull, f)
      TransformKeys(expr, lambda).bind(validateBinding)
  }
}

/**
 * Builds a bound [[ArrayAggregate]]. The accumulator variable uses the data type of `zero`
 * and is declared nullable in both the merge and the finish lambda. A non-array `expr` fails
 * the match.
 */
def aggregate(
    expr: Expression,
    zero: Expression,
    merge: (Expression, Expression) => Expression,
    finish: Expression => Expression): Expression = {
  expr.dataType match {
    case ArrayType(elementType, containsNull) =>
      val accumulatorType = zero.dataType
      // Merge lambda is built first, then the finish lambda, matching argument order.
      val mergeLambda = createLambda(accumulatorType, true, elementType, containsNull, merge)
      val finishLambda = createLambda(accumulatorType, true, finish)
      ArrayAggregate(expr, zero, mergeLambda, finishLambda).bind(validateBinding)
  }
}

/**
 * Builds a bound [[ArrayAggregate]] whose finish step hands the accumulator back unchanged.
 */
def aggregate(
    expr: Expression,
    zero: Expression,
    merge: (Expression, Expression) => Expression): Expression =
  aggregate(expr, zero, merge, (accumulator: Expression) => accumulator)

/**
 * Builds a bound [[TransformValues]] whose lambda takes a non-nullable key variable and a
 * value variable following the map's `valueContainsNull`. A non-map `expr` fails the match.
 */
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  expr.dataType match {
    case MapType(keyType, valueType, valueContainsNull) =>
      val lambda = createLambda(keyType, false, valueType, valueContainsNull, f)
      TransformValues(expr, lambda).bind(validateBinding)
  }
}
import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
Expand Down Expand Up @@ -163,10 +68,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("MapFilter") {
// Builds a bound MapFilter; the key variable is non-nullable and the value variable follows
// the map's valueContainsNull. A non-map `expr` fails the match.
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  expr.dataType match {
    case MapType(keyType, valueType, valueContainsNull) =>
      val predicate = createLambda(keyType, false, valueType, valueContainsNull, f)
      MapFilter(expr, predicate).bind(validateBinding)
  }
}
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
Expand Down Expand Up @@ -244,11 +145,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("ArrayExists") {
// Builds a bound ArrayExists; the predicate variable takes the array's element type and
// containsNull flag. A non-array `expr` fails the match.
def exists(expr: Expression, f: Expression => Expression): Expression = {
  expr.dataType match {
    case ArrayType(elementType, containsNull) =>
      val predicate = createLambda(elementType, containsNull, f)
      ArrayExists(expr, predicate).bind(validateBinding)
  }
}

val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
Expand Down Expand Up @@ -457,16 +353,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("MapZipWith") {
// Builds a bound MapZipWith. The key variable uses the left map's key type and is
// non-nullable; both value variables are declared nullable. The left type is inspected before
// the right one, and a non-map input fails the corresponding match.
def map_zip_with(
    left: Expression,
    right: Expression,
    f: (Expression, Expression, Expression) => Expression): Expression = {
  left.dataType match {
    case MapType(keyType, leftValueType, _) =>
      right.dataType match {
        case MapType(_, rightValueType, _) =>
          val lambda =
            createLambda(keyType, false, leftValueType, true, rightValueType, true, f)
          MapZipWith(left, right, lambda).bind(validateBinding)
      }
  }
}

val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4),
Expand Down Expand Up @@ -549,15 +435,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

test("ZipWith") {
// Builds a bound ZipWith. Both element variables are declared nullable regardless of the
// inputs' containsNull. The left type is inspected before the right one, and a non-array
// input fails the corresponding match.
def zip_with(
    left: Expression,
    right: Expression,
    f: (Expression, Expression) => Expression): Expression = {
  left.dataType match {
    case ArrayType(leftElementType, _) =>
      right.dataType match {
        case ArrayType(rightElementType, _) =>
          val lambda = createLambda(leftElementType, true, rightElementType, true, f)
          ZipWith(left, right, lambda).bind(validateBinding)
      }
  }
}

val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false))
val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down
Loading