diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 61091bb803e4..6a1e46049e19 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -94,6 +94,64 @@ private[spark] class DTStatsAggregator( impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } + /** + * Calculate gain for a given (featureOffset, leftBin, parentBin). + * + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + * @param leftBinIndex Index of the leftChild in allStats + * Given by featureOffset + leftBinIndex * statsSize + * @param parentBinIndex Index of the parent in allStats + * Given by featureOffset + parentBinIndex * statsSize + */ + def calculateGain( + featureOffset: Int, + leftBinIndex: Int, + parentBinIndex: Int): Double = { + val leftChildOffset = featureOffset + leftBinIndex * statsSize + val parentOffset = featureOffset + parentBinIndex * statsSize + val gain = metadata.impurity match { + case Gini => Gini.calculateGain( + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case Entropy => Entropy.calculateGain( + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case Variance => Variance.calculateGain( + allStats, leftChildOffset, allStats, parentOffset, statsSize, + metadata.minInstancesPerNode, metadata.minInfoGain) + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + gain + } + + /** + * Calculate gain for a given (featureOffset, leftBin). + * The stats of the parent are inferred from parentStats. + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. 
+ * @param leftBinIndex Index of the leftChild in allStats + * Given by featureOffset + leftBinIndex * statsSize + */ + def calculateGain( + featureOffset: Int, + leftBinIndex: Int): Double = { + val leftChildOffset = featureOffset + leftBinIndex * statsSize + val gain = metadata.impurity match { + case Gini => Gini.calculateGain( + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case Entropy => Entropy.calculateGain( + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case Variance => Variance.calculateGain( + allStats, leftChildOffset, parentStats, 0, statsSize, metadata.minInstancesPerNode, + metadata.minInfoGain) + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + gain + } + /** * Get an [[ImpurityCalculator]] for the parent node. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eb..d3297ac1693d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -613,65 +613,6 @@ private[spark] object RandomForest extends Logging { } } - /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. 
- * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) - */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. 
- if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) - } - /** * Find the best split for a node. * @@ -684,16 +625,10 @@ private[spark] object RandomForest extends Logging { featuresForNode: Option[Array[Int]], node: LearningNode): (Split, ImpurityStats) = { - // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level == 0) { - null - } else { - node.stats - } // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = + val (bestSplit, bestGain, bestFeatureOffset, bestSplitIndex) = Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => val featureIndex = if (featuresForNode.nonEmpty) { featuresForNode.get.apply(featureIndexIdx) @@ -712,30 +647,23 @@ private[spark] object RandomForest extends Logging { splitIndex += 1 } // Find best split. 
- val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { splitIdx => + val gain = binAggregates.calculateGain(nodeFeatureOffset, splitIdx, numSplits) + (splitIdx, gain) + }.maxBy(_._2) + val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex) + (bestFeatureSplit, maxGain, nodeFeatureOffset, bestFeatureSplitIndex) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { splitIdx => + val gain = binAggregates.calculateGain(leftChildOffset, splitIdx) + (splitIdx, gain) + }.maxBy(_._2) + val bestFeatureSplit = splits(featureIndex)(bestFeatureSplitIndex) + (bestFeatureSplit, maxGain, leftChildOffset, bestFeatureSplitIndex) } else { // Ordered categorical feature val 
nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) @@ -794,27 +722,37 @@ private[spark] object RandomForest extends Logging { // lastCategory = index of bin with total aggregates for this (node, feature) val lastCategory = categoriesSortedByCentroid.last._1 // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, maxGain) = + Range(0, numSplits).map { splitIdx => + val featureValue = categoriesSortedByCentroid(splitIdx)._1 + val gain = binAggregates.calculateGain(nodeFeatureOffset, featureValue, lastCategory) + (splitIdx, gain) + }.maxBy(_._2) + val bestFeatureValue = categoriesSortedByCentroid(bestFeatureSplitIndex)._1 val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) val bestFeatureSplit = new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) + (bestFeatureSplit, maxGain, nodeFeatureOffset, bestFeatureValue) } - }.maxBy(_._2.gain) - - (bestSplit, bestSplitStats) + }.maxBy(_._2) + + val leftImpurityCalculator = binAggregates.getImpurityCalculator( + bestFeatureOffset, bestSplitIndex) + val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() + val rightImpurityCalculator = parentImpurityCalculator.copy.subtract( + leftImpurityCalculator) + val bestFeatureGainStats = { + if (bestGain == Double.MinValue) { + 
ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + else { + new ImpurityStats(bestGain, parentImpurityCalculator.calculate(), + parentImpurityCalculator, leftImpurityCalculator, + rightImpurityCalculator) + } + } + (bestSplit, bestFeatureGainStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 3a731f45d6a0..a2e37db73515 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -27,7 +27,13 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} @Experimental object Entropy extends Impurity { - private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + private[tree] def log2(x: Double): Double = { + if (x == 0) { + return 0.0 + } else { + return scala.math.log(x) / scala.math.log(2) + } + } /** * :: DeveloperApi :: @@ -47,10 +53,8 @@ object Entropy extends Impurity { var classIndex = 0 while (classIndex < numClasses) { val classCount = counts(classIndex) - if (classCount != 0) { - val freq = classCount / totalCount - impurity -= freq * log2(freq) - } + val freq = classCount / totalCount + impurity -= freq * log2(freq) classIndex += 1 } impurity @@ -76,6 +80,72 @@ object Entropy extends Impurity { @Since("1.1.0") def instance: this.type = this + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentsStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. 
+ * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return Double.MinValue if gain < minInfoGain. + * @return information gain. + */ + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentStats: Array[Double], + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + var leftCount = 0.0 + var totalCount = 0.0 + var i = 0 + while (i < statsSize) { + leftCount += allStats(leftChildOffset + i) + totalCount += parentStats(parentOffset + i) + i += 1 + } + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + var leftImpurity = 0.0 + var rightImpurity = 0.0 + var parentImpurity = 0.0 + + i = 0 + while (i < statsSize) { + val leftStats = allStats(leftChildOffset + i) + val totalStats = parentStats(parentOffset + i) + + val leftFreq = leftStats / leftCount + val rightFreq = (totalStats - leftStats) / rightCount + val parentFreq = totalStats / totalCount + + leftImpurity -= leftFreq * log2(leftFreq) + rightImpurity -= rightFreq * log2(rightFreq) + parentImpurity -= parentFreq * log2(parentFreq) + + i += 1 + } + val leftWeighted = leftCount / totalCount * leftImpurity + val rightWeighted = rightCount / totalCount * rightImpurity + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < minInfoGain) { + return Double.MinValue + } + gain + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 7730c0a8c111..1e90ec78af7a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -73,6
+73,75 @@ object Gini extends Impurity { @Since("1.1.0") def instance: this.type = this + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return Double.MinValue if gain < minInfoGain. + * @return information gain. + */ + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentStats: Array[Double], + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + + var leftCount = 0.0 + var totalCount = 0.0 + var i = 0 + while (i < statsSize) { + leftCount += allStats(leftChildOffset + i) + totalCount += parentStats(parentOffset + i) + i += 1 + } + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + var leftImpurity = 1.0 + var rightImpurity = 1.0 + var parentImpurity = 1.0 + + i = 0 + while (i < statsSize) { + val leftStats = allStats(leftChildOffset + i) + val totalStats = parentStats(parentOffset + i) + + val leftFreq = leftStats / leftCount + val rightFreq = (totalStats - leftStats) / rightCount + val parentFreq = totalStats / totalCount + + leftImpurity -= leftFreq * leftFreq + rightImpurity -= rightFreq * rightFreq + parentImpurity -= parentFreq * parentFreq + + i += 1 + } + + val leftWeighted = leftCount /
totalCount * leftImpurity + val rightWeighted = rightCount / totalCount * rightImpurity + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < minInfoGain) { + return Double.MinValue + } + gain + } + + } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 65f0163ec605..fe1c43dc944f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -52,6 +52,30 @@ trait Impurity extends Serializable { @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double + + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return Double.MinValue if gain < minInfoGain. + * @return information gain.
+ */ + protected def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentStats: Array[Double], + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 2423516123b8..84eadcda3628 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -64,6 +64,59 @@ object Variance extends Impurity { @Since("1.0.0") def instance: this.type = this + /** + * Information gain calculation. + * allStats(leftChildOffset: leftChildOffset + statsSize) contains the impurity + * information of the leftChild. + * parentStats(parentOffset: parentOffset + statsSize) contains the impurity + * information of the parent. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param leftChildOffset Start index of stats for the left child. + * @param parentStats Flat stats array for impurity calculation of the parent. + * @param parentOffset Start index of stats for the parent. + * @param statsSize Size of the stats for the left child and the parent. + * @param minInstancesPerNode minimum no. of instances in the child nodes for non-zero gain. + * @param minInfoGain return Double.MinValue if gain < minInfoGain. + * @return information gain.
+ */ + override def calculateGain( + allStats: Array[Double], + leftChildOffset: Int, + parentStats: Array[Double], + parentOffset: Int, + statsSize: Int, + minInstancesPerNode: Int, + minInfoGain: Double): Double = { + val leftCount = allStats(leftChildOffset) + val totalCount = parentStats(parentOffset) + val rightCount = totalCount - leftCount + + if ((leftCount < minInstancesPerNode) || + (rightCount < minInstancesPerNode)) { + return Double.MinValue + } + + val leftSum = allStats(leftChildOffset + 1) + val leftSumSquares = allStats(leftChildOffset + 2) + + val parentSum = parentStats(parentOffset + 1) + val parentSumSquares = parentStats(parentOffset + 2) + + val rightSum = parentSum - leftSum + val rightSumSquares = parentSumSquares - leftSumSquares + + val parentImpurity = (parentSumSquares - (parentSum * parentSum) / totalCount) / totalCount + val leftWeighted = (leftSumSquares - (leftSum * leftSum) / leftCount) / totalCount + val rightWeighted = (rightSumSquares - (rightSum * rightSum) / rightCount) / totalCount + val gain = parentImpurity - leftWeighted - rightWeighted + + if (gain < minInfoGain) { + return Double.MinValue + } + gain + + } + } /**