@@ -137,32 +137,14 @@ trait CheckAnalysis {
        case e => e.children.foreach(checkValidAggregateExpression)
      }

-     def checkSupportedGroupingDataType(
-         expressionString: String,
-         dataType: DataType): Unit = dataType match {
-       case BinaryType =>
-         failAnalysis(s"expression $expressionString cannot be used in " +
-           s"grouping expression because it is in binary type or its inner field is " +
-           s"in binary type")
-       case a: ArrayType =>
-         failAnalysis(s"expression $expressionString cannot be used in " +
-           s"grouping expression because it is in array type or its inner field is " +
-           s"in array type")
-       case m: MapType =>
-         failAnalysis(s"expression $expressionString cannot be used in " +
-           s"grouping expression because it is in map type or its inner field is " +
-           s"in map type")
-       case s: StructType =>
-         s.fields.foreach { f =>
-           checkSupportedGroupingDataType(expressionString, f.dataType)
-         }
-       case udt: UserDefinedType[_] =>
-         checkSupportedGroupingDataType(expressionString, udt.sqlType)
-       case _ => // OK
-     }
-
      def checkValidGroupingExprs(expr: Expression): Unit = {
-       checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+       // Check if the data type of expr is orderable.
+       if (!RowOrdering.isOrderable(expr.dataType)) {
+         failAnalysis(
+           s"expression ${expr.prettyString} cannot be used as a grouping expression " +
+           s"because its data type ${expr.dataType.simpleString} is not an orderable " +
+           s"data type.")
+       }

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
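Net effect of this hunk: the hand-maintained blacklist of ungroupable types is dropped, and grouping validity is instead delegated to RowOrdering.isOrderable, which this patch extends to recurse into arrays and UDTs (see the RowOrdering hunk below). A standalone sketch of that recursion, using a stand-in DataType ADT so it can be pasted into a plain Scala REPL without Spark on the classpath:

    // Stand-in ADT, not Spark's actual DataType hierarchy.
    sealed trait DataType
    case object IntegerType extends DataType
    case class ArrayType(elementType: DataType) extends DataType
    case class MapType(keyType: DataType, valueType: DataType) extends DataType
    case class StructType(fieldTypes: Seq[DataType]) extends DataType

    // Mirrors the isOrderable change: structs recurse into their fields, arrays
    // (new in this patch) recurse into the element type, maps stay unorderable.
    def isOrderable(dt: DataType): Boolean = dt match {
      case IntegerType            => true
      case StructType(fields)     => fields.forall(isOrderable)
      case ArrayType(elementType) => isOrderable(elementType)
      case _: MapType             => false
    }

    assert(isOrderable(ArrayType(IntegerType)))                        // groupable after this patch
    assert(!isOrderable(ArrayType(MapType(IntegerType, IntegerType)))) // still rejected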
@@ -267,6 +267,55 @@ class CodeGenContext {
      case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
      case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
      case NullType => "0"
+     case array: ArrayType =>
+       val elementType = array.elementType
+       val elementA = freshName("elementA")
+       val isNullA = freshName("isNullA")
+       val elementB = freshName("elementB")
+       val isNullB = freshName("isNullB")
+       val compareFunc = freshName("compareArray")
+       val i = freshName("i")
+       val minLength = freshName("minLength")
+       val funcCode: String =
+         s"""
+           public int $compareFunc(ArrayData a, ArrayData b) {
+             int lengthA = a.numElements();
+             int lengthB = b.numElements();
+             int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
+             boolean $isNullA;
+             boolean $isNullB;
+             ${javaType(elementType)} $elementA;
+             ${javaType(elementType)} $elementB;
[Review comment, Contributor] These could be defined in the loop (letting the compiler optimize them more easily).
+             for (int $i = 0; $i < $minLength; $i++) {
[Review comment, Contributor] i should be enough here.
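              // Sketch of what the two review comments above suggest (not part of
              // this diff): loop-scoped declarations and a plain loop variable.
              //   for (int i = 0; i < $minLength; i++) {
              //     boolean isNullA = a.isNullAt(i);
              //     boolean isNullB = b.isNullAt(i);
              //     ...
              //   }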
+               $isNullA = a.isNullAt($i);
+               $isNullB = b.isNullAt($i);
+
+               if ($isNullA && $isNullB) {
+                 // Nothing
+               } else if ($isNullA) {
+                 return -1;
+               } else if ($isNullB) {
+                 return 1;
+               } else {
+                 $elementA = ${getValue("a", elementType, i)};
+                 $elementB = ${getValue("b", elementType, i)};
+                 int comp = ${genComp(elementType, elementA, elementB)};
+                 if (comp != 0) {
+                   return comp;
+                 }
+               }
+             }
+
+             if (lengthA < lengthB) {
+               return -1;
+             } else if (lengthA > lengthB) {
+               return 1;
+             }
+             return 0;
+           }
+         """
+       addNewFunction(compareFunc, funcCode)
+       s"this.$compareFunc($c1, $c2)"
      case schema: StructType =>
        val comparisons = GenerateOrdering.genComparisons(this, schema)
        val compareFunc = freshName("compareStruct")
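For readability: the generated comparator above implements the usual lexicographic order on arrays, comparing element-wise with null elements sorting first and breaking ties by length. A minimal interpreted sketch of the same semantics, using Seq[Option[Int]] as a stand-in for ArrayData so it runs in a plain REPL without Spark:

    // None models a null element; Some(x) a non-null value.
    def compareArrays(a: Seq[Option[Int]], b: Seq[Option[Int]]): Int = {
      val minLength = math.min(a.length, b.length)
      var i = 0
      while (i < minLength) {
        (a(i), b(i)) match {
          case (None, None)       => ()        // both null: keep scanning
          case (None, _)          => return -1 // null sorts first
          case (_, None)          => return 1
          case (Some(x), Some(y)) =>
            val comp = x.compare(y)
            if (comp != 0) return comp
        }
        i += 1
      }
      a.length.compare(b.length)               // shorter prefix sorts first
    }

    // [1, null] < [1, 2], because a null element sorts before any value:
    assert(compareArrays(Seq(Some(1), None), Seq(Some(1), Some(2))) < 0)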
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._


@@ -29,35 +30,76 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))

+   private def compareValue(
+       left: Any,
+       right: Any,
+       dataType: DataType,
+       direction: SortDirection): Int = {
+     if (left == null && right == null) {
+       return 0
+     } else if (left == null) {
+       return if (direction == Ascending) -1 else 1
+     } else if (right == null) {
+       return if (direction == Ascending) 1 else -1
+     } else {
+       dataType match {
+         case dt: AtomicType if direction == Ascending =>
+           return dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+         case dt: AtomicType if direction == Descending =>
+           return dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+         case s: StructType if direction == Ascending =>
+           return s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+         case s: StructType if direction == Descending =>
+           return s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+         case a: ArrayType =>
+           val leftArray = left.asInstanceOf[ArrayData]
+           val rightArray = right.asInstanceOf[ArrayData]
+           val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
+           var i = 0
+           while (i < minLength) {
+             val isNullLeft = leftArray.isNullAt(i)
[Review comment, Contributor Author] These lines have been moved to compareValue.
+             val isNullRight = rightArray.isNullAt(i)
+             if (isNullLeft && isNullRight) {
+               // Do nothing.
+             } else if (isNullLeft) {
+               return if (direction == Ascending) -1 else 1
+             } else if (isNullRight) {
+               return if (direction == Ascending) 1 else -1
+             } else {
+               val comp =
+                 compareValue(
+                   leftArray.get(i, a.elementType),
+                   rightArray.get(i, a.elementType),
+                   a.elementType,
+                   direction)
+               if (comp != 0) {
+                 return comp
+               }
+             }
+             i += 1
+           }
+           if (leftArray.numElements() < rightArray.numElements()) {
+             return if (direction == Ascending) -1 else 1
+           } else if (leftArray.numElements() > rightArray.numElements()) {
+             return if (direction == Ascending) 1 else -1
+           } else {
+             return 0
+           }
+         case other =>
+           throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+       }
+     }
+   }

    def compare(a: InternalRow, b: InternalRow): Int = {
      var i = 0
      while (i < ordering.size) {
        val order = ordering(i)
        val left = order.child.eval(a)
        val right = order.child.eval(b)

-       if (left == null && right == null) {
-         // Both null, continue looking.
-       } else if (left == null) {
-         return if (order.direction == Ascending) -1 else 1
-       } else if (right == null) {
-         return if (order.direction == Ascending) 1 else -1
-       } else {
-         val comparison = order.dataType match {
-           case dt: AtomicType if order.direction == Ascending =>
-             dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-           case dt: AtomicType if order.direction == Descending =>
-             dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-           case s: StructType if order.direction == Ascending =>
-             s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
-           case s: StructType if order.direction == Descending =>
-             s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-           case other =>
-             throw new IllegalArgumentException(s"Type $other does not support ordered operations")
-         }
-         if (comparison != 0) {
-           return comparison
-         }
-       }
+       val comparison = compareValue(left, right, order.dataType, order.direction)
+       if (comparison != 0) {
+         return comparison
+       }
        i += 1
      }
@@ -86,6 +128,8 @@ object RowOrdering {
      case NullType => true
      case dt: AtomicType => true
      case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
+     case array: ArrayType => isOrderable(array.elementType)
+     case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
      case _ => false
    }

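Taken together, compareValue and the widened isOrderable mean an array column can now back a sort or a grouping whenever its element type is itself orderable, recursively. A hedged usage sketch against the Catalyst API as it appears in this diff (entry points assumed from the code above, with this branch of Spark on the classpath):

    import org.apache.spark.sql.catalyst.expressions.RowOrdering
    import org.apache.spark.sql.types._

    // Orderability now recurses through array element types:
    RowOrdering.isOrderable(ArrayType(IntegerType))                    // true after this patch
    RowOrdering.isOrderable(ArrayType(ArrayType(StringType)))          // true: bottoms out at atomic types
    RowOrdering.isOrderable(ArrayType(MapType(IntegerType, LongType))) // false: maps remain unorderable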
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
+import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData}
import org.apache.spark.sql.types._

import scala.beans.{BeanProperty, BeanInfo}
@@ -53,21 +53,29 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
}

@BeanInfo
-private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])

private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {

-  override def sqlType: DataType = ArrayType(IntegerType)
+  override def sqlType: DataType = MapType(IntegerType, IntegerType)

-  override def serialize(obj: Any): ArrayData = {
+  override def serialize(obj: Any): MapData = {
    obj match {
-     case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+     case groupableData: UngroupableData =>
+       val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
+       val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
+       new ArrayBasedMapData(keyArray, valueArray)
    }
  }

  override def deserialize(datum: Any): UngroupableData = {
    datum match {
-     case data: Array[Int] => UngroupableData(data)
+     case data: MapData =>
+       val keyArray = data.keyArray().array
+       val valueArray = data.valueArray().array
+       assert(keyArray.length == valueArray.length)
+       val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
+       UngroupableData(mapData)
    }
  }

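The reworked test UDT now maps onto Catalyst's MapType, so serialize/deserialize should round-trip a Scala Map through ArrayBasedMapData. A hedged sanity check, assuming the test classes above are on the classpath (this assertion is not itself a test in this PR):

    val udt = new UngroupableUDT()
    val original = UngroupableData(Map(1 -> 10, 2 -> 20))
    // serialize produces a MapData; deserialize rebuilds the Scala Map.
    val roundTripped = udt.deserialize(udt.serialize(original))
    assert(roundTripped == original)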
@@ -154,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest {

  errorTest(
    "sorting by unsupported column types",
-   listRelation.orderBy('list.asc),
-   "sort" :: "type" :: "array<int>" :: Nil)
+   mapRelation.orderBy('map.asc),
+   "sort" :: "type" :: "map<int,int>" :: Nil)

errorTest(
"non-boolean filters",
@@ -259,32 +267,33 @@ class AnalysisErrorSuite extends AnalysisTest {
      case true =>
        assertAnalysisSuccess(plan, true)
      case false =>
-       assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
+       assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil)
    }

  }

  val supportedDataTypes = Seq(
-   StringType,
+   StringType, BinaryType,
    NullType, BooleanType,
    ByteType, ShortType, IntegerType, LongType,
    FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
    DateType, TimestampType,
+   ArrayType(IntegerType),
    new StructType()
      .add("f1", FloatType, nullable = true)
      .add("f2", StringType, nullable = true),
+   new StructType()
+     .add("f1", FloatType, nullable = true)
+     .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
    new GroupableUDT())
  supportedDataTypes.foreach { dataType =>
    checkDataType(dataType, shouldSuccess = true)
  }

  val unsupportedDataTypes = Seq(
-   BinaryType,
-   ArrayType(IntegerType),
    MapType(StringType, LongType),
    new StructType()
      .add("f1", FloatType, nullable = true)
-     .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
+     .add("f2", MapType(StringType, LongType), nullable = true),
    new UngroupableUDT())
  unsupportedDataTypes.foreach { dataType =>
    checkDataType(dataType, shouldSuccess = false)
@@ -48,4 +48,7 @@ object TestRelations {

val listRelation = LocalRelation(
AttributeReference("list", ArrayType(IntegerType))())

+  val mapRelation = LocalRelation(
+    AttributeReference("map", MapType(IntegerType, IntegerType))())
}
@@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

-import scala.math._
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -49,40 +47,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
futures.foreach(Await.result(_, 10.seconds))
}

- // Test GenerateOrdering for all common types. For each type, we construct random input rows that
- // contain two columns of that type, then for pairs of randomly-generated rows we check that
- // GenerateOrdering agrees with RowOrdering.
- (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
-   test(s"GenerateOrdering with $dataType") {
-     val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
-     val genOrdering = GenerateOrdering.generate(
-       BoundReference(0, dataType, nullable = true).asc ::
-       BoundReference(1, dataType, nullable = true).asc :: Nil)
-     val rowType = StructType(
-       StructField("a", dataType, nullable = true) ::
-       StructField("b", dataType, nullable = true) :: Nil)
-     val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
-     assume(maybeDataGenerator.isDefined)
-     val randGenerator = maybeDataGenerator.get
-     val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
-     for (_ <- 1 to 50) {
-       val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
-       val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
-       withClue(s"a = $a, b = $b") {
-         assert(genOrdering.compare(a, a) === 0)
-         assert(genOrdering.compare(b, b) === 0)
-         assert(rowOrdering.compare(a, a) === 0)
-         assert(rowOrdering.compare(b, b) === 0)
-         assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
-         assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
-         assert(
-           signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
-           "Generated and non-generated orderings should agree")
-       }
-     }
-   }
- }
-
test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))