-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation #4677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
77549a9
3e74372
55e5c3b
fad9b6e
b928a19
b48a70f
e4d799b
1bb21d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -427,6 +427,18 @@ We omit some decision tree parameters since those are covered in the [decision t | |
|
|
||
| * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. | ||
|
|
||
| #### Validation while training | ||
|
|
||
| Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might | ||
| be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this | ||
| option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset. | ||
|
|
||
| The training is stopped when the improvement in the validation error is not more than a certain tolerance | ||
| (supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We generally don't use bold for arguments in the docs, but you could link to the API docs. |
||
| decreases with the increase in number of trees and then increases as the model starts to overfit. There might | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "decreases with the increase in number of trees and then increases" --> "decreases initially and later increases" |
||
| be cases, in which the validation error does not change monotonically, and the user is advised to set a large | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "be cases, in which" --> (no comma) |
||
| enough negative tolerance and examine the validation curve to make further inference. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "to make further inference" --> "to tune the number of iterations" |
||
|
|
||
|
|
||
| ### Examples | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) | |
| def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
| val algo = boostingStrategy.treeStrategy.algo | ||
| algo match { | ||
| case Regression => GradientBoostedTrees.boost(input, boostingStrategy) | ||
| case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) | ||
| case Classification => | ||
| // Map labels to -1, +1 so binary classification can be treated as regression. | ||
| val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
| GradientBoostedTrees.boost(remappedInput, boostingStrategy) | ||
| GradientBoostedTrees.boost(remappedInput, | ||
| remappedInput, boostingStrategy, validate=false) | ||
| case _ => | ||
| throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") | ||
| } | ||
|
|
@@ -76,8 +77,44 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) | |
| def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
| run(input.rdd) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Method to validate a gradient boosting model | ||
| * @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
| * @param validateInput Validation dataset: | ||
| RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. | ||
| Should follow same distribution as trainInput. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think talking about target distributions may confuse some people. Maybe we can clarify as follows:
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, and we should also explicitly say that validateInput should be a different dataset than trainInput. (We don't need to check for this, though. If it is the same, then validationTol acts like convergenceTol.) |
||
| * @return a gradient boosted trees model that can be used for prediction | ||
| */ | ||
| def runWithValidation( | ||
| trainInput: RDD[LabeledPoint], | ||
| validateInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no guarantee that training and validation are following the same distribution. I know it happens in practice but I don't know the theory behind using distribution A for training but another distribution B for validation. It would be better if we say that they should follow the same distribution.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean to add a comment? |
||
| val algo = boostingStrategy.treeStrategy.algo | ||
| algo match { | ||
| case Regression => GradientBoostedTrees.boost( | ||
| trainInput, validateInput, boostingStrategy, validate=true) | ||
| case Classification => | ||
| // Map labels to -1, +1 so binary classification can be treated as regression. | ||
| val remappedTrainInput = trainInput.map( | ||
| x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
| val remappedValidateInput = trainInput.map( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "trainInput" --> "validateInput"
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oops. :/ |
||
| x => new LabeledPoint((x.label * 2) - 1, x.features)) | ||
| GradientBoostedTrees.boost(remappedTrainInput, remappedValidateInput, boostingStrategy, | ||
| validate=true) | ||
| case _ => | ||
| throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. | ||
| */ | ||
| def runWithValidation( | ||
| trainInput: JavaRDD[LabeledPoint], | ||
| validateInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { | ||
| runWithValidation(trainInput.rdd, validateInput.rdd) | ||
| } | ||
| } | ||
|
|
||
| object GradientBoostedTrees extends Logging { | ||
|
|
||
|
|
@@ -108,12 +145,16 @@ object GradientBoostedTrees extends Logging { | |
| /** | ||
| * Internal method for performing regression using trees as base learners. | ||
| * @param input training dataset | ||
| * @param validateInput validation dataset, ignored if validate is set to false. | ||
| * @param boostingStrategy boosting parameters | ||
| * @param validate whether or not to use the validation dataset. | ||
| * @return a gradient boosted trees model that can be used for prediction | ||
| */ | ||
| private def boost( | ||
| input: RDD[LabeledPoint], | ||
| boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { | ||
| validateInput: RDD[LabeledPoint], | ||
| boostingStrategy: BoostingStrategy, | ||
| validate: Boolean = false): GradientBoostedTreesModel = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for default value; better to be explicit internally |
||
|
|
||
| val timer = new TimeTracker() | ||
| timer.start("total") | ||
|
|
@@ -129,6 +170,7 @@ object GradientBoostedTrees extends Logging { | |
| val learningRate = boostingStrategy.learningRate | ||
| // Prepare strategy for individual trees, which use regression with variance impurity. | ||
| val treeStrategy = boostingStrategy.treeStrategy.copy | ||
| val validationTol = boostingStrategy.validationTol | ||
| treeStrategy.algo = Regression | ||
| treeStrategy.impurity = Variance | ||
| treeStrategy.assertValid() | ||
|
|
@@ -152,13 +194,16 @@ object GradientBoostedTrees extends Logging { | |
| baseLearnerWeights(0) = 1.0 | ||
| val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) | ||
| logDebug("error of gbt = " + loss.computeError(startingModel, input)) | ||
|
|
||
| // Note: A model of type regression is used since we require raw prediction | ||
| timer.stop("building tree 0") | ||
|
|
||
| var bestValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0 | ||
| var bestM = 1 | ||
|
|
||
| // psuedo-residual for second iteration | ||
| data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), | ||
| point.features)) | ||
|
|
||
| var m = 1 | ||
| while (m < numIterations) { | ||
| timer.start(s"building tree $m") | ||
|
|
@@ -177,6 +222,24 @@ object GradientBoostedTrees extends Logging { | |
| val partialModel = new GradientBoostedTreesModel( | ||
| Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) | ||
| logDebug("error of gbt = " + loss.computeError(partialModel, input)) | ||
|
|
||
| if (validate) { | ||
| // Stop training early if | ||
| // 1. Reduction in error is lesser than the validationTol or | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "lesser" --> "less" |
||
| // 2. If the error increases, that is if the model is overfit. | ||
| // We want the model returned corresponding to the best validation error. | ||
| val currentValidateError = loss.computeError(partialModel, validateInput) | ||
| if (bestValidateError - currentValidateError < validationTol) { | ||
| return new GradientBoostedTreesModel( | ||
| boostingStrategy.treeStrategy.algo, | ||
| baseLearners.slice(0, bestM), | ||
| baseLearnerWeights.slice(0, bestM)) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. formatting: |
||
| else if (currentValidateError < bestValidateError){ | ||
| bestValidateError = currentValidateError | ||
| bestM = m + 1 | ||
| } | ||
| } | ||
| // Update data with pseudo-residuals | ||
| data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), | ||
| point.features)) | ||
|
|
@@ -191,4 +254,5 @@ object GradientBoostedTrees extends Logging { | |
| new GradientBoostedTreesModel( | ||
| boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,12 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} | |
| * weak hypotheses used in the final model. | ||
| * @param learningRate Learning rate for shrinking the contribution of each estimator. The | ||
| * learning rate should be between in the interval (0, 1] | ||
| * @param validationTol Useful when runWithValidation is used. If the error rate between two | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also: "the error rate" --> "the error rate on the validationInput" |
||
| iterations is lesser than the validationTol, then stop. If run | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "lesser" --> "less" Also, use double brackets to add link to "run" to be clearer. |
||
| is used, then this parameter is ignored. | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Something weird is going on with the indentation and the lack of "*" (for comments) here
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove empty line |
||
| a pair of RDD's are supplied to run. If the error rate | ||
| * between two iterations is lesser than convergenceTol, then training stops. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "lesser" --> "less" |
||
| */ | ||
| @Experimental | ||
| case class BoostingStrategy( | ||
|
|
@@ -42,7 +48,8 @@ case class BoostingStrategy( | |
| @BeanProperty var loss: Loss, | ||
| // Optional boosting parameters | ||
| @BeanProperty var numIterations: Int = 100, | ||
| @BeanProperty var learningRate: Double = 0.1) extends Serializable { | ||
| @BeanProperty var learningRate: Double = 0.1, | ||
| @BeanProperty var validationTol: Double = 1e-5) extends Serializable { | ||
|
|
||
| /** | ||
| * Check validity of parameters. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -158,6 +158,63 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { | |
| } | ||
| } | ||
| } | ||
|
|
||
| test("runWithValidation performs better on a validation dataset (Regression)") { | ||
| // Set numIterations large enough so that it early stops. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "early stops" --> "stops early" |
||
| val numIterations = 20 | ||
| val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) | ||
| val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) | ||
|
|
||
| val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, | ||
| categoricalFeaturesInfo = Map.empty) | ||
| Array(SquaredError, AbsoluteError).foreach { error => | ||
| val boostingStrategy = | ||
| new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0) | ||
|
|
||
| val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. formatting: Put .runWithValidation on the next line with its parameters |
||
| trainRdd, validateRdd) | ||
| assert(gbtValidate.numTrees != numIterations) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
|
|
||
| val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) | ||
| val errorWithoutValidation = error.computeError(gbt, validateRdd) | ||
| val errorWithValidation = error.computeError(gbtValidate, validateRdd) | ||
| assert(errorWithValidation < errorWithoutValidation) | ||
| } | ||
| } | ||
|
|
||
| test("runWithValidation performs better on a validation dataset (Classification)") { | ||
| // Set numIterations large enough so that it early stops. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: "stops early" |
||
| val numIterations = 20 | ||
| val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) | ||
| val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) | ||
|
|
||
| val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, | ||
| categoricalFeaturesInfo = Map.empty) | ||
| val boostingStrategy = | ||
| new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0) | ||
|
|
||
| // Test that it stops early. | ||
| val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move .runWithValidation to next line |
||
| trainRdd, validateRdd) | ||
| assert(gbtValidate.numTrees != numIterations) | ||
|
|
||
| // Remap labels to {-1, 1} | ||
| val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) | ||
|
|
||
| // The error checked for internally in the GradientBoostedTrees is based on Regression. | ||
| // Hence for the validation model, the Classification error need not be strictly less than | ||
| // that done with validation. | ||
| val gbtValidateRegressor = new GradientBoostedTreesModel( | ||
|
||
| Regression, gbtValidate.trees, gbtValidate.treeWeights) | ||
| val errorWithValidation = LogLoss.computeError(gbtValidateRegressor, remappedInput) | ||
|
|
||
| val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) | ||
| val gbtRegressor = new GradientBoostedTreesModel(Regression, gbt.trees, gbt.treeWeights) | ||
| val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput) | ||
|
|
||
| assert(errorWithValidation < errorWithoutValidation) | ||
| } | ||
|
||
|
|
||
| } | ||
|
|
||
| private object GradientBoostedTreesSuite { | ||
|
|
@@ -166,4 +223,6 @@ private object GradientBoostedTreesSuite { | |
| val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) | ||
|
|
||
| val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) | ||
| val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) | ||
| val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"with more number of trees" --> "with more trees"
"it might be" --> "it is"