@@ -20,7 +20,8 @@ package org.apache.hudi
2020import org .apache .hudi .ColumnStatsIndexSupport .composeIndexSchema
2121import org .apache .hudi .testutils .HoodieClientTestBase
2222import org .apache .spark .sql .HoodieCatalystExpressionUtils .resolveExpr
23- import org .apache .spark .sql .catalyst .expressions .{Expression , Not }
23+ import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
24+ import org .apache .spark .sql .catalyst .expressions .{Expression , InSet , Not }
2425import org .apache .spark .sql .functions .{col , lower }
2526import org .apache .spark .sql .hudi .DataSkippingUtils
2627import org .apache .spark .sql .internal .SQLConf .SESSION_LOCAL_TIMEZONE
@@ -34,6 +35,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource}
3435
3536import java .sql .Timestamp
3637import scala .collection .JavaConverters ._
38+ import scala .collection .immutable .HashSet
3739
3840// NOTE: Only A, B columns are indexed
3941case class IndexRow (fileName : String ,
@@ -80,31 +82,38 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
8082 val indexSchema : StructType = composeIndexSchema(indexedCols, sourceTableSchema)
8183
/**
 * Verifies that a filter supplied in its string form, once resolved against the source table schema
 * and translated into a Column Stats Index lookup filter, selects exactly the expected files.
 *
 * @param sourceFilterExprStr filter expression in its string form
 * @param input               index rows the translated filter is evaluated against
 * @param expectedOutput      file names expected to pass the translated filter
 */
@ParameterizedTest
@MethodSource(Array(
  "testBasicLookupFilterExpressionsSource",
  "testAdvancedLookupFilterExpressionsSource",
  "testCompositeFilterExpressionsSource"
))
def testLookupFilterExpressions(sourceFilterExprStr: String, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
  // Pin the session timezone to UTC, so that the output of all date-bound utilities
  // stays consistent with the fixtures
  spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")

  val matchedFiles: Seq[String] =
    applyFilterExpr(resolveExpr(spark, sourceFilterExprStr, sourceTableSchema), input)

  assertEquals(expectedOutput, matchedFiles)
}

/**
 * Verifies filters that are supplied directly as Catalyst expressions (rather than as strings):
 * the expression is resolved against the source table schema, translated into a Column Stats Index
 * lookup filter, and must select exactly the expected files.
 *
 * @param filterExpr     filter expression to be resolved and translated
 * @param input          index rows the translated filter is evaluated against
 * @param expectedOutput file names expected to pass the translated filter
 */
@ParameterizedTest
@MethodSource(Array(
  "testMiscLookupFilterExpressionsSource"
))
def testMiscLookupFilterExpressions(filterExpr: Expression, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
  // Pin the session timezone to UTC, so that the output of all date-bound utilities
  // stays consistent with the fixtures
  spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")

  val matchedFiles: Seq[String] =
    applyFilterExpr(resolveExpr(spark, filterExpr, sourceTableSchema), input)

  assertEquals(expectedOutput, matchedFiles)
}
107115
116+
108117 @ ParameterizedTest
109118 @ MethodSource (Array (" testStringsLookupFilterExpressionsSource" ))
110119 def testStringsLookupFilterExpressions (sourceExpr : Expression , input : Seq [IndexRow ], output : Seq [String ]): Unit = {
@@ -124,6 +133,18 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
124133
125134 assertEquals(output, rows)
126135 }
136+
/**
 * Translates the given (already resolved) expression into a Column Stats Index lookup filter,
 * applies it to a data frame built from the provided index rows, and returns the "fileName"
 * value of every row that passes the filter.
 *
 * @param resolvedExpr resolved filter expression to translate
 * @param input        index rows to filter
 * @return file names of the rows matching the translated filter
 */
private def applyFilterExpr(resolvedExpr: Expression, input: Seq[IndexRow]): Seq[String] = {
  val indexFilterExpr = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)

  val indexRows = input.map(indexRow => indexRow.toRow).asJava
  val matchingRows = spark.createDataFrame(indexRows, indexSchema)
    .where(new Column(indexFilterExpr))
    .select("fileName")
    .collect()

  matchingRows.map(row => row.getString(0)).toSeq
}
127148}
128149
129150object TestDataSkippingUtils {
@@ -159,6 +180,23 @@ object TestDataSkippingUtils {
159180 )
160181 }
161182
/**
 * Supplies the arguments for `testMiscLookupFilterExpressions`: a single case checking an
 * [[InSet]] filter over column "A" with the values 0 and 1.
 *
 * Each argument set is (filter expression, index rows, expected file names).
 */
def testMiscLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
  val inSetCase: Arguments = arguments(
    InSet(UnresolvedAttribute("A"), HashSet(0, 1)),
    // NOTE(review): the positional values after `valueCount` are presumably column A's
    //               min, max and null-count -- confirm against the IndexRow definition
    Seq(
      IndexRow("file_1", valueCount = 1, 1, 2, 0),
      IndexRow("file_2", valueCount = 1, -1, 1, 0),
      IndexRow("file_3", valueCount = 1, -2, -1, 0)
    ),
    Seq("file_1", "file_2"))

  // NOTE: [[Arrays.stream]] has to be used here, since Scala can't properly resolve between
  //       the 2 overloads of [[Stream.of]] when only a single element is passed
  java.util.Arrays.stream(Array(inSetCase))
}
199+
162200 def testBasicLookupFilterExpressionsSource (): java.util.stream.Stream [Arguments ] = {
163201 java.util.stream.Stream .of(
164202 // TODO cases
0 commit comments