diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. + this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql index f62b10ca0037b..492a405d7ebbd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql @@ -32,3 +32,27 @@ select 1 - 2; select 2 * 5; select 5 % 3; select pmod(-7, 3); + +-- math functions +select cot(1); +select cot(null); +select cot(0); +select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out index ce42c016a7100..3811cd2c30986 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 44 -- !query 0 @@ -224,3 +224,135 @@ select pmod(-7, 3) struct -- !query 27 output 2 + + +-- !query 28 +select cot(1) +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 29 +select cot(null) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 30 +select cot(0) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 31 +select cot(-1) +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 32 +select ceiling(0) +-- !query 32 schema +struct +-- !query 32 output +0 + + +-- !query 33 +select ceiling(1) +-- !query 33 schema +struct +-- !query 33 output +1 + + +-- !query 34 +select ceil(1234567890123456) +-- !query 34 schema +struct +-- !query 34 output +1234567890123456 + + +-- !query 35 +select ceiling(1234567890123456) +-- !query 35 schema +struct +-- !query 35 output +1234567890123456 + + +-- !query 36 +select ceil(0.01) +-- !query 36 schema +struct +-- !query 36 output +1 + + +-- !query 37 +select ceiling(-0.10) +-- !query 37 schema +struct +-- !query 37 output +0 + + +-- !query 38 +select floor(0) +-- !query 38 schema +struct +-- !query 38 output +0 + + +-- !query 39 +select floor(1) +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +select floor(1234567890123456) +-- !query 40 schema +struct +-- !query 40 output +1234567890123456 + + +-- !query 41 +select floor(0.01) +-- !query 41 schema +struct +-- !query 41 output +0 + + +-- !query 42 +select floor(-0.10) +-- !query 42 schema +struct +-- !query 42 output +-1 + + +-- !query 43 +select 1 > 0.00001 +-- !query 43 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 43 output +true