[SPARK-3162] [MLlib] Add local tree training for decision tree regressors #19433
New file: AggUpdateUtils.scala
@@ -0,0 +1,84 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.ml.tree.Split

/**
 * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training.
 */
private[impl] object AggUpdateUtils {

  /**
   * Updates the parent node stats of the passed-in impurity aggregator with the labels
   * corresponding to the feature values at indices [from, to).
   */
  private[impl] def updateParentImpurity(
      statsAggregator: DTStatsAggregator,
      col: FeatureVector,
      from: Int,
      to: Int,
      labels: Array[Double]): Unit = {
    from.until(to).foreach { idx =>
      val rowIndex = col.indices(idx)
      val label = labels(rowIndex)
      statsAggregator.updateParent(label, instanceWeight = 1)
    }
  }

  /**
   * Updates the aggregator for an (unordered feature, label) pair.
   * @param splits Array of arrays of splits for each feature; splits(i) = splits for feature i.
   */
  private[impl] def updateUnorderedFeature(
      agg: DTStatsAggregator,
      featureValue: Int,
      label: Double,
      featureIndex: Int,
      featureIndexIdx: Int,
      splits: Array[Array[Split]],
      instanceWeight: Double = 1.0): Unit = {
    val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
    // Each unordered split has a corresponding bin for impurity stats of data points that fall
    // onto the left side of the split. For each unordered split, update the left-side bin if
    // applicable for the current data point.
    val numSplits = agg.metadata.numSplits(featureIndex)
    val featureSplits = splits(featureIndex)
    var splitIndex = 0
    while (splitIndex < numSplits) {
      if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
        agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight)
      }
      splitIndex += 1
    }
  }

  /** Updates the aggregator for an (ordered feature, label) pair. */
  private[impl] def updateOrderedFeature(
      agg: DTStatsAggregator,
      featureValue: Int,
      label: Double,
      featureIndex: Int,
      featureIndexIdx: Int,
      instanceWeight: Double = 1.0): Unit = {
    // The bin index of an ordered feature is just the feature value itself.
    val binIndex = featureValue
    agg.update(featureIndexIdx, binIndex, label, instanceWeight)
  }
}
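For intuition, here is a minimal, self-contained sketch of the unordered-feature update above. It is not part of the PR: CategoricalSplit and the raw label-sum array are hypothetical stand-ins for Spark's Split and DTStatsAggregator. Each unordered split owns one left-side bin, and a data point contributes to that bin only when it falls on the split's left side.

object UnorderedUpdateSketch {
  // Hypothetical stand-in for a categorical Split: a data point goes left
  // if its category is a member of leftCategories.
  final case class CategoricalSplit(leftCategories: Set[Int]) {
    def shouldGoLeft(featureValue: Int): Boolean = leftCategories.contains(featureValue)
  }

  def main(args: Array[String]): Unit = {
    // Three candidate unordered splits over a 3-category feature.
    val splits = Array(
      CategoricalSplit(Set(0)),
      CategoricalSplit(Set(0, 1)),
      CategoricalSplit(Set(1)))
    // leftLabelSum(i) plays the role of split i's left-side bin in DTStatsAggregator.
    val leftLabelSum = new Array[Double](splits.length)

    def update(featureValue: Int, label: Double, instanceWeight: Double = 1.0): Unit = {
      var splitIndex = 0
      while (splitIndex < splits.length) {
        // Update the left-side bin only if this point falls left of the split.
        if (splits(splitIndex).shouldGoLeft(featureValue)) {
          leftLabelSum(splitIndex) += label * instanceWeight
        }
        splitIndex += 1
      }
    }

    update(featureValue = 0, label = 2.0) // falls left of splits 0 and 1
    update(featureValue = 1, label = 3.0) // falls left of splits 1 and 2
    println(leftLabelSum.mkString(", ")) // prints: 2.0, 5.0, 3.0
  }
}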
New file: FeatureVector.scala
@@ -0,0 +1,144 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.util.collection.BitSet

/**
 * Stores values for a single training data column (a single continuous or categorical feature).
 *
 * Values are currently stored in a dense representation only.
 * TODO: Support sparse storage (to optimize deeper levels of the tree), and maybe compressed
 * storage (to optimize upper levels of the tree).
 *
 * TODO: Sort feature values to support more complicated splitting logic (e.g. considering every
 * possible continuous split instead of discretizing continuous features).
 *
 * NOTE: We could add sorting of feature values in this PR; the only change required would be to
 * sort feature values at construction time. Sorting might improve locality during stats
 * aggregation (we'd frequently update the same O(statsSize) array for a (feature, bin),
 * instead of frequently updating for the same feature).
 *
 * @param featureArity For categorical features, this gives the number of categories.
 *                     For continuous features, this should be set to 0.
 */
private[impl] class FeatureVector(
    val featureIndex: Int,
    val featureArity: Int,
    val values: Array[Int],
    val rowIndices: Option[Array[Int]])
  extends Serializable {
  // Associates feature values with training point rows: indices(i) = training point index
  // (row index) of the ith feature value.
  val indices = rowIndices.getOrElse(values.indices.toArray)

  def isCategorical: Boolean = featureArity > 0

  /** For debugging */
  override def toString: String = {
    "  FeatureVector(" +
    s"    featureIndex: $featureIndex,\n" +
    s"    featureType: ${if (featureArity == 0) "Continuous" else "Categorical"},\n" +
    s"    featureArity: $featureArity,\n" +
    s"    values: ${values.mkString(", ")},\n" +
    s"    indices: ${indices.mkString(", ")},\n" +
    "  )"
  }

  def deepCopy(): FeatureVector =
    new FeatureVector(featureIndex, featureArity, values.clone(), Some(indices.clone()))

  override def equals(other: Any): Boolean = {
    other match {
      case o: FeatureVector =>
        featureIndex == o.featureIndex && featureArity == o.featureArity &&
          values.sameElements(o.values) && indices.sameElements(o.indices)
      case _ => false
    }
  }

  /**
   * Reorders the subset of feature values at indices [from, to) in this column
   * according to the split information encoded in instanceBitVector (feature values for rows
   * that split left appear before feature values for rows that split right).
   *
   * @param tempVals Destination buffer for reordered feature values
   * @param tempIndices Destination buffer for row indices corresponding to reordered feature
   *                    values
   * @param numLeftRows Number of rows on the left side of the split
   * @param instanceBitVector instanceBitVector(i) = true if the row for the ith feature
   *                          value splits right, false otherwise
   */
  private[ml] def updateForSplit(
      from: Int,
      to: Int,
      numLeftRows: Int,
      tempVals: Array[Int],
      tempIndices: Array[Int],
      instanceBitVector: BitSet): Unit = {

    // BEGIN SORTING
    // We reorder the [from, to) slice of this column based on each instance's bit.
    // All instances going "left" in the split (bit = false) should be ordered before
    // instances going "right" (bit = true). The instanceBitVector gives us the split bit
    // for each instance based on the instance's index within the slice.
    // We copy each feature value (and its row index) into tempVals and tempIndices either:
    //  1) at positions [0, numLeftRows) of the temp buffers if the bit is false, or
    //  2) at positions [numLeftRows, to - from) if the bit is true.
    var (leftInstanceIdx, rightInstanceIdx) = (0, numLeftRows)
    var idx = from
    while (idx < to) {
      val indexForVal = indices(idx)
      val bit = instanceBitVector.get(idx - from)
      if (bit) {
        tempVals(rightInstanceIdx) = values(idx)
        tempIndices(rightInstanceIdx) = indexForVal
        rightInstanceIdx += 1
      } else {
        tempVals(leftInstanceIdx) = values(idx)
        tempIndices(leftInstanceIdx) = indexForVal
        leftInstanceIdx += 1
      }
      idx += 1
    }
    // END SORTING
    // Copy the reordered values and row indices back into this column.
    System.arraycopy(tempVals, 0, values, from, to - from)
    System.arraycopy(tempIndices, 0, indices, from, to - from)
  }

  override def hashCode: Int = {
    com.google.common.base.Objects.hashCode(
      featureIndex: java.lang.Integer,
      featureArity: java.lang.Integer,
      values,
      indices)
  }
}

private[impl] object FeatureVector {
  /**
   * Stores column values sorted by decision tree node (i.e. all column values for a node occur
   * in a contiguous subarray).
   */
  def apply(
      featureIndex: Int,
      featureArity: Int,
      values: Array[Int]): FeatureVector = {
    new FeatureVector(featureIndex, featureArity, values, rowIndices = None)
  }
}
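The reordering in updateForSplit is a stable two-way partition rather than a full sort. Here is a self-contained sketch (not the PR's code) under simplified assumptions: a plain Array[Boolean] stands in for Spark's BitSet, and the slice covers the whole column (from = 0).

object PartitionSketch {
  def main(args: Array[String]): Unit = {
    // Feature values, their row indices, and a hypothetical split decision
    // per position (true = row goes right), mimicking instanceBitVector.
    val values = Array(5, 3, 8, 1, 9)
    val indices = Array(0, 1, 2, 3, 4)
    val goesRight = Array(false, true, false, true, false)
    val numLeftRows = goesRight.count(bit => !bit)

    val tempVals = new Array[Int](values.length)
    val tempIndices = new Array[Int](values.length)
    var leftPos = 0
    var rightPos = numLeftRows
    var i = 0
    while (i < values.length) {
      // Left rows fill [0, numLeftRows); right rows fill [numLeftRows, length).
      // Relative order within each side is preserved (a stable partition).
      if (goesRight(i)) {
        tempVals(rightPos) = values(i); tempIndices(rightPos) = indices(i); rightPos += 1
      } else {
        tempVals(leftPos) = values(i); tempIndices(leftPos) = indices(i); leftPos += 1
      }
      i += 1
    }
    println(tempVals.mkString(", "))    // prints: 5, 8, 9, 3, 1
    println(tempIndices.mkString(", ")) // prints: 0, 2, 4, 1, 3
  }
}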
New file: ImpurityUtils.scala
@@ -0,0 +1,106 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats

/** Helper methods for impurity-related calculations during node split decisions. */
private[impl] object ImpurityUtils {

  /**
   * Calculates the impurity statistics for a given (feature, split) based upon left/right
   * aggregates.
   *
   * @param leftImpurityCalculator left node aggregates for this (feature, split)
   * @param rightImpurityCalculator right node aggregates for this (feature, split)
   * @param metadata learning and dataset metadata for DecisionTree
   * @return Impurity statistics for this (feature, split)
   */
  private[impl] def calculateImpurityStats(
      leftImpurityCalculator: ImpurityCalculator,
      rightImpurityCalculator: ImpurityCalculator,
      metadata: DecisionTreeMetadata): ImpurityStats = {

    val parentImpurityCalculator = leftImpurityCalculator.copy.add(rightImpurityCalculator)
    val impurity = parentImpurityCalculator.calculate()

    val leftCount = leftImpurityCalculator.count
    val rightCount = rightImpurityCalculator.count
    val totalCount = leftCount + rightCount

    // If the left or right child doesn't satisfy the minimum instance count per node,
    // then this split is invalid; return invalid information gain stats.
    if ((leftCount < metadata.minInstancesPerNode) ||
        (rightCount < metadata.minInstancesPerNode)) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

    val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
    val rightImpurity = rightImpurityCalculator.calculate()

    val leftWeight = leftCount / totalCount.toDouble
    val rightWeight = rightCount / totalCount.toDouble

    val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    // If the information gain doesn't satisfy the minimum information gain,
    // then this split is invalid; return invalid information gain stats.
    // NOTE: We check gain < metadata.minInfoGain and gain <= 0 separately, as this is what the
    // original tree training logic did.
    if (gain < metadata.minInfoGain || gain <= 0) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }

    new ImpurityStats(gain, impurity, parentImpurityCalculator,
      leftImpurityCalculator, rightImpurityCalculator)
  }

  /**
   * Given an impurity aggregator containing label statistics for a given (node, feature, bin),
   * returns the corresponding "centroid", used to order bins while computing best splits.
   *
   * @param metadata learning and dataset metadata for DecisionTree
   */
  private[impl] def getCentroid(
      metadata: DecisionTreeMetadata,
      binStats: ImpurityCalculator): Double = {

    if (binStats.count != 0) {
      if (metadata.isMulticlass) {
        // multiclass classification
        // For categorical features in multiclass classification,
        // the bins are ordered by the impurity of their corresponding labels.
        binStats.calculate()
      } else if (metadata.isClassification) {
        // binary classification
        // For categorical features in binary classification,
        // the bins are ordered by the count of class 1.
        binStats.stats(1)
      } else {
        // regression
        // For categorical features in regression,
        // the bins are ordered by the prediction.
        binStats.predict
      }
    } else {
      Double.MaxValue
    }
  }
}
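As a worked example of the gain formula in calculateImpurityStats, here is a self-contained sketch assuming variance impurity (the regression case); the variance helper is a stand-in for ImpurityCalculator.calculate(), not the PR's code.

object GainSketch {
  // Stand-in for ImpurityCalculator.calculate() under variance impurity.
  def variance(labels: Seq[Double]): Double = {
    val mean = labels.sum / labels.size
    labels.map(x => (x - mean) * (x - mean)).sum / labels.size
  }

  def main(args: Array[String]): Unit = {
    val leftLabels = Seq(1.0, 1.0)  // labels routed left by a candidate split
    val rightLabels = Seq(5.0, 7.0) // labels routed right
    val parentLabels = leftLabels ++ rightLabels
    val totalCount = parentLabels.size.toDouble

    // gain = parentImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
    val gain = variance(parentLabels) -
      (leftLabels.size / totalCount) * variance(leftLabels) -
      (rightLabels.size / totalCount) * variance(rightLabels)
    println(gain) // prints 6.25: the split removes most of the parent's variance of 6.75
  }
}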
Actually, updateParentImpurity has no relation to any particular feature column; it takes a feature column only to use its indices array, so passing any feature column would work. This looks odd, though, and could be designed better.
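One possible refactor along the lines of this comment (hypothetical, not part of the PR) would accept the row-index array directly, making explicit that the method does not depend on any particular feature column:

  // Hypothetical signature change: accept the shared row-index array instead of a
  // FeatureVector, since only col.indices is used by the current implementation.
  private[impl] def updateParentImpurity(
      statsAggregator: DTStatsAggregator,
      indices: Array[Int],
      from: Int,
      to: Int,
      labels: Array[Double]): Unit = {
    from.until(to).foreach { idx =>
      // Look up the label for the row at position idx and add it to the parent stats.
      statsAggregator.updateParent(labels(indices(idx)), instanceWeight = 1)
    }
  }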