@@ -47,6 +47,7 @@ private[spark] object PythonEvalType {
val SQL_GROUPED_AGG_PANDAS_UDF = 202
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_COGROUPED_MAP_PANDAS_UDF = 205

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -56,6 +57,7 @@ private[spark] object PythonEvalType {
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
}
}

1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -74,6 +74,7 @@ class PythonEvalType(object):
SQL_GROUPED_AGG_PANDAS_UDF = 202
SQL_WINDOW_AGG_PANDAS_UDF = 203
SQL_SCALAR_PANDAS_ITER_UDF = 204
SQL_COGROUPED_MAP_PANDAS_UDF = 205


def portable_hash(x):
34 changes: 34 additions & 0 deletions python/pyspark/serializers.py
@@ -356,6 +356,24 @@ def __repr__(self):
return "ArrowStreamPandasSerializer"


class InterleavedArrowReader(object):

def __init__(self, stream):
import pyarrow as pa
self._schema1 = pa.read_schema(stream)
Review comment (Contributor Author):

I wanted to read these with the message reader as well, but for some reason pa.read_schema(self._reader.read_next_message()) didn't work.

self._schema2 = pa.read_schema(stream)
self._reader = pa.MessageReader.open_stream(stream)

def __iter__(self):
return self

def __next__(self):
import pyarrow as pa
batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1)
batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2)
return batch1, batch2


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
@@ -401,6 +419,22 @@ def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class InterleavedArrowStreamPandasSerializer(ArrowStreamPandasUDFSerializer):

def __init__(self, timezone, safecheck, assign_cols_by_name):
super(InterleavedArrowStreamPandasSerializer, self).__init__(timezone, safecheck, assign_cols_by_name)

def load_stream(self, stream):
"""
Deserialize interleaved ArrowRecordBatches and yield, for each cogroup, a tuple of two lists of pandas.Series, one list per side.
"""
import pyarrow as pa
reader = InterleavedArrowReader(pa.input_stream(stream))
for batch1, batch2 in reader:
yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()])
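
For orientation, a minimal sketch of the interleaved wire format these two classes assume: two schema messages up front, then record batches alternating between the left and right side of each cogroup. The byte layout below is written with pyarrow purely for illustration and is an assumption about the stream shape, not the exact bytes the JVM produces.

import io
import pyarrow as pa

# Invented example batches standing in for the two cogrouped sides.
left_batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], ['k'])
right_batch = pa.RecordBatch.from_arrays([pa.array([10.0, 20.0])], ['v'])

buf = io.BytesIO()
buf.write(left_batch.schema.serialize())   # schema of side 1
buf.write(right_batch.schema.serialize())  # schema of side 2
buf.write(left_batch.serialize())          # batch for side 1
buf.write(right_batch.serialize())         # batch for side 2
buf.seek(0)

reader = InterleavedArrowReader(pa.input_stream(buf))
for b1, b2 in reader:
    print(b1.num_rows, b2.num_rows)  # one pair of batches per iteration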


class BatchedSerializer(Serializer):

"""
21 changes: 21 additions & 0 deletions python/pyspark/sql/cogroup.py
@@ -0,0 +1,21 @@
from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):

def __init__(self, gd1, gd2):
self._gd1 = gd1
self._gd2 = gd2
self.sql_ctx = gd1.sql_ctx

def apply(self, udf):
all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)

@staticmethod
def _extract_cols(gd):
df = gd._df
return [df[col] for col in df.columns]
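
For context, here is a minimal usage sketch of the API this module adds. It assumes an active SparkSession bound to spark, and the column names and values are invented for illustration; the pattern itself mirrors the test added later in this PR.

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Hypothetical data: both sides share the grouping key 'id'.
df1 = spark.createDataFrame([(1, 20, 200), (1, 21, 210)], ['id', 'k', 'v'])
df2 = spark.createDataFrame([(1, 20, 2000)], ['id', 'k', 'v2'])

@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
    # left/right are the cogrouped rows of df1/df2 as pandas DataFrames.
    return pd.merge(left, right, how='outer', on=['k', 'id'])

# Group by column objects (not strings) so the analyzer can tell the sides apart.
result = df1.groupby(df1.id).cogroup(df2.groupby(df2.id)).apply(merge_pandas)
result.show()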

3 changes: 3 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2800,6 +2800,8 @@ class PandasUDFType(object):

GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF

GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF


@@ -3178,6 +3180,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")
4 changes: 4 additions & 0 deletions python/pyspark/sql/group.py
@@ -22,6 +22,7 @@
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData

__all__ = ["GroupedData"]

@@ -220,6 +221,9 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self._df)

def cogroup(self, other):
return CoGroupedData(self, other)

@since(2.3)
def apply(self, udf):
"""
93 changes: 93 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
@@ -0,0 +1,93 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import unittest
import sys

from collections import OrderedDict
from decimal import Decimal

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal

if have_pyarrow:
import pyarrow as pa


"""
Tests below use pd.DataFrame.assign, which will infer mixed types (unicode/str) for column names
from kwargs with Python 2, so we need to set check_column_type=False to avoid this check.
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):

@property
def data1(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks')))\
.withColumn("v", col('k') * 10)\
.drop('ks')

@property
def data2(self):
return self.spark.range(10).toDF('id') \
.withColumn("ks", array([lit(i) for i in range(20, 30)])) \
.withColumn("k", explode(col('ks'))) \
.withColumn("v2", col('k') * 100) \
.drop('ks')

def test_simple(self):
import pandas as pd

l = self.data1
r = self.data2

@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
return pd.merge(left, right, how='outer', on=['k', 'id'])

# TODO: Grouping by a string fails to resolve here as analyzer cannot determine side
Review comment (Contributor Author):

As the comment says, l.groupBy('id').cogroup(r.groupBy('id')) will fail because the resolver can't work out whether each 'id' column should be resolved from l or r. I'm a bit unsure as to the best way of solving this.
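
For illustration, the distinction described above, using the l and r DataFrames from this test (the failing form is an assumption based on the comment, not something exercised here):

# Ambiguous: both sides have an 'id' column, so the analyzer cannot tell
# which DataFrame each string key should resolve against.
#   l.groupby('id').cogroup(r.groupby('id')).apply(merge_pandas)
# Unambiguous: column objects pin each key to its own DataFrame, as used below.
#   l.groupby(l.id).cogroup(r.groupby(r.id)).apply(merge_pandas)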

result = l\
.groupby(l.id)\
.cogroup(r.groupby(r.id))\
.apply(merge_pandas)\
.sort(['id', 'k'])\
.toPandas()

expected = pd\
.merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id'])

assert_frame_equal(expected, result, check_column_type=_check_column_type)

Review comment:

Hi @d80tb7, I work with Li and am also interested in cogroup.

Can I ask how you were able to get your test to run? I wasn't able to run it without the following snippet:

if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_grouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)

taken from the other similar tests like test_pandas_udf_grouped_map.

Review comment (Contributor Author):

Hi @hjoo

So far I've just been running it via PyCharm's unit test runner under Python 3. I suspect the problem you had was that the iterator I added wasn't compatible with Python 2 (sorry!). I've fixed the iterator and added a similar snippet to the one you provided above. Now I can run it with python/run-tests --testnames pyspark.sql.tests.test_pandas_udf_cogrouped_map.

If you still have problems let me know the error you get and I'll take a look.

41 changes: 36 additions & 5 deletions python/pyspark/worker.py
@@ -38,7 +38,7 @@
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasUDFSerializer
BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
@@ -111,8 +111,25 @@ def verify_result_length(result, length):
map(verify_result_type, f(*iterator)))


def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrap_cogrouped_map_pandas_udf(f, return_type):

def wrapped(left, right):
Review comment (Contributor):

Is left a list of pd.Series here? Probably name them left_series and value_series to be more readable?

Review comment (Contributor Author):

Yes they are: they are the value series for the left and right sides of the cogroup, respectively. Agreed that the names aren't the best; I'll improve them when I do a tidy-up.

import pandas as pd
result = f(pd.concat(left, axis=1), pd.concat(right, axis=1))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
if not len(result.columns) == len(return_type):
raise RuntimeError(
"Number of columns of the returned pandas.DataFrame "
"doesn't match specified schema. "
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
return result

return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))]
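
As a rough sketch of how this wrapper is driven (the series names and values are invented; the calling convention is inferred from the serializer and mapper code in this PR):

# The interleaved serializer yields a tuple of two lists of pandas.Series;
# the lambda returned above unpacks it and pairs the user function's
# DataFrame result with its Arrow return type.
#   left_series  = [pd.Series([1, 1], name='id'), pd.Series([20, 21], name='k')]
#   right_series = [pd.Series([1], name='id'), pd.Series([20], name='k')]
#   runner = wrap_cogrouped_map_pandas_udf(merge_pandas, return_type)
#   [(result_df, arrow_type)] = runner((left_series, right_series))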


def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
import pandas as pd

@@ -232,6 +249,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = _get_argspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -246,6 +265,7 @@ def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}

if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
@@ -269,10 +289,13 @@

# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
else:
df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
df_for_struct)
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
df_for_struct)
else:
ser = BatchedSerializer(PickleSerializer(), 100)

@@ -343,6 +366,14 @@ def map_batch(batch):
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
arg_offsets, udf = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
mapper_str = "lambda a: f(a)"
else:
# Create function like this:
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
@@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas(
override val producedAttributes = AttributeSet(output)
}


case class FlatMapCoGroupsInPandas(
leftAttributes: Seq[Attribute],
rightAttributes: Seq[Attribute],
functionExpr: Expression,
output: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {

override val producedAttributes = AttributeSet(output)
}


trait BaseEvalPython extends UnaryNode {

def udfs: Seq[PythonUDF]
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
*/
@Stable
class RelationalGroupedDataset protected[sql](
df: DataFrame,
groupingExprs: Seq[Expression],
private val df: DataFrame,
Review comment (Contributor):

Are these changes needed?

Review comment (Contributor Author):

I think they are; I need the accessors in flatMapCoGroupsInPandas.

private val groupingExprs: Seq[Expression],
groupType: RelationalGroupedDataset.GroupType) {

private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
Expand Down Expand Up @@ -523,6 +523,33 @@ class RelationalGroupedDataset protected[sql](
Dataset.ofRows(df.sparkSession, plan)
}

private[sql] def flatMapCoGroupsInPandas
(r: RelationalGroupedDataset, expr: PythonUDF): DataFrame = {
require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
"Must pass a cogrouped map udf")
require(expr.dataType.isInstanceOf[StructType],
s"The returnType of the udf must be a ${StructType.simpleString}")

val leftGroupingNamedExpressions = groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}

val rightGroupingNamedExpressions = r.groupingExprs.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}

val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
val left = df.logicalPlan
val right = r.df.logicalPlan
val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)

Dataset.ofRows(df.sparkSession, plan)
}

override def toString: String = {
val builder = new StringBuilder
builder.append("RelationalGroupedDataset: [grouping expressions: [")
@@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
f, p, b, is, ot, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) =>
execution.python.FlatMapCoGroupsInPandasExec(
leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, _, _, in, out, child) =>