Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,20 @@ class WindowSpec private[sql](
private def between(typ: FrameType, start: Long, end: Long): WindowSpec = {
val boundaryStart = start match {
case 0 => CurrentRow
case Long.MinValue => UnboundedPreceding
case x if x < 0 => ValuePreceding(-start.toInt)
case x if x > 0 => ValueFollowing(start.toInt)
case x if x < Int.MinValue => UnboundedPreceding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we throw an exception if x < Int.MinValue and x > Long.MinValue? @hvanhovell what do you think?

BTW, I remember we have documentation explaining this behavior; we should update that too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, the doc is in rangeBetween

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, the type of start and end should not be Long here, but we cannot change it for compatibility reasons.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @hvanhovell any ideas?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case x if x < 0 && x >= Int.MinValue => ValuePreceding(-start.toInt)
case x if x > 0 && x <= Int.MaxValue => ValueFollowing(start.toInt)
case _ => throw new IllegalArgumentException(s"Boundary start($start) should not be " +
s"larger than Int.MaxValue(${Int.MaxValue}).")
}

val boundaryEnd = end match {
case 0 => CurrentRow
case Long.MaxValue => UnboundedFollowing
case x if x < 0 => ValuePreceding(-end.toInt)
case x if x > 0 => ValueFollowing(end.toInt)
case x if x > Int.MaxValue => UnboundedFollowing
case x if x < 0 && x >= Int.MinValue => ValuePreceding(-end.toInt)
case x if x > 0 && x <= Int.MaxValue => ValueFollowing(end.toInt)
case _ => throw new IllegalArgumentException(s"Boundary end($end) should not be " +
s"smaller than Int.MinValue(${Int.MinValue}).")
}

new WindowSpec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,118 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
df.select(selectList: _*).where($"value" < 2),
Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0)))
}

test("SPARK-19451: Underlying integer overflow in Window function") {
  val df = Seq((1L, "a"), (1L, "a"), (2L, "a"), (1L, "b"), (2L, "b"), (3L, "b"))
    .toDF("id", "category")
  df.createOrReplaceTempView("window_table")

  // Every frame below shares the same partitioning and ordering; only the boundaries vary.
  val window = Window.partitionBy('category).orderBy('id)

  // Each case is (frame start, frame end, expected rows). The boundaries +/-2160000000L lie
  // outside the Int range, which is the overflow scenario this test guards against.
  // Expected rows are listed as: id 1, 2, 3 of category "b", then id 1, 1, 2 of category "a".

  // range frames
  val rangeFrameCases = Seq(
    (-2160000000L, -1L, Seq(
      Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3),
      Row(1, "a", null), Row(1, "a", null), Row(2, "a", 2))),
    (-2160000000L, 0L, Seq(
      Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6),
      Row(1, "a", 2), Row(1, "a", 2), Row(2, "a", 4))),
    (-2160000000L, 2L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))),
    (-2160000000L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))),
    (-1L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))),
    (0L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 2))),
    (2L, 2160000000L, Seq(
      Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null),
      Row(1, "a", null), Row(1, "a", null), Row(2, "a", null))))

  rangeFrameCases.foreach { case (start, end, expected) =>
    checkAnswer(
      df.select('id, 'category, sum("id").over(window.rangeBetween(start, end))),
      expected)
  }

  // row frames
  val rowFrameCases = Seq(
    (-2160000000L, -1L, Seq(
      Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3),
      Row(1, "a", null), Row(1, "a", 1), Row(2, "a", 2))),
    (-2160000000L, 0L, Seq(
      Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6),
      Row(1, "a", 1), Row(1, "a", 2), Row(2, "a", 4))),
    (-2160000000L, 2L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))),
    (-2160000000L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4))),
    (-1L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5),
      Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 3))),
    (0L, 2160000000L, Seq(
      Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3),
      Row(1, "a", 4), Row(1, "a", 3), Row(2, "a", 2))),
    (2L, 2160000000L, Seq(
      Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null),
      Row(1, "a", 2), Row(1, "a", null), Row(2, "a", null))))

  rowFrameCases.foreach { case (start, end, expected) =>
    checkAnswer(
      df.select('id, 'category, sum("id").over(window.rowsBetween(start, end))),
      expected)
  }

  // An end boundary smaller than Int.MinValue(-2147483648) must be rejected.
  // `intercept` fails the test if no IllegalArgumentException is thrown.
  intercept[IllegalArgumentException] {
    checkAnswer(
      df.select('id, 'category,
        sum("id").over(window.rowsBetween(-3160000000L, -2160000000L))),
      Seq())
  }

  // A start boundary larger than Int.MaxValue(2147483647) must be rejected.
  intercept[IllegalArgumentException] {
    checkAnswer(
      df.select('id, 'category,
        sum("id").over(window.rowsBetween(2160000000L, 3160000000L))),
      Seq())
  }
}
}