Skip to content

Commit e750d3e

Browse files
committed
Fixed python style.
1 parent b69f201 commit e750d3e

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

python/pyspark/ml/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
15291529

15301530
@keyword_only
15311531
@since("2.0.0")
1532-
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None, parallelism=None):
1532+
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
1533+
classifier=None, parallelism=None):
15331534
"""
15341535
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
15351536
Sets params for OneVsRest.
@@ -1565,11 +1566,9 @@ def trainSingleClass(index):
15651566
(classifier.predictionCol, predictionCol)])
15661567
return classifier.fit(trainingDataset, paramMap)
15671568

1568-
# TODO: Parallel training for all classes.
15691569
pool = ThreadPool(processes=self.getParallelism())
15701570

15711571
models = pool.map(trainSingleClass, range(numClasses))
1572-
#models = [trainSingleClass(i) for i in range(numClasses)]
15731572

15741573
if handlePersistence:
15751574
multiclassLabeled.unpersist()
@@ -1652,6 +1651,7 @@ def _to_java(self):
16521651
_java_obj.setPredictionCol(self.getPredictionCol())
16531652
return _java_obj
16541653

1654+
16551655
class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
16561656
"""
16571657
.. note:: Experimental

python/pyspark/ml/tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,7 @@ def test_output_columns(self):
12341234
output = model.transform(df)
12351235
self.assertEqual(output.columns, ["label", "features", "prediction"])
12361236

1237+
12371238
class ParOneVsRestTests(SparkSessionTestCase):
12381239

12391240
def test_copy(self):
@@ -1261,6 +1262,7 @@ def test_output_columns(self):
12611262
output = model.transform(df)
12621263
self.assertEqual(output.columns, ["label", "features", "prediction"])
12631264

1265+
12641266
class HashingTFTest(SparkSessionTestCase):
12651267

12661268
def test_apply_binary_term_freqs(self):

0 commit comments

Comments
 (0)