Skip to content

Commit 8d1dec4

Browse files
jkbradleymengxr
authored andcommitted
[mllib] DecisionTree Strategy parameter checks
Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters. CC mengxr Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes apache#1821 from jkbradley/dt-robustness and squashes the following commits: 4dc449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-robustness 7a61f7b [Joseph K. Bradley] Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters
1 parent 75993a6 commit 8d1dec4

2 files changed

Lines changed: 38 additions & 3 deletions

File tree

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ import org.apache.spark.util.random.XORShiftRandom
4444
@Experimental
4545
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
4646

47+
strategy.assertValid()
48+
4749
/**
4850
* Method to train a decision tree model over an RDD
4951
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -1465,10 +1467,14 @@ object DecisionTree extends Serializable with Logging {
14651467

14661468

14671469
/*
1468-
* Ensure #bins is always greater than the categories. For multiclass classification,
1469-
* #bins should be greater than 2^(maxCategories - 1) - 1.
1470+
* Ensure numBins is always greater than the categories. For multiclass classification,
1471+
* numBins should be greater than 2^(maxCategories - 1) - 1.
14701472
* It's a limitation of the current implementation but a reasonable trade-off since features
14711473
* with large number of categories get favored over continuous features.
1474+
*
1475+
* This needs to be checked here instead of in Strategy since numBins can be determined
1476+
* by the number of training examples.
1477+
* TODO: Allow this case, where we simply will know nothing about some categories.
14721478
*/
14731479
if (strategy.categoricalFeaturesInfo.size > 0) {
14741480
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
2020
import scala.collection.JavaConverters._
2121

2222
import org.apache.spark.annotation.Experimental
23-
import org.apache.spark.mllib.tree.impurity.Impurity
23+
import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
2424
import org.apache.spark.mllib.tree.configuration.Algo._
2525
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
2626

@@ -90,4 +90,33 @@ class Strategy (
9090
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
9191
}
9292

93+
private[tree] def assertValid(): Unit = {
94+
algo match {
95+
case Classification =>
96+
require(numClassesForClassification >= 2,
97+
s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
98+
s" but numClassesForClassification = $numClassesForClassification.")
99+
require(Set(Gini, Entropy).contains(impurity),
100+
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
101+
s" Valid settings: Gini, Entropy")
102+
case Regression =>
103+
require(impurity == Variance,
104+
s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
105+
s" Valid settings: Variance")
106+
case _ =>
107+
throw new IllegalArgumentException(
108+
s"DecisionTree Strategy given invalid algo parameter: $algo." +
109+
s" Valid settings are: Classification, Regression.")
110+
}
111+
require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
112+
s" Valid values are integers >= 0.")
113+
require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
114+
s" Valid values are integers >= 2.")
115+
categoricalFeaturesInfo.foreach { case (feature, arity) =>
116+
require(arity >= 2,
117+
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
118+
s" feature $feature has $arity categories. The number of categories should be >= 2.")
119+
}
120+
}
121+
93122
}

0 commit comments

Comments
 (0)