Skip to content
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e5ebdad
[SPARK-23922][SQL] Add arrays_overlap function
mgaido91 Apr 10, 2018
682bc73
fix python style
mgaido91 Apr 10, 2018
876cd93
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 17, 2018
88e09b3
review comments
mgaido91 Apr 20, 2018
c895707
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 20, 2018
65b7d6d
introduce BinaryArrayExpressionWithImplicitCast
mgaido91 Apr 27, 2018
f9a1ecf
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 27, 2018
1dbcd0c
fix type check
mgaido91 Apr 27, 2018
076fc69
fix scalastyle
mgaido91 Apr 27, 2018
eafca0f
fix build error
mgaido91 Apr 27, 2018
5925104
fix
mgaido91 Apr 27, 2018
2a1121c
address comments
mgaido91 May 3, 2018
bf81e4a
use sets instead of nested loops
mgaido91 May 3, 2018
4a18ba8
address review comments
mgaido91 May 4, 2018
566946a
address review comments
mgaido91 May 4, 2018
710433e
add test case for null
mgaido91 May 4, 2018
3cf410a
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 7, 2018
9d086f9
address comments
mgaido91 May 7, 2018
964f7af
use findTightestCommonType for type inference
mgaido91 May 8, 2018
41ef6c6
support binary and complex data types
mgaido91 May 9, 2018
3dd724b
review comments
mgaido91 May 11, 2018
f7089f5
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 11, 2018
e36a5d7
fix compilation error
mgaido91 May 11, 2018
49d9372
address comments
mgaido91 May 14, 2018
227437b
address comment
mgaido91 May 15, 2018
2e9e024
fix null handling with complex types
mgaido91 May 16, 2018
92730a1
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 17, 2018
56c59ae
fix build
mgaido91 May 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,20 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def arrays_overlap(a1, a2):
"""
Collection function: returns true if the arrays contain any common non-null element; if not,
returns null if any of the arrays contains a null element and false otherwise.

>>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
[Row(overlap=True), Row(overlap=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))


@since(2.4)
def slice(x, start, length):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,45 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
* casting.
*/
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ImplicitCastInputTypes trait is able to work with any number of children. Would it be possible to implement this trait to behave in the same way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible indeed. Though, as far as I know there is no use case for a function with a different number of children, so I am not sure if it makes sense to generalize it. @cloud-fan @kiszk @ueshin WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @ueshin pointed out here, concat is also a use case that has a different number of children. Am I wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk you are not wrong, but Concat is a very specific case, since it supports also Strings and Binarys, so it would anyway require a specific implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I would like to hear other opinions

with ImplicitCastInputTypes {

protected lazy val elementType: DataType = inputTypes.head.asInstanceOf[ArrayType].elementType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be a def

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it better a def than a lazy val?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lazy val will be serialized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for your explanation!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(or @transient lazy val too optionally)


override def inputTypes: Seq[AbstractDataType] = {
TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does presto allow implicitly casting to string for these collection functions? e.g. can ArraysOverlap work for array of int and array of string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it does not for int and string, but it does for decimal and int. What shall we do?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we probably need to call TypeCoercion.findTightestCommonType here, and fix findTightestCommonType for array/struct/map types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about findWiderTypeWithoutStringPromotionForTwo?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now the question is, shall we allow precision lose for array functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question. Checking the way we are doing it I would say no. Since we are bounding in a quite strange way at the moment (causing loss of int digits instead of decimals) I would say no, since this could lead to have many NULLs. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok then findTightestCommonType is a better choice?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I'll go for it. Thanks.

case Some(arrayType) => Seq(arrayType, arrayType)
case _ => Seq.empty
}
}

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
TypeCheckResult.TypeCheckSuccess
case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " +
s"been two ${ArrayType.simpleString}s with same element type, but it's " +
s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
}
}
}


/**
* Given an array or map, returns its size. Returns -1 if null.
*/
Expand Down Expand Up @@ -529,6 +559,157 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}

/**
* Checks if the two arrays contain at least one common element.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2. If the arrays have no common element and either of them contains a null element null is returned, false otherwise.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
true
""", since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you override prettyName to a value following the conventions?

override def prettyName: String = "arrays_overlap"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!

extends BinaryArrayExpressionWithImplicitCast {

override def dataType: DataType = BooleanType

override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
right.dataType.asInstanceOf[ArrayType].containsNull
}

override def nullSafeEval(a1: Any, a2: Any): Any = {
var hasNull = false
val arr1 = a1.asInstanceOf[ArrayData]
val arr2 = a2.asInstanceOf[ArrayData]
val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the biggerDt is not used

(arr1, arr2, left.dataType.asInstanceOf[ArrayType])
} else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can skip if the right is containsNull == false?

(arr2, arr1, right.dataType.asInstanceOf[ArrayType])
}
if (smaller.numElements() > 0) {
val smallestSet = new mutable.HashSet[Any]
smaller.foreach(elementType, (_, v) =>
if (v == null) {
hasNull = true
} else {
smallestSet += v
})
bigger.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else if (smallestSet.contains(v1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't work with BinaryType(the data is byte[]). We may need to wrap values with ByteBuffer first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually it was not working also with ArrayType, so I addressed the problem in a more general way which supports both these cases. Thanks.

return true
}
)
} else if (containsNull(bigger, biggerDt)) {
hasNull = true
}
if (hasNull) {
null
} else {
false
}
}

def containsNull(arr: ArrayData, dt: ArrayType): Boolean = {
if (dt.containsNull) {
var i = 0
var hasNull = false
while (i < arr.numElements && !hasNull) {
hasNull = arr.isNullAt(i)
i += 1
}
hasNull
} else {
false
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (a1, a2) => {
val i = ctx.freshName("i")
val smaller = ctx.freshName("smallerArray")
val bigger = ctx.freshName("biggerArray")
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
val smallerEmptyCode = if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
s"""
|else {
| for (int $i = 0; $i < $bigger.numElements(); $i ++) {
| if ($bigger.isNullAt($i)) {
| ${ev.isNull} = true;
| break;
| }
| }
|}
""".stripMargin
} else {
""
}
val javaElementClass = CodeGenerator.boxedType(elementType)
val javaSet = classOf[java.util.HashSet[_]].getName
val set2 = ctx.freshName("set")
val addToSetFromSmallerCode = nullSafeElementCodegen(
smaller, i, s"$set2.add($getFromSmaller);", s"${ev.isNull} = true;")
val elementIsInSetCode = nullSafeElementCodegen(
bigger,
i,
s"""
|if ($set2.contains($getFromBigger)) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| break;
|}
|""".stripMargin,
s"${ev.isNull} = true;")
s"""
|ArrayData $smaller;
|ArrayData $bigger;
|if ($a1.numElements() > $a2.numElements()) {
| $bigger = $a1;
| $smaller = $a2;
|} else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

| $smaller = $a1;
| $bigger = $a2;
|}
|if ($smaller.numElements() > 0) {
| $javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>();
| for (int $i = 0; $i < $smaller.numElements(); $i ++) {
| $addToSetFromSmallerCode
| }
| for (int $i = 0; $i < $bigger.numElements(); $i ++) {
| $elementIsInSetCode
| }
|} $smallerEmptyCode
|""".stripMargin
})
}

def nullSafeElementCodegen(
arrayVar: String,
index: String,
code: String,
isNullCode: String): String = {
if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this depend on whether the input array arrayVar contains null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately we don't know which one we have here (the left or the rigth) as arrayVar, since we don't know which one is the smaller/bigger and this can change record to record. So we can skip the null check only if both them don't contain null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see, makes sense!

s"""
|if ($arrayVar.isNullAt($index)) {
| $isNullCode
|} else {
| $code
|}
|""".stripMargin
} else {
code
}
}

override def prettyName: String = "arrays_overlap"
}

/**
* Slices an array according to the requested start index and length
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,37 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("ArraysOverlap") {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType))

val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))

checkEvaluation(ArraysOverlap(a0, a1), true)
checkEvaluation(ArraysOverlap(a0, a2), null)
checkEvaluation(ArraysOverlap(a1, a2), true)
checkEvaluation(ArraysOverlap(a1, a3), false)
checkEvaluation(ArraysOverlap(a0, a4), false)
checkEvaluation(ArraysOverlap(a2, a4), null)
checkEvaluation(ArraysOverlap(a4, a2), null)

checkEvaluation(ArraysOverlap(a5, a6), true)
checkEvaluation(ArraysOverlap(a5, a7), null)
checkEvaluation(ArraysOverlap(a6, a7), false)

// null handling
checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null)
checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null)
checkEvaluation(ArraysOverlap(
Literal.create(Seq(null), ArrayType(IntegerType)),
Literal.create(Seq(null), ArrayType(IntegerType))), null)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if arrays_overlap(array(), array(null))?
Seems like Presto returns false for the case. TestArrayOperators.java#L1041
Also can you add the test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am returning null for it. This is interesting. I checked Presto's implementation and it returns false if any of the input arrays is empty. I am copying Presto's behavior but this is quite against what the docs say:

Returns null if there are no non-null elements in common but either array contains null.

I will add a sentence to clarify the behavior in our docs. Thanks for this nice catch!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a test case for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is covered by https://github.com/apache/spark/pull/21028/files#diff-d31eca9f1c4c33104dc2cb8950486910R163 for instance. Anyway, I am adding another on which is exactly this one.

}

test("Slice") {
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,16 @@ object functions {
ArrayContains(column.expr, Literal(value))
}

/**
* Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and
* any of the arrays contains a `null`, it returns `null`. It returns `false` otherwise.
* @group collection_funcs
* @since 2.4.0
*/
def arrays_overlap(a1: Column, a2: Column): Column = withExpr {
ArraysOverlap(a1.expr, a2.expr)
}

/**
* Returns an array containing all the elements in `x` from index `start` (or starting from the
* end if `start` is negative) with the specified `length`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("arrays_overlap function") {
val df = Seq(
(Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),
(Seq.empty[Option[Int]], Seq[Option[Int]](Some(-1), None)),
(Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2)))
).toDF("a", "b")

val answer = Seq(Row(false), Row(null), Row(true))

checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer)
checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer)

checkAnswer(sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))"), Row(false))

Copy link
Member

@kiszk kiszk May 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test like this for null inputs?

val df = Seq((null, null)).toDF("a", "b")
val ans = ...
checkAnswer(df.select(array_overlap($"a", $"b")), ans)
checkAnswer(df.selectExpr("array_overlap(a, b)"), ans)

Do we expect the result is null or an exception is thrown?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@kiszk kiszk May 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, I think no. It is good to have test cases with null in primitive.
On the other hand, my comment is talking about null handling in Dataframe API.
Other operations aslo perform tests with null in DataFrame API and primitive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, now I see what you mean. I can add it, but it seems useless to me. This function accepts only Arrays so any other type (NullType included) throws an AnalysisException. In array_contains is different since we the second argument can be anything and so makes sense to check the behavior of NullType which is handled differently from the others. Do you agree?

Copy link
Member

@kiszk kiszk May 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that to add this makes sense to explicitly ensure so any other type (NullType included) throws an AnalysisException. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I don't see its utility but I see also no harm in introducing it, so if you think it is a added value, I think it is fine to add it. So I just added it, thanks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you.

intercept[AnalysisException] {
sql("select arrays_overlap(array(array(1)), array('a'))")
}

intercept[AnalysisException] {
sql("select arrays_overlap(null, null)")
}
}

test("slice function") {
val df = Seq(
Seq(1, 2, 3),
Expand Down