Skip to content

Commit 6bbd769

Browse files
committed
fixes for SPARK-5722
1 parent e477e91 commit 6bbd769

2 files changed

Lines changed: 30 additions & 1 deletion

File tree

python/pyspark/sql/tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
from pyspark.sql import SQLContext, HiveContext, Column
3838
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
39-
UserDefinedType, DoubleType, LongType, StringType
39+
UserDefinedType, DoubleType, LongType, StringType, _infer_type
4040
from pyspark.tests import ReusedPySparkTestCase
4141

4242

@@ -210,6 +210,28 @@ def test_struct_in_map(self):
210210
self.assertEqual(1, k.i)
211211
self.assertEqual("", v.s)
212212

213+
def test_infer_long_type(self):
214+
longrow = [Row(f1='a', f2=100000000000000)]
215+
lrdd = self.sc.parallelize(longrow)
216+
slrdd = self.sqlCtx.inferSchema(lrdd)
217+
self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
218+
219+
# this saving as Parquet caused issues as well.
220+
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
221+
slrdd.saveAsParquetFile(output_dir)
222+
df1 = self.sqlCtx.parquetFile(output_dir)
223+
self.assertEquals('a', df1.first().f1)
224+
self.assertEquals(100000000000000, df1.first().f2)
225+
226+
self.assertEquals(point, ExamplePoint(1.0, 2.0))
227+
self.assertEqual(_infer_type(1), IntegerType())
228+
self.assertEqual(_infer_type(2**10), IntegerType())
229+
self.assertEqual(_infer_type(2**20), IntegerType())
230+
self.assertEqual(_infer_type(2**31 - 1), IntegerType())
231+
self.assertEqual(_infer_type(2**31), LongType())
232+
self.assertEqual(_infer_type(2**61), LongType())
233+
self.assertEqual(_infer_type(2**71), LongType())
234+
213235
def test_convert_row_to_dict(self):
214236
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
215237
self.assertEqual(1, row.asDict()['l'][0].a)

python/pyspark/sql/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,13 @@ def _infer_type(obj):
579579

580580
dataType = _type_mappings.get(type(obj))
581581
if dataType is not None:
582+
# Conform to Java int/long sizes SPARK-5722
583+
# Inference is usually done on a sample of the dataset
584+
# so, if values that should be Long do not appear in
585+
# the sample, the dataType will be chosen as IntegerType
586+
if dataType == IntegerType:
587+
if obj > 2**31 - 1 or obj < -2**31:
588+
dataType = LongType
582589
return dataType()
583590

584591
if isinstance(obj, dict):

0 commit comments

Comments
 (0)