Skip to content

Commit 6e31b7c

Browse files
author
Alexey Kudinkin
authored
[HUDI-4851] Fixing CSI not handling InSet operator properly (apache#6685)
1 parent 22d6019 commit 6e31b7c

2 files changed

Lines changed: 70 additions & 22 deletions

File tree

hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.hudi.common.util.ValidationUtils.checkState
2323
import org.apache.spark.internal.Logging
2424
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
2525
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
26-
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
26+
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, InSet, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
2727
import org.apache.spark.sql.functions.col
2828
import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._
2929
import org.apache.spark.sql.types.StructType
@@ -61,7 +61,7 @@ object DataSkippingUtils extends Logging {
6161
}
6262
}
6363

64-
private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = {
64+
private def tryComposeIndexFilterExpr(sourceFilterExpr: Expression, indexSchema: StructType): Option[Expression] = {
6565
//
6666
// For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're
6767
// assuming that
@@ -91,7 +91,7 @@ object DataSkippingUtils extends Logging {
9191
// colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA))
9292
// colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA))
9393
//
94-
sourceExpr match {
94+
sourceFilterExpr match {
9595
// If Expression is not resolved, we can't perform the analysis accurately, bailing
9696
case expr if !expr.resolved => None
9797

@@ -227,6 +227,16 @@ object DataSkippingUtils extends Logging {
227227
list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)
228228
}
229229

230+
// Filter "expr(colA) in (B1, B2, ...)"
231+
// NOTE: [[InSet]] is an optimized version of the [[In]] expression, where every sub-expression within the
232+
// set is a static literal
233+
case InSet(sourceExpr @ AllowedTransformationExpression(attrRef), hset: Set[Any]) =>
234+
getTargetIndexedColumnName(attrRef, indexSchema)
235+
.map { colName =>
236+
val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
237+
hset.map(value => genColumnValuesEqualToExpression(colName, Literal(value), targetExprBuilder)).reduce(Or)
238+
}
239+
230240
// Filter "expr(colA) not in (B1, B2, ...)"
231241
// Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup
232242
// NOTE: This is NOT an inversion of the `in (B1, B2, ...)` expr; it is equivalent to "colA != B1 AND colA != B2 AND ..."
@@ -331,8 +341,8 @@ private object ColumnStatsExpressionUtils {
331341
@inline def genColValueCountExpr: Expression = col(getValueCountColumnNameFor).expr
332342

333343
@inline def genColumnValuesEqualToExpression(colName: String,
334-
value: Expression,
335-
targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
344+
value: Expression,
345+
targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
336346
val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
337347
val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
338348
// Only case when column C contains value V is when min(C) <= V <= max(c)

hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ package org.apache.hudi
2020
import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
2121
import org.apache.hudi.testutils.HoodieClientTestBase
2222
import org.apache.spark.sql.HoodieCatalystExpressionUtils.resolveExpr
23-
import org.apache.spark.sql.catalyst.expressions.{Expression, Not}
23+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
24+
import org.apache.spark.sql.catalyst.expressions.{Expression, InSet, Not}
2425
import org.apache.spark.sql.functions.{col, lower}
2526
import org.apache.spark.sql.hudi.DataSkippingUtils
2627
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -34,6 +35,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource}
3435

3536
import java.sql.Timestamp
3637
import scala.collection.JavaConverters._
38+
import scala.collection.immutable.HashSet
3739

3840
// NOTE: Only A, B columns are indexed
3941
case class IndexRow(fileName: String,
@@ -80,31 +82,38 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
8082
val indexSchema: StructType = composeIndexSchema(indexedCols, sourceTableSchema)
8183

8284
@ParameterizedTest
83-
@MethodSource(
84-
Array(
85-
"testBasicLookupFilterExpressionsSource",
86-
"testAdvancedLookupFilterExpressionsSource",
87-
"testCompositeFilterExpressionsSource"
88-
))
89-
def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = {
85+
@MethodSource(Array(
86+
"testBasicLookupFilterExpressionsSource",
87+
"testAdvancedLookupFilterExpressionsSource",
88+
"testCompositeFilterExpressionsSource"
89+
))
90+
def testLookupFilterExpressions(sourceFilterExprStr: String, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
9091
// We have to fix the timezone to make sure all date-bound utilities output
9192
// is consistent with the fixtures
9293
spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
9394

94-
val resolvedExpr: Expression = resolveExpr(spark, sourceExpr, sourceTableSchema)
95-
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
95+
val resolvedFilterExpr: Expression = resolveExpr(spark, sourceFilterExprStr, sourceTableSchema)
96+
val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
9697

97-
val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
98+
assertEquals(expectedOutput, rows)
99+
}
98100

99-
val rows = indexDf.where(new Column(lookupFilter))
100-
.select("fileName")
101-
.collect()
102-
.map(_.getString(0))
103-
.toSeq
101+
@ParameterizedTest
102+
@MethodSource(Array(
103+
"testMiscLookupFilterExpressionsSource"
104+
))
105+
def testMiscLookupFilterExpressions(filterExpr: Expression, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
106+
// We have to fix the timezone to make sure all date-bound utilities output
107+
// is consistent with the fixtures
108+
spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
104109

105-
assertEquals(output, rows)
110+
val resolvedFilterExpr: Expression = resolveExpr(spark, filterExpr, sourceTableSchema)
111+
val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
112+
113+
assertEquals(expectedOutput, rows)
106114
}
107115

116+
108117
@ParameterizedTest
109118
@MethodSource(Array("testStringsLookupFilterExpressionsSource"))
110119
def testStringsLookupFilterExpressions(sourceExpr: Expression, input: Seq[IndexRow], output: Seq[String]): Unit = {
@@ -124,6 +133,18 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
124133

125134
assertEquals(output, rows)
126135
}
136+
137+
private def applyFilterExpr(resolvedExpr: Expression, input: Seq[IndexRow]): Seq[String] = {
138+
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
139+
140+
val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
141+
142+
indexDf.where(new Column(lookupFilter))
143+
.select("fileName")
144+
.collect()
145+
.map(_.getString(0))
146+
.toSeq
147+
}
127148
}
128149

129150
object TestDataSkippingUtils {
@@ -159,6 +180,23 @@ object TestDataSkippingUtils {
159180
)
160181
}
161182

183+
def testMiscLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
184+
// NOTE: Have to use [[Arrays.stream]], as Scala can't properly resolve the two overloads of [[Stream.of]]
185+
// (for single element)
186+
java.util.Arrays.stream(
187+
Array(
188+
arguments(
189+
InSet(UnresolvedAttribute("A"), HashSet(0, 1)),
190+
Seq(
191+
IndexRow("file_1", valueCount = 1, 1, 2, 0),
192+
IndexRow("file_2", valueCount = 1, -1, 1, 0),
193+
IndexRow("file_3", valueCount = 1, -2, -1, 0)
194+
),
195+
Seq("file_1", "file_2"))
196+
)
197+
)
198+
}
199+
162200
def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
163201
java.util.stream.Stream.of(
164202
// TODO cases

0 commit comments

Comments
 (0)