Skip to content

Commit 763a1f8

Browse files
committed
address review comments
1 parent f5ebbe8 commit 763a1f8

File tree

3 files changed

+26
-30
lines changed

3 files changed

+26
-30
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elem
477477
return fromPrimitiveArray(null, offset, length, elementSize);
478478
}
479479

480-
public static boolean useGenericArrayData(int elementSize, int length) {
480+
public static boolean canUseGenericArrayData(int elementSize, int length) {
481481
final long headerInBytes = calculateHeaderPortionInBytes(length);
482482
final long valueRegionInBytes = (long)elementSize * length;
483483
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3262,23 +3262,13 @@ case class ArrayDistinct(child: Expression)
32623262
override def prettyName: String = "array_distinct"
32633263
}
32643264

3265-
object ArraySetLike {
3266-
def throwUnionLengthOverflowException(length: Int): Unit = {
3267-
throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
3268-
s"elements due to exceeding the array size limit " +
3269-
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
3270-
}
3271-
}
3272-
3273-
3265+
/**
3266+
* Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
3267+
*/
32743268
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
32753269
override def dataType: DataType = {
3276-
val dataTypes = children.map(_.dataType)
3277-
dataTypes.headOption.map {
3278-
case ArrayType(et, _) =>
3279-
ArrayType(et, dataTypes.exists(_.asInstanceOf[ArrayType].containsNull))
3280-
case dt => dt
3281-
}.getOrElse(StringType)
3270+
val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
3271+
ArrayType(elementType, dataTypes.exists(_.containsNull))
32823272
}
32833273

32843274
override def checkInputDataTypes(): TypeCheckResult = {
@@ -3301,6 +3291,15 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
33013291
}
33023292
}
33033293

3294+
object ArraySetLike {
3295+
def throwUnionLengthOverflowException(length: Int): Unit = {
3296+
throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
3297+
s"elements due to exceeding the array size limit " +
3298+
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
3299+
}
3300+
}
3301+
3302+
33043303
/**
33053304
* Returns an array of the elements in the union of x and y, without duplicates
33063305
*/
@@ -3353,7 +3352,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33533352
// store elements into resultArray
33543353
var nullElementSize = 0
33553354
var pos = 0
3356-
Seq(array1, array2).foreach(array => {
3355+
Seq(array1, array2).foreach { array =>
33573356
var i = 0
33583357
while (i < array.numElements()) {
33593358
val size = if (!isLongType) hsInt.size else hsLong.size
@@ -3380,7 +3379,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33803379
}
33813380
i += 1
33823381
}
3383-
})
3382+
}
33843383
pos
33853384
}
33863385

@@ -3396,7 +3395,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
33963395
hsInt = new OpenHashSet[Int]
33973396
val elements = evalIntLongPrimitiveType(array1, array2, null, false)
33983397
hsInt = new OpenHashSet[Int]
3399-
val resultArray = if (UnsafeArrayData.useGenericArrayData(
3398+
val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
34003399
IntegerType.defaultSize, elements)) {
34013400
new GenericArrayData(new Array[Any](elements))
34023401
} else {
@@ -3411,7 +3410,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
34113410
hsLong = new OpenHashSet[Long]
34123411
val elements = evalIntLongPrimitiveType(array1, array2, null, true)
34133412
hsLong = new OpenHashSet[Long]
3414-
val resultArray = if (UnsafeArrayData.useGenericArrayData(
3413+
val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
34153414
LongType.defaultSize, elements)) {
34163415
new GenericArrayData(new Array[Any](elements))
34173416
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
11711171
val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
11721172
val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
11731173
val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
1174-
val a03 = Literal.create(Seq(-5, 4, -3, 2, -1), ArrayType(IntegerType, containsNull = false))
1174+
val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
11751175
val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
11761176
val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
11771177
val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
@@ -1191,17 +1191,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
11911191
val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
11921192
val a31 = Literal.create(null, ArrayType(StringType))
11931193

1194-
checkEvaluation(ArrayUnion(a00, a01), UnsafeArrayData.fromPrimitiveArray(Array(1, 2, 3, 4)))
1195-
checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3, -1))
1196-
checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, -1, 1, null, 5))
1194+
checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4))
1195+
checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
1196+
checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
11971197
checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
1198-
checkEvaluation(
1199-
ArrayUnion(a05, a06), UnsafeArrayData.fromPrimitiveArray(Array[Byte](1, 2, 3, 4)))
1200-
checkEvaluation(
1201-
ArrayUnion(a07, a08), UnsafeArrayData.fromPrimitiveArray(Array[Short](1, 2, 3, 4)))
1198+
checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
1199+
checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
12021200

1203-
checkEvaluation(
1204-
ArrayUnion(a10, a11), UnsafeArrayData.fromPrimitiveArray(Array(1L, 2L, 3L, 4L)))
1201+
checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
12051202
checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
12061203
checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L))
12071204
checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L))

0 commit comments

Comments
 (0)