-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23922][SQL] Add arrays_overlap function #21028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e5ebdad
682bc73
876cd93
88e09b3
c895707
65b7d6d
f9a1ecf
1dbcd0c
076fc69
eafca0f
5925104
2a1121c
bf81e4a
4a18ba8
566946a
710433e
3cf410a
9d086f9
964f7af
41ef6c6
3dd724b
f7089f5
e36a5d7
49d9372
227437b
2e9e024
92730a1
56c59ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -288,6 +288,114 @@ case class ArrayContains(left: Expression, right: Expression) | |
| override def prettyName: String = "array_contains" | ||
| } | ||
|
|
||
| /** | ||
| * Checks if the two arrays contain at least one common element. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); | ||
| true | ||
| """, since = "2.4.0") | ||
| case class ArraysOverlap(left: Expression, right: Expression) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't you override
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, thanks! |
||
| extends BinaryExpression with ImplicitCastInputTypes { | ||
|
|
||
| private lazy val elementType = inputTypes.head.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def dataType: DataType = BooleanType | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = left.dataType match { | ||
|
||
| case la: ArrayType if la.sameType(right.dataType) => | ||
| Seq(la, la) | ||
| case _ => Seq.empty | ||
| } | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| if (!left.dataType.isInstanceOf[ArrayType] || !right.dataType.isInstanceOf[ArrayType] || | ||
| !left.dataType.sameType(right.dataType)) { | ||
| TypeCheckResult.TypeCheckFailure("Arguments must be arrays with the same element type.") | ||
| } else { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } | ||
|
||
| } | ||
|
|
||
| override def nullable: Boolean = { | ||
| left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || | ||
| right.dataType.asInstanceOf[ArrayType].containsNull | ||
| } | ||
|
|
||
| override def nullSafeEval(a1: Any, a2: Any): Any = { | ||
| var hasNull = false | ||
| val arr1 = a1.asInstanceOf[ArrayData] | ||
| val arr2 = a2.asInstanceOf[ArrayData] | ||
| if (arr1.numElements() > 0) { | ||
|
||
| arr1.foreach(elementType, (_, v1) => | ||
| if (v1 == null) { | ||
| hasNull = true | ||
| } else { | ||
| arr2.foreach(elementType, (_, v2) => | ||
| if (v2 == null) { | ||
| hasNull = true | ||
| } else if (v1 == v2) { | ||
| return true | ||
| } | ||
| ) | ||
| } | ||
| ) | ||
| } else { | ||
|
||
| arr2.foreach(elementType, (_, v) => | ||
| if (v == null) { | ||
| return null | ||
| } | ||
| ) | ||
| } | ||
| if (hasNull) { | ||
| null | ||
| } else { | ||
| false | ||
| } | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, (a1, a2) => { | ||
| val i1 = ctx.freshName("i") | ||
| val i2 = ctx.freshName("i") | ||
| val getValue1 = CodeGenerator.getValue(a1, elementType, i1) | ||
| val getValue2 = CodeGenerator.getValue(a2, elementType, i2) | ||
| s""" | ||
| |if ($a1.numElements() > 0) { | ||
| | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { | ||
| | if ($a1.isNullAt($i1)) { | ||
| | ${ev.isNull} = true; | ||
| | } else { | ||
| | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { | ||
| | if ($a2.isNullAt($i2)) { | ||
| | ${ev.isNull} = true; | ||
| | } else if (${ctx.genEqual(elementType, getValue1, getValue2)}) { | ||
| | ${ev.isNull} = false; | ||
| | ${ev.value} = true; | ||
| | break; | ||
| | } | ||
| | } | ||
| | if (${ev.value}) { | ||
| | break; | ||
| | } | ||
| | } | ||
| | } | ||
| |} else { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. |
||
| | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { | ||
| | if ($a2.isNullAt($i2)) { | ||
| | ${ev.isNull} = true; | ||
| | break; | ||
| | } | ||
| | } | ||
| |} | ||
| |""".stripMargin | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns the minimum value in the array. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,30 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) | ||
| } | ||
|
|
||
| test("ArraysOverlap") { | ||
| val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) | ||
| val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) | ||
| val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) | ||
| val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) | ||
| val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) | ||
|
|
||
| val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) | ||
| val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) | ||
| val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) | ||
|
|
||
| checkEvaluation(ArraysOverlap(a0, a1), true) | ||
| checkEvaluation(ArraysOverlap(a0, a2), null) | ||
| checkEvaluation(ArraysOverlap(a1, a2), true) | ||
| checkEvaluation(ArraysOverlap(a1, a3), false) | ||
| checkEvaluation(ArraysOverlap(a0, a4), false) | ||
| checkEvaluation(ArraysOverlap(a2, a4), null) | ||
| checkEvaluation(ArraysOverlap(a4, a2), null) | ||
|
|
||
| checkEvaluation(ArraysOverlap(a5, a6), true) | ||
| checkEvaluation(ArraysOverlap(a5, a7), null) | ||
| checkEvaluation(ArraysOverlap(a6, a7), false) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add cases for one of the two arguments is |
||
|
|
||
| test("Array Min") { | ||
| checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) | ||
| checkEvaluation( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a note for null handling?