Skip to content
4 changes: 4 additions & 0 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,10 @@ def _infer_type(obj):

dataType = _type_mappings.get(type(obj))
if dataType is not None:
# Conform to Java int/long sizes SPARK-5722
if dataType == IntegerType:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here.

if obj > 2**31 - 1 or obj < -2**31:
dataType = LongType
return dataType()

if isinstance(obj, dict):
Expand Down
16 changes: 15 additions & 1 deletion python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
CloudPickleSerializer, CompressedSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
UserDefinedType, DoubleType, LongType, _infer_type
from pyspark import shuffle

_have_scipy = False
Expand Down Expand Up @@ -923,6 +923,20 @@ def test_infer_schema(self):
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.first()[0])

def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
lrdd = self.sc.parallelize(longrow)
slrdd = self.sqlCtx.inferSchema(lrdd)
self.assertEqual(slrdd.schema().fields[1].dataType, LongType())

self.assertEqual(_infer_type(1), IntegerType())
self.assertEqual(_infer_type(2**10), IntegerType())
self.assertEqual(_infer_type(2**20), IntegerType())
self.assertEqual(_infer_type(2**31 - 1), IntegerType())
self.assertEqual(_infer_type(2**31), LongType())
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())

def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
Expand Down