@@ -201,6 +201,12 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
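For context, a runnable sketch (not from the PR; the data, app name, and maxIter value are invented for illustration) of what the new member reports on a fitted model:

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("gbt-numtrees-sketch").getOrCreate()
import spark.implicits._

// Tiny toy dataset with the default "label"/"features" column names.
val df = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(0.9, 0.1))
).toDF("label", "features")

// A GBT trains one tree per boosting iteration, so getNumTrees should equal maxIter here.
val model = new GBTClassifier().setMaxIter(5).fit(df)
assert(model.getNumTrees == 5)
spark.stop()
```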

@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils

@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
     }
   }
 
-  override def validateParams(): Unit = {
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     checkThresholdConsistency()
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }

Contributor: Is this actually necessary? Before, validateParams() was never used. Seems like we could just remove it.

Author: I think it's a bug that validateParams was never used. It should validate parameter interactions before fitting (if necessary); that is why we deprecated validateParams and moved what it does into transformSchema. We had no corresponding test cases before, so no test broke when we deprecated validateParams. I added test cases in this PR.
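To make the moved check concrete, a hedged sketch (the DataFrame `df` is assumed, not shown) of the interaction it now catches when fit or transformSchema runs:

```scala
import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setThreshold(0.3)               // implies thresholds of (0.7, 0.3)
  .setThresholds(Array(0.5, 0.5))  // implies a threshold of 0.5 -- inconsistent
// lr.fit(df)  // would throw IllegalArgumentException via checkThresholdConsistency()
```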

@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
     }
   }
 
-  /**
-   * Number of trees in ensemble
-   *
-   * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
-   */
-  // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
-  @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* @group setParam
*/
@Since("1.6.0")
@deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
mllib/src/main/scala/org/apache/spark/ml/param/params.scala (0 additions, 15 deletions)
@@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
       .map(m => m.invoke(this).asInstanceOf[Param[_]])
   }
 
-  /**
-   * Validates parameter values stored internally.
-   * Raise an exception if any parameter value is invalid.
-   *
-   * This only needs to check for interactions between parameters.
-   * Parameter value checks which do not depend on other parameters are handled by
-   * `Param.validate()`. This method does not handle input/output column parameters;
-   * those are checked during schema validation.
-   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
-   */
-  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
-  def validateParams(): Unit = {
-    // Do nothing by default. Override to handle Param interactions.
-  }
-
   /**
    * Explains a param.
    * @param param input param, must belong to this instance.
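For orientation, a hedged sketch (the `MyAlgoParams` trait is hypothetical, not part of this PR) of where each kind of check lives once validateParams() is gone:

```scala
import org.apache.spark.ml.param.{IntParam, Params, ParamValidators}
import org.apache.spark.sql.types.StructType

trait MyAlgoParams extends Params {
  // Single-param constraints ride on the Param itself via a validator.
  final val k = new IntParam(this, "k", "Number of clusters. Must be > 1.",
    ParamValidators.gt(1))

  // Cross-param interactions are checked during schema validation instead,
  // typically from a validateAndTransformSchema helper called by transformSchema.
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    // e.g. require($(k) <= someOtherParamValue, "...") would go here
    schema
  }
}
```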
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have an @Since tag?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.


@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

@@ -605,9 +605,6 @@ class LinearRegressionSummary private[regression] (
     private val privateModel: LinearRegressionModel,
     private val diagInvAtWA: Array[Double]) extends Serializable {
 
-  @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
-  val model: LinearRegressionModel = privateModel
-
   @transient private val metrics = new RegressionMetrics(
     predictions
       .select(col(predictionCol), col(labelCol).cast(DoubleType))
@@ -144,7 +144,7 @@ class RandomForestRegressionModel private[ml] (
     private val _trees: Array[DecisionTreeRegressionModel],
     override val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
-  with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+  with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
   with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -181,14 +181,6 @@ class RandomForestRegressionModel private[ml] (
     _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
   }
 
-  /**
-   * Number of trees in ensemble
-   * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
-   */
-  // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
-  @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestRegressionModel = {
     copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
   /** Trees in this ensemble. Warning: These have null parent Estimators. */
   def trees: Array[M]
 
-  /**
-   * Number of trees in ensemble
-   */
-  val getNumTrees: Int = trees.length
-
   /** Weights for each tree, zippable with [[trees]] */
   def treeWeights: Array[Double]
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala (30 additions, 45 deletions)
@@ -317,8 +317,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
   }
 }
 
-/** Used for [[RandomForestParams]] */
-private[ml] trait HasFeatureSubsetStrategy extends Params {
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+   * (default = 20)
+   *
+   * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+   * is that the param `maxIter` controls how many trees a GBT has. The semantics in the
+   * algorithms are a bit different.
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
+
+  setDefault(numTrees -> 20)
+
+  /** @group setParam */
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+  /** @group getParam */
+  final def getNumTrees: Int = $(numTrees)
 
   /**
    * The number of features to consider for splits at each tree node.

Contributor: Note: the reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) is because the param maxIter controls how many trees a GBT has. The semantics in the algos are a bit different.

Author: Good suggestion, added.

Member: Are these setter methods in traits Java-compatible?

Author: Yeah, we already have setNumTrees which calls super.setNumTrees in RandomForestClassifier and RandomForestRegressor.

@@ -364,38 +388,6 @@ private[ml] trait HasFeatureSubsetStrategy {
   final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
 }
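With numTrees folded into RandomForestParams, the concrete estimators pick the setter/getter pair up directly. A quick sketch (illustrative, needs only spark-mllib on the classpath):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

// Param plumbing is plain JVM code; no SparkSession is needed to exercise it.
val rf = new RandomForestClassifier().setNumTrees(50)
assert(rf.getNumTrees == 50)  // the default would be 20
```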

-/**
- * Used for [[RandomForestParams]].
- * This is separated out from [[RandomForestParams]] because of an issue with the
- * `numTrees` method conflicting with this Param in the Estimator.
- */
-private[ml] trait HasNumTrees extends Params {
-
-  /**
-   * Number of trees to train (>= 1).
-   * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
-   * TODO: Change to always do bootstrapping (simpler). SPARK-7130
-   * (default = 20)
-   * @group param
-   */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidators.gtEq(1))
-
-  setDefault(numTrees -> 20)
-
-  /** @group setParam */
-  def setNumTrees(value: Int): this.type = set(numTrees, value)
-
-  /** @group getParam */
-  final def getNumTrees: Int = $(numTrees)
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with HasNumTrees
-
 private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =

@@ -405,15 +397,9 @@ private[spark] object RandomForestParams {
 private[ml] trait RandomForestClassifierParams
   extends RandomForestParams with TreeClassifierParams
 
-private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeClassifierParams
-
 private[ml] trait RandomForestRegressorParams
   extends RandomForestParams with TreeRegressorParams
 
-private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeRegressorParams
-
 /**
  * Parameters for Gradient-Boosted Tree algorithms.
  *
@@ -443,12 +429,11 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
    * (default = 0.1)
    * @group setParam
    */
-  def setStepSize(value: Double): this.type = set(stepSize, value)
-
-  override def validateParams(): Unit = {
+  def setStepSize(value: Double): this.type = {
     require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
-      getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
-      s"but it given invalid value $getStepSize.")
+      value), "GBT parameter stepSize should be in interval (0, 1], " +
+      s"but it given invalid value $value.")
+    set(stepSize, value)
   }
 
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */

Contributor: I'm not sure this is where we want to do this? The deprecation warning said to move checks from here into transformSchema - but it's possible the deprecation warning was incomplete.

Contributor: I think it's better to fail at set time rather than when fit is called. This was also not used before; it seems we missed it when we deprecated validateParams()?

Author (yanboliang, Nov 24, 2016): The original validateParams() was only used to check for interactions between parameters. Parameter value checks which do not depend on other parameters are handled by `Param.validate()` at the definition, i.e.

    final val k = new IntParam(this, "k", "The number of clusters to create. " +
      "Must be > 1.", ParamValidators.gt(1))

However, stepSize is defined in a shared trait and inherited by lots of subclasses which have different constraints for this parameter, so I added the check in the setter. Yeah, I think the original place of this check (putting it in validateParams) was a bug, since it does not involve interactions between parameters. We missed it when we deprecated validateParams(), since there was no test case to validate this param before. I added a corresponding test in this PR.

Contributor: That makes more sense :) +1

Member: This validation should really be in the Param itself. It's not, since it's a shared param, but perhaps we should just copy the shared Param code here. The problem with putting validation in the setter is that Params can get set in other ways too (such as gbt.set(stepSize, value)).

Member: "it given invalid value" -> "it was given an invalid value"
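A small sketch (separate from the PR's own test, invented here) of the resulting fail-fast behavior:

```scala
import org.apache.spark.ml.classification.GBTClassifier

val gbt = new GBTClassifier()
gbt.setStepSize(0.1)    // fine: 0.1 lies in (0, 1]
// gbt.setStepSize(10)  // throws IllegalArgumentException immediately, at set time
```

As the last reviewer notes, paths that bypass the setter (for example a ParamMap supplied to fit) would skip this require, which is the argument for attaching a validator to the Param itself, as numTrees does above.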
@@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite {
    * Sets the Spark SQLContext to use for saving/loading.
    */
   @Since("1.6.0")
-  @deprecated("Use session instead", "2.0.0")
+  @deprecated("Use session instead. This method will be removed in 2.2.0.", "2.0.0")
   def context(sqlContext: SQLContext): this.type = {
     optionSparkSession = Option(sqlContext.sparkSession)
     this
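A migration sketch for callers (assumes a fitted MLWritable `model`, an active SparkSession `spark`, and an invented save path):

```scala
// Before (deprecated, removal slated for 2.2.0):
// model.write.context(spark.sqlContext).save("/tmp/my-model")

// After:
model.write.session(spark).save("/tmp/my-model")
```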
@@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     ParamsSuite.checkParams(model)
   }
 
+  test("GBT parameter stepSize should be in interval (0, 1]") {
+    withClue("GBT parameter stepSize should be in interval (0, 1]") {
+      intercept[IllegalArgumentException] {
+        new GBTClassifier().setStepSize(10)
+      }
+    }
+  }
+
   test("Binary classification with continuous features: Log Loss") {
     val categoricalFeatures = Map.empty[Int, Int]
     testCombinations.foreach {
@@ -190,6 +190,12 @@ class LogisticRegressionSuite
       }
     }
     // thresholds and threshold must be consistent: values
+    withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+      intercept[IllegalArgumentException] {
+        lr2.fit(smallBinaryDataset,
+          lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+      }
+    }
     withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
       intercept[IllegalArgumentException] {
         val lr2model = lr2.fit(smallBinaryDataset,
project/MimaExcludes.scala (26 additions, 0 deletions)
@@ -867,6 +867,32 @@ object MimaExcludes {
     // [SPARK-12221] Add CPU time to metrics
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
+    ) ++ Seq(
+      // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
+    )
}
