Skip to content

Commit a61911c

Browse files
schintapHyukjinKwon
authored andcommitted
[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
### What changes were proposed in this pull request? UnionRDD of PairRDDs causing a bug. The fix is to check for instance type before proceeding ### Why are the changes needed? Changes are needed to avoid users running into issues with union rdd operation with any other type other than JavaRDD. ### Does this PR introduce _any_ user-facing change? Yes Before: SparkSession available as 'spark'. >>> rdd1 = sc.parallelize([1,2,3,4,5]) >>> rdd2 = sc.parallelize([6,7,8,9,10]) >>> pairRDD1 = rdd1.zip(rdd2) >>> unionRDD1 = sc.union([pairRDD1, pairRDD1]) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/gs/spark/latest/python/pyspark/context.py", line 870, in union jrdds[i] = rdds[i]._jrdd File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in setitem File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value py4j.protocol.Py4JError: An error occurred while calling None.None. Trace: py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166) at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144) at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97) at py4j.GatewayConnection.run(GatewayConnection.java:238) at java.lang.Thread.run(Thread.java:748) After: >>> rdd2 = sc.parallelize([6,7,8,9,10]) >>> pairRDD1 = rdd1.zip(rdd2) >>> unionRDD1 = sc.union([pairRDD1, pairRDD1]) >>> unionRDD1.collect() [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)] ### How was this patch tested? Tested with the reproduced piece of code above manually Closes #28603 from redsanket/SPARK-31788. Authored-by: schintap <schintap@verizonmedia.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
1 parent 753636e commit a61911c

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

python/pyspark/context.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tempfile import NamedTemporaryFile
2626

2727
from py4j.protocol import Py4JError
28+
from py4j.java_gateway import is_instance_of
2829

2930
from pyspark import accumulators
3031
from pyspark.accumulators import Accumulator
@@ -864,10 +865,17 @@ def union(self, rdds):
864865
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
865866
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
866867
rdds = [x._reserialize() for x in rdds]
868+
gw = SparkContext._gateway
867869
cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
868-
jrdds = SparkContext._gateway.new_array(cls, len(rdds))
870+
is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
871+
jrdds = gw.new_array(cls, len(rdds))
869872
for i in range(0, len(rdds)):
870-
jrdds[i] = rdds[i]._jrdd
873+
if is_jrdd:
874+
jrdds[i] = rdds[i]._jrdd
875+
else:
876+
# zip could return JavaPairRDD hence we ensure `_jrdd`
877+
# to be `JavaRDD` by wrapping it in a `map`
878+
jrdds[i] = rdds[i].map(lambda x: x)._jrdd
871879
return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
872880

873881
def broadcast(self, value):

python/pyspark/tests/test_rdd.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,15 @@ def test_zip_chaining(self):
168168
set([(x, (x, x)) for x in 'abc'])
169169
)
170170

171+
def test_union_pair_rdd(self):
172+
# Regression test for SPARK-31788
173+
rdd = self.sc.parallelize([1, 2])
174+
pair_rdd = rdd.zip(rdd)
175+
self.assertEqual(
176+
self.sc.union([pair_rdd, pair_rdd]).collect(),
177+
[((1, 1), (2, 2)), ((1, 1), (2, 2))]
178+
)
179+
171180
def test_deleting_input_files(self):
172181
# Regression test for SPARK-1025
173182
tempFile = tempfile.NamedTemporaryFile(delete=False)

0 commit comments

Comments
 (0)