
Commit 1479bde

Patrick Woody committed: pr feedback

1 parent 426374b · commit 1479bde

2 files changed: +61, -57 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala (10 additions, 8 deletions)

@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
-import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -357,20 +358,21 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C
 }
 
 private[columnar] final class ArrayColumnStats(dataType: DataType)
-  extends OrderableSafeColumnStats[ArrayData](dataType) {
-  override def getValue(row: InternalRow, ordinal: Int): ArrayData = row.getArray(ordinal)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
 
-  override def copy(value: ArrayData): ArrayData = value.copy()
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
 }
 
 private[columnar] final class StructColumnStats(dataType: DataType)
-  extends OrderableSafeColumnStats[InternalRow](dataType) {
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
   private val numFields = dataType.asInstanceOf[StructType].fields.length
 
-  override def getValue(row: InternalRow, ordinal: Int): InternalRow =
-    row.getStruct(ordinal, numFields)
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]
 
-  override def copy(value: InternalRow): InternalRow = value.copy()
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
 }
 
 private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats {
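Why the stats classes now pin down UnsafeArrayData/UnsafeRow and copy before retaining a bound: getValue returns a view over bytes that the producer may overwrite when it reuses its buffer for the next row, so a retained lower/upper bound must be a materialized copy. The new "Reuse UnsafeArrayData for stats" test below exercises exactly that reuse. A minimal standalone sketch of the hazard (not part of the commit; the object name is illustrative, the UnsafeArrayData calls are the same ones the test uses):

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

    // Illustrative only; not part of this commit.
    object AliasingSketch {
      def main(args: Array[String]): Unit = {
        val buffer = UnsafeArrayData.fromPrimitiveArray(Array(1))

        // Retaining without copy() keeps a view over the shared buffer, so a
        // later in-place update silently rewrites the "retained" bound.
        val retainedView = buffer
        buffer.setInt(0, 42)
        println(retainedView.getInt(0)) // prints 42, not 1

        // copy() materializes the bytes, so the retained bound stays stable;
        // this is what the copy() overrides above rely on.
        val fresh = UnsafeArrayData.fromPrimitiveArray(Array(1))
        val retainedCopy = fresh.copy()
        fresh.setInt(0, 42)
        println(retainedCopy.getInt(0)) // still prints 1
      }
    }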

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala (51 additions, 49 deletions)

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -35,9 +35,30 @@ class ColumnStatsSuite extends SparkFunSuite {
   )
   testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
   testDecimalColumnStats(Array(null, null, 0, 0, 0))
-  testArrayColumnStats(ArrayType(IntegerType), orderable = true, Array(null, null, 0, 0, 0))
-  testStructColumnStats(
-    StructType(Array(StructField("test", DataTypes.StringType))),
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
     orderable = true,
     Array(null, null, 0, 0, 0)
   )
@@ -120,58 +141,23 @@ class ColumnStatsSuite extends SparkFunSuite {
     }
   }
 
-  def testArrayColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
-
-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new ArrayColumnStats(dataType)
-      objectStats.collectedStatistics.zip(initialStatistics).foreach {
-        case (actual, expected) => assert(actual === expected)
-      }
-    }
-
-    test(s"${dataType.typeName}: non-empty") {
-      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new ArrayColumnStats(dataType)
-      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
-      rows.foreach(objectStats.gatherStats(_, 0))
-
-      val stats = objectStats.collectedStatistics
-      if (orderable) {
-        val values = rows.take(10).map(_.get(0, columnType.dataType))
-        val ordering = TypeUtils.getInterpretedOrdering(dataType)
-
-        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
-        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
-      } else {
-        assertResult(null, "Wrong lower bound")(stats(0))
-        assertResult(null, "Wrong upper bound")(stats(1))
-      }
-      assertResult(10, "Wrong null count")(stats(2))
-      assertResult(20, "Wrong row count")(stats(3))
-      assertResult(stats(4), "Wrong size in bytes") {
-        rows.map { row =>
-          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
-        }.sum
-      }
-    }
-  }
-
-  def testStructColumnStats(
-      dataType: DataType, orderable: Boolean, initialStatistics: Array[Any]): Unit = {
-    val columnType = ColumnType(dataType)
+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {
 
-    test(s"${dataType.typeName}: empty") {
-      val objectStats = new StructColumnStats(dataType)
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
       objectStats.collectedStatistics.zip(initialStatistics).foreach {
         case (actual, expected) => assert(actual === expected)
       }
     }
 
-    test(s"${dataType.typeName}: non-empty") {
+    test(s"${dataType.typeName}, $orderable: non-empty") {
       import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
-      val objectStats = new StructColumnStats(dataType)
+      val objectStats = statsSupplier()
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
      rows.foreach(objectStats.gatherStats(_, 0))
 
@@ -224,4 +210,20 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("Reuse UnsafeArrayData for stats") {
+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
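Note on the orderable = false registration above: ArrayType(MapType(IntegerType, StringType)) is unorderable because Catalyst defines no ordering for maps, so the stats keep null lower/upper bounds for such columns while still tracking null count, row count, and size. A small sketch of the distinction (illustrative, assuming RowOrdering.isOrderable, imported in ColumnStats.scala above, is the deciding predicate):

    import org.apache.spark.sql.catalyst.expressions.RowOrdering
    import org.apache.spark.sql.types._

    // Illustrative only; not part of this commit.
    object OrderabilitySketch {
      def main(args: Array[String]): Unit = {
        // Arrays of atomic element types can be ordered, so bounds are tracked.
        println(RowOrdering.isOrderable(ArrayType(IntegerType))) // true
        // Maps have no ordering, so bounds stay null (the orderable = false case).
        println(RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType)))) // false
      }
    }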
