Skip to content

Commit d86dae8

Browse files
zero323yanboliang
authored andcommitted
[SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params
## What changes were proposed in this pull request? - Replace `getParam` calls with `getOrDefault` calls. - Fix exception message to avoid unintended `TypeError`. - Add unit tests ## How was this patch tested? New unit tests. Author: zero323 <[email protected]> Closes #17891 from zero323/SPARK-20631. (cherry picked from commit 804949c) Signed-off-by: Yanbo Liang <[email protected]>
1 parent 4665997 commit d86dae8

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

python/pyspark/ml/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,13 @@ def getThresholds(self):
200200

201201
def _checkThresholdConsistency(self):
202202
if self.isSet(self.threshold) and self.isSet(self.thresholds):
203-
ts = self.getParam(self.thresholds)
203+
ts = self.getOrDefault(self.thresholds)
204204
if len(ts) != 2:
205205
raise ValueError("Logistic Regression getThreshold only applies to" +
206206
" binary classification, but thresholds has length != 2." +
207-
" thresholds: " + ",".join(ts))
207+
" thresholds: {0}".format(str(ts)))
208208
t = 1.0/(1.0 + ts[0]/ts[1])
209-
t2 = self.getParam(self.threshold)
209+
t2 = self.getOrDefault(self.threshold)
210210
if abs(t2 - t) >= 1E-5:
211211
raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
212212
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))

python/pyspark/ml/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,18 @@ def test_logistic_regression(self):
765765
except OSError:
766766
pass
767767

768+
def logistic_regression_check_thresholds(self):
769+
self.assertIsInstance(
770+
LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
771+
LogisticRegressionModel
772+
)
773+
774+
self.assertRaisesRegexp(
775+
ValueError,
776+
"Logistic Regression getThreshold found inconsistent.*$",
777+
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
778+
)
779+
768780
def _compare_params(self, m1, m2, param):
769781
"""
770782
Compare 2 ML Params instances for the given param, and assert both have the same param value

0 commit comments

Comments
 (0)