Skip to content

Commit 69786ea

Browse files
zero323yanboliang
authored andcommitted
[SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params
## What changes were proposed in this pull request? - Replace `getParam` calls with `getOrDefault` calls. - Fix exception message to avoid unintended `TypeError`. - Add unit tests ## How was this patch tested? New unit tests. Author: zero323 <[email protected]> Closes #17891 from zero323/SPARK-20631. (cherry picked from commit 804949c) Signed-off-by: Yanbo Liang <[email protected]>
1 parent 8e09789 commit 69786ea

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

python/pyspark/ml/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,13 @@ def getThresholds(self):
238238

239239
def _checkThresholdConsistency(self):
240240
if self.isSet(self.threshold) and self.isSet(self.thresholds):
241-
ts = self.getParam(self.thresholds)
241+
ts = self.getOrDefault(self.thresholds)
242242
if len(ts) != 2:
243243
raise ValueError("Logistic Regression getThreshold only applies to" +
244244
" binary classification, but thresholds has length != 2." +
245-
" thresholds: " + ",".join(ts))
245+
" thresholds: {0}".format(str(ts)))
246246
t = 1.0/(1.0 + ts[0]/ts[1])
247-
t2 = self.getParam(self.threshold)
247+
t2 = self.getOrDefault(self.threshold)
248248
if abs(t2 - t) >= 1E-5:
249249
raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
250250
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))

python/pyspark/ml/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,18 @@ def test_logistic_regression(self):
808808
except OSError:
809809
pass
810810

811+
def logistic_regression_check_thresholds(self):
812+
self.assertIsInstance(
813+
LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
814+
LogisticRegressionModel
815+
)
816+
817+
self.assertRaisesRegexp(
818+
ValueError,
819+
"Logistic Regression getThreshold found inconsistent.*$",
820+
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
821+
)
822+
811823
def _compare_params(self, m1, m2, param):
812824
"""
813825
Compare 2 ML Params instances for the given param, and assert both have the same param value

0 commit comments

Comments
 (0)