# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
    """Regression test for SPARK-24721: Python UDFs used both as projections
    and as filters must evaluate correctly against a file source, a V1 data
    source and a V2 data source, whether the UDF input is a literal or a
    column read from the source.
    """
    from pyspark.sql.functions import udf, lit, col

    # mkdtemp + rmtree: reserve a unique, non-existent path for Spark to
    # create; the trailing finally removes whatever Spark wrote there.
    path = tempfile.mkdtemp()
    shutil.rmtree(path)

    try:
        self.spark.range(1).write.mode("overwrite").format('csv').save(path)
        # Normalize all three sources to a single int column named 'i' so the
        # same UDF expressions apply to each.
        filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
        datasource_df = self.spark.read \
            .format("org.apache.spark.sql.sources.SimpleScanSource") \
            .option('from', 0).option('to', 1).load().toDF('i')
        datasource_v2_df = self.spark.read \
            .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
            .load().toDF('i', 'j')

        # Projection UDFs: one fed a literal, one fed a source column.
        c1 = udf(lambda x: x + 1, 'int')(lit(1))
        c2 = udf(lambda x: x + 1, 'int')(col('i'))

        # Filter UDFs that reject every row, again literal- and column-fed.
        f1 = udf(lambda x: False, 'boolean')(lit(1))
        f2 = udf(lambda x: False, 'boolean')(col('i'))

        for df in [filesource_df, datasource_df, datasource_v2_df]:
            result = df.withColumn('c', c1)
            expected = df.withColumn('c', lit(2))
            # assertEqual, not the deprecated assertEquals alias
            # (removed in Python 3.12).
            self.assertEqual(expected.collect(), result.collect())

        for df in [filesource_df, datasource_df, datasource_v2_df]:
            result = df.withColumn('c', c2)
            expected = df.withColumn('c', col('i') + 1)
            self.assertEqual(expected.collect(), result.collect())

        # An always-false UDF filter must yield an empty result on every source.
        for df in [filesource_df, datasource_df, datasource_v2_df]:
            for f in [f1, f2]:
                result = df.filter(f)
                self.assertEqual(0, result.count())
    finally:
        shutil.rmtree(path)
# SPARK-24721
@unittest.skipIf(not _test_compiled, _test_not_compiled_message)
def test_datasource_with_udf(self):
    """Pandas-UDF counterpart of SQLTests.test_datasource_with_udf
    (SPARK-24721): scalar Pandas UDFs used as projections and filters must
    work against file, V1 and V2 data sources with both literal and column
    inputs.

    Kept as a separate test because the Arrow dependency is optional.
    """
    import pandas as pd
    import numpy as np
    from pyspark.sql.functions import pandas_udf, lit, col

    # mkdtemp + rmtree: reserve a unique, non-existent path for Spark to
    # create; the trailing finally removes whatever Spark wrote there.
    path = tempfile.mkdtemp()
    shutil.rmtree(path)

    try:
        self.spark.range(1).write.mode("overwrite").format('csv').save(path)
        # Normalize all three sources to a single int column named 'i' so the
        # same UDF expressions apply to each.
        filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
        datasource_df = self.spark.read \
            .format("org.apache.spark.sql.sources.SimpleScanSource") \
            .option('from', 0).option('to', 1).load().toDF('i')
        datasource_v2_df = self.spark.read \
            .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
            .load().toDF('i', 'j')

        # Projection UDFs: one fed a literal, one fed a source column.
        c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
        c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))

        # Filter UDFs returning an all-False boolean Series sized to the
        # input batch, again literal- and column-fed.
        f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
        f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))

        for df in [filesource_df, datasource_df, datasource_v2_df]:
            result = df.withColumn('c', c1)
            expected = df.withColumn('c', lit(2))
            # assertEqual, not the deprecated assertEquals alias
            # (removed in Python 3.12).
            self.assertEqual(expected.collect(), result.collect())

        for df in [filesource_df, datasource_df, datasource_v2_df]:
            result = df.withColumn('c', c2)
            expected = df.withColumn('c', col('i') + 1)
            self.assertEqual(expected.collect(), result.collect())

        # An always-false UDF filter must yield an empty result on every source.
        for df in [filesource_df, datasource_df, datasource_v2_df]:
            for f in [f1, f2]:
                result = df.filter(f)
                self.assertEqual(0, result.count())
    finally:
        shutil.rmtree(path)
0 commit comments