Changes from 10 commits (of 18):
5650e98 Changed CrossValidator and TrainValidationSplit fit methods to evalua… (BryanCutler, Jan 31, 2017)
36a1a68 made closure vars more explicit, moved param default to trait (BryanCutler, Feb 14, 2017)
b051afa added paramvalidator for numParallelEval to ensure >=1 (BryanCutler, Feb 14, 2017)
46fe252 added test cases for CrossValidation and TrainValidationSplit (BryanCutler, Feb 15, 2017)
1274ba4 added numParallelEval param usage to examples (BryanCutler, Feb 15, 2017)
80ac2fd added documentation to ml-tuning (BryanCutler, Feb 16, 2017)
8126710 changed sliding window limit to use a semaphore instead to prevent wa… (BryanCutler, Feb 16, 2017)
6a9b735 added note about parallelism capped by Scala collection thread pool, … (BryanCutler, Feb 16, 2017)
1c2e391 reworked to use ExecutorService and Futures (BryanCutler, Feb 28, 2017)
9e055cd fixed wildcard import (BryanCutler, Feb 28, 2017)
97ad7b4 made doc changes (BryanCutler, Apr 11, 2017)
5e8a086 changed ExecutorService factory to a trait to be compatible with Java (BryanCutler, Apr 12, 2017)
864c99c Merge remote-tracking branch 'upstream/master' into parallel-model-ev… (BryanCutler, Jun 13, 2017)
ad8a870 Changed ExecutorService to be set explicitly instead of factory (BryanCutler, Jun 14, 2017)
911af1d added HasParallelism trait (BryanCutler, Aug 23, 2017)
658aacb Updated to use Trait HasParallelsim (BryanCutler, Aug 23, 2017)
2c73b0b fixed up docs (BryanCutler, Aug 23, 2017)
7a8221b removed blas calculation for CrossValidator metric calc, was not nece… (BryanCutler, Sep 5, 2017)
3 changes: 3 additions & 0 deletions docs/ml-tuning.md
@@ -55,6 +55,9 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c
method in each of these evaluators.

To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility.
Sets of parameters from the parameter grid can be evaluated in parallel by setting `numParallelEval` to a value of 2 or more (a value of 1 evaluates in serial) before running model selection with `CrossValidator` or `TrainValidationSplit`.
Contributor:
I think it's a little clearer to say something like "By default, sets of parameters ... will be evaluated in serial. Parameter evaluation can be done in parallel by setting ..."

Contributor:
We should add a note that this only works in Scala/Java, not Python (or R)

The value of `numParallelEval` should be chosen carefully to maximize parallelism without exceeding cluster resources, and will be capped at the number of cores in the driver system. Generally speaking, a value up to 10 should be sufficient for most clusters.

Contributor:
Also will need to mention that custom ExecutorService can be specified, and some detail on the default thread pool it creates (and that it is a new separate pool to avoid blocking any of the default Scala pools).

BryanCutler (author):
Since that API is marked as experimental, maybe it would be better to not document right away until we are sure this is what we need?


# Cross-Validation

@@ -93,6 +93,7 @@ object ModelSelectionViaCrossValidationExample {
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // Use 3+ in practice
.setNumParallelEval(2) // Evaluate up to 2 parameter settings in parallel

// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(training)
@@ -65,6 +65,8 @@ object ModelSelectionViaTrainValidationSplitExample {
.setEstimatorParamMaps(paramGrid)
// 80% of the data will be used for training and the remaining 20% for validation.
.setTrainRatio(0.8)
// Evaluate up to 2 parameter settings in parallel
.setNumParallelEval(2)
BryanCutler (author), Apr 12, 2017:
TODO: I should probably set this in Java too, to be consistent


// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training)
@@ -20,20 +20,24 @@ package org.apache.spark.ml.tuning
import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils


/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
@@ -91,6 +95,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.2.0")
def setNumParallelEval(value: Int): this.type = set(numParallelEval, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
@@ -100,31 +108,60 @@
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)

// Create execution context, run in serial if numParallelEval is 1
val executionContext = $(numParallelEval) match {
Contributor:
I think it's ok to make this implicit val ... - it makes the Future code below clearer and it will be in immediate scope for all the calls so I don't think we need to specify.
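For reference, a sketch of the implicit form this suggestion describes, built from the match expression in this diff:

    implicit val executionContext = $(numParallelEval) match {
      case 1 => ThreadUtils.sameThread
      case n => ExecutionContext.fromExecutorService(executorServiceFactory(n))
    }

With the context implicit in scope, the explicit `(executionContext)` arguments in the Future calls below could be dropped.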

BryanCutler (author):
I kind of prefer to explicitly set the execution context. It does add some clutter, but then it is clear what calls rely on it, just in case some things get moved around in the future. Do you mind if we leave it as is?

case 1 =>
ThreadUtils.sameThread
case n =>
ExecutionContext.fromExecutorService(executorServiceFactory(n))
}
Contributor:
See my other comments about the API for setting executor service.


val instr = Instrumentation.create(this, dataset)
instr.logParams(numFolds, seed)
logTuningParams(instr)

// Compute metrics for each model over each split
logDebug(s"Running cross-validation with level of parallelism: $numParallelEval.")
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric
i += 1

// Fit models in a Future with thread-pool size determined by '$numParallelEval'
val models = epm.map { paramMap =>
Contributor:
If we use the implicit val above we can do:

    val models = epm.map { paramMap =>
      Future { est.fit(trainingDataset, paramMap) }.mapTo[Model[_]]
    }

Contributor:
I think using the name modelFutures instead of models would be clearer.

Future[Model[_]] {
val model = est.fit(trainingDataset, paramMap)
model.asInstanceOf[Model[_]]
} (executionContext)
}

Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
Contributor:
Likewise:

    Future.sequence[Model[_], Iterable](models).onComplete { _ =>
      trainingDataset.unpersist()
    }

trainingDataset.unpersist()
} (executionContext)

// Evaluate models in a Future with thread-pool size determined by '$numParallelEval'
val foldMetricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.flatMap { model =>
Contributor:
Note we could use for comprehension here. But I tried it and it doesn't really make it all that much simpler, and the explicit flatMap and Future here makes it a bit clearer.
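For comparison, a sketch of the for-comprehension alternative mentioned, assuming an implicit execution context is in scope:

    val metricFuture = for {
      model <- modelFuture
      metric <- Future {
        eval.evaluate(model.transform(validationDataset, paramMap))
      }
    } yield metric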

Future {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
Contributor:
Can then remove the (executionContext) here and elsewhere

} (executionContext)
}

// Wait for metrics to be calculated before upersisting validation dataset
Contributor:
Typo: upersisting -> unpersisting

val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
Contributor:
Does it make sense to also use sequence here?

    val metrics = (ThreadUtils.awaitResult(
      Future.sequence[Double, Iterable](metricFutures), Duration.Inf)).toArray

BryanCutler (author):
I thought about that, but since it's a blocking call anyway, it will still be bound by the longest running thread.

Contributor:
Sure, not a big deal either way

validationDataset.unpersist()
}
foldMetrics
}.transpose.map(_.sum)

// Calculate average metric over all splits
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
Contributor:
I don't like this line of code because it uses a low-level API, which makes the code difficult to read.
And this is not a bottleneck.
So I think we can simply use:

val metrics = ...
   ....
   .transpose.map(_.sum / $(numFolds))

instead.
What do you think about it?

Contributor:
I must say I wondered why the author of this bothered with a BLAS call to do the computation... even if there were a few thousand param combinations it would still not be significant overhead. +1

BryanCutler (author):
Yeah, I agree. I'll go ahead and make that change here.


logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
@@ -20,6 +20,8 @@ package org.apache.spark.ml.tuning
import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.language.existentials

import org.apache.hadoop.fs.Path
@@ -33,6 +35,7 @@ import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
@@ -87,37 +90,64 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.2.0")
def setNumParallelEval(value: Int): this.type = set(numParallelEval, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)

// Create execution context, run in serial if numParallelEval is 1
val executionContext = $(numParallelEval) match {
case 1 =>
ThreadUtils.sameThread
case n =>
ExecutionContext.fromExecutorService(executorServiceFactory(n))
}

val instr = Instrumentation.create(this, dataset)
instr.logParams(trainRatio, seed)
logTuningParams(instr)

logDebug(s"Running validation with level of parallelism: $numParallelEval.")
val Array(trainingDataset, validationDataset) =
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
trainingDataset.cache()
validationDataset.cache()

// multi-model training
// Fit models in a Future with thread-pool size determined by '$numParallelEval'
logDebug(s"Train split with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
metrics(i) += metric
i += 1
val models = epm.map { paramMap =>
Contributor:
models ==> modelFutures

Future[Model[_]] {
val model = est.fit(trainingDataset, paramMap)
model.asInstanceOf[Model[_]]
} (executionContext)
}

Future.sequence[Model[_], Iterable](models)(implicitly, executionContext).onComplete { _ =>
trainingDataset.unpersist()
} (executionContext)

// Evaluate models concurrently, limited by a barrier with '$numParallelEval' permits
Contributor:
This comment is stale

val metricFutures = models.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.flatMap { model =>
Future {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
} (executionContext)
}

// Wait for all metrics to be calculated
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

validationDataset.unpersist()

logInfo(s"Train validation split metrics: ${metrics.toSeq}")
@@ -17,18 +17,22 @@

package org.apache.spark.ml.tuning

import java.util.concurrent.ExecutorService

import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, _}
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamPair, Params, ParamValidators}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

/**
* Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
@@ -67,6 +71,39 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
/** @group getParam */
def getEvaluator: Evaluator = $(evaluator)

/**
* param to control the number of models evaluated in parallel
* Default: 1
*
* @group param
*/
val numParallelEval: IntParam = new IntParam(this, "numParallelEval",
Contributor:
Perhaps set this as group expertParam for now

"max number of models to evaluate in parallel, 1 for serial evaluation",
ParamValidators.gtEq(1))

/** @group getParam */
def getNumParallelEval: Int = $(numParallelEval)

/**
* Creates an execution service to be used for validation, defaults to a thread-pool with
* size of `numParallelEval`
*/
protected var executorServiceFactory: (Int) => ExecutorService = {
(requestedMaxThreads: Int) => ThreadUtils.newDaemonCachedThreadPool(
Contributor:
I think it may be problematic (and probably unnecessary) to have daemon threads here - we don't have a shutdown hook.

We could add a shutdown hook to tie the lifecycle to SparkContext (SparkSession); perhaps in setExecutorService:

...
ShutdownHookManager.addShutdownHook(() => executorService.shutdown())
...

BryanCutler (author), Apr 11, 2017:
So my thinking was that if the thread calling fit is terminated, it would have to be the JVM shutting down, which would exit without waiting for these daemon threads. We don't really care at what point the daemon threads stop or if they stop abruptly, since any unfinished work is useless. So I'm not sure if adding a shutdownHook would do anything different?

On the other hand, if the SparkSession wanted to cancel the running threads with the JVM still running, I think it could do that if it provided its own ExecutorService.

s"${this.getClass.getSimpleName}-thread-pool", requestedMaxThreads)
}

/**
* Sets a function to get an execution service to be used for validation
*
* @param getExecutorService function to get an ExecutorService given a requestedMaxThread size
*/
@Experimental
@InterfaceStability.Unstable
def setExecutorService(getExecutorService: (Int) => ExecutorService): Unit = {
Contributor:
I think I prefer a simpler API:

 def setExecutorService(executorService: ExecutorService): Unit = {

Not sure the function version will work nicely with Java.

I think the idea should be that setting numParallelEval will specify the number of threads for the default thread-pool, while a custom one can be set with this method (advanced usage). If set, the custom one will override.
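A sketch of the simpler setter and override behavior this comment describes; the `customExecutorService` field is hypothetical:

    // Store the user-supplied service directly; None means use the
    // default pool sized by numParallelEval.
    private var customExecutorService: Option[ExecutorService] = None

    def setExecutorService(executorService: ExecutorService): Unit = {
      customExecutorService = Some(executorService)
    }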

BryanCutler (author):
I think you are probably right about the Java thing, I'll test that out.

> custom one can be set with this method (advanced usage). If set, the custom one will override.

That's how this works, the executor service can be overridden and set any size thread pool, it is just passed the numParallelEval param for convenience in case it wants to use it.

BryanCutler (author):
One reason I did not go with the simpler API you mentioned is that it requires the ExecutorService to be instantiated with a default, which then exists for the life of the CrossValidator or TrainValidationSplit. The way it is currently, by default the ExecutorService is created at the beginning of fit when it is needed, and then cleaned up at the end.

BryanCutler (author):
Ah, I see what you are getting at. Even if I override with a custom ExecutorService but leave numParallelEval at 1, it will not use the custom one... I can fix that.
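A sketch of the fix being described, reusing the hypothetical `customExecutorService` option from above; a custom service, if set, wins even when numParallelEval is 1:

    val executionContext = customExecutorService match {
      case Some(es) => ExecutionContext.fromExecutorService(es)
      case None if $(numParallelEval) == 1 => ThreadUtils.sameThread
      case None =>
        ExecutionContext.fromExecutorService(executorServiceFactory($(numParallelEval)))
    }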

executorServiceFactory = getExecutorService
}

protected def transformSchemaImpl(schema: StructType): StructType = {
require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
val firstEstimatorParamMap = $(estimatorParamMaps).head
@@ -85,6 +122,8 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
}

setDefault(numParallelEval -> 1)
}

private[ml] object ValidatorParams {
@@ -121,6 +121,33 @@ class CrossValidatorSuite
}
}

test("cross validation with parallel evaluation") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 3))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new CrossValidator()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(2)
.setNumParallelEval(1)
val cvSerialModel = cv.fit(dataset)
cv.setNumParallelEval(2)
Contributor:
This test is not deterministic now. @MLnick, do you know if we have utilities to retry tests a number of times?

BryanCutler (author):
Yes, the models could now be evaluated in a different order, but the end result of returning the best model would be deterministic and shouldn't ever change.

Contributor:
Yes, I'd say we don't care about deterministic execution order here - we care about the result being the same regardless of execution order.

val cvParallelModel = cv.fit(dataset)

val serialMetrics = cvSerialModel.avgMetrics.sorted
val parallelMetrics = cvParallelModel.avgMetrics.sorted
assert(serialMetrics === parallelMetrics)

val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parentSerial.getRegParam === parentParallel.getRegParam)
Contributor:
Does it make sense to also check uid equality here?

BryanCutler (author):
It should probably be done in the test that already runs checkCopy on line 62 (at least until we clean up these basic checks). I'll take a look at that.

assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
}

test("read/write: CrossValidator with simple estimator") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()
@@ -36,9 +36,14 @@ class TrainValidationSplitSuite

import testImplicits._

test("train validation with logistic regression") {
val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
@transient var dataset: Dataset[_] = _

override def beforeAll(): Unit = {
super.beforeAll()
dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
}

test("train validation with logistic regression") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
@@ -118,6 +123,32 @@
}
}

test("train validation with parallel evaluation") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
.addGrid(lr.maxIter, Array(0, 3))
.build()
val eval = new BinaryClassificationEvaluator
val cv = new TrainValidationSplit()
.setEstimator(lr)
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumParallelEval(1)
val cvSerialModel = cv.fit(dataset)
cv.setNumParallelEval(2)
val cvParallelModel = cv.fit(dataset)

val serialMetrics = cvSerialModel.validationMetrics.sorted
val parallelMetrics = cvParallelModel.validationMetrics.sorted
assert(serialMetrics === parallelMetrics)

val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parentSerial.getRegParam === parentParallel.getRegParam)
assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
}

test("read/write: TrainValidationSplit") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()