Changes from 18 commits

Commits (25)
282e724
[SPARK-23736][SQL] Implementation of the concat_arrays function conca…
Mar 13, 2018
aa5a089
[SPARK-23736][SQL] Code style fixes.
Mar 26, 2018
90d3ab7
[SPARK-23736][SQL] Improving the description of the ConcatArrays expr…
Mar 26, 2018
bb46c3d
[SPARK-23736][SQL] Merging concat and concat_arrays into one function.
Mar 26, 2018
11205af
[SPARK-23736][SQL] Adding new line at the end of the unresolved.scala…
Mar 26, 2018
753499d
[SPARK-23736][SQL] Fixing failing unit test from DDLSuite.
Mar 26, 2018
2efdd77
[SPARK-23736][SQL] Changing method styling according to the standards.
Mar 27, 2018
fd84bee
[SPARK-23736][SQL] Changing data type to ArrayType(StringType) for th…
Mar 27, 2018
116f91f
[SPARK-23736][SQL] Fixing a SparkR unit test by filtering out Unresol…
Mar 27, 2018
e199ac5
[SPARK-23736][SQL] Merging the current master into the feature branch.
Mar 28, 2018
067c2db
[SPARK-23736][SQL] Merging the current master to the feature branch.
Mar 29, 2018
090929f
[SPARK-23736][SQL] Merging string concat and array concat into one ex…
Apr 6, 2018
8abd1a8
[SPARK-23736][SQL] Adding more test cases
Apr 7, 2018
367ee22
[SPARK-23736][SQL] Optimizing null elements protection.
Apr 7, 2018
6bb33e6
[SPARK-23736][SQL] Protection against the length limit of Java functions
Apr 12, 2018
57b250c
Merge remote-tracking branch 'spark/master' into feature/array-api-co…
Apr 12, 2018
944e0c9
[SPARK-23736][SQL] Adding test for the limit of Java function size.
Apr 12, 2018
7f5124b
[SPARK-23736][SQL] Adding more tests
Apr 13, 2018
0201e4b
[SPARK-23736][SQL] Checks of max array size + Rewriting codegen using…
Apr 16, 2018
600ae89
[SPARK-23736][SQL] Merging current master into the feature branch.
Apr 16, 2018
f2a67e8
[SPARK-23736][SQL] Fixing exception messages
Apr 17, 2018
8a125d9
[SPARK-23736][SQL] Small refactoring
Apr 18, 2018
5a4cc8c
[SPARK-23736][SQL] Merging current master to the feature branch
Apr 18, 2018
f7bdcf7
[SPARK-23736][SQL] Merging current master to the feature branch.
Apr 19, 2018
36d5d25
[SPARK-23736][SQL] Merging current master to the feature branch.
Apr 19, 2018
34 changes: 19 additions & 15 deletions python/pyspark/sql/functions.py
@@ -1426,21 +1426,6 @@ def hash(*cols):
del _name, _doc


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))


@since(1.5)
@ignore_unicode_prefix
def concat_ws(sep, *cols):
@@ -1846,6 +1831,25 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
The function works with strings, binary and compatible array columns.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]

>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
Member:
Why did we move this down?

Contributor (Author):
The whole file is divided into sections according to groups of functions. Based on @gatorsmile's suggestion, the concat function should be categorized as a collection function, so I moved it to comply with the file structure.



@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.

@@ -308,7 +308,6 @@ object FunctionRegistry {
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
@@ -408,6 +407,7 @@
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[Concat]("concat"),
CreateStruct.registryEntry,

// misc functions

@@ -504,6 +504,14 @@ object TypeCoercion {
case None => a
}

case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case None => c
}

case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {

@@ -21,8 +21,11 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}

/**
* Concatenates multiple input columns together into a single column.
* The function works with strings, binary and compatible array columns.
*/
@ExpressionDescription(
usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
examples = """
Examples:
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
[1,2,3,4,5,6]
""")
Member:
Shall we add since too?

...
       [1,2,3,4,5,6]
  """,
  since = "2.4.0")

case class Concat(children: Seq[Expression]) extends Expression {

val allowedTypes = Seq(StringType, BinaryType, ArrayType)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
} else {
val childTypes = children.map(_.dataType)
if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
return TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
}
TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
}
}

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

lazy val javaType: String = CodeGenerator.javaType(dataType)
Member:
We can move this into doGenCode() method.

Contributor (Author):
Good point! But I think it would be better to reuse javaType also in genCodeForPrimitiveArrays and genCodeForNonPrimitiveArrays.


override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = dataType match {
Contributor:
This pattern match will probably cause a significant regression in the interpreted (non-codegen) mode, due to the way Scala pattern matching is implemented.

Contributor (Author):
Thanks! I've created #22471 to call the pattern matching only once.

WDYT about Reverse? It looks like a similar problem.
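
For context, a minimal self-contained sketch of the "match once" idea (the DataType and eval names below are simplified stand-ins, not the actual #22471 patch):

sealed trait DataType
case object StringType extends DataType
case object BinaryType extends DataType

class ConcatLike(dataType: DataType, children: Seq[() => Any]) {
  // The match on dataType runs once, when doConcat is first forced,
  // instead of once per evaluated row inside eval().
  private lazy val doConcat: Seq[Any] => Any = dataType match {
    case StringType => inputs => inputs.mkString
    case BinaryType => inputs =>
      inputs.map(_.asInstanceOf[Array[Byte]]).reduce(_ ++ _)
  }

  def eval(): Any = doConcat(children.map(_.apply()))
}

For example, new ConcatLike(StringType, Seq(() => "Spark", () => "SQL")).eval() returns "SparkSQL", and repeated calls reuse the function selected by the single match.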

case BinaryType =>
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
case StringType =>
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs : _*)
case ArrayType(elementType, _) =>
val inputs = children.toStream.map(_.eval(input))
if (inputs.contains(null)) {
null
} else {
val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType))
Member (@kiszk, Apr 13, 2018):
Can we always allocate a concatenated array? I think the total array element size may overflow in some cases.
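
For illustration, a stand-alone sketch of the kind of guard being asked for here; the limit mirrors ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, but this is a simplified example rather than the code that ended up in the patch (see the later "Checks of max array size" commit in the list above):

object ArraySizeGuard {
  // Roughly ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH from Spark's unsafe module.
  private val MaxElements: Int = Int.MaxValue - 15

  // Sum the element counts as Long so the addition itself cannot overflow Int,
  // then fail fast if the concatenated array would exceed the JVM array limit.
  def totalNumElements(counts: Seq[Int]): Int = {
    val total = counts.map(_.toLong).sum
    if (total > MaxElements) {
      throw new IllegalStateException(
        s"Cannot concatenate arrays with $total total elements; the limit is $MaxElements.")
    }
    total.toInt
  }
}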

new GenericArrayData(elements)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args")

val inputs = evals.zipWithIndex.map { case (eval, index) =>
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
}

val (concatenator, initCode) = dataType match {
case BinaryType =>
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, _) =>
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrayConcat(ctx, elementType)
} else {
genCodeForComplexArrayConcat(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"${javaType}[]", args) :: Nil)
ev.copy(s"""
$initCode
$codes
${javaType} ${ev.value} = $concatenator.concat($args);
Member:
nit: $javaType

boolean ${ev.isNull} = ${ev.value} == null;
""")
}

private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
val tempVariableName = ctx.freshName("tempNumElements")
val numElementsConstant = ctx.freshName("numElements")
val assignments = (0 until children.length)
.map(idx => s"$tempVariableName[0] += args[$idx].numElements();")

val assignmentSection = ctx.splitExpressions(
expressions = assignments,
funcName = "complexArrayConcat",
arguments = Seq((s"${javaType}[]", "args"), ("int[]", tempVariableName)))

(s"""
|int[] $tempVariableName = new int[]{0};
|$assignmentSection
|final int $numElementsConstant = $tempVariableName[0];
""".stripMargin,
Member:
I guess we can simply use a for-loop here?

int $tempVariableName = 0;
for (int $idx = 0; $idx < ${children.length}; $idx++) {
  $tempVariableName += args[$idx].numElements();
}
final int $numElementsConstant = $tempVariableName;

numElementsConstant)
}

private def nullArgumentProtection(ctx: CodegenContext) : String = {
val isNullVariable = ctx.freshName("isArrayNull")
val assignments = children
.zipWithIndex
.filter(_._1.nullable)
.map(ci => s"$isNullVariable[0] |= args[${ci._2}] == null;")

if (assignments.length > 0) {
val assignmentSection = ctx.splitExpressions(
expressions = assignments,
funcName = "isNullArrayConcat",
arguments = Seq((s"${javaType}[]", "args"), ("boolean[]", isNullVariable)))

s"""
|boolean[] $isNullVariable = new boolean[]{false};
|$assignmentSection;
|if ($isNullVariable[0]) return null;
""".stripMargin
Member:
I guess we can simply use a for-loop here?

for (int $idx = 0; $idx < ${children.length}; $idx++) {
  if (args[$idx] == null) {
    return null;
  }
}

We can return as soon as we find a null in this case.

} else {
""
}
}

private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, elementType: DataType): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

val unsafeArraySizeInBytes = s"""
|int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
| ${elementType.defaultSize} * $numElemName
|);
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET

val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val assignments = (0 until children.length).map { idx =>
s"""
|for (int z = 0; z < args[$idx].numElements(); z++) {
| if (args[$idx].isNullAt(z)) {
| $arrayData.setNullAt($counter[0]);
| } else {
| $arrayData.set$primitiveValueTypeName(
| $counter[0],
| ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")}
| );
| }
| $counter[0]++;
|}
""".stripMargin
}
val assignmentSection = ctx.splitExpressions(
expressions = assignments,
funcName = "primitiveArrayConcat",
arguments = Seq(
(s"${javaType}[]", "args"),
("UnsafeArrayData", arrayData),
("int[]", counter)))
Member:
I guess we can simply use a for-loop here?

for (int $idx = 0; $idx < ${children.length}; $idx++) {
  for (int z = 0; z < args[$idx].numElements(); z++) {
    ...
  }
}


s"""new Object() {
Member:
nit:

s"""
   |new Object() {
...

| public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
| ${nullArgumentProtection(ctx)}
| $numElemCode
| $unsafeArraySizeInBytes
| byte[] $arrayName = new byte[$arraySizeName];
| UnsafeArrayData $arrayData = new UnsafeArrayData();
| Platform.putLong($arrayName, $baseOffset, $numElemName);
| $arrayData.pointTo($arrayName, $baseOffset, $arraySizeName);
| int[] $counter = new int[]{0};
| $assignmentSection
| return $arrayData;
| }
|}""".stripMargin
}

private def genCodeForComplexArrayConcat(ctx: CodegenContext, elementType: DataType): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayData = ctx.freshName("arrayObjects")
val counter = ctx.freshName("counter")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

val assignments = (0 until children.length).map { idx =>
s"""
|for (int z = 0; z < args[$idx].numElements(); z++) {
| $arrayData[$counter[0]] = ${CodeGenerator.getValue(s"args[$idx]", elementType, "z")};
Member:
Do we need to check for null?

Contributor (Author):
Here we operate only with non-primitive types, where null is treated as a regular value, so the null check shouldn't be necessary. The added tests should cover this scenario.
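
A tiny stand-alone illustration of that point, using GenericArrayData and UTF8String as imported in this file (the sample values are arbitrary):

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.unsafe.types.UTF8String

// For non-primitive element types the concatenated values live in an Object[],
// so a null element is simply copied over and read back as null afterwards.
val merged = new GenericArrayData(Array[Any](
  UTF8String.fromString("a"), null, UTF8String.fromString("b")))

assert(merged.isNullAt(1))          // the null element is preserved
assert(merged.numElements() == 3)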

| $counter[0]++;
|}
""".stripMargin
}
val assignmentSection = ctx.splitExpressions(
expressions = assignments,
funcName = "complexArrayConcat",
arguments = Seq((s"${javaType}[]", "args"), ("Object[]", arrayData), ("int[]", counter)))
Member:
I guess we can simply use a for-loop here?

for (int $idx = 0; $idx < ${children.length}; $idx++) {
  for (int z = 0; z < args[$idx].numElements(); z++) {
    ...
  }
}


s"""new Object() {
| public ArrayData concat(${CodeGenerator.javaType(dataType)}[] args) {
| ${nullArgumentProtection(ctx)}
| $numElemCode
| Object[] $arrayData = new Object[$numElemName];
| int[] $counter = new int[]{0};
| $assignmentSection
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin
}

override def toString: String = s"concat(${children.mkString(", ")})"

override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}