Skip to content

Commit f0f80ed

Browse files
committed
improved tests
1 parent d739eea commit f0f80ed

2 files changed

Lines changed: 14 additions & 20 deletions

File tree

python/pyspark/sql/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,7 @@ def test_stopiteration_in_udf(self):
908908
def foo(x):
909909
raise StopIteration()
910910

911-
with self.assertRaises(Py4JJavaError) as cm:
911+
with self.assertRaises(Py4JJavaError):
912912
self.spark.range(0, 1000).withColumn('v', udf(foo)).show()
913913

914914
def test_validate_column_types(self):

python/pyspark/tests.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,28 +1279,22 @@ def test_pipe_unicode(self):
12791279

12801280
def test_stopiteration_in_client_code(self):
12811281

1282-
def a_rdd(keyed=False):
1283-
return self.sc.parallelize(
1284-
((x % 2, x) if keyed else x)
1285-
for x in range(10)
1286-
)
1287-
12881282
def stopit(*x):
12891283
raise StopIteration()
12901284

1291-
def do_test(action, *args, **kwargs):
1292-
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
1293-
action(*args, **kwargs)
1294-
1295-
do_test(a_rdd().map(stopit).collect)
1296-
do_test(a_rdd().filter(stopit).collect)
1297-
do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect)
1298-
do_test(a_rdd().foreach, stopit)
1299-
do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit)
1300-
do_test(a_rdd().reduce, stopit)
1301-
do_test(a_rdd().fold, 0, stopit)
1302-
do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1)
1303-
do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit)
1285+
seq_rdd = self.sc.parallelize(range(10))
1286+
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
1287+
exc = Py4JJavaError, RuntimeError
1288+
1289+
self.assertRaises(exc, seq_rdd.map(stopit).collect)
1290+
self.assertRaises(exc, seq_rdd.filter(stopit).collect)
1291+
self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
1292+
self.assertRaises(exc, seq_rdd.foreach, stopit)
1293+
self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit)
1294+
self.assertRaises(exc, seq_rdd.reduce, stopit)
1295+
self.assertRaises(exc, seq_rdd.fold, 0, stopit)
1296+
self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1)
1297+
self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit)
13041298

13051299

13061300
class ProfilerTests(PySparkTestCase):

0 commit comments

Comments
 (0)