Skip to content

Commit 6b24926

Browse files
huaxingao, dongjoon-hyun
authored and committed
[SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue data type for In predicate
### What changes were proposed in this pull request? When building V2 `In` Predicate in `V2ExpressionBuilder`, `InSet.dataType` (which is `BooleanType`) is used to build the `LiteralValue`; `InSet.child.dataType` should be used instead. ### Why are the changes needed? bug fix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #37271 from huaxingao/inset. Authored-by: huaxingao <huaxin_gao@apple.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent ae1f6a2 commit 6b24926

2 files changed

Lines changed: 228 additions & 5 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
5252
} else {
5353
Some(ref)
5454
}
55-
case in @ InSet(child, hset) =>
55+
case InSet(child, hset) =>
5656
generateExpression(child).map { v =>
5757
val children =
58-
(v +: hset.toSeq.map(elem => LiteralValue(elem, in.dataType))).toArray[V2Expression]
58+
(v +: hset.toSeq.map(elem => LiteralValue(elem, child.dataType))).toArray[V2Expression]
5959
new V2Predicate("IN", children)
6060
}
6161
// Because we only convert In to InSet in Optimizer when there are more than certain

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala

Lines changed: 226 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.plans.PlanTest
2323
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
24-
import org.apache.spark.sql.connector.expressions.filter.Predicate
24+
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
2525
import org.apache.spark.sql.test.SharedSparkSession
2626
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType}
27+
import org.apache.spark.unsafe.types.UTF8String
2728

2829
class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
2930
val attrInts = Seq(
@@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
5556
"a.b.cint" // three level nested field
5657
))
5758

58-
test("SPARK-39784: translate binary expression") { attrInts
59-
.foreach { case (attrInt, intColName) =>
59+
val attrStrs = Seq(
60+
$"cstr".string,
61+
$"c.str".string,
62+
GetStructField($"a".struct(StructType(
63+
StructField("cint", IntegerType, nullable = true) ::
64+
StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
65+
GetStructField($"a".struct(StructType(
66+
StructField("c.str", StringType, nullable = true) ::
67+
StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
68+
GetStructField($"a.b".struct(StructType(
69+
StructField("cint1", IntegerType, nullable = true) ::
70+
StructField("cint2", IntegerType, nullable = true) ::
71+
StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
72+
GetStructField($"a.b".struct(StructType(
73+
StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
74+
GetStructField(GetStructField($"a".struct(StructType(
75+
StructField("cint1", IntegerType, nullable = true) ::
76+
StructField("b", StructType(StructField("cstr", StringType, nullable = true) ::
77+
StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None)
78+
).zip(Seq(
79+
"cstr",
80+
"`c.str`", // single level field that contains `dot` in name
81+
"a.cstr", // two level nested field
82+
"a.`c.str`", // two level nested field, and nested level contains `dot`
83+
"`a.b`.cstr", // two level nested field, and top level contains `dot`
84+
"`a.b`.`c.str`", // two level nested field, and both levels contain `dot`
85+
"a.b.cstr" // three level nested field
86+
))
87+
88+
test("translate simple expression") { attrInts.zip(attrStrs)
89+
.foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
6090
testTranslateFilter(EqualTo(attrInt, 1),
6191
Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
6292
testTranslateFilter(EqualTo(1, attrInt),
@@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
86116
Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
87117
testTranslateFilter(LessThanOrEqual(1, attrInt),
88118
Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
119+
120+
testTranslateFilter(IsNull(attrInt),
121+
Some(new Predicate("IS_NULL", Array(FieldReference(intColName)))))
122+
testTranslateFilter(IsNotNull(attrInt),
123+
Some(new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
124+
125+
testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
126+
Some(new Predicate("IN", Array(FieldReference(intColName),
127+
LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
128+
LiteralValue(3, IntegerType)))))
129+
130+
testTranslateFilter(In(attrInt, Seq(1, 2, 3)),
131+
Some(new Predicate("IN", Array(FieldReference(intColName),
132+
LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
133+
LiteralValue(3, IntegerType)))))
134+
135+
// cint > 1 AND cint < 10
136+
testTranslateFilter(And(
137+
GreaterThan(attrInt, 1),
138+
LessThan(attrInt, 10)),
139+
Some(new V2And(
140+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
141+
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType))))))
142+
143+
// cint >= 8 OR cint <= 2
144+
testTranslateFilter(Or(
145+
GreaterThanOrEqual(attrInt, 8),
146+
LessThanOrEqual(attrInt, 2)),
147+
Some(new V2Or(
148+
new Predicate(">=", Array(FieldReference(intColName), LiteralValue(8, IntegerType))),
149+
new Predicate("<=", Array(FieldReference(intColName), LiteralValue(2, IntegerType))))))
150+
151+
testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
152+
Some(new V2Not(new Predicate(">=", Array(FieldReference(intColName),
153+
LiteralValue(8, IntegerType))))))
154+
155+
testTranslateFilter(StartsWith(attrStr, "a"),
156+
Some(new Predicate("STARTS_WITH", Array(FieldReference(strColName),
157+
LiteralValue(UTF8String.fromString("a"), StringType)))))
158+
159+
testTranslateFilter(EndsWith(attrStr, "a"),
160+
Some(new Predicate("ENDS_WITH", Array(FieldReference(strColName),
161+
LiteralValue(UTF8String.fromString("a"), StringType)))))
162+
163+
testTranslateFilter(Contains(attrStr, "a"),
164+
Some(new Predicate("CONTAINS", Array(FieldReference(strColName),
165+
LiteralValue(UTF8String.fromString("a"), StringType)))))
166+
}
167+
}
168+
169+
test("translate complex expression") {
170+
attrInts.foreach { case (attrInt, intColName) =>
171+
172+
// ABS(cint) - 2 <= 1
173+
testTranslateFilter(LessThanOrEqual(
174+
// Expressions are not supported
175+
// Functions such as 'Abs' are not supported
176+
Subtract(Abs(attrInt), 2), 1), None)
177+
178+
// (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
179+
testTranslateFilter(Or(
180+
And(
181+
GreaterThan(attrInt, 1),
182+
LessThan(attrInt, 10)
183+
),
184+
And(
185+
GreaterThan(attrInt, 50),
186+
LessThan(attrInt, 100))),
187+
Some(new V2Or(
188+
new V2And(
189+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
190+
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
191+
new V2And(
192+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(50, IntegerType))),
193+
new Predicate("<", Array(FieldReference(intColName),
194+
LiteralValue(100, IntegerType)))))
195+
)
196+
)
197+
198+
// (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
199+
testTranslateFilter(Or(
200+
And(
201+
GreaterThan(attrInt, 1),
202+
// Functions such as 'Abs' are not supported
203+
LessThan(Abs(attrInt), 10)
204+
),
205+
And(
206+
GreaterThan(attrInt, 50),
207+
LessThan(attrInt, 100))), None)
208+
209+
// NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
210+
testTranslateFilter(Not(And(
211+
Or(
212+
LessThanOrEqual(attrInt, 1),
213+
// Functions such as 'Abs' are not supported
214+
GreaterThanOrEqual(Abs(attrInt), 10)
215+
),
216+
Or(
217+
LessThanOrEqual(attrInt, 50),
218+
GreaterThanOrEqual(attrInt, 100)))), None)
219+
220+
// (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
221+
testTranslateFilter(Or(
222+
Or(
223+
EqualTo(attrInt, 1),
224+
EqualTo(attrInt, 10)
225+
),
226+
Or(
227+
GreaterThan(attrInt, 0),
228+
LessThan(attrInt, -10))),
229+
Some(new V2Or(
230+
new V2Or(
231+
new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
232+
new Predicate("=", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
233+
new V2Or(
234+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(0, IntegerType))),
235+
new Predicate("<", Array(FieldReference(intColName), LiteralValue(-10, IntegerType)))))
236+
)
237+
)
238+
239+
// (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
240+
testTranslateFilter(Or(
241+
Or(
242+
EqualTo(attrInt, 1),
243+
// Functions such as 'Abs' are not supported
244+
EqualTo(Abs(attrInt), 10)
245+
),
246+
Or(
247+
GreaterThan(attrInt, 0),
248+
LessThan(attrInt, -10))), None)
249+
250+
// In end-to-end testing, conjunctive predicate should have been split
251+
// before reaching DataSourceStrategy.translateFilter.
252+
// This is for UT purpose to test each [[case]].
253+
// (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
254+
testTranslateFilter(And(
255+
And(
256+
GreaterThan(attrInt, 1),
257+
LessThan(attrInt, 10)
258+
),
259+
And(
260+
EqualTo(attrInt, 6),
261+
IsNotNull(attrInt))),
262+
Some(new V2And(
263+
new V2And(
264+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
265+
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
266+
new V2And(
267+
new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))),
268+
new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
269+
)
270+
)
271+
272+
// (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
273+
testTranslateFilter(And(
274+
And(
275+
GreaterThan(attrInt, 1),
276+
LessThan(attrInt, 10)
277+
),
278+
And(
279+
// Functions such as 'Abs' are not supported
280+
EqualTo(Abs(attrInt), 6),
281+
IsNotNull(attrInt))), None)
282+
283+
// (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
284+
testTranslateFilter(And(
285+
Or(
286+
GreaterThan(attrInt, 1),
287+
LessThan(attrInt, 10)
288+
),
289+
Or(
290+
EqualTo(attrInt, 6),
291+
IsNotNull(attrInt))),
292+
Some(new V2And(
293+
new V2Or(
294+
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
295+
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
296+
new V2Or(
297+
new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))),
298+
new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
299+
)
300+
)
301+
302+
// (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
303+
testTranslateFilter(And(
304+
Or(
305+
GreaterThan(attrInt, 1),
306+
LessThan(attrInt, 10)
307+
),
308+
Or(
309+
// Functions such as 'Abs' are not supported
310+
EqualTo(Abs(attrInt), 6),
311+
IsNotNull(attrInt))), None)
89312
}
90313
}
91314

0 commit comments

Comments
 (0)