Skip to content

Commit f0fd13c

Browse files
committed
Quick addition of window size test
1 parent de7209c commit f0fd13c

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ private[feature] trait Word2VecBase extends Params
4848
/** @group getParam */
4949
def getVectorSize: Int = $(vectorSize)
5050

51+
/**
52+
* The window size (context words from [-window, window])
53+
* @group param
54+
*/
55+
final val windowSize = new IntParam(
56+
this, "windowSize", "the window size (context words from [-window, window])")
57+
setDefault(windowSize -> 5)
58+
59+
/** @group getParam */
60+
def getWindowSize: Int = $(windowSize)
61+
5162
/**
5263
* Number of partitions for sentences of words.
5364
* @group param
@@ -102,6 +113,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
102113
/** @group setParam */
103114
def setVectorSize(value: Int): this.type = set(vectorSize, value)
104115

116+
/** @group setParam */
117+
def setWindowSize(value: Int): this.type = set(windowSize, value)
118+
105119
/** @group setParam */
106120
def setStepSize(value: Double): this.type = set(stepSize, value)
107121

@@ -127,6 +141,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
127141
.setNumPartitions($(numPartitions))
128142
.setSeed($(seed))
129143
.setVectorSize($(vectorSize))
144+
.setWindowSize($(windowSize))
130145
.fit(input)
131146
copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
132147
}

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ class Word2Vec extends Serializable with Logging {
128128
this
129129
}
130130

131+
/**
132+
* Sets the window of words (default: 5)
133+
*/
134+
@Since("1.6.0")
135+
def setWindowSize(window: Int): this.type = {
136+
this.window = window
137+
this
138+
}
139+
131140
/**
132141
* Sets minCount, the minimum number of times a token must appear to be included in the word2vec
133142
* model's vocabulary (default: 5).
@@ -144,7 +153,7 @@ class Word2Vec extends Serializable with Logging {
144153
private val MAX_SENTENCE_LENGTH = 1000
145154

146155
/** context words from [-window, window] */
147-
private val window = 5
156+
private var window = 5
148157

149158
private var trainWordsCount = 0
150159
private var vocabSize = 0

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,42 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
131131
expectedSimilarity.zip(similarity).map {
132132
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
133133
}
134+
}
135+
136+
test("window size") {
137+
138+
val sqlContext = new SQLContext(sc)
139+
import sqlContext.implicits._
140+
141+
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
142+
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
143+
val docDF = doc.zip(doc).toDF("text", "alsotext")
134144

145+
val model = new Word2Vec()
146+
.setVectorSize(3)
147+
.setWindowSize(2)
148+
.setInputCol("text")
149+
.setOutputCol("result")
150+
.setSeed(42L)
151+
.fit(docDF)
152+
153+
val (synonyms, similarity) = model.findSynonyms("a", 6).map {
154+
case Row(w: String, sim: Double) => (w, sim)
155+
}.collect().unzip
156+
157+
// Increase the window size
158+
val biggerModel = new Word2Vec()
159+
.setVectorSize(3)
160+
.setInputCol("text")
161+
.setOutputCol("result")
162+
.setSeed(42L)
163+
.setWindowSize(10)
164+
.fit(docDF)
165+
166+
val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map {
167+
case Row(w: String, sim: Double) => (w, sim)
168+
}.collect().unzip
169+
// The similarity score should be very different with the larger window
170+
assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
135171
}
136172
}
137-

0 commit comments

Comments
 (0)