[SPARK-11827] [SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections #10125
Changes from 11 commits
**CatalystTypeConverters.scala**

```diff
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst

 import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
+import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable

 import scala.language.existentials
+import scala.math.{BigInt => ScalaBigInt}

 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
```
```diff
@@ -321,11 +323,13 @@ object CatalystTypeConverters {
   }

   private class DecimalConverter(dataType: DecimalType)
-    extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+      extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
     override def toCatalystImpl(scalaValue: Any): Decimal = {
       val decimal = scalaValue match {
         case d: BigDecimal => Decimal(d)
         case d: JavaBigDecimal => Decimal(d)
+        case d: JavaBigInteger => Decimal(d)
```

**Contributor:** Can you hold on until #13008? Then we can revert this change as

```diff
+        case d: ScalaBigInt => Decimal(d)
         case d: Decimal => d
       }
       if (decimal.changePrecision(dataType.precision, dataType.scale)) {
```
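Taken together, the two new cases let the Catalyst converter for `DecimalType` accept `java.math.BigInteger` and `scala.math.BigInt` alongside the decimal types it already handled. A minimal sketch of the behavior this hunk enables, assuming access to the internal `CatalystTypeConverters` API (`createToCatalystConverter` is the entry point visible in the stack trace at the bottom of this thread):

```scala
import java.math.{BigInteger => JavaBigInteger}

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{Decimal, DecimalType}

// Build a converter targeting the decimal type that big integers map to.
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(DecimalType(38, 0))

// Before this patch both of these fell through the match with a MatchError;
// now they are routed through Decimal like the other decimal inputs.
val fromJava = toCatalyst(new JavaBigInteger("1234567")).asInstanceOf[Decimal]
val fromScala = toCatalyst(scala.math.BigInt(1234567)).asInstanceOf[Decimal]
```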
**Decimal.scala**

```diff
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.types

-import java.math.{MathContext, RoundingMode}
+import java.math.{BigInteger, MathContext, RoundingMode}

 import org.apache.spark.annotation.DeveloperApi
```
```diff
@@ -128,6 +128,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     this
   }

+  /**
+   * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
+   */
+  def set(BigIntVal: BigInteger): Decimal = {
```
**Contributor:** Lower-case the variable name.

**Contributor (author):** I will change it.
```diff
+    try {
+      this.decimalVal = null
+      this.longVal = BigIntVal.longValueExact()
+      this._precision = DecimalType.MAX_PRECISION
+      this._scale = 0
+      this
+    }
+    catch {
+      case e: ArithmeticException =>
+        throw new IllegalArgumentException(s"BigInteger ${BigIntVal} too large for decimal")
+    }
+  }
+
   /**
    * Set this Decimal to the given Decimal value.
    */
```
```diff
@@ -371,6 +388,10 @@ object Decimal {

   def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)

+  def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value)
+
+  def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger)
+
   def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
     new Decimal().set(value, precision, scale)
```
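With the two factory methods in place, a `Decimal` can be built directly from either big-integer type. A quick sketch of what they do, including the failure mode inherited from `longValueExact()` in `set` above, where values outside the 64-bit `Long` range are rejected even though the declared precision is 38:

```scala
import java.math.BigInteger

import org.apache.spark.sql.types.Decimal

val fromJava = Decimal(new BigInteger("1234567"))   // precision 38, scale 0
val fromScala = Decimal(scala.math.BigInt(1234567)) // routed through the same set()

// set() stores the value via longValueExact(), so a BigInteger wider than
// 64 bits throws the IllegalArgumentException raised in the catch block:
// Decimal(new BigInteger("99999999999999999999"))  // 20 digits > Long.MAX_VALUE
```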
**DecimalType.scala**

```diff
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.types

+import java.math.BigInteger
+
 import scala.reflect.runtime.universe.typeTag

 import org.apache.spark.annotation.DeveloperApi
```
```diff
@@ -109,6 +111,7 @@ object DecimalType extends AbstractDataType {
   val MAX_SCALE = 38
   val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
   val USER_DEFAULT: DecimalType = DecimalType(10, 0)
+  val BIGINT_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 0)

   // The decimal types compatible with other numeric types
   private[sql] val ByteDecimal = DecimalType(3, 0)
```
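`BIGINT_DEFAULT` is the Catalyst type that type inference now assigns to big-integer fields: the full 38 digits of precision with nothing after the decimal point. A minimal sketch of the resulting schema, assuming an active `SparkSession` with its implicits imported:

```scala
import org.apache.spark.sql.types.DecimalType

assert(DecimalType.BIGINT_DEFAULT == DecimalType(38, 0))

// Reflection-based inference on a case class field (sketch):
case class Account(balance: scala.math.BigInt)
val df = Seq(Account(BigInt(42))).toDF()
df.printSchema()
// root
//  |-- balance: decimal(38,0) (nullable = true)
```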
**JavaDataFrameSuite.java**

```diff
@@ -21,6 +21,8 @@
 import java.net.URISyntaxException;
 import java.net.URL;
 import java.util.*;
+import java.math.BigInteger;
+import java.math.BigDecimal;

 import scala.collection.JavaConverters;
 import scala.collection.Seq;
```
```diff
@@ -130,6 +132,7 @@ public static class Bean implements Serializable {
     private Integer[] b = { 0, 1 };
     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
     private List<String> d = Arrays.asList("floppy", "disk");
+    private BigInteger e = new BigInteger("1234567");

     public double getA() {
       return a;
```
```diff
@@ -146,6 +149,8 @@ public Map<String, int[]> getC() {
     public List<String> getD() {
       return d;
     }
+
+    public BigInteger getE() { return e; }
   }

   void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
```
```diff
@@ -163,7 +168,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     Assert.assertEquals(
       new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
       schema.apply("d"));
-    Row first = df.select("a", "b", "c", "d").first();
+    Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38, 0), true, Metadata.empty()),
+      schema.apply("e"));
+    Row first = df.select("a", "b", "c", "d", "e").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
```
```diff
@@ -182,6 +189,8 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     for (int i = 0; i < d.length(); i++) {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
+    // java.math.BigInteger is equivalent to Spark Decimal(38,0)
+    Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4).setScale(0));
   }

   @Test
```
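The bean test pins down both halves of the mapping: the inferred schema type and the value read back. A condensed sketch of the same round trip from Scala, reusing the test's `Bean` class (with its `BigInteger` property `e`) and assuming an active `SparkSession`:

```scala
import java.math.BigDecimal

// A BigInteger bean property surfaces as decimal(38,0) in the schema and
// reads back as a java.math.BigDecimal through Row.getDecimal.
val bean = new Bean()
val df = spark.createDataFrame(java.util.Arrays.asList(bean), classOf[Bean])
val first = df.select("e").first()
assert(first.getDecimal(0).setScale(0) == new BigDecimal(bean.getE()))
```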
**ScalaReflectionRelationSuite.scala**

```diff
@@ -34,7 +34,13 @@ case class ReflectData(
     decimalField: java.math.BigDecimal,
     date: Date,
     timestampField: Timestamp,
-    seqInt: Seq[Int])
+    seqInt: Seq[Int],
+    javaBigInt: java.math.BigInteger,
+    scalaBigInt: scala.math.BigInt)
+
+case class ReflectData3(
+    scalaBigInt: scala.math.BigInt
+)

 case class NullReflectData(
     intField: java.lang.Integer,
```
```diff
@@ -77,13 +83,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {

   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
+      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3),
+      new java.math.BigInteger("1"), scala.math.BigInt(1))
     Seq(data).toDF().registerTempTable("reflectData")

     assert(sql("SELECT * FROM reflectData").collect().head ===
       Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
         new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
-        new Timestamp(12345), Seq(1, 2, 3)))
+        new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1),
+        new java.math.BigDecimal(1)))
   }

   test("query case class RDD with nulls") {
```
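Note the expected `Row` holds `java.math.BigDecimal` for both the `javaBigInt` and `scalaBigInt` columns: once a value is stored as a Catalyst `Decimal(38, 0)`, the external Java type it converts back to is `java.math.BigDecimal`, regardless of which big-integer type it started as. A one-line sketch of that asymmetry:

```scala
import org.apache.spark.sql.types.Decimal

// BigInteger in, BigDecimal out: rows expose the Decimal's external form.
val d = Decimal(new java.math.BigInteger("1"))
assert(d.toJavaBigDecimal == new java.math.BigDecimal(1).setScale(0))
```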
**Contributor:** Why change this file? I think we should use encoders most of the time.

**Contributor (author):** Sure, I will take this out. Thanks.

**Contributor (author):** Hello Wenchen: I have to keep `case d: JavaBigInteger => Decimal(d)` there; otherwise, this test case fails for java.math.BigInteger:
```java
@Test
public void testCreateDataFrameFromLocalJavaBeans() {
  Bean bean = new Bean();
  List<Bean> data = Arrays.asList(bean);
  Dataset<Row> df = spark.createDataFrame(data, Bean.class);
  validateDataFrameWithBeans(bean, df);
}
```
Here is the trace:

```
scala.MatchError: 1234567 (of class java.math.BigInteger)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$DecimalConverter.toCatalystImpl(CatalystTypeConverters.scala:326)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$DecimalConverter.toCatalystImpl(CatalystTypeConverters.scala:323)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:102)
  at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:401)
  at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:892)
  at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:892)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
  at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
  at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:892)
  at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:890)
  at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
  at scala.collection.Iterator$class.toStream(Iterator.scala:1322)
  at scala.collection.AbstractIterator.toStream(Iterator.scala:1336)
  at scala.collection.TraversableOnce$class.toSeq(TraversableOnce.scala:298)
  at scala.collection.AbstractIterator.toSeq(Iterator.scala:1336)
  at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:373)
  at test.org.apache.spark.sql.JavaDataFrameSuite.testCreateDataFrameFromLocalJavaBeans(JavaDataFrameSuite.java:200)
```