Skip to content

Commit 35725f7

Browse files
WeichenXu123jkbradley
authored andcommitted
[SPARK-22332][ML][TEST] Fix NaiveBayes unit test occasionly fail (cause by test dataset not deterministic)
## What changes were proposed in this pull request? Fix NaiveBayes unit test occasionly fail: Set seed for `BrzMultinomial.sample`, make `generateNaiveBayesInput` output deterministic dataset. (If we do not set seed, the generated dataset will be random, and the model will be possible to exceed the tolerance in the test, which trigger this failure) ## How was this patch tested? Manually run tests multiple times and check each time output models contains the same values. Author: WeichenXu <[email protected]> Closes #19558 from WeichenXu123/fix_nb_test_seed. (cherry picked from commit 841f1d7) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent 9ed6404 commit 35725f7

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
2020
import scala.util.Random
2121

2222
import breeze.linalg.{DenseVector => BDV, Vector => BV}
23-
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
23+
import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis}
2424

2525
import org.apache.spark.{SparkException, SparkFunSuite}
2626
import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
@@ -329,6 +329,7 @@ object NaiveBayesSuite {
329329
val _pi = pi.map(math.exp)
330330
val _theta = theta.map(row => row.map(math.exp))
331331

332+
implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed)
332333
for (i <- 0 until nPoints) yield {
333334
val y = calcLabel(rnd.nextDouble(), _pi)
334335
val xi = modelType match {

0 commit comments

Comments
 (0)