Skip to content

Commit c4da811

Browse files
committed
set parent for train validation split
1 parent 6bb2871 commit c4da811

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] (
221221
uid,
222222
bestModel.copy(extra).asInstanceOf[Model[_]],
223223
validationMetrics.clone())
224-
copyValues(copied, extra)
224+
copyValues(copied, extra).setParent(parent)
225225
}
226226

227227
@Since("2.0.0")

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model}
2222
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
2323
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
2424
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
25-
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
25+
import org.apache.spark.ml.linalg.Vectors
2626
import org.apache.spark.ml.param.ParamMap
2727
import org.apache.spark.ml.param.shared.HasInputCol
2828
import org.apache.spark.ml.regression.LinearRegression
29-
import org.apache.spark.ml.util.DefaultReadWriteTest
29+
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
3030
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
3131
import org.apache.spark.sql.Dataset
3232
import org.apache.spark.sql.types.StructType
@@ -78,6 +78,10 @@ class TrainValidationSplitSuite
7878
.setTrainRatio(0.5)
7979
.setSeed(42L)
8080
val cvModel = cv.fit(dataset)
81+
82+
// copied model must have the same paren.
83+
MLTestingUtils.checkCopy(cvModel)
84+
8185
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
8286
assert(parent.getRegParam === 0.001)
8387
assert(parent.getMaxIter === 10)

0 commit comments

Comments
 (0)