@@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2121import org .apache .spark .sql .catalyst .expressions ._
2222import org .apache .spark .sql .catalyst .plans .PlanTest
2323import org .apache .spark .sql .connector .expressions .{FieldReference , LiteralValue }
24- import org .apache .spark .sql .connector .expressions .filter .Predicate
24+ import org .apache .spark .sql .connector .expressions .filter .{ And => V2And , Not => V2Not , Or => V2Or , Predicate }
2525import org .apache .spark .sql .test .SharedSparkSession
2626import org .apache .spark .sql .types .{BooleanType , IntegerType , StringType , StructField , StructType }
27+ import org .apache .spark .unsafe .types .UTF8String
2728
2829class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
2930 val attrInts = Seq (
@@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
5556 " a.b.cint" // three level nested field
5657 ))
5758
58- test(" SPARK-39784: translate binary expression" ) { attrInts
59- .foreach { case (attrInt, intColName) =>
59+ val attrStrs = Seq ( // string-typed column expressions, zipped below with the field name expected for each
60+ $" cstr" .string,
61+ $" c.str" .string,
62+ GetStructField ($" a" .struct(StructType (
63+ StructField (" cint" , IntegerType , nullable = true ) ::
64+ StructField (" cstr" , StringType , nullable = true ) :: Nil )), 1 , None ),
65+ GetStructField ($" a" .struct(StructType (
66+ StructField (" c.str" , StringType , nullable = true ) ::
67+ StructField (" cint" , IntegerType , nullable = true ) :: Nil )), 0 , None ),
68+ GetStructField ($" a.b" .struct(StructType (
69+ StructField (" cint1" , IntegerType , nullable = true ) ::
70+ StructField (" cint2" , IntegerType , nullable = true ) ::
71+ StructField (" cstr" , StringType , nullable = true ) :: Nil )), 2 , None ),
72+ GetStructField ($" a.b" .struct(StructType (
73+ StructField (" c.str" , StringType , nullable = true ) :: Nil )), 0 , None ),
74+ GetStructField (GetStructField ($" a" .struct(StructType (
75+ StructField (" cint1" , IntegerType , nullable = true ) ::
76+ StructField (" b" , StructType (StructField (" cstr" , StringType , nullable = true ) ::
77+ StructField (" cint2" , IntegerType , nullable = true ) :: Nil )) :: Nil )), 1 , None ), 0 , None )
78+ ).zip(Seq (
79+ " cstr" ,
80+ " `c.str`" , // single level field that contains `dot` in name
81+ " a.cstr" , // two level nested field
82+ " a.`c.str`" , // two level nested field, and nested level contains `dot`
83+ " `a.b`.cstr" , // two level nested field, and top level contains `dot`
84+ " `a.b`.`c.str`" , // two level nested field, and both levels contain `dot`
85+ " a.b.cstr" // three level nested field
86+ ))
87+
88+ test(" translate simple expression" ) { attrInts.zip(attrStrs)
89+ .foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
6090 testTranslateFilter(EqualTo (attrInt, 1 ),
6191 Some (new Predicate (" =" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType )))))
6292 testTranslateFilter(EqualTo (1 , attrInt),
@@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
86116 Some (new Predicate (" <=" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType )))))
87117 testTranslateFilter(LessThanOrEqual (1 , attrInt),
88118 Some (new Predicate (" >=" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType )))))
119+
120+ testTranslateFilter(IsNull (attrInt),
121+ Some (new Predicate (" IS_NULL" , Array (FieldReference (intColName)))))
122+ testTranslateFilter(IsNotNull (attrInt),
123+ Some (new Predicate (" IS_NOT_NULL" , Array (FieldReference (intColName)))))
124+
125+ testTranslateFilter(InSet (attrInt, Set (1 , 2 , 3 )),
126+ Some (new Predicate (" IN" , Array (FieldReference (intColName),
127+ LiteralValue (1 , IntegerType ), LiteralValue (2 , IntegerType ),
128+ LiteralValue (3 , IntegerType )))))
129+
130+ testTranslateFilter(In (attrInt, Seq (1 , 2 , 3 )),
131+ Some (new Predicate (" IN" , Array (FieldReference (intColName),
132+ LiteralValue (1 , IntegerType ), LiteralValue (2 , IntegerType ),
133+ LiteralValue (3 , IntegerType )))))
134+
135+ // cint > 1 AND cint < 10
136+ testTranslateFilter(And (
137+ GreaterThan (attrInt, 1 ),
138+ LessThan (attrInt, 10 )),
139+ Some (new V2And (
140+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType ))),
141+ new Predicate (" <" , Array (FieldReference (intColName), LiteralValue (10 , IntegerType ))))))
142+
143+ // cint >= 8 OR cint <= 2
144+ testTranslateFilter(Or (
145+ GreaterThanOrEqual (attrInt, 8 ),
146+ LessThanOrEqual (attrInt, 2 )),
147+ Some (new V2Or (
148+ new Predicate (" >=" , Array (FieldReference (intColName), LiteralValue (8 , IntegerType ))),
149+ new Predicate (" <=" , Array (FieldReference (intColName), LiteralValue (2 , IntegerType ))))))
150+
151+ testTranslateFilter(Not (GreaterThanOrEqual (attrInt, 8 )),
152+ Some (new V2Not (new Predicate (" >=" , Array (FieldReference (intColName),
153+ LiteralValue (8 , IntegerType ))))))
154+
155+ testTranslateFilter(StartsWith (attrStr, " a" ),
156+ Some (new Predicate (" STARTS_WITH" , Array (FieldReference (strColName),
157+ LiteralValue (UTF8String .fromString(" a" ), StringType )))))
158+
159+ testTranslateFilter(EndsWith (attrStr, " a" ),
160+ Some (new Predicate (" ENDS_WITH" , Array (FieldReference (strColName),
161+ LiteralValue (UTF8String .fromString(" a" ), StringType )))))
162+
163+ testTranslateFilter(Contains (attrStr, " a" ),
164+ Some (new Predicate (" CONTAINS" , Array (FieldReference (strColName),
165+ LiteralValue (UTF8String .fromString(" a" ), StringType )))))
166+ }
167+ }
168+
169+ test(" translate complex expression" ) { // nested And/Or/Not filters; every case that contains Abs expects None
170+ attrInts.foreach { case (attrInt, intColName) =>
171+
172+ // ABS(cint) - 2 <= 1
173+ testTranslateFilter(LessThanOrEqual (
174+ // Expressions are not supported
175+ // Functions such as 'Abs' are not supported
176+ Subtract (Abs (attrInt), 2 ), 1 ), None )
177+
178+ // (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
179+ testTranslateFilter(Or (
180+ And (
181+ GreaterThan (attrInt, 1 ),
182+ LessThan (attrInt, 10 )
183+ ),
184+ And (
185+ GreaterThan (attrInt, 50 ),
186+ LessThan (attrInt, 100 ))),
187+ Some (new V2Or (
188+ new V2And (
189+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType ))),
190+ new Predicate (" <" , Array (FieldReference (intColName), LiteralValue (10 , IntegerType )))),
191+ new V2And (
192+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (50 , IntegerType ))),
193+ new Predicate (" <" , Array (FieldReference (intColName),
194+ LiteralValue (100 , IntegerType )))))
195+ )
196+ )
197+
198+ // (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
199+ testTranslateFilter(Or (
200+ And (
201+ GreaterThan (attrInt, 1 ),
202+ // Functions such as 'Abs' are not supported
203+ LessThan (Abs (attrInt), 10 )
204+ ),
205+ And (
206+ GreaterThan (attrInt, 50 ),
207+ LessThan (attrInt, 100 ))), None )
208+
209+ // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
210+ testTranslateFilter(Not (And (
211+ Or (
212+ LessThanOrEqual (attrInt, 1 ),
213+ // Functions such as 'Abs' are not supported
214+ GreaterThanOrEqual (Abs (attrInt), 10 )
215+ ),
216+ Or (
217+ LessThanOrEqual (attrInt, 50 ),
218+ GreaterThanOrEqual (attrInt, 100 )))), None )
219+
220+ // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
221+ testTranslateFilter(Or (
222+ Or (
223+ EqualTo (attrInt, 1 ),
224+ EqualTo (attrInt, 10 )
225+ ),
226+ Or (
227+ GreaterThan (attrInt, 0 ),
228+ LessThan (attrInt, - 10 ))),
229+ Some (new V2Or (
230+ new V2Or (
231+ new Predicate (" =" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType ))),
232+ new Predicate (" =" , Array (FieldReference (intColName), LiteralValue (10 , IntegerType )))),
233+ new V2Or (
234+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (0 , IntegerType ))),
235+ new Predicate (" <" , Array (FieldReference (intColName), LiteralValue (- 10 , IntegerType )))))
236+ )
237+ )
238+
239+ // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
240+ testTranslateFilter(Or (
241+ Or (
242+ EqualTo (attrInt, 1 ),
243+ // Functions such as 'Abs' are not supported
244+ EqualTo (Abs (attrInt), 10 )
245+ ),
246+ Or (
247+ GreaterThan (attrInt, 0 ),
248+ LessThan (attrInt, - 10 ))), None )
249+
250+ // In end-to-end testing, conjunctive predicates should have been split
251+ // before reaching DataSourceStrategy.translateFilter.
252+ // This is for UT purpose to test each [[case]].
253+ // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
254+ testTranslateFilter(And (
255+ And (
256+ GreaterThan (attrInt, 1 ),
257+ LessThan (attrInt, 10 )
258+ ),
259+ And (
260+ EqualTo (attrInt, 6 ),
261+ IsNotNull (attrInt))),
262+ Some (new V2And (
263+ new V2And (
264+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType ))),
265+ new Predicate (" <" , Array (FieldReference (intColName), LiteralValue (10 , IntegerType )))),
266+ new V2And (
267+ new Predicate (" =" , Array (FieldReference (intColName), LiteralValue (6 , IntegerType ))),
268+ new Predicate (" IS_NOT_NULL" , Array (FieldReference (intColName)))))
269+ )
270+ )
271+
272+ // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
273+ testTranslateFilter(And (
274+ And (
275+ GreaterThan (attrInt, 1 ),
276+ LessThan (attrInt, 10 )
277+ ),
278+ And (
279+ // Functions such as 'Abs' are not supported
280+ EqualTo (Abs (attrInt), 6 ),
281+ IsNotNull (attrInt))), None )
282+
283+ // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
284+ testTranslateFilter(And (
285+ Or (
286+ GreaterThan (attrInt, 1 ),
287+ LessThan (attrInt, 10 )
288+ ),
289+ Or (
290+ EqualTo (attrInt, 6 ),
291+ IsNotNull (attrInt))),
292+ Some (new V2And (
293+ new V2Or (
294+ new Predicate (" >" , Array (FieldReference (intColName), LiteralValue (1 , IntegerType ))),
295+ new Predicate (" <" , Array (FieldReference (intColName), LiteralValue (10 , IntegerType )))),
296+ new V2Or (
297+ new Predicate (" =" , Array (FieldReference (intColName), LiteralValue (6 , IntegerType ))),
298+ new Predicate (" IS_NOT_NULL" , Array (FieldReference (intColName)))))
299+ )
300+ )
301+
302+ // (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
303+ testTranslateFilter(And (
304+ Or (
305+ GreaterThan (attrInt, 1 ),
306+ LessThan (attrInt, 10 )
307+ ),
308+ Or (
309+ // Functions such as 'Abs' are not supported
310+ EqualTo (Abs (attrInt), 6 ),
311+ IsNotNull (attrInt))), None )
89312 }
90313 }
91314
0 commit comments