diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 82692334544e..8cda5e55685d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -40,10 +40,13 @@ import org.apache.spark.sql.types._
  *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1    max(s1, s2)
  *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1    max(s1, s2)
  *   e1 * e2      p1 + p2 + 1                            s1 + s2
- *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)     max(6, s1 + p2 + 1)
+ *   e1 / e2      max(p1-s1+s2, 0) + max(6, s1+adjP2+1)  max(6, s1+adjP2+1)
  *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)        max(s1, s2)
  *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)        max(s1, s2)
  *
+ * Where adjP2 is p2 - s2 if s2 < 0, p2 otherwise. This adjustment is needed because Spark does not
+ * forbid decimals with negative scale, while MS SQL and Hive do.
+ *
  * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
  * needed are out of the range of available values, the scale is reduced up to 6, in order to
  * prevent the truncation of the integer part of the decimals.
@@ -129,16 +132,17 @@
         resultType)
 
     case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+      val adjP2 = if (s2 < 0) p2 - s2 else p2
       val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
-        // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
-        // Scale: max(6, s1 + p2 + 1)
-        val intDig = p1 - s1 + s2
-        val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
+        // Precision: max(p1 - s1 + s2, 0) + max(6, s1 + adjP2 + 1)
+        // Scale: max(6, s1 + adjP2 + 1)
+        val intDig = max(p1 - s1 + s2, 0) // p1 - s1 + s2 can be negative if s2 < 0
+        val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)
         val prec = intDig + scale
         DecimalType.adjustPrecisionScale(prec, scale)
       } else {
-        var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-        var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
+        var intDig = max(min(DecimalType.MAX_SCALE, p1 - s1 + s2), 0) // can be negative if s2 < 0
+        var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + adjP2 + 1))
         val diff = (intDig + decDig) - DecimalType.MAX_SCALE
         if (diff > 0) {
           decDig -= diff / 2 + 1
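To make the new bounds concrete, here is the `allowPrecisionLoss` branch traced by hand as a minimal standalone sketch (plain Scala, not Spark code). It uses decimal(10, 4) / decimal(2, -7), which corresponds to the `Divide(c, b)` case in `ArithmeticExpressionSuite` below, assuming those literals infer decimal(10, 4) and decimal(2, -7); the constant mirrors `DecimalType.MINIMUM_ADJUSTED_SCALE`.

```scala
import scala.math.max

object DivideTypeSketch extends App {
  val (p1, s1) = (10, 4)   // dividend: decimal(10, 4), e.g. 123456.7891
  val (p2, s2) = (2, -7)   // divisor: decimal(2, -7), e.g. 1.0E8
  val MINIMUM_ADJUSTED_SCALE = 6  // mirrors DecimalType.MINIMUM_ADJUSTED_SCALE

  // Widen p2 when the divisor has a negative scale.
  val adjP2 = if (s2 < 0) p2 - s2 else p2                  // 2 - (-7) = 9

  val intDig = max(p1 - s1 + s2, 0)                        // max(-1, 0) = 0
  val scale = max(MINIMUM_ADJUSTED_SCALE, s1 + adjP2 + 1)  // max(6, 14) = 14
  val prec = intDig + scale                                // 14

  // Old formulas: prec = (p1 - s1 + s2) + max(6, s1 + p2 + 1) = -1 + 7 = 6,
  // with scale 7, i.e. the invalid type decimal(6, 7). The new formulas give:
  println(s"decimal($prec, $scale)")                       // decimal(14, 14)
}
```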
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 25eddaf06a78..ce70dda42b80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -27,12 +27,16 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 
 /**
  * The data type representing `java.math.BigDecimal` values.
- * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
- * of digits on right side of dot).
+ * A Decimal represents an exact numeric in which the precision and scale can be arbitrarily sized.
+ * The precision is the number of significant digits and can range from 1 to 38. The scale can
+ * be positive or negative. If zero or positive, the scale is the number of digits to the right of
+ * the decimal point. If negative, the unscaled value of the number is multiplied by ten to the
+ * power of the negation of the scale.
  *
- * The precision can be up to 38, scale can also be up to 38 (less or equal to precision).
+ * Note that not all data sources support negative scales. In that case, writing decimals with
+ * negative scales can lead to errors and exceptions.
  *
- * The default precision and scale is (10, 0).
+ * The default precision and scale is (38, 18).
  *
  * Please use `DataTypes.createDecimalType()` to create a specific instance.
  *
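The negative-scale semantics in the new scaladoc match `java.math.BigDecimal`: the value is unscaledValue × 10^(-scale). A small illustrative sketch, not part of the patch:

```scala
import java.math.{BigDecimal => JBigDecimal, BigInteger}

object NegativeScaleDemo extends App {
  // unscaledValue = 1000, scale = -6  =>  1000 * 10^6 = 1.000E+9
  val d = new JBigDecimal(BigInteger.valueOf(1000), -6)
  println(d)            // 1.000E+9
  println(d.precision)  // 4 (significant digits)
  println(d.scale)      // -6

  // Scientific-notation strings parse the same way, so literals like the
  // 1000e6 used in the SQL tests below carry a negative scale.
  println(new JBigDecimal("1000e6").scale)  // -6
}
```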
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index bd87ca6017e9..ec3c04b43f79 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -276,9 +276,11 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
     val a = AttributeReference("a", DecimalType(3, -10))()
     val b = AttributeReference("b", DecimalType(1, -1))()
     val c = AttributeReference("c", DecimalType(35, 1))()
+    val nonNegative = AttributeReference("nn", DecimalType(11, 0))()
     checkType(Multiply(a, b), DecimalType(5, -11))
     checkType(Multiply(a, c), DecimalType(38, -9))
     checkType(Multiply(b, c), DecimalType(37, 0))
+    checkType(Divide(nonNegative, a), DecimalType(15, 14))
   }
 
   /** strength reduction for integer/decimal comparisons */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 1318ab185983..31206567a924 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -376,4 +376,32 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
     assert(ctx2.inlinedMutableStates.size == 1)
   }
+
+  test("SPARK-25454: decimal operations with negative scale") {
+    val a = Literal(BigDecimal(1234567891))
+    val b = Literal(BigDecimal(100e6))
+    val c = Literal(BigDecimal(123456.7891))
+    val d = Literal(BigDecimal(678e8))
+    Seq(b, d).foreach { l =>
+      assert(l.dataType.isInstanceOf[DecimalType] &&
+        l.dataType.asInstanceOf[DecimalType].scale < 0)
+    }
+    Seq("true", "false").foreach { allowPrecLoss =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss) {
+        checkEvaluationWithAnalysis(Add(a, b), Decimal(BigDecimal(1334567891)))
+        checkEvaluationWithAnalysis(Add(b, c), Decimal(BigDecimal(100123456.7891)))
+        checkEvaluationWithAnalysis(Add(b, d), Decimal(BigDecimal(67900e6)))
+        checkEvaluationWithAnalysis(Subtract(a, b), Decimal(BigDecimal(1134567891)))
+        checkEvaluationWithAnalysis(Subtract(b, c), Decimal(BigDecimal(99876543.2109)))
+        checkEvaluationWithAnalysis(Subtract(d, b), Decimal(BigDecimal(67700e6)))
+        checkEvaluationWithAnalysis(Multiply(a, b), Decimal(BigDecimal(123456789100000000L)))
+        checkEvaluationWithAnalysis(Multiply(b, c), Decimal(BigDecimal(12345678910000L)))
+        checkEvaluationWithAnalysis(Multiply(d, b), Decimal(BigDecimal(67800e14)))
+        checkEvaluationWithAnalysis(Divide(a, b), Decimal(BigDecimal(12.34567891)))
+        checkEvaluationWithAnalysis(Divide(b, c), Decimal(BigDecimal(810.000007)))
+        checkEvaluationWithAnalysis(Divide(c, b), Decimal(BigDecimal(0.001234567891)))
+        checkEvaluationWithAnalysis(Divide(d, b), Decimal(BigDecimal(678)))
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 1c91adab7137..e9c0c07adc7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -28,7 +28,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, ResolveTimeZone}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -44,6 +45,9 @@ import org.apache.spark.util.Utils
 trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase {
   self: SparkFunSuite =>
 
+  private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+  private val analyzer = new Analyzer(catalog, conf)
+
   protected def create_row(values: Any*): InternalRow = {
     InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
   }
@@ -302,6 +306,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
   }
 
+  protected def checkEvaluationWithAnalysis(
+      expression: Expression,
+      expected: Any,
+      inputRow: InternalRow = EmptyRow): Unit = {
+    val plan = Project(Alias(expression, s"Analyzed($expression)")() :: Nil, OneRowRelation())
+    val analyzedPlan = analyzer.execute(plan)
+    checkEvaluationWithoutCodegen(analyzedPlan.expressions.head, expected, inputRow)
+  }
+
   protected def checkDoubleEvaluation(
       expression: => Expression,
       expected: Spread[Double],
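A note on why the new helper analyzes the plan instead of calling `checkEvaluation` directly: `DecimalPrecision` is a type-coercion rule that only fires during analysis, so a hand-built `Divide(a, b)` would never receive the precision/scale fix-up under test (roughly, analysis rewrites it into a `CheckOverflow` over the promoted, cast operands, as in the rule above). The expected values themselves can be cross-checked with plain `java.math.BigDecimal`; a sketch below, with the result scales taken from the new rule and assuming HALF_UP rounding, which matches the expected values:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object ExpectedValueCheck extends App {
  val b = new JBigDecimal("1.0E8")        // decimal(2, -7), like Literal(BigDecimal(100e6))
  val c = new JBigDecimal("123456.7891")  // decimal(10, 4)

  // Divide(b, c): result scale = max(6, s1 + adjP2 + 1) = max(6, -7 + 10 + 1) = 6
  println(b.divide(c, 6, RoundingMode.HALF_UP))   // 810.000007

  // Divide(c, b): result scale = max(6, 4 + 9 + 1) = 14
  println(c.divide(b, 14, RoundingMode.HALF_UP))  // 0.00123456789100
}
```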
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
index 28a0e20c0f49..e8b52da5cfac 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
@@ -83,4 +83,7 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
 select 123456789123456789.1234567890 * 1.123456789123456789;
 select 12345678912345.123456789123 / 0.000000012345678;
 
+-- division with negative scale operands
+select 26393499451 / 1000e6;
+
 drop table decimals_test;
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
index cbf44548b3cc..b2d305b9a3e1 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 40
+-- Number of queries: 41
 
 
 -- !query 0
@@ -328,8 +328,16 @@ NULL
 
 
 -- !query 39
-drop table decimals_test
+select 26393499451 / 1000e6
 -- !query 39 schema
-struct<>
+struct<(CAST(CAST(26393499451 AS DECIMAL(11,0)) AS DECIMAL(11,0)) / CAST(1.000E+9 AS DECIMAL(11,0))):decimal(16,11)>
 -- !query 39 output
+26.393499451
+
+
+-- !query 40
+drop table decimals_test
+-- !query 40 schema
+struct<>
+-- !query 40 output
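As a cross-check, the decimal(16,11) in the generated schema follows directly from the new rule: 26393499451 is decimal(11, 0), and 1000e6 carries unscaled value 1000 with scale -6, i.e. decimal(4, -6). A standalone worked sketch, for illustration only:

```scala
object SqlResultTypeSketch extends App {
  val (p1, s1) = (11, 0)  // 26393499451
  val (p2, s2) = (4, -6)  // 1000e6 = 1.000E+9

  val adjP2 = if (s2 < 0) p2 - s2 else p2         // 4 - (-6) = 10
  val intDig = math.max(p1 - s1 + s2, 0)          // max(5, 0) = 5
  val scale = math.max(6, s1 + adjP2 + 1)         // max(6, 11) = 11
  println(s"decimal(${intDig + scale}, $scale)")  // decimal(16,11), as in the schema above
  // 26393499451 / 1000000000 = 26.393499451, which fits comfortably.
}
```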