@@ -249,7 +249,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
249249 binaryPythonDataSource
250250 }
251251
252- private lazy val pandasFunc : Array [Byte ] = if (shouldTestPandasUDFs) {
252+ private lazy val pandasScalarFunc : Array [Byte ] = if (shouldTestPandasUDFs) {
253253 var binaryPandasFunc : Array [Byte ] = null
254254 withTempPath { path =>
255255 Process (
@@ -272,6 +272,29 @@ object IntegratedUDFTestUtils extends SQLHelper {
272272 throw new RuntimeException (s " Python executable [ $pythonExec] and/or pyspark are unavailable. " )
273273 }
274274
275+ private lazy val pandasScalarIterFunc : Array [Byte ] = if (shouldTestPandasUDFs) {
276+ var binaryPandasFunc : Array [Byte ] = null
277+ withTempPath { path =>
278+ Process (
279+ Seq (
280+ pythonExec,
281+ " -c" ,
282+ " from pyspark.sql.types import StringType; " +
283+ " from pyspark.serializers import CloudPickleSerializer; " +
284+ s " f = open(' $path', 'wb'); " +
285+ " f.write(CloudPickleSerializer().dumps((" +
286+ " lambda it: (x.apply(lambda v: None if v is None else str(v)) for x in it), " +
287+ " StringType())))" ),
288+ None ,
289+ " PYTHONPATH" -> s " $pysparkPythonPath: $pythonPath" ).!!
290+ binaryPandasFunc = Files .readAllBytes(path.toPath)
291+ }
292+ assert(binaryPandasFunc != null )
293+ binaryPandasFunc
294+ } else {
295+ throw new RuntimeException (s " Python executable [ $pythonExec] and/or pyspark are unavailable. " )
296+ }
297+
275298 private lazy val pandasGroupedAggFunc : Array [Byte ] = if (shouldTestPandasUDFs) {
276299 var binaryPandasFunc : Array [Byte ] = null
277300 withTempPath { path =>
@@ -1380,7 +1403,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
13801403 private [IntegratedUDFTestUtils ] lazy val udf = new UserDefinedPythonFunction (
13811404 name = name,
13821405 func = SimplePythonFunction (
1383- command = pandasFunc .toImmutableArraySeq,
1406+ command = pandasScalarFunc .toImmutableArraySeq,
13841407 envVars = workerEnv.clone().asInstanceOf [java.util.Map [String , String ]],
13851408 pythonIncludes = List .empty[String ].asJava,
13861409 pythonExec = pythonExec,
@@ -1410,6 +1433,60 @@ object IntegratedUDFTestUtils extends SQLHelper {
14101433 val prettyName : String = " Scalar Pandas UDF"
14111434 }
14121435
1436+ /**
1437+ * A Scalar Iterator Pandas UDF that takes one column, casts into string, executes the
1438+ * Python native function, and casts back to the type of input column.
1439+ *
1440+ * Virtually equivalent to:
1441+ *
1442+ * {{{
1443+ * from pyspark.sql.functions import pandas_udf, PandasUDFType
1444+ *
1445+ * df = spark.range(3).toDF("col")
1446+ * scalar_iter_udf = pandas_udf(
1447+ * lambda it: map(lambda x: x.apply(lambda v: str(v)), it),
1448+ * "string",
1449+ * PandasUDFType.SCALAR_ITER)
1450+ * casted_col = scalar_iter_udf(df.col.cast("string"))
1451+ * casted_col.cast(df.schema["col"].dataType)
1452+ * }}}
1453+ */
1454+ case class TestScalarIterPandasUDF (
1455+ name : String ,
1456+ returnType : Option [DataType ] = None ) extends TestUDF {
1457+ private [IntegratedUDFTestUtils ] lazy val udf = new UserDefinedPythonFunction (
1458+ name = name,
1459+ func = SimplePythonFunction (
1460+ command = pandasScalarIterFunc.toImmutableArraySeq,
1461+ envVars = workerEnv.clone().asInstanceOf [java.util.Map [String , String ]],
1462+ pythonIncludes = List .empty[String ].asJava,
1463+ pythonExec = pythonExec,
1464+ pythonVer = pythonVer,
1465+ broadcastVars = List .empty[Broadcast [PythonBroadcast ]].asJava,
1466+ accumulator = null ),
1467+ dataType = StringType ,
1468+ pythonEvalType = PythonEvalType .SQL_SCALAR_PANDAS_ITER_UDF ,
1469+ udfDeterministic = true ) {
1470+
1471+ override def builder (e : Seq [Expression ]): Expression = {
1472+ assert(e.length == 1 , " Defined UDF only has one column" )
1473+ val expr = e.head
1474+ val rt = returnType.getOrElse {
1475+ assert(expr.resolved, " column should be resolved to use the same type " +
1476+ " as input. Try df(name) or df.col(name)" )
1477+ expr.dataType
1478+ }
1479+ val pythonUDF = new PythonUDFWithoutId (
1480+ super .builder(Cast (expr, StringType ) :: Nil ).asInstanceOf [PythonUDF ])
1481+ Cast (pythonUDF, rt)
1482+ }
1483+ }
1484+
1485+ def apply (exprs : Column * ): Column = udf(exprs : _* )
1486+
1487+ val prettyName : String = " Scalar Pandas Iterator UDF"
1488+ }
1489+
14131490 /**
14141491 * A Grouped Aggregate Pandas UDF that takes one column, executes the
14151492 * Python native function calculating the count of the column using pandas.
@@ -1606,6 +1683,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
16061683 def registerTestUDF (testUDF : TestUDF , session : classic.SparkSession ): Unit = testUDF match {
16071684 case udf : TestPythonUDF => session.udf.registerPython(udf.name, udf.udf)
16081685 case udf : TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf)
1686+ case udf : TestScalarIterPandasUDF => session.udf.registerPython(udf.name, udf.udf)
16091687 case udf : TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf)
16101688 case udf : TestScalaUDF =>
16111689 val registry = session.sessionState.functionRegistry
0 commit comments