-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19357][ML] Adding parallel model evaluation in ML tuning #16774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
5650e98
36a1a68
b051afa
46fe252
1274ba4
80ac2fd
8126710
6a9b735
1c2e391
9e055cd
97ad7b4
5e8a086
864c99c
ad8a870
911af1d
658aacb
2c73b0b
7a8221b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,6 +55,9 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c | |
| method in each of these evaluators. | ||
|
|
||
| To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. | ||
| By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `numParallelEval` with a value of 2 or more before running model selection with `CrossValidator` or `TrainValidationSplit`. | ||
|
||
| The value of `numParallelEval` should be chosen carefully to maximize parallelism without exceeding cluster resources, and will be capped at the number of cores in the driver system. Generally speaking, a value up to 10 should be sufficient for most clusters. | ||
|
|
||
|
||
|
|
||
| # Cross-Validation | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,8 @@ object ModelSelectionViaTrainValidationSplitExample { | |
| .setEstimatorParamMaps(paramGrid) | ||
| // 80% of the data will be used for training and the remaining 20% for validation. | ||
| .setTrainRatio(0.8) | ||
| // Evaluate up to 2 parameter settings in parallel | ||
| .setNumParallelEval(2) | ||
|
||
|
|
||
| // Run train validation split, and choose the best set of parameters. | ||
| val model = trainValidationSplit.fit(training) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,20 +20,24 @@ package org.apache.spark.ml.tuning | |
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.{ExecutionContext, Future} | ||
| import scala.concurrent.duration.Duration | ||
|
|
||
| import com.github.fommil.netlib.F2jBLAS | ||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s.DefaultFormats | ||
|
|
||
| import org.apache.spark.annotation.Since | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.ml._ | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.util.MLUtils | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
|
|
||
| /** | ||
| * Params for [[CrossValidator]] and [[CrossValidatorModel]]. | ||
|
|
@@ -91,6 +95,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| @Since("2.0.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setNumParallelEval(value: Int): this.type = set(numParallelEval, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): CrossValidatorModel = { | ||
| val schema = dataset.schema | ||
|
|
@@ -100,31 +108,60 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| val eval = $(evaluator) | ||
| val epm = $(estimatorParamMaps) | ||
| val numModels = epm.length | ||
| val metrics = new Array[Double](epm.length) | ||
|
|
||
| // Create execution context, run in serial if numParallelEval is 1 | ||
| val executionContext = $(numParallelEval) match { | ||
|
||
| case 1 => | ||
| ThreadUtils.sameThread | ||
| case n => | ||
| ExecutionContext.fromExecutorService(executorServiceFactory(n)) | ||
| } | ||
|
||
|
|
||
| val instr = Instrumentation.create(this, dataset) | ||
| instr.logParams(numFolds, seed) | ||
| logTuningParams(instr) | ||
|
|
||
| // Compute metrics for each model over each split | ||
| logDebug(s"Running cross-validation with level of parallelism: $numParallelEval.") | ||
| val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) | ||
| splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => | ||
| val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => | ||
| val trainingDataset = sparkSession.createDataFrame(training, schema).cache() | ||
| val validationDataset = sparkSession.createDataFrame(validation, schema).cache() | ||
| // multi-model training | ||
| logDebug(s"Train split $splitIndex with multiple sets of parameters.") | ||
| val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] | ||
| trainingDataset.unpersist() | ||
| var i = 0 | ||
| while (i < numModels) { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) | ||
| logDebug(s"Got metric $metric for model trained with ${epm(i)}.") | ||
| metrics(i) += metric | ||
| i += 1 | ||
|
|
||
| // Fit models in a Future with thread-pool size determined by '$numParallelEval' | ||
| val models = epm.map { paramMap => | ||
|
||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| } (executionContext) | ||
| } | ||
|
|
||
| Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => | ||
|
||
| trainingDataset.unpersist() | ||
| } (executionContext) | ||
|
|
||
| // Evaluate models in a Future with thread-pool size determined by '$numParallelEval' | ||
| val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) => | ||
| modelFuture.flatMap { model => | ||
|
||
| Future { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(model.transform(validationDataset, paramMap)) | ||
| logDebug(s"Got metric $metric for model trained with $paramMap.") | ||
| metric | ||
| } (executionContext) | ||
|
||
| } (executionContext) | ||
| } | ||
|
|
||
| // Wait for metrics to be calculated before unpersisting validation dataset | ||
|
||
| val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to also use val metrics = (ThreadUtils.awaitResult(
Future.sequence[Double, Iterable](metricFutures), Duration.Inf)).toArray
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about that, but since it's a blocking call anyway, it will still be bound by the longest running thread.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, not a big deal either way |
||
| validationDataset.unpersist() | ||
| } | ||
| foldMetrics | ||
| }.transpose.map(_.sum) | ||
|
|
||
| // Calculate average metric over all splits | ||
| f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) | ||
|
||
|
|
||
| logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") | ||
| val (bestMetric, bestIndex) = | ||
| if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,8 @@ package org.apache.spark.ml.tuning | |
| import java.util.{List => JList} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.{ExecutionContext, Future} | ||
| import scala.concurrent.duration.Duration | ||
| import scala.language.existentials | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
|
|
@@ -33,6 +35,7 @@ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} | |
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
| /** | ||
| * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. | ||
|
|
@@ -87,37 +90,64 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St | |
| @Since("2.0.0") | ||
| def setSeed(value: Long): this.type = set(seed, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setNumParallelEval(value: Int): this.type = set(numParallelEval, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { | ||
| val schema = dataset.schema | ||
| transformSchema(schema, logging = true) | ||
| val est = $(estimator) | ||
| val eval = $(evaluator) | ||
| val epm = $(estimatorParamMaps) | ||
| val numModels = epm.length | ||
| val metrics = new Array[Double](epm.length) | ||
|
|
||
| // Create execution context, run in serial if numParallelEval is 1 | ||
| val executionContext = $(numParallelEval) match { | ||
| case 1 => | ||
| ThreadUtils.sameThread | ||
| case n => | ||
| ExecutionContext.fromExecutorService(executorServiceFactory(n)) | ||
| } | ||
|
|
||
| val instr = Instrumentation.create(this, dataset) | ||
| instr.logParams(trainRatio, seed) | ||
| logTuningParams(instr) | ||
|
|
||
| logDebug(s"Running validation with level of parallelism: $numParallelEval.") | ||
| val Array(trainingDataset, validationDataset) = | ||
| dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) | ||
| trainingDataset.cache() | ||
| validationDataset.cache() | ||
|
|
||
| // multi-model training | ||
| // Fit models in a Future with thread-pool size determined by '$numParallelEval' | ||
| logDebug(s"Train split with multiple sets of parameters.") | ||
| val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] | ||
| trainingDataset.unpersist() | ||
| var i = 0 | ||
| while (i < numModels) { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) | ||
| logDebug(s"Got metric $metric for model trained with ${epm(i)}.") | ||
| metrics(i) += metric | ||
| i += 1 | ||
| val models = epm.map { paramMap => | ||
|
||
| Future[Model[_]] { | ||
| val model = est.fit(trainingDataset, paramMap) | ||
| model.asInstanceOf[Model[_]] | ||
| } (executionContext) | ||
| } | ||
|
|
||
| Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ => | ||
| trainingDataset.unpersist() | ||
| } (executionContext) | ||
|
|
||
| // Evaluate models in a Future with thread-pool size determined by '$numParallelEval' | ||
|
||
| val metricFutures = models.zip(epm).map { case (modelFuture, paramMap) => | ||
| modelFuture.flatMap { model => | ||
| Future { | ||
| // TODO: duplicate evaluator to take extra params from input | ||
| val metric = eval.evaluate(model.transform(validationDataset, paramMap)) | ||
| logDebug(s"Got metric $metric for model trained with $paramMap.") | ||
| metric | ||
| } (executionContext) | ||
| } (executionContext) | ||
| } | ||
|
|
||
| // Wait for all metrics to be calculated | ||
| val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) | ||
|
|
||
| validationDataset.unpersist() | ||
|
|
||
| logInfo(s"Train validation split metrics: ${metrics.toSeq}") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,18 +17,22 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.util.concurrent.ExecutorService | ||
|
|
||
| import org.apache.hadoop.fs.Path | ||
| import org.json4s.{DefaultFormats, _} | ||
| import org.json4s.jackson.JsonMethods._ | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, InterfaceStability} | ||
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.ml.{Estimator, Model} | ||
| import org.apache.spark.ml.evaluation.Evaluator | ||
| import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} | ||
| import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamPair, Params, ParamValidators} | ||
| import org.apache.spark.ml.param.shared.HasSeed | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.ml.util.DefaultParamsReader.Metadata | ||
| import org.apache.spark.sql.types.StructType | ||
| import org.apache.spark.util.ThreadUtils | ||
|
|
||
| /** | ||
| * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. | ||
|
|
@@ -67,6 +71,39 @@ private[ml] trait ValidatorParams extends HasSeed with Params { | |
| /** @group getParam */ | ||
| def getEvaluator: Evaluator = $(evaluator) | ||
|
|
||
| /** | ||
| * param to control the number of models evaluated in parallel | ||
| * Default: 1 | ||
| * | ||
| * @group param | ||
| */ | ||
| val numParallelEval: IntParam = new IntParam(this, "numParallelEval", | ||
|
||
| "max number of models to evaluate in parallel, 1 for serial evaluation", | ||
| ParamValidators.gtEq(1)) | ||
|
|
||
| /** @group getParam */ | ||
| def getNumParallelEval: Int = $(numParallelEval) | ||
|
|
||
| /** | ||
| * Creates an execution service to be used for validation, defaults to a thread-pool with | ||
| * size of `numParallelEval` | ||
| */ | ||
| protected var executorServiceFactory: (Int) => ExecutorService = { | ||
| (requestedMaxThreads: Int) => ThreadUtils.newDaemonCachedThreadPool( | ||
|
||
| s"${this.getClass.getSimpleName}-thread-pool", requestedMaxThreads) | ||
| } | ||
|
|
||
| /** | ||
| * Sets a function to get an execution service to be used for validation | ||
| * | ||
| * @param getExecutorService function to get an ExecutorService given a requestedMaxThreads size | ||
| */ | ||
| @Experimental | ||
| @InterfaceStability.Unstable | ||
| def setExecutorService(getExecutorService: (Int) => ExecutorService): Unit = { | ||
|
||
| executorServiceFactory = getExecutorService | ||
| } | ||
|
|
||
| protected def transformSchemaImpl(schema: StructType): StructType = { | ||
| require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps") | ||
| val firstEstimatorParamMap = $(estimatorParamMaps).head | ||
|
|
@@ -85,6 +122,8 @@ private[ml] trait ValidatorParams extends HasSeed with Params { | |
| instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) | ||
| instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) | ||
| } | ||
|
|
||
| setDefault(numParallelEval -> 1) | ||
| } | ||
|
|
||
| private[ml] object ValidatorParams { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,6 +121,33 @@ class CrossValidatorSuite | |
| } | ||
| } | ||
|
|
||
| test("cross validation with parallel evaluation") { | ||
| val lr = new LogisticRegression | ||
| val lrParamMaps = new ParamGridBuilder() | ||
| .addGrid(lr.regParam, Array(0.001, 1000.0)) | ||
| .addGrid(lr.maxIter, Array(0, 3)) | ||
| .build() | ||
| val eval = new BinaryClassificationEvaluator | ||
| val cv = new CrossValidator() | ||
| .setEstimator(lr) | ||
| .setEstimatorParamMaps(lrParamMaps) | ||
| .setEvaluator(eval) | ||
| .setNumFolds(2) | ||
| .setNumParallelEval(1) | ||
| val cvSerialModel = cv.fit(dataset) | ||
| cv.setNumParallelEval(2) | ||
|
||
| val cvParallelModel = cv.fit(dataset) | ||
|
|
||
| val serialMetrics = cvSerialModel.avgMetrics.sorted | ||
| val parallelMetrics = cvParallelModel.avgMetrics.sorted | ||
| assert(serialMetrics === parallelMetrics) | ||
|
|
||
| val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression] | ||
| val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression] | ||
| assert(parentSerial.getRegParam === parentParallel.getRegParam) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to also check
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should probably be done in the test that already runs |
||
| assert(parentSerial.getMaxIter === parentParallel.getMaxIter) | ||
| } | ||
|
|
||
| test("read/write: CrossValidator with simple estimator") { | ||
| val lr = new LogisticRegression().setMaxIter(3) | ||
| val evaluator = new BinaryClassificationEvaluator() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's a little clearer to say something like "By default, sets of parameters ... will be evaluated in serial. Parameter evaluation can be done in parallel by setting ..."