Skip to content

Commit 381f17b

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-12201][SQL] add type coercion rule for greatest/least
checked with hive, greatest/least should cast their children to a tightest common type, i.e. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error` Author: Wenchen Fan <wenchen@databricks.com> Closes apache#10196 from cloud-fan/type-coercion.
1 parent 75c60bf commit 381f17b

3 files changed

Lines changed: 47 additions & 0 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,20 @@ object HiveTypeCoercion {
594594
case None => c
595595
}
596596

597+
case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
598+
val types = children.map(_.dataType)
599+
findTightestCommonType(types) match {
600+
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
601+
case None => g
602+
}
603+
604+
case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
605+
val types = children.map(_.dataType)
606+
findTightestCommonType(types) match {
607+
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
608+
case None => l
609+
}
610+
597611
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
598612
NaNvl(l, Cast(r, DoubleType))
599613
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
3232
'intField.int,
3333
'stringField.string,
3434
'booleanField.boolean,
35+
'decimalField.decimal(8, 0),
3536
'arrayField.array(StringType),
3637
'mapField.map(StringType, LongType))
3738

@@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
189190
assertError(Round('intField, 'mapField), "requires int type")
190191
assertError(Round('booleanField, 'intField), "requires numeric type")
191192
}
193+
194+
test("check types for Greatest/Least") {
195+
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
196+
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
197+
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
198+
assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
199+
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
200+
}
201+
}
192202
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest {
251251
:: Nil))
252252
}
253253

254+
test("greatest/least cast") {
255+
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
256+
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
257+
operator(Literal(1.0)
258+
:: Literal(1)
259+
:: Literal.create(1.0, FloatType)
260+
:: Nil),
261+
operator(Cast(Literal(1.0), DoubleType)
262+
:: Cast(Literal(1), DoubleType)
263+
:: Cast(Literal.create(1.0, FloatType), DoubleType)
264+
:: Nil))
265+
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
266+
operator(Literal(1L)
267+
:: Literal(1)
268+
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
269+
:: Nil),
270+
operator(Cast(Literal(1L), DecimalType(22, 0))
271+
:: Cast(Literal(1), DecimalType(22, 0))
272+
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
273+
:: Nil))
274+
}
275+
}
276+
254277
test("nanvl casts") {
255278
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
256279
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),

0 commit comments

Comments
 (0)