Skip to content

Commit 90efeff

Browse files
committed
[SPARK-21274] Add a new generator function replicate_rows to support EXCEPT ALL and INTERSECT ALL
1 parent 4d5de4d commit 90efeff

7 files changed

Lines changed: 241 additions & 0 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ object FunctionRegistry {
212212
expression[Rand]("rand"),
213213
expression[Randn]("randn"),
214214
expression[Stack]("stack"),
215+
expression[ReplicateRows]("replicate_rows"),
215216
expression[CaseWhen]("when"),
216217

217218
// math functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ object TypeCoercion {
6262
new ImplicitTypeCasts(conf) ::
6363
DateTimeOperations ::
6464
WindowFrameCoercion ::
65+
ReplicateRowsCoercion ::
6566
Nil
6667

6768
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -702,6 +703,21 @@ object TypeCoercion {
702703
}
703704
}
704705

706+
/**
 * Coerces the first argument (the row count) of a [[ReplicateRows]] expression to
 * LongType by inserting a cast when it is one of the narrower integral types.
 *
 * Note: unlike the original draft, LongType is deliberately excluded from the match
 * so the rule is a no-op for already-coerced plans instead of rebuilding an
 * identical expression node on every rule invocation.
 */
object ReplicateRowsCoercion extends TypeCoercionRule {
  // Integral types that can be widened to LongType without loss.
  private val castableTypes = Seq(IntegerType, ShortType, ByteType)

  override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    case s @ ReplicateRows(children)
        if s.childrenResolved && castableTypes.contains(children.head.dataType) =>
      // Widen only the row-count expression; the replicated value expressions
      // are left untouched.
      val numRowExpr = children.head
      val castedExpr = ImplicitTypeCasts.implicitCast(numRowExpr, LongType).getOrElse(numRowExpr)
      ReplicateRows(castedExpr +: children.tail)
  }
}
720+
705721
/**
706722
* Coerces the types of [[Concat]] children to expected ones.
707723
*

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2626
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
2727
import org.apache.spark.sql.types._
2828

29+
2930
/**
3031
* An expression that produces zero or more rows given a single input row.
3132
*
@@ -222,6 +223,51 @@ case class Stack(children: Seq[Expression]) extends Generator {
222223
}
223224
}
224225

226+
/**
 * Replicates the input row N times, where N is supplied as the first argument.
 * {{{
 *   SELECT replicate_rows(2, "val1", "val2") ->
 *   2 val1 val2
 *   2 val1 val2
 * }}}
 */
@ExpressionDescription(
  usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `expr1`, ..., `exprk` into `n` rows.",
  examples = """
    Examples:
      > SELECT _FUNC_(2, "val1", "val2");
       2 val1 val2
       2 val1 val2
  """)
case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {

  // Requires at least one value expression after the count, and the count must
  // already be LongType (narrower integral counts are widened by the
  // corresponding type-coercion rule before this check runs).
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.length < 2) {
      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
    } else if (children.head.dataType != LongType) {
      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  // Output columns are named col0..colK and mirror every child's type,
  // including the leading count column.
  override def elementSchema: StructType = {
    val fields = children.zipWithIndex.map { case (child, i) =>
      StructField(s"col$i", child.dataType)
    }
    StructType(fields)
  }

  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
    // NOTE(review): a null count unboxes to 0L here, so NULL and non-positive
    // counts both yield an empty result — matches the golden-file expectations.
    val rowCount = children.head.eval(input).asInstanceOf[Long]
    // Evaluate each child once; the same values appear in every emitted row.
    val columnValues = children.map(_.eval(input)).toArray
    Range.Long(0, rowCount, 1).map { _ =>
      // Give each emitted row its own backing array so downstream consumers
      // cannot alias each other's storage.
      InternalRow(columnValues.clone(): _*)
    }
  }
}
270+
225271
/**
226272
* Wrapper around another generator to specify outer behavior. This is used to implement functions
227273
* such as explode_outer. This expression gets replaced during analysis.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,6 +1353,31 @@ class TypeCoercionSuite extends AnalysisTest {
13531353
SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing))
13541354
)
13551355
}
1356+
1357+
test("type coercion for ReplicateRows") {
1358+
val rule = TypeCoercion.ReplicateRowsCoercion
1359+
// Cast is setup to promote the first expression to Long
1360+
// for numeric types.
1361+
ruleTest(rule,
1362+
ReplicateRows(Seq(1.toShort, Literal("rowdata"))),
1363+
ReplicateRows(Seq(Cast(1.toShort, LongType), Literal("rowdata"))))
1364+
ruleTest(rule,
1365+
ReplicateRows(Seq(1, Literal("rowdata"))),
1366+
ReplicateRows(Seq(Cast(1, LongType), Literal("rowdata"))))
1367+
ruleTest(rule,
1368+
ReplicateRows(Seq(1.toByte, Literal("rowdata"))),
1369+
ReplicateRows(Seq(Cast(1.toByte, LongType), Literal("rowdata"))))
1370+
1371+
// No cast here since the expected type is Long.
1372+
ruleTest(rule,
1373+
ReplicateRows(Seq(1L, Literal("rowdata"))),
1374+
ReplicateRows(Seq(1L, Literal("rowdata"))))
1375+
1376+
// No type coercion when first expression is a non numeric type.
1377+
ruleTest(rule,
1378+
ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))),
1379+
ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))))
1380+
}
13561381
}
13571382

13581383

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
-- Test data: c1 drives the replication count and includes zero, negative,
-- and NULL counts; c3 includes a NULL value to be replicated.
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
  (1, 'row1', 1.1),
  (2, 'row2', 2.2),
  (0, 'row3', 3.3),
  (-1,'row4', 4.4),
  (null,'row5', 5.5),
  (3, 'row6', null)
  AS tab1(c1, c2, c3);

-- c1, c2 replicated c1 times
SELECT replicate_rows(c1, c2) FROM tab1;

-- c1, c2, c2 (c2 repeated) replicated c1 times
SELECT replicate_rows(c1, c2, c2) FROM tab1;

-- c1, c2, c2, c2, c3 replicated c1 times
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1;

-- Used as a derived table in a FROM clause, with aliased output columns.
SELECT c2, c1
FROM (
  SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
);

-- A computed column expression as a replicated value.
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1;

-- Clean-up
DROP VIEW IF EXISTS tab1;
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- Number of queries: 7
3+
4+
5+
-- !query 0
6+
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
7+
(1, 'row1', 1.1),
8+
(2, 'row2', 2.2),
9+
(0, 'row3', 3.3),
10+
(-1,'row4', 4.4),
11+
(null,'row5', 5.5),
12+
(3, 'row6', null)
13+
AS tab1(c1, c2, c3)
14+
-- !query 0 schema
15+
struct<>
16+
-- !query 0 output
17+
18+
19+
20+
-- !query 1
21+
SELECT replicate_rows(c1, c2) FROM tab1
22+
-- !query 1 schema
23+
struct<col0:bigint,col1:string>
24+
-- !query 1 output
25+
1 row1
26+
2 row2
27+
2 row2
28+
3 row6
29+
3 row6
30+
3 row6
31+
32+
33+
-- !query 2
34+
SELECT replicate_rows(c1, c2, c2) FROM tab1
35+
-- !query 2 schema
36+
struct<col0:bigint,col1:string,col2:string>
37+
-- !query 2 output
38+
1 row1 row1
39+
2 row2 row2
40+
2 row2 row2
41+
3 row6 row6
42+
3 row6 row6
43+
3 row6 row6
44+
45+
46+
-- !query 3
47+
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1
48+
-- !query 3 schema
49+
struct<col0:bigint,col1:string,col2:string,col3:string,col4:decimal(2,1)>
50+
-- !query 3 output
51+
1 row1 row1 row1 1.1
52+
2 row2 row2 row2 2.2
53+
2 row2 row2 row2 2.2
54+
3 row6 row6 row6 NULL
55+
3 row6 row6 row6 NULL
56+
3 row6 row6 row6 NULL
57+
58+
59+
-- !query 4
60+
SELECT c2, c1
61+
FROM (
62+
SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
63+
)
64+
-- !query 4 schema
65+
struct<c2:string,c1:bigint>
66+
-- !query 4 output
67+
row1 1
68+
row2 2
69+
row2 2
70+
row6 3
71+
row6 3
72+
row6 3
73+
74+
75+
-- !query 5
76+
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1
77+
-- !query 5 schema
78+
struct<col0:bigint,col1:string,col2:string>
79+
-- !query 5 output
80+
1 row1... row1
81+
2 row2... row2
82+
2 row2... row2
83+
3 row6... row6
84+
3 row6... row6
85+
3 row6... row6
86+
87+
88+
-- !query 6
89+
DROP VIEW IF EXISTS tab1
90+
-- !query 6 schema
91+
struct<>
92+
-- !query 6 output
93+

sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,37 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
307307
sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
308308
Row(1, null) :: Row(2, null) :: Nil)
309309
}
310+
311+
test("ReplicateRows generator") {
312+
val df = spark.range(1)
313+
314+
// Empty DataFrame suppress the result generation
315+
checkAnswer(spark.emptyDataFrame.selectExpr("replicate_rows(1, 1, 2, 3)"), Nil)
316+
317+
checkAnswer(df.selectExpr("replicate_rows(1, 2.5)"), Row(1, 2.5) :: Nil)
318+
checkAnswer(df.selectExpr("replicate_rows(1, null)"), Row(1, null) :: Nil)
319+
checkAnswer(df.selectExpr("replicate_rows(3, 'row1')"),
320+
Row(3, "row1") :: Row(3, "row1") :: Row(3, "row1") :: Nil)
321+
checkAnswer(df.selectExpr("replicate_rows(-1, 2.5)"), Nil)
322+
323+
// The data for the same column should have the same type.
324+
val msg1 = intercept[AnalysisException] {
325+
df.selectExpr("replicate_rows(1)")
326+
}.getMessage
327+
assert(msg1.contains("requires at least 2 arguments"))
328+
329+
// The data for the same column should have the same type.
330+
val msg2 = intercept[AnalysisException] {
331+
df.selectExpr("replicate_rows('a', 1)")
332+
}.getMessage
333+
assert(msg2.contains("The number of rows must be a positive long value."))
334+
335+
val msg3 = intercept[AnalysisException] {
336+
df.selectExpr("replicate_rows(null, 1)")
337+
}.getMessage
338+
assert(msg3.contains("The number of rows must be a positive long value."))
339+
340+
}
310341
}
311342

312343
case class EmptyGenerator() extends Generator {

0 commit comments

Comments
 (0)