Skip to content

Commit f3fe554

Browse files
committed
[SPARK-10835][ML] Word2Vec should accept non-nullable string arrays, in addition to the existing nullable string arrays
## What changes were proposed in this pull request?

To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram.

## How was this patch tested?

Jenkins tests.

Author: Sean Owen <[email protected]>

Closes apache#15179 from srowen/SPARK-10835.
1 parent 7c38252 commit f3fe554

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params
108108
* Validate and transform the input schema.
109109
*/
110110
protected def validateAndTransformSchema(schema: StructType): StructType = {
111-
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
111+
// Accept both ArrayType(StringType, true) and ArrayType(StringType, false) for the input
// column. NOTE(review): the boolean is presumably ArrayType's element-nullability flag
// (containsNull), which is what lets non-nullable arrays such as NGram output pass — confirm.
val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
112+
// NOTE(review): checkColumnTypes presumably passes when the column matches any candidate.
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
112113
// The returned schema is the input schema with the output column added as a VectorUDT.
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
113114
}
114115
}

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
207207
val newInstance = testDefaultReadWrite(instance)
208208
assert(newInstance.getVectors.collect() === instance.getVectors.collect())
209209
}
210+
211+
test("Word2Vec works with input that is non-nullable (NGram)") {
212+
// Bind the session to a local val: an import needs a stable identifier to import from.
val spark = this.spark
213+
import spark.implicits._
214+
215+
// Two identical documents, each split on single spaces into an array of tokens.
val sentence = "a q s t q s t b b b s t m s t m q "
216+
val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
217+
218+
// Turn the "text" column into 2-grams in the "ngrams" column.
val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams")
219+
val ngramDF = ngram.transform(docDF)
220+
221+
// Fit Word2Vec directly on the NGram output column.
val model = new Word2Vec()
222+
.setVectorSize(2)
223+
.setInputCol("ngrams")
224+
.setOutputCol("result")
225+
.fit(ngramDF)
226+
227+
// No assertions on the resulting vectors: the test only checks that fitting and
// transforming the NGram output complete without throwing.
228+
model.transform(ngramDF).collect()
229+
}
230+
210231
}
211232

0 commit comments

Comments
 (0)