Skip to content

Commit 0eb7f07

Browse files
author
VinceShieh
committed
code refactoring
Signed-off-by: VinceShieh <[email protected]>
1 parent 5d4b07f commit 0eb7f07

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
5656
"error (throw an error), or 'keep' (map unseen labels with indices [numLabels]).",
5757
ParamValidators.inArray(supportedHandleInvalids))
5858

59-
/** @group getParam */
60-
@Since("2.1.0")
61-
def getHandleInvalid: String = $(handleInvalid)
62-
63-
/** @group setParam */
64-
@Since("2.1.0")
65-
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
66-
setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
6759
/** Validates and transforms the input schema. */
6860
protected def validateAndTransformSchema(schema: StructType): StructType = {
6961
val inputColName = $(inputCol)
@@ -105,6 +97,15 @@ class StringIndexer @Since("1.4.0") (
10597
@Since("1.4.0")
10698
def setOutputCol(value: String): this.type = set(outputCol, value)
10799

100+
/** @group getParam */
101+
@Since("2.1.0")
102+
def getHandleInvalid: String = $(handleInvalid)
103+
104+
/** @group setParam */
105+
@Since("2.1.0")
106+
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
107+
setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
108+
108109
@Since("2.0.0")
109110
override def fit(dataset: Dataset[_]): StringIndexerModel = {
110111
transformSchema(dataset.schema, logging = true)
@@ -171,6 +172,15 @@ class StringIndexerModel (
171172
@Since("1.4.0")
172173
def setOutputCol(value: String): this.type = set(outputCol, value)
173174

175+
/** @group getParam */
176+
@Since("2.1.0")
177+
def getHandleInvalid: String = $(handleInvalid)
178+
179+
/** @group setParam */
180+
@Since("2.1.0")
181+
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
182+
setDefault(handleInvalid, ERROR_UNSEEN_LABEL)
183+
174184
@Since("2.0.0")
175185
override def transform(dataset: Dataset[_]): DataFrame = {
176186
if (!dataset.schema.fieldNames.contains($(inputCol))) {

0 commit comments

Comments
 (0)