-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23920][SQL]add array_remove to remove all elements that equal element from array #21069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
f92e18c
f6a629b
1c24720
89b4f48
9281ae2
074ed88
52d2308
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1882,3 +1882,98 @@ case class ArrayRepeat(left: Expression, right: Expression) | |
| } | ||
|
|
||
| } | ||
|
|
||
| /** | ||
| * Remove all elements that equal to element from the given array | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); | ||
| [1,2,null] | ||
| """, since = "2.4.0") | ||
| case class ArrayRemove(left: Expression, right: Expression) | ||
| extends BinaryExpression with ImplicitCastInputTypes { | ||
|
|
||
| override def dataType: DataType = left.dataType | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = | ||
| Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) | ||
|
|
||
| lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def nullSafeEval(arr: Any, value: Any): Any = { | ||
| val elementType = left.dataType.asInstanceOf[ArrayType].elementType | ||
| val data = arr.asInstanceOf[ArrayData].toArray[AnyRef](elementType).filter(_ != value) | ||
|
||
| new GenericArrayData(data.asInstanceOf[Array[Any]]) | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, (arr, value) => { | ||
| val numsToRemove = ctx.freshName("numsToRemove") | ||
| val newArraySize = ctx.freshName("newArraySize") | ||
| val i = ctx.freshName("i") | ||
| val getValue = CodeGenerator.getValue(arr, elementType, i) | ||
| val isEqual = ctx.genEqual(elementType, value, getValue) | ||
| s""" | ||
| |int $numsToRemove = 0; | ||
| |for (int $i = 0; $i < $arr.numElements(); $i ++) { | ||
| | if (!$arr.isNullAt($i) && $isEqual) { | ||
| | $numsToRemove = $numsToRemove + 1; | ||
| | } | ||
| |} | ||
| |int $newArraySize = $arr.numElements() - $numsToRemove; | ||
| |${genCodeForResult(ctx, ev, arr, value, newArraySize)} | ||
| """.stripMargin | ||
| }) | ||
| } | ||
|
|
||
| def genCodeForResult( | ||
| ctx: CodegenContext, | ||
| ev: ExprCode, | ||
| inputArray: String, | ||
| value: String, | ||
| newArraySize: String): String = { | ||
| val values = ctx.freshName("values") | ||
| val i = ctx.freshName("i") | ||
| val pos = ctx.freshName("pos") | ||
| val getValue = CodeGenerator.getValue(inputArray, elementType, i) | ||
| val isEqual = ctx.genEqual(elementType, value, getValue) | ||
| if (!CodeGenerator.isPrimitiveType(elementType)) { | ||
| val arrayClass = classOf[GenericArrayData].getName | ||
| s""" | ||
| |int $pos = 0; | ||
| |Object[] $values = new Object[$newArraySize]; | ||
| |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { | ||
| | if (!($isEqual)) { | ||
|
||
| | $values[$pos] = $getValue; | ||
| | $pos = $pos + 1; | ||
| | } | ||
| |} | ||
| |${ev.value} = new $arrayClass($values); | ||
| """.stripMargin | ||
| } else { | ||
| val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) | ||
| s""" | ||
| |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} | ||
| |int $pos = 0; | ||
| |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { | ||
| | if ($inputArray.isNullAt($i)) { | ||
| | $values.setNullAt($pos); | ||
| | $pos = $pos + 1; | ||
| | } | ||
| | else { | ||
| | if (!($isEqual)) { | ||
| | $values.set$primitiveValueTypeName($pos, $getValue); | ||
| | $pos = $pos + 1; | ||
| | } | ||
| | } | ||
| |} | ||
| |${ev.value} = $values; | ||
| """.stripMargin | ||
| } | ||
| } | ||
|
|
||
| override def prettyName: String = "array_remove" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -552,4 +552,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) | ||
| checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) | ||
| } | ||
|
|
||
| test("Array remove") { | ||
| val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) | ||
| val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) | ||
| val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) | ||
| val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) | ||
| val a4 = Literal.create(null, ArrayType(StringType)) | ||
| val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) | ||
| val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) | ||
|
|
||
| checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) | ||
| checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) | ||
| checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) | ||
| checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) | ||
| checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) | ||
|
||
| checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) | ||
|
|
||
| checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) | ||
| checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) | ||
| checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) | ||
| checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) | ||
|
|
||
| checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) | ||
| checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) | ||
|
|
||
| checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) | ||
|
|
||
| checkEvaluation(ArrayRemove(a4, Literal("a")), null) | ||
|
|
||
| checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) | ||
| checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) | ||
|
||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will cause
ClassCastException. See #21401.Also could you add tests similar to tests added in #21401?