Closed
2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/test_algorithms.py
@@ -86,7 +86,7 @@ def test_raw_and_probability_prediction(self):
expected_rawPrediction = [-11.6081922998, -8.15827998691, 22.17757045]
self.assertTrue(result.prediction, expected_prediction)
self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
- self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
+ self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1))
Member

Is 1 the smallest tolerance that passes?

Member Author

Yup.

JDK 8:

[-11.19194106875243,-7.677866573997363,21.280214474039443]

JDK 11:

[-11.608192299802019,-8.158279986906651,22.177570449962918]

It seems floating-point differences affect the results, though they are still roughly correct.
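To make the size of the gap concrete, here is a minimal sketch that compares the two rawPrediction vectors quoted above. The `allclose` helper is a hypothetical pure-Python stand-in that follows the documented `np.allclose` rule (`|a - b| <= atol + rtol * |b|`, with `rtol` defaulting to `1e-5`); it is not the test's own code.

```python
# rawPrediction values quoted in the comment above.
jdk8 = [-11.19194106875243, -7.677866573997363, 21.280214474039443]
jdk11 = [-11.608192299802019, -8.158279986906651, 22.177570449962918]


def allclose(actual, expected, atol, rtol=1e-5):
    """Hypothetical stand-in for np.allclose's element-wise check:
    |actual - expected| <= atol + rtol * |expected|."""
    return all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


# Largest element-wise gap between the two JDK runs (roughly 0.9).
max_abs_diff = max(abs(a - b) for a, b in zip(jdk8, jdk11))
print(max_abs_diff)

# The old bound rejects the JDK 8 values; the new bound accepts them.
print(allclose(jdk8, jdk11, atol=1e-4))  # False
print(allclose(jdk8, jdk11, atol=1))     # True
```

The largest gap is just under 0.9, so an absolute tolerance of 1 is about the smallest round value that accepts both runs.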

Member

I'm not sure where the difference comes from, but it could be subtle differences in randomization or something similar across the JDKs. If these two tests are the only ones that vary, I think we're OK. I agree with loosening the bound here, as these are log-odds, and I suspect the test values were picked just because they are what some previous run spit out (that is, they are too specific).
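A rough illustration of why a looser bound on log-odds is harmless: assuming the probabilities are a softmax of the raw scores (an assumption here, not something confirmed in this thread), both rawPrediction vectors map to almost identical probabilities, which would explain why the probability check can keep `atol=1E-4`.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# rawPrediction values reported for each JDK earlier in the thread.
jdk8_raw = [-11.19194106875243, -7.677866573997363, 21.280214474039443]
jdk11_raw = [-11.608192299802019, -8.158279986906651, 22.177570449962918]

jdk8_prob = softmax(jdk8_raw)
jdk11_prob = softmax(jdk11_raw)

# The third class dominates in both runs, so the probability vectors
# differ by far less than the 1E-4 tolerance.
max_prob_diff = max(abs(a - b) for a, b in zip(jdk8_prob, jdk11_prob))
print(jdk8_prob)
print(jdk11_prob)
print(max_prob_diff < 1e-4)  # True
```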



class OneVsRestTests(SparkSessionTestCase):
6 changes: 3 additions & 3 deletions python/pyspark/mllib/clustering.py
@@ -383,11 +383,11 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
Member Author

For instance, the weights within the Gaussian mixture model:

JDK 8

weights: WrappedArray(0.49520257460263445, 0.33813075873069875, 0.16666666666666685)

JDK 11

weights: WrappedArray(0.5000000000000001, 0.33333333333333326, 0.16666666666666666)

Member

Also probably OK for the same reason. The test was too specific.
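For a sense of scale, this small sketch compares the mixture weights quoted above. The doctest bounds apply to the `predictSoft` output rather than to the weights, so this only illustrates the size of the drift between JDKs; it does not reproduce the doctest.

```python
import math

# Mixture weights quoted in the comment above.
jdk8_weights = [0.49520257460263445, 0.33813075873069875, 0.16666666666666685]
jdk11_weights = [0.5000000000000001, 0.33333333333333326, 0.16666666666666666]

# Both runs still produce weights that sum to 1.
print(math.isclose(sum(jdk8_weights), 1.0))   # True
print(math.isclose(sum(jdk11_weights), 1.0))  # True

# Largest per-component drift between the two runs (a bit under 0.005).
max_weight_diff = max(abs(a - b) for a, b in zip(jdk8_weights, jdk11_weights))
print(max_weight_diff)

# Drift of this size exceeds a 0.001 bound but sits well inside 0.03.
print(max_weight_diff < 0.001)  # False
print(max_weight_diff < 0.03)   # True
```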

- >>> abs(softPredicted[0] - 1.0) < 0.001
+ >>> abs(softPredicted[0] - 1.0) < 0.03
True
- >>> abs(softPredicted[1] - 0.0) < 0.001
+ >>> abs(softPredicted[1] - 0.0) < 0.03
True
- >>> abs(softPredicted[2] - 0.0) < 0.001
+ >>> abs(softPredicted[2] - 0.0) < 0.03
True
True

>>> path = tempfile.mkdtemp()