Skip to content

Commit 584eb9e

Browse files

[SPARK-16289][SQL] Implement posexplode table generating function

1 parent: f454a7f — commit 584eb9e

7 files changed

Lines changed: 59 additions & 11 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ object FunctionRegistry {
175175
expression[NullIf]("nullif"),
176176
expression[Nvl]("nvl"),
177177
expression[Nvl2]("nvl2"),
178+
expression[PosExplode]("posexplode"),
178179
expression[Rand]("rand"),
179180
expression[Randn]("randn"),
180181
expression[CreateStruct]("struct"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,10 @@ case class UserDefinedGenerator(
9494
}
9595

9696
/**
97-
* Given an input array produces a sequence of rows for each value in the array.
97+
* A base class for Explode and PosExplode
9898
*/
99-
// scalastyle:off line.size.limit
100-
@ExpressionDescription(
101-
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
102-
// scalastyle:on line.size.limit
103-
case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
99+
abstract class ExplodeBase(child: Expression, position: Boolean)
100+
extends UnaryExpression with Generator with CodegenFallback with Serializable {
104101

105102
override def children: Seq[Expression] = child :: Nil
106103

@@ -115,9 +112,19 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
115112

116113
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
117114
override def elementSchema: StructType = child.dataType match {
118-
case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
115+
case ArrayType(et, containsNull) =>
116+
if (position) {
117+
new StructType().add("pos", IntegerType, false).add("col", et, containsNull)
118+
} else {
119+
new StructType().add("col", et, containsNull)
120+
}
119121
case MapType(kt, vt, valueContainsNull) =>
120-
new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
122+
if (position) {
123+
new StructType().add("pos", IntegerType, false).add("key", kt, false)
124+
.add("value", vt, valueContainsNull)
125+
} else {
126+
new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
127+
}
121128
}
122129

123130
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +136,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
129136
} else {
130137
val rows = new Array[InternalRow](inputArray.numElements())
131138
inputArray.foreach(et, (i, e) => {
132-
rows(i) = InternalRow(e)
139+
rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
133140
})
134141
rows
135142
}
@@ -141,11 +148,33 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
141148
val rows = new Array[InternalRow](inputMap.numElements())
142149
var i = 0
143150
inputMap.foreach(kt, vt, (k, v) => {
144-
rows(i) = InternalRow(k, v)
151+
rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
145152
i += 1
146153
})
147154
rows
148155
}
149156
}
150157
}
151158
}
159+
160+
/**
161+
* Given an input array produces a sequence of rows for each value in the array.
162+
*/
163+
// scalastyle:off line.size.limit
164+
@ExpressionDescription(
165+
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.")
166+
// scalastyle:on line.size.limit
167+
case class Explode(child: Expression)
168+
extends ExplodeBase(child, position = false) with Serializable {
169+
}
170+
171+
/**
172+
* Given an input array produces a sequence of rows for each position and value in the array.
173+
*/
174+
// scalastyle:off line.size.limit
175+
@ExpressionDescription(
176+
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.")
177+
// scalastyle:on line.size.limit
178+
case class PosExplode(child: Expression)
179+
extends ExplodeBase(child, position = true) with Serializable {
180+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
166166
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
167167
assertError(Explode('intField),
168168
"input to function explode should be array or map type")
169+
assertError(PosExplode('intField),
170+
"input to function explode should be array or map type")
169171
}
170172

171173
test("check types for CreateNamedStruct") {

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
159159
// Leave an unaliased generator with an empty list of names since the analyzer will generate
160160
// the correct defaults after the nested expression's type has been resolved.
161161
case explode: Explode => MultiAlias(explode, Nil)
162+
case explode: PosExplode => MultiAlias(explode, Nil)
162163

163164
case jt: JsonTuple => MultiAlias(jt, Nil)
164165

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2721,6 +2721,14 @@ object functions {
27212721
*/
27222722
def explode(e: Column): Column = withExpr { Explode(e.expr) }
27232723

2724+
/**
2725+
* Creates a new row for each element with position in the given array or map column.
2726+
*
2727+
* @group collection_funcs
2728+
* @since 2.1.0
2729+
*/
2730+
def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
2731+
27242732
/**
27252733
* Extracts json object from a json string based on json path specified, and returns json string
27262734
* of the extracted json object. It will return null if the input json string is invalid.

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
129129
Row(1) :: Row(2) :: Row(3) :: Nil)
130130
}
131131

132+
test("single posexplode") {
133+
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
134+
checkAnswer(
135+
df.select(posexplode('intList)),
136+
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
137+
}
138+
132139
test("explode and other columns") {
133140
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
134141

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,6 @@ private[sql] class HiveSessionCatalog(
231231
"xpath_number", "xpath_short", "xpath_string",
232232

233233
// table generating function
234-
"inline", "posexplode"
234+
"inline"
235235
)
236236
}

0 commit comments

Comments (0)