diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 2bcbb92f1a469..d95a0b0a2b2be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -59,9 +59,10 @@ object Literal {
     case s: String => Literal(UTF8String.fromString(s), StringType)
     case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
     case b: Boolean => Literal(b, BooleanType)
-    case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
+    case d: BigDecimal =>
+      Literal(Decimal(d), DecimalType.fromJVMDecimal(d.precision, d.scale))
     case d: JavaBigDecimal =>
-      Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
+      Literal(Decimal(d), DecimalType.fromJVMDecimal(d.precision, d.scale))
     case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 15004e4b9667d..02a6159502672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -140,9 +140,15 @@ object DecimalType extends AbstractDataType {
   }
 
   private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match {
-    case v: Short => fromBigDecimal(BigDecimal(v))
-    case v: Int => fromBigDecimal(BigDecimal(v))
-    case v: Long => fromBigDecimal(BigDecimal(v))
+    case v: Short =>
+      val bd = BigDecimal(v)
+      fromJVMDecimal(bd.precision, bd.scale)
+    case v: Int =>
+      val bd = BigDecimal(v)
+      fromJVMDecimal(bd.precision, bd.scale)
+    case v: Long =>
+      val bd = BigDecimal(v)
+      fromJVMDecimal(bd.precision, bd.scale)
     case _ => forType(literal.dataType)
   }
 
@@ -150,6 +156,27 @@ object DecimalType extends AbstractDataType {
     DecimalType(Math.max(d.precision, d.scale), d.scale)
   }
 
+  /**
+   * A JVM decimal (BigDecimal in Java/Scala) defines precision and scale differently from Spark
+   * SQL. For example, a JVM decimal counts digits from the leftmost nonzero digit of the unscaled
+   * value, while Spark SQL counts all digits starting from the decimal point, so the precision is
+   * never smaller than the scale. For "0.001", a JVM decimal has precision 1 and scale 3, while
+   * Spark SQL treats both the precision and the scale as 3. A JVM decimal also allows a negative
+   * scale, which Spark SQL cannot handle well. This method creates a DecimalType from a JVM
+   * decimal's precision and scale, applying the proper translations.
+   */
+  private[sql] def fromJVMDecimal(precision: Int, scale: Int): DecimalType = {
+    if (scale < 0) {
+      // Spark SQL accepts a negative scale only by accident, and its behavior is confusing. We
+      // should avoid creating a negative-scale DecimalType ourselves.
+      DecimalType(precision - scale, 0)
+    } else {
+      // For a JVM decimal, the scale can be larger than the precision. In this case, Spark SQL
+      // takes the scale as the precision, to satisfy Spark SQL's definition of precision.
+      DecimalType(Math.max(precision, scale), scale)
+    }
+  }
+
   private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
     DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
index 28a0e20c0f495..4f65aea17e1a4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -83,4 +83,7 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
 select 123456789123456789.1234567890 * 1.123456789123456789;
 select 12345678912345.123456789123 / 0.000000012345678;
 
+-- division with integer in scientific notation
+select 26393499451 / 1000e6;
+
 drop table decimals_test;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
index cbf44548b3cce..8427a3295f18e 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 40
+-- Number of queries: 41
 
 
 -- !query 0
@@ -116,7 +116,7 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000
 -- !query 13
 select 2.35E10 * 1.0
 -- !query 13 schema
-struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(14,1)>
 -- !query 13 output
 23500000000
 
@@ -156,7 +156,7 @@ NULL
 -- !query 18
 select 1.2345678901234567890E30 * 1.2345678901234567890E25
 -- !query 18 schema
-struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(31,0)) * CAST(1.2345678901234567890E+25 AS DECIMAL(31,0))):decimal(38,0)>
 -- !query 18 output
 NULL
 
@@ -258,7 +258,7 @@ NULL
 -- !query 30
 select 2.35E10 * 1.0
 -- !query 30 schema
-struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(14,1)>
 -- !query 30 output
 23500000000
 
@@ -298,7 +298,7 @@ NULL
 -- !query 35
 select 1.2345678901234567890E30 * 1.2345678901234567890E25
 -- !query 35 schema
-struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(31,0)) * CAST(1.2345678901234567890E+25 AS DECIMAL(31,0))):decimal(38,0)>
 -- !query 35 output
 NULL
 
@@ -328,8 +328,16 @@ NULL
 
 
 -- !query 39
-drop table decimals_test
+select 26393499451 / 1000e6
 -- !query 39 schema
-struct<>
+struct<(CAST(CAST(26393499451 AS DECIMAL(11,0)) AS DECIMAL(11,0)) / CAST(1.000E+9 AS DECIMAL(11,0))):decimal(22,11)>
 -- !query 39 output
+26.393499451
+
+
+-- !query 40
+drop table decimals_test
+-- !query 40 schema
+struct<>
+-- !query 40 output
 
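
Not part of the patch, but for reviewers who want to see the translation in action: below is a minimal, self-contained Scala sketch (the names FromJVMDecimalDemo and translate are invented for illustration) that mirrors the new fromJVMDecimal logic on the two corner cases called out in its doc comment.

object FromJVMDecimalDemo {
  // Mirrors the translation performed by the patched DecimalType.fromJVMDecimal.
  def translate(precision: Int, scale: Int): (Int, Int) =
    if (scale < 0) (precision - scale, 0)      // fold a negative scale into the integral digits
    else (math.max(precision, scale), scale)   // JVM scale may exceed JVM precision

  def main(args: Array[String]): Unit = {
    val small = BigDecimal("0.001")            // JVM view: precision 1, scale 3
    println(translate(small.precision, small.scale))  // prints (3,3)  -> decimal(3,3)

    val big = BigDecimal("2.35E+10")           // JVM view: precision 3, scale -8
    println(translate(big.precision, big.scale))       // prints (11,0) -> decimal(11,0)
  }
}

Under Spark's multiplication rule (result precision p1 + p2 + 1, scale s1 + s2), the decimal(11,0) literal for 2.35E10 times the decimal(2,1) literal for 1.0 yields precision 11 + 2 + 1 = 14 and scale 0 + 1 = 1, which is exactly the corrected decimal(14,1) schema in queries 13 and 30 above.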