@@ -1279,28 +1279,22 @@ def test_pipe_unicode(self):
12791279
12801280 def test_stopiteration_in_client_code (self ):
12811281
1282- def a_rdd (keyed = False ):
1283- return self .sc .parallelize (
1284- ((x % 2 , x ) if keyed else x )
1285- for x in range (10 )
1286- )
1287-
12881282 def stopit (* x ):
12891283 raise StopIteration ()
12901284
1291- def do_test ( action , * args , ** kwargs ):
1292- with self .assertRaises (( Py4JJavaError , RuntimeError )) as cm :
1293- action ( * args , ** kwargs )
1294-
1295- do_test ( a_rdd () .map (stopit ).collect )
1296- do_test ( a_rdd () .filter (stopit ).collect )
1297- do_test ( a_rdd () .cartesian (a_rdd () ).flatMap (stopit ).collect )
1298- do_test ( a_rdd () .foreach , stopit )
1299- do_test ( a_rdd ( keyed = True ) .reduceByKeyLocally , stopit )
1300- do_test ( a_rdd () .reduce , stopit )
1301- do_test ( a_rdd () .fold , 0 , stopit )
1302- do_test ( a_rdd () .aggregate , 0 , stopit , lambda * x : 1 )
1303- do_test ( a_rdd () .aggregate , 0 , lambda * x : 1 , stopit )
1285+ seq_rdd = self . sc . parallelize ( range ( 10 ))
1286+ keyed_rdd = self .sc . parallelize (( x % 2 , x ) for x in range ( 10 ))
1287+ exc = Py4JJavaError , RuntimeError
1288+
1289+ self . assertRaises ( exc , seq_rdd .map (stopit ).collect )
1290+ self . assertRaises ( exc , seq_rdd .filter (stopit ).collect )
1291+ self . assertRaises ( exc , seq_rdd .cartesian (seq_rdd ).flatMap (stopit ).collect )
1292+ self . assertRaises ( exc , seq_rdd .foreach , stopit )
1293+ self . assertRaises ( exc , keyed_rdd .reduceByKeyLocally , stopit )
1294+ self . assertRaises ( exc , seq_rdd .reduce , stopit )
1295+ self . assertRaises ( exc , seq_rdd .fold , 0 , stopit )
1296+ self . assertRaises ( exc , seq_rdd .aggregate , 0 , stopit , lambda * x : 1 )
1297+ self . assertRaises ( exc , seq_rdd .aggregate , 0 , lambda * x : 1 , stopit )
13041298
13051299
13061300class ProfilerTests (PySparkTestCase ):
0 commit comments