Skip to content

Commit a1f91cd

Browse files
committed
refactor Unbounded
1 parent 5c9a992 commit a1f91cd

File tree

9 files changed

+117
-44
lines changed

9 files changed

+117
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala

Lines changed: 47 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,9 @@ case class WindowSpecDefinition(
7373
s"The data type '${orderSpec.head.dataType}' used in the order specification does " +
7474
s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " +
7575
"range frame.")
76+
case f: SpecifiedWindowFrame if !isValidFrameBoundary(f.lower, f.upper) =>
77+
TypeCheckFailure(s"The upper bound of the window frame is '${f.upper.sql}', which is " +
78+
s"smaller than the lower bound '${f.lower.sql}'.")
7679
case _ => TypeCheckSuccess
7780
}
7881
}
@@ -90,6 +93,15 @@ case class WindowSpecDefinition(
9093
}
9194

9295
private def isValidFrameType(ft: DataType): Boolean = orderSpec.head.dataType == ft
96+
97+
private def isValidFrameBoundary(lower: Expression, upper: Expression): Boolean = {
98+
(lower, upper) match {
99+
case (UnboundedFollowing, _) => false
100+
case (_, UnboundedPreceding) => false
101+
case (l: Expression, u: SpecialFrameBoundary) => !u.notFollows(l)
102+
case _ => true
103+
}
104+
}
93105
}
94106

95107
/**
@@ -141,13 +153,37 @@ sealed trait SpecialFrameBoundary extends Expression with Unevaluable {
141153
override def dataType: DataType = NullType
142154
override def foldable: Boolean = false
143155
override def nullable: Boolean = false
156+
157+
def notFollows(other: Expression): Boolean
144158
}
145159

146160
/** UNBOUNDED boundary. */
147-
case object Unbounded extends SpecialFrameBoundary
161+
case object UnboundedPreceding extends SpecialFrameBoundary {
162+
override def sql: String = "UNBOUNDED PRECEDING"
163+
164+
override def notFollows(other: Expression): Boolean = true
165+
}
166+
167+
case object UnboundedFollowing extends SpecialFrameBoundary {
168+
override def sql: String = "UNBOUNDED FOLLOWING"
169+
170+
override def notFollows(other: Expression): Boolean = other match {
171+
case UnboundedFollowing => true
172+
case _ => false
173+
}
174+
}
148175

149176
/** CURRENT ROW boundary. */
150-
case object CurrentRow extends SpecialFrameBoundary
177+
case object CurrentRow extends SpecialFrameBoundary {
178+
override def sql: String = "CURRENT ROW"
179+
180+
override def notFollows(other: Expression): Boolean = other match {
181+
case UnboundedPreceding => false
182+
case CurrentRow => false
183+
case e: Expression if e.foldable => GreaterThan(e, Literal(0)).eval().asInstanceOf[Boolean]
184+
case _ => true
185+
}
186+
}
151187

152188
/**
153189
* Represents a window frame.
@@ -206,12 +242,12 @@ case class SpecifiedWindowFrame(
206242
}
207243

208244
override def sql: String = {
209-
val lowerSql = boundarySql(lower, "PRECEDING")
210-
val upperSql = boundarySql(upper, "FOLLOWING")
245+
val lowerSql = boundarySql(lower)
246+
val upperSql = boundarySql(upper)
211247
s"${frameType.sql} BETWEEN $lowerSql AND $upperSql"
212248
}
213249

214-
def isUnbounded: Boolean = lower == Unbounded && upper == Unbounded
250+
def isUnbounded: Boolean = lower == UnboundedPreceding && upper == UnboundedFollowing
215251

216252
def isValueBound: Boolean = valueBoundary.nonEmpty
217253

@@ -220,9 +256,8 @@ case class SpecifiedWindowFrame(
220256
case _ => false
221257
}
222258

223-
private def boundarySql(expr: Expression, defaultDirection: String): String = expr match {
224-
case CurrentRow => "CURRENT ROW"
225-
case Unbounded => "UNBOUNDED " + defaultDirection
259+
private def boundarySql(expr: Expression): String = expr match {
260+
case e: SpecialFrameBoundary => e.sql
226261
case UnaryMinus(n) => n.sql + " PRECEDING"
227262
case e: Expression => e.sql + " FOLLOWING"
228263
}
@@ -257,11 +292,11 @@ object SpecifiedWindowFrame {
257292
if (hasOrderSpecification && acceptWindowFrame) {
258293
// If order spec is defined and the window function supports user specified window frames,
259294
// the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
260-
SpecifiedWindowFrame(RangeFrame, Unbounded, CurrentRow)
295+
SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
261296
} else {
262297
// Otherwise, the default frame is
263298
// ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING.
264-
SpecifiedWindowFrame(RowFrame, Unbounded, Unbounded)
299+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
265300
}
266301
}
267302
}
@@ -429,7 +464,7 @@ case class Lag(input: Expression, offset: Expression, default: Expression)
429464

430465
abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction {
431466
self: Product =>
432-
override val frame = SpecifiedWindowFrame(RowFrame, Unbounded, CurrentRow)
467+
override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
433468
override def dataType: DataType = IntegerType
434469
override def nullable: Boolean = true
435470
override lazy val mergeExpressions =
@@ -493,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
493528
override def dataType: DataType = DoubleType
494529
// The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must
495530
// return the same value for equal values in the partition.
496-
override val frame = SpecifiedWindowFrame(RangeFrame, Unbounded, CurrentRow)
531+
override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
497532
override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
498533
override def prettyName: String = "cume_dist"
499534
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1190,13 +1190,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
11901190

11911191
ctx.boundType.getType match {
11921192
case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
1193-
Unbounded
1193+
UnboundedPreceding
11941194
case SqlBaseParser.PRECEDING =>
11951195
UnaryMinus(value)
11961196
case SqlBaseParser.CURRENT =>
11971197
CurrentRow
11981198
case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
1199-
Unbounded
1199+
UnboundedFollowing
12001200
case SqlBaseParser.FOLLOWING =>
12011201
value
12021202
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1138,11 +1138,11 @@ class TypeCoercionSuite extends AnalysisTest {
11381138
windowSpec(
11391139
Seq(UnresolvedAttribute("a")),
11401140
Seq(SortOrder(Literal(1L), Ascending)),
1141-
SpecifiedWindowFrame(RangeFrame, CurrentRow, Unbounded)),
1141+
SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)),
11421142
windowSpec(
11431143
Seq(UnresolvedAttribute("a")),
11441144
Seq(SortOrder(Literal(1L), Ascending)),
1145-
SpecifiedWindowFrame(RangeFrame, CurrentRow, Unbounded))
1145+
SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing))
11461146
)
11471147
}
11481148
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -270,11 +270,11 @@ class ExpressionParserSuite extends PlanTest {
270270
("10 preceding", -Literal(10), CurrentRow),
271271
("2147483648 preceding", -Literal(2147483648L), CurrentRow),
272272
("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), // Will fail during analysis
273-
("unbounded preceding", Unbounded, CurrentRow),
274-
("unbounded following", Unbounded, CurrentRow), // Will fail during analysis
275-
("between unbounded preceding and current row", Unbounded, CurrentRow),
273+
("unbounded preceding", UnboundedPreceding, CurrentRow),
274+
("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
275+
("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
276276
("between unbounded preceding and unbounded following",
277-
Unbounded, Unbounded),
277+
UnboundedPreceding, UnboundedFollowing),
278278
("between 10 preceding and current row", -Literal(10), CurrentRow),
279279
("between current row and 5 following", CurrentRow, Literal(5)),
280280
("between 10 preceding and 5 following", -Literal(10), Literal(5))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -438,7 +438,7 @@ class TreeNodeSuite extends SparkFunSuite {
438438

439439
// Converts WindowFrame to JSON
440440
assertJSON(
441-
SpecifiedWindowFrame(RowFrame, Unbounded, CurrentRow),
441+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow),
442442
List(
443443
JObject(
444444
"class" -> classOf[SpecifiedWindowFrame].getName,
@@ -447,7 +447,7 @@ class TreeNodeSuite extends SparkFunSuite {
447447
"lower" -> 0,
448448
"upper" -> 1),
449449
JObject(
450-
"class" -> Unbounded.getClass.getName,
450+
"class" -> UnboundedPreceding.getClass.getName,
451451
"num-children" -> 0),
452452
JObject(
453453
"class" -> CurrentRow.getClass.getName,

sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -219,13 +219,13 @@ case class WindowExec(
219219
offset)
220220

221221
// Entire Partition Frame.
222-
case ("AGGREGATE", _, Unbounded, Unbounded) =>
222+
case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
223223
target: InternalRow => {
224224
new UnboundedWindowFunctionFrame(target, processor)
225225
}
226226

227227
// Growing Frame.
228-
case ("AGGREGATE", frameType, Unbounded, upper) =>
228+
case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
229229
target: InternalRow => {
230230
new UnboundedPrecedingWindowFunctionFrame(
231231
target,
@@ -234,7 +234,7 @@ case class WindowExec(
234234
}
235235

236236
// Shrinking Frame.
237-
case ("AGGREGATE", frameType, lower, Unbounded) =>
237+
case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
238238
target: InternalRow => {
239239
new UnboundedFollowingWindowFunctionFrame(
240240
target,

sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -125,14 +125,14 @@ class WindowSpec private[sql](
125125
def rowsBetween(start: Long, end: Long): WindowSpec = {
126126
val boundaryStart = start match {
127127
case 0 => CurrentRow
128-
case Long.MinValue => Unbounded
128+
case Long.MinValue => UnboundedPreceding
129129
case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt)
130130
case x => throw new AnalysisException(s"Boundary start is not a valid integer: $x")
131131
}
132132

133133
val boundaryEnd = end match {
134134
case 0 => CurrentRow
135-
case Long.MaxValue => Unbounded
135+
case Long.MaxValue => UnboundedFollowing
136136
case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt)
137137
case x => throw new AnalysisException(s"Boundary end is not a valid integer: $x")
138138
}
@@ -193,13 +193,13 @@ class WindowSpec private[sql](
193193
def rangeBetween(start: Long, end: Long): WindowSpec = {
194194
val boundaryStart = start match {
195195
case 0 => CurrentRow
196-
case Long.MinValue => Unbounded
196+
case Long.MinValue => UnboundedPreceding
197197
case x => Literal(x)
198198
}
199199

200200
val boundaryEnd = end match {
201201
case 0 => CurrentRow
202-
case Long.MaxValue => Unbounded
202+
case Long.MaxValue => UnboundedFollowing
203203
case x => Literal(x)
204204
}
205205

sql/core/src/test/resources/sql-tests/inputs/window.sql

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,14 @@ RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate,
2424
SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC
2525
RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
2626

27+
-- Invalid window frame
28+
SELECT val, cate, count(val) OVER(PARTITION BY cate
29+
ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
30+
SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val DESC
31+
RANGE BETWEEN 1 FOLLOWING AND CURRENT ROW) FROM testData ORDER BY cate, val;
32+
SELECT val, cate, count(val) OVER(PARTITION BY cate
33+
RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val;
34+
2735
-- Window functions
2836
SELECT val, cate,
2937
max(val) OVER w AS max,

sql/core/src/test/resources/sql-tests/results/window.sql.out

Lines changed: 45 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 13
2+
-- Number of queries: 16
33

44

55
-- !query 0
@@ -126,6 +126,36 @@ NULL a NULL
126126

127127

128128
-- !query 8
129+
SELECT val, cate, count(val) OVER(PARTITION BY cate
130+
ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val
131+
-- !query 8 schema
132+
struct<>
133+
-- !query 8 output
134+
org.apache.spark.sql.AnalysisException
135+
cannot resolve '(PARTITION BY testdata.`cate` ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING)' due to data type mismatch: The upper bound of the window frame is '1', which is smaller than the lower bound 'UNBOUNDED FOLLOWING'.; line 1 pos 33
136+
137+
138+
-- !query 9
139+
SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val DESC
140+
RANGE BETWEEN 1 FOLLOWING AND CURRENT ROW) FROM testData ORDER BY cate, val
141+
-- !query 9 schema
142+
struct<>
143+
-- !query 9 output
144+
org.apache.spark.sql.AnalysisException
145+
cannot resolve '(PARTITION BY testdata.`cate` ORDER BY testdata.`val` DESC NULLS LAST RANGE BETWEEN 1 FOLLOWING AND CURRENT ROW)' due to data type mismatch: The upper bound of the window frame is 'CURRENT ROW', which is smaller than the lower bound '1'.; line 1 pos 33
146+
147+
148+
-- !query 10
149+
SELECT val, cate, count(val) OVER(PARTITION BY cate
150+
RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val
151+
-- !query 10 schema
152+
struct<>
153+
-- !query 10 output
154+
org.apache.spark.sql.AnalysisException
155+
cannot resolve '(PARTITION BY testdata.`cate` RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame cannot be used in an unordered window specification.; line 1 pos 33
156+
157+
158+
-- !query 11
129159
SELECT val, cate,
130160
max(val) OVER w AS max,
131161
min(val) OVER w AS min,
@@ -152,9 +182,9 @@ approx_count_distinct(val) OVER w AS approx_count_distinct
152182
FROM testData
153183
WINDOW w AS (PARTITION BY cate ORDER BY val)
154184
ORDER BY cate, val
155-
-- !query 8 schema
185+
-- !query 11 schema
156186
struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint>
157-
-- !query 8 output
187+
-- !query 11 output
158188
NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0
159189
3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1
160190
NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0
@@ -166,11 +196,11 @@ NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.
166196
3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3
167197

168198

169-
-- !query 9
199+
-- !query 12
170200
SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val
171-
-- !query 9 schema
201+
-- !query 12 schema
172202
struct<val:int,cate:string,avg(CAST(NULL AS DOUBLE)) OVER (PARTITION BY cate ORDER BY val ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):double>
173-
-- !query 9 output
203+
-- !query 12 output
174204
NULL NULL NULL
175205
3 NULL NULL
176206
NULL a NULL
@@ -182,20 +212,20 @@ NULL a NULL
182212
3 b NULL
183213

184214

185-
-- !query 10
215+
-- !query 13
186216
SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val
187-
-- !query 10 schema
217+
-- !query 13 schema
188218
struct<>
189-
-- !query 10 output
219+
-- !query 13 output
190220
org.apache.spark.sql.AnalysisException
191221
Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table;
192222

193223

194-
-- !query 11
224+
-- !query 14
195225
SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val
196-
-- !query 11 schema
226+
-- !query 14 schema
197227
struct<val:int,cate:string,sum(CAST(val AS BIGINT)) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint,avg(CAST(val AS BIGINT)) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):double>
198-
-- !query 11 output
228+
-- !query 14 output
199229
NULL NULL 13 1.8571428571428572
200230
3 NULL 13 1.8571428571428572
201231
NULL a 13 1.8571428571428572
@@ -207,7 +237,7 @@ NULL a 13 1.8571428571428572
207237
3 b 13 1.8571428571428572
208238

209239

210-
-- !query 12
240+
-- !query 15
211241
SELECT val, cate,
212242
first_value(false) OVER w AS first_value,
213243
first_value(true, true) OVER w AS first_value_ignore_null,
@@ -218,9 +248,9 @@ last_value(false, false) OVER w AS last_value_contain_null
218248
FROM testData
219249
WINDOW w AS ()
220250
ORDER BY cate, val
221-
-- !query 12 schema
251+
-- !query 15 schema
222252
struct<val:int,cate:string,first_value:boolean,first_value_ignore_null:boolean,first_value_contain_null:boolean,last_value:boolean,last_value_ignore_null:boolean,last_value_contain_null:boolean>
223-
-- !query 12 output
253+
-- !query 15 output
224254
NULL NULL false true false false true false
225255
3 NULL false true false false true false
226256
NULL a false true false false true false

0 commit comments

Comments
 (0)