Commit 2325a4f

Add more tests; Fix lint; Remove hardcoded scala-2.11 from require_test_compiled
1 parent: d440cbf

2 files changed: 51 additions & 20 deletions

File tree

python/pyspark/sql/tests.py
python/pyspark/sql/utils.py

python/pyspark/sql/tests.py

Lines changed: 46 additions & 18 deletions
@@ -3377,29 +3377,42 @@ def test_ignore_column_of_all_nulls(self):
 
     # SPARK-24721
     @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
-    def test_datasource_with_udf_filter_lit_input(self):
+    def test_datasource_with_udf(self):
         from pyspark.sql.functions import udf, lit, col
 
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
+
         try:
             self.spark.range(1).write.mode("overwrite").format('csv').save(path)
-            filesource_df = self.spark.read.csv(path)
+            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
             datasource_df = self.spark.read \
                 .format("org.apache.spark.sql.sources.SimpleScanSource") \
-                .option('from', 0).option('to', 1).load()
+                .option('from', 0).option('to', 1).load().toDF('i')
             datasource_v2_df = self.spark.read \
-                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
-                .load()
+                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load().toDF('i', 'j')
+
+            c1 = udf(lambda x: x + 1, 'int')(lit(1))
+            c2 = udf(lambda x: x + 1, 'int')(col('i'))
 
-            filter1 = udf(lambda: False, 'boolean')()
-            filter2 = udf(lambda x: False, 'boolean')(lit(1))
+            f1 = udf(lambda x: False, 'boolean')(lit(1))
+            f2 = udf(lambda x: False, 'boolean')(col('i'))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c1)
+                expected = df.withColumn('c', lit(2))
+                self.assertEquals(expected.collect(), result.collect())
 
             for df in [filesource_df, datasource_df, datasource_v2_df]:
-                for f in [filter1, filter2]:
+                result = df.withColumn('c', c2)
+                expected = df.withColumn('c', col('i') + 1)
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [f1, f2]:
                     result = df.filter(f)
                     self.assertEquals(0, result.count())
-
         finally:
             shutil.rmtree(path)
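
For readers skimming the diff, here is a minimal standalone sketch of the behavior the new assertions pin down, assuming a local SparkSession and a single int column named 'i' (the session setup is illustrative, not part of the commit):

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, lit, col

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.range(3).toDF('i')

# A row-at-a-time UDF over a literal: every row should get 1 + 1 = 2.
c1 = udf(lambda x: x + 1, 'int')(lit(1))
# The same UDF over a real column: each row should get i + 1.
c2 = udf(lambda x: x + 1, 'int')(col('i'))
# A UDF used as a filter predicate that rejects every row.
f1 = udf(lambda x: False, 'boolean')(lit(1))

assert df.withColumn('c', c1).collect() == df.withColumn('c', lit(2)).collect()
assert df.withColumn('c', c2).collect() == df.withColumn('c', col('i') + 1).collect()
assert df.filter(f1).count() == 0

The original test only covered UDF filters with literal inputs (hence its old name); the rewrite also checks UDFs over real columns and UDFs in projections, and runs each check against the CSV file source plus the v1 and v2 test data sources, since the regression being guarded involves data source scans feeding Python UDFs.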

@@ -5307,31 +5320,46 @@ def f3(x):
 
     # SPARK-24721
     @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
-    def test_datasource_with_udf_filter_lit_input(self):
-        # Same as SQLTests.test_datasource_with_udf_filter_lit_input, but with Pandas UDF
+    def test_datasource_with_udf(self):
+        # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
         # This needs to be a separate test because the Arrow dependency is optional
         import pandas as pd
         import numpy as np
         from pyspark.sql.functions import pandas_udf, lit, col
 
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
+
         try:
             self.spark.range(1).write.mode("overwrite").format('csv').save(path)
-            filesource_df = self.spark.read.csv(path)
+            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
             datasource_df = self.spark.read \
                 .format("org.apache.spark.sql.sources.SimpleScanSource") \
-                .option('from', 0).option('to', 1).load()
+                .option('from', 0).option('to', 1).load().toDF('i')
             datasource_v2_df = self.spark.read \
-                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
-                .load()
+                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load().toDF('i', 'j')
+
+            c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
+            c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))
 
-            f = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
+            f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
+            f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c1)
+                expected = df.withColumn('c', lit(2))
+                self.assertEquals(expected.collect(), result.collect())
 
             for df in [filesource_df, datasource_df, datasource_v2_df]:
-                result = df.filter(f)
-                self.assertEquals(0, result.count())
+                result = df.withColumn('c', c2)
+                expected = df.withColumn('c', col('i') + 1)
+                self.assertEquals(expected.collect(), result.collect())
 
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [f1, f2]:
+                    result = df.filter(f)
+                    self.assertEquals(0, result.count())
         finally:
             shutil.rmtree(path)
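
The Pandas variant differs in the UDF contract: the lambda receives a whole pandas.Series per Arrow batch and must return a Series of the same length, which is why the all-False predicate is written as pd.Series(np.repeat(False, len(x))). A rough sketch of that shape in isolation, assuming the optional pandas and pyarrow dependencies are installed (session setup again illustrative, not from the commit):

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, lit, col

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.range(3).toDF('i')

# Called once per batch: x is a pd.Series, and the result must be a
# boolean pd.Series of the same length -- here, all False.
reject_all = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')

assert df.filter(reject_all(lit(1))).count() == 0    # literal input
assert df.filter(reject_all(col('i'))).count() == 0  # column input

Keeping this as a separate test, as the in-code comment says, lets the suite skip it cleanly when Arrow is unavailable.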

python/pyspark/sql/utils.py

Lines changed: 5 additions & 2 deletions
@@ -156,14 +156,17 @@ def require_test_compiled():
     """ Raise Exception if test classes are not compiled
     """
     import os
+    import glob
     try:
         spark_home = os.environ['SPARK_HOME']
     except KeyError:
         raise RuntimeError('SPARK_HOME is not defined in environment')
 
     test_class_path = os.path.join(
-        spark_home, 'sql', 'core', 'target', 'scala-2.11', 'test-classes')
-    if not os.path.isdir(test_class_path):
+        spark_home, 'sql', 'core', 'target', '*', 'test-classes')
+    paths = glob.glob(test_class_path)
+
+    if len(paths) == 0:
         raise RuntimeError(
             "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)
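
The utils.py change replaces the hardcoded scala-2.11 segment with a glob, so the check passes for whichever Scala version the test classes were built against. A quick sketch of how the pattern resolves; the scala-2.12 path in the comment is a hypothetical example, not something the commit guarantees:

import glob
import os

spark_home = os.environ.get('SPARK_HOME', '/opt/spark')  # illustrative fallback
pattern = os.path.join(spark_home, 'sql', 'core', 'target', '*', 'test-classes')

# glob.glob expands '*' against the directories actually present, e.g.
# ['/opt/spark/sql/core/target/scala-2.12/test-classes'], and returns []
# when nothing matches -- which is exactly the failure the RuntimeError reports.
print(glob.glob(pattern))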
