Skip to content

Commit 6cbde33

Browse files
yanboliangsrowen
authored andcommitted
[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests.
## What changes were proposed in this pull request? This is follow-up for #14378. When we add ```transformSchema``` for all estimators and transformers, I found there are tests failed for ```StringIndexer``` and ```VectorAssembler```. So I moved these parts of work separately in this PR, to make it more clear to review. The corresponding tests should throw ```IllegalArgumentException``` at schema validation period after we add ```transformSchema```. It's efficient that to throw exception at the start of ```fit``` or ```transform``` rather than during the process. ## How was this patch tested? Modified unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #14455 from yanboliang/transformSchema.
1 parent 1f96c97 commit 6cbde33

4 files changed

Lines changed: 16 additions & 5 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") (
8585

8686
@Since("2.0.0")
8787
override def fit(dataset: Dataset[_]): StringIndexerModel = {
88+
transformSchema(dataset.schema, logging = true)
8889
val counts = dataset.select(col($(inputCol)).cast(StringType))
8990
.rdd
9091
.map(_.getString(0))
@@ -160,7 +161,7 @@ class StringIndexerModel (
160161
"Skip StringIndexerModel.")
161162
return dataset.toDF
162163
}
163-
validateAndTransformSchema(dataset.schema)
164+
transformSchema(dataset.schema, logging = true)
164165

165166
val indexer = udf { label: String =>
166167
if (labelToIndex.contains(label)) {
@@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String)
305306

306307
@Since("2.0.0")
307308
override def transform(dataset: Dataset[_]): DataFrame = {
309+
transformSchema(dataset.schema, logging = true)
308310
val inputColSchema = dataset.schema($(inputCol))
309311
// If the labels array is empty use column metadata
310312
val values = if (!isDefined(labels) || $(labels).isEmpty) {

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
5151

5252
@Since("2.0.0")
5353
override def transform(dataset: Dataset[_]): DataFrame = {
54+
transformSchema(dataset.schema, logging = true)
5455
// Schema transformation.
5556
val schema = dataset.schema
5657
lazy val first = dataset.toDF.first()

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,20 @@ class StringIndexerSuite
120120

121121
test("StringIndexerModel can't overwrite output column") {
122122
val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
123+
intercept[IllegalArgumentException] {
124+
new StringIndexer()
125+
.setInputCol("input")
126+
.setOutputCol("output")
127+
.fit(df)
128+
}
129+
123130
val indexer = new StringIndexer()
124131
.setInputCol("input")
125-
.setOutputCol("output")
132+
.setOutputCol("indexedInput")
126133
.fit(df)
134+
127135
intercept[IllegalArgumentException] {
128-
indexer.transform(df)
136+
indexer.setOutputCol("output").transform(df)
129137
}
130138
}
131139

mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ class VectorAssemblerSuite
7474
val assembler = new VectorAssembler()
7575
.setInputCols(Array("a", "b", "c"))
7676
.setOutputCol("features")
77-
val thrown = intercept[SparkException] {
77+
val thrown = intercept[IllegalArgumentException] {
7878
assembler.transform(df)
7979
}
80-
assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
80+
assert(thrown.getMessage contains "Data type StringType is not supported")
8181
}
8282

8383
test("ML attributes") {

0 commit comments

Comments
 (0)