Skip to content

Commit de07d06

Browse files
yu-iskwmengxr
authored andcommitted
[SPARK-10266][DOCUMENTATION, ML] Fixed @SInCE annotation for ml.tunning
cc mengxr noel-smith I worked on this issues based on #8729. ehsanmok thank you for your contricution! Author: Yu ISHIKAWA <[email protected]> Author: Ehsan M.Kermani <[email protected]> Closes #9338 from yu-iskw/JIRA-10266.
1 parent 452690b commit de07d06

3 files changed

Lines changed: 58 additions & 16 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ package org.apache.spark.ml.tuning
1919

2020
import com.github.fommil.netlib.F2jBLAS
2121
import org.apache.hadoop.fs.Path
22-
import org.json4s.{JObject, DefaultFormats}
2322
import org.json4s.jackson.JsonMethods._
23+
import org.json4s.{DefaultFormats, JObject}
2424

25-
import org.apache.spark.ml.classification.OneVsRestParams
26-
import org.apache.spark.ml.feature.RFormulaModel
27-
import org.apache.spark.{SparkContext, Logging}
25+
import org.apache.spark.{Logging, SparkContext}
2826
import org.apache.spark.annotation.{Experimental, Since}
2927
import org.apache.spark.ml._
28+
import org.apache.spark.ml.classification.OneVsRestParams
3029
import org.apache.spark.ml.evaluation.Evaluator
30+
import org.apache.spark.ml.feature.RFormulaModel
3131
import org.apache.spark.ml.param._
32-
import org.apache.spark.ml.util._
3332
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
33+
import org.apache.spark.ml.util._
3434
import org.apache.spark.mllib.util.MLUtils
3535
import org.apache.spark.sql.DataFrame
3636
import org.apache.spark.sql.types.StructType
@@ -58,26 +58,34 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
5858
* :: Experimental ::
5959
* K-fold cross validation.
6060
*/
61+
@Since("1.2.0")
6162
@Experimental
62-
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
63+
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
64+
extends Estimator[CrossValidatorModel]
6365
with CrossValidatorParams with MLWritable with Logging {
6466

67+
@Since("1.2.0")
6568
def this() = this(Identifiable.randomUID("cv"))
6669

6770
private val f2jBLAS = new F2jBLAS
6871

6972
/** @group setParam */
73+
@Since("1.2.0")
7074
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
7175

7276
/** @group setParam */
77+
@Since("1.2.0")
7378
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
7479

7580
/** @group setParam */
81+
@Since("1.2.0")
7682
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
7783

7884
/** @group setParam */
85+
@Since("1.2.0")
7986
def setNumFolds(value: Int): this.type = set(numFolds, value)
8087

88+
@Since("1.4.0")
8189
override def fit(dataset: DataFrame): CrossValidatorModel = {
8290
val schema = dataset.schema
8391
transformSchema(schema, logging = true)
@@ -116,10 +124,12 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
116124
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
117125
}
118126

127+
@Since("1.4.0")
119128
override def transformSchema(schema: StructType): StructType = {
120129
$(estimator).transformSchema(schema)
121130
}
122131

132+
@Since("1.4.0")
123133
override def validateParams(): Unit = {
124134
super.validateParams()
125135
val est = $(estimator)
@@ -128,6 +138,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
128138
}
129139
}
130140

141+
@Since("1.4.0")
131142
override def copy(extra: ParamMap): CrossValidator = {
132143
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
133144
if (copied.isDefined(estimator)) {
@@ -308,26 +319,31 @@ object CrossValidator extends MLReadable[CrossValidator] {
308319
* @param avgMetrics Average cross-validation metrics for each paramMap in
309320
* [[CrossValidator.estimatorParamMaps]], in the corresponding order.
310321
*/
322+
@Since("1.2.0")
311323
@Experimental
312324
class CrossValidatorModel private[ml] (
313-
override val uid: String,
314-
val bestModel: Model[_],
315-
val avgMetrics: Array[Double])
325+
@Since("1.4.0") override val uid: String,
326+
@Since("1.2.0") val bestModel: Model[_],
327+
@Since("1.5.0") val avgMetrics: Array[Double])
316328
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
317329

330+
@Since("1.4.0")
318331
override def validateParams(): Unit = {
319332
bestModel.validateParams()
320333
}
321334

335+
@Since("1.4.0")
322336
override def transform(dataset: DataFrame): DataFrame = {
323337
transformSchema(dataset.schema, logging = true)
324338
bestModel.transform(dataset)
325339
}
326340

341+
@Since("1.4.0")
327342
override def transformSchema(schema: StructType): StructType = {
328343
bestModel.transformSchema(schema)
329344
}
330345

346+
@Since("1.4.0")
331347
override def copy(extra: ParamMap): CrossValidatorModel = {
332348
val copied = new CrossValidatorModel(
333349
uid,

mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning
2020
import scala.annotation.varargs
2121
import scala.collection.mutable
2222

23-
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.annotation.{Experimental, Since}
2424
import org.apache.spark.ml.param._
2525

2626
/**
2727
* :: Experimental ::
2828
* Builder for a param grid used in grid search-based model selection.
2929
*/
30+
@Since("1.2.0")
3031
@Experimental
31-
class ParamGridBuilder {
32+
class ParamGridBuilder @Since("1.2.0") {
3233

3334
private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
3435

3536
/**
3637
* Sets the given parameters in this grid to fixed values.
3738
*/
39+
@Since("1.2.0")
3840
def baseOn(paramMap: ParamMap): this.type = {
3941
baseOn(paramMap.toSeq: _*)
4042
this
@@ -43,6 +45,7 @@ class ParamGridBuilder {
4345
/**
4446
* Sets the given parameters in this grid to fixed values.
4547
*/
48+
@Since("1.2.0")
4649
@varargs
4750
def baseOn(paramPairs: ParamPair[_]*): this.type = {
4851
paramPairs.foreach { p =>
@@ -54,6 +57,7 @@ class ParamGridBuilder {
5457
/**
5558
* Adds a param with multiple values (overwrites if the input param exists).
5659
*/
60+
@Since("1.2.0")
5761
def addGrid[T](param: Param[T], values: Iterable[T]): this.type = {
5862
paramGrid.put(param, values)
5963
this
@@ -64,41 +68,47 @@ class ParamGridBuilder {
6468
/**
6569
* Adds a double param with multiple values.
6670
*/
71+
@Since("1.2.0")
6772
def addGrid(param: DoubleParam, values: Array[Double]): this.type = {
6873
addGrid[Double](param, values)
6974
}
7075

7176
/**
7277
* Adds a int param with multiple values.
7378
*/
79+
@Since("1.2.0")
7480
def addGrid(param: IntParam, values: Array[Int]): this.type = {
7581
addGrid[Int](param, values)
7682
}
7783

7884
/**
7985
* Adds a float param with multiple values.
8086
*/
87+
@Since("1.2.0")
8188
def addGrid(param: FloatParam, values: Array[Float]): this.type = {
8289
addGrid[Float](param, values)
8390
}
8491

8592
/**
8693
* Adds a long param with multiple values.
8794
*/
95+
@Since("1.2.0")
8896
def addGrid(param: LongParam, values: Array[Long]): this.type = {
8997
addGrid[Long](param, values)
9098
}
9199

92100
/**
93101
* Adds a boolean param with true and false.
94102
*/
103+
@Since("1.2.0")
95104
def addGrid(param: BooleanParam): this.type = {
96105
addGrid[Boolean](param, Array(true, false))
97106
}
98107

99108
/**
100109
* Builds and returns all combinations of parameters specified by the param grid.
101110
*/
111+
@Since("1.2.0")
102112
def build(): Array[ParamMap] = {
103113
var paramMaps = Array(new ParamMap)
104114
paramGrid.foreach { case (param, values) =>

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml.tuning
1919

2020
import org.apache.spark.Logging
21-
import org.apache.spark.annotation.Experimental
21+
import org.apache.spark.annotation.{Experimental, Since}
2222
import org.apache.spark.ml.evaluation.Evaluator
2323
import org.apache.spark.ml.{Estimator, Model}
2424
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
@@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
5151
* and uses evaluation metric on the validation set to select the best model.
5252
* Similar to [[CrossValidator]], but only splits the set once.
5353
*/
54+
@Since("1.5.0")
5455
@Experimental
55-
class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
56+
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
57+
extends Estimator[TrainValidationSplitModel]
5658
with TrainValidationSplitParams with Logging {
5759

60+
@Since("1.5.0")
5861
def this() = this(Identifiable.randomUID("tvs"))
5962

6063
/** @group setParam */
64+
@Since("1.5.0")
6165
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
6266

6367
/** @group setParam */
68+
@Since("1.5.0")
6469
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
6570

6671
/** @group setParam */
72+
@Since("1.5.0")
6773
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
6874

6975
/** @group setParam */
76+
@Since("1.5.0")
7077
def setTrainRatio(value: Double): this.type = set(trainRatio, value)
7178

79+
@Since("1.5.0")
7280
override def fit(dataset: DataFrame): TrainValidationSplitModel = {
7381
val schema = dataset.schema
7482
transformSchema(schema, logging = true)
@@ -108,10 +116,12 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
108116
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
109117
}
110118

119+
@Since("1.5.0")
111120
override def transformSchema(schema: StructType): StructType = {
112121
$(estimator).transformSchema(schema)
113122
}
114123

124+
@Since("1.5.0")
115125
override def validateParams(): Unit = {
116126
super.validateParams()
117127
val est = $(estimator)
@@ -120,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
120130
}
121131
}
122132

133+
@Since("1.5.0")
123134
override def copy(extra: ParamMap): TrainValidationSplit = {
124135
val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
125136
if (copied.isDefined(estimator)) {
@@ -140,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
140151
* @param bestModel Estimator determined best model.
141152
* @param validationMetrics Evaluated validation metrics.
142153
*/
154+
@Since("1.5.0")
143155
@Experimental
144156
class TrainValidationSplitModel private[ml] (
145-
override val uid: String,
146-
val bestModel: Model[_],
147-
val validationMetrics: Array[Double])
157+
@Since("1.5.0") override val uid: String,
158+
@Since("1.5.0") val bestModel: Model[_],
159+
@Since("1.5.0") val validationMetrics: Array[Double])
148160
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
149161

162+
@Since("1.5.0")
150163
override def validateParams(): Unit = {
151164
bestModel.validateParams()
152165
}
153166

167+
@Since("1.5.0")
154168
override def transform(dataset: DataFrame): DataFrame = {
155169
transformSchema(dataset.schema, logging = true)
156170
bestModel.transform(dataset)
157171
}
158172

173+
@Since("1.5.0")
159174
override def transformSchema(schema: StructType): StructType = {
160175
bestModel.transformSchema(schema)
161176
}
162177

178+
@Since("1.5.0")
163179
override def copy(extra: ParamMap): TrainValidationSplitModel = {
164180
val copied = new TrainValidationSplitModel (
165181
uid,

0 commit comments

Comments
 (0)