Skip to content

Commit 84cc1be

Browse files
committed
update lda example
1 parent d65656c commit 84cc1be

1 file changed

Lines changed: 40 additions & 109 deletions

File tree

examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala

Lines changed: 40 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
// scalastyle:off println
1919
package org.apache.spark.examples.mllib
2020

21-
import java.text.BreakIterator
22-
23-
import scala.collection.mutable
21+
import org.apache.spark.ml.Pipeline
22+
import org.apache.spark.ml.feature.{CountVectorizerModel, CountVectorizer, StopWordsRemover, RegexTokenizer}
23+
import org.apache.spark.sql.{Row, SQLContext}
2424

2525
import scopt.OptionParser
2626

@@ -118,7 +118,7 @@ object LDAExample {
118118
// Load documents, and prepare them for LDA.
119119
val preprocessStart = System.nanoTime()
120120
val (corpus, vocabArray, actualNumTokens) =
121-
preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
121+
preProcess(sc, params.input, params.vocabSize, params.stopwordFile)
122122
corpus.cache()
123123
val actualCorpusSize = corpus.count()
124124
val actualVocabSize = vocabArray.size
@@ -186,121 +186,52 @@ object LDAExample {
186186
* Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
187187
* @return (corpus, vocabulary as array, total token count in corpus)
188188
*/
189-
private def preprocess(
189+
private def preProcess(
190190
sc: SparkContext,
191191
paths: Seq[String],
192192
vocabSize: Int,
193-
stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
193+
stopWordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
194194

195195
// Get dataset of document texts
196196
// One document per line in each text file. If the input consists of many small files,
197197
// this can result in a large number of small partitions, which can degrade performance.
198198
// In this case, consider using coalesce() to create fewer, larger partitions.
199199
val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
200-
201-
// Split text into words
202-
val tokenizer = new SimpleTokenizer(sc, stopwordFile)
203-
val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
204-
id -> tokenizer.getWords(text)
205-
}
206-
tokenized.cache()
207-
208-
// Counts words: RDD[(word, wordCount)]
209-
val wordCounts: RDD[(String, Long)] = tokenized
210-
.flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
211-
.reduceByKey(_ + _)
212-
wordCounts.cache()
213-
val fullVocabSize = wordCounts.count()
214-
// Select vocab
215-
// (vocab: Map[word -> id], total tokens after selecting vocab)
216-
val (vocab: Map[String, Int], selectedTokenCount: Long) = {
217-
val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
218-
// Use all terms
219-
wordCounts.collect().sortBy(-_._2)
220-
} else {
221-
// Sort terms to select vocab
222-
wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
223-
}
224-
(tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
200+
val sqlContext = new SQLContext(sc)
201+
import sqlContext.implicits._
202+
203+
val df = textRDD.toDF("texts")
204+
val customizedStopWords: Array[String] = if (stopWordFile.isEmpty) {
205+
Array.empty[String]
206+
} else {
207+
val stopWordText = sc.textFile(stopWordFile).collect()
208+
stopWordText.flatMap(_.stripMargin.split("\\s+"))
225209
}
226-
227-
val documents = tokenized.map { case (id, tokens) =>
228-
// Filter tokens by vocabulary, and create word count vector representation of document.
229-
val wc = new mutable.HashMap[Int, Int]()
230-
tokens.foreach { term =>
231-
if (vocab.contains(term)) {
232-
val termIndex = vocab(term)
233-
wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
234-
}
235-
}
236-
val indices = wc.keys.toArray.sorted
237-
val values = indices.map(i => wc(i).toDouble)
238-
239-
val sb = Vectors.sparse(vocab.size, indices, values)
240-
(id, sb)
241-
}
242-
243-
val vocabArray = new Array[String](vocab.size)
244-
vocab.foreach { case (term, i) => vocabArray(i) = term }
245-
246-
(documents, vocabArray, selectedTokenCount)
210+
val tokenizer = new RegexTokenizer()
211+
.setInputCol("texts")
212+
.setOutputCol("rawTokens")
213+
val stopWordsRemover = new StopWordsRemover()
214+
.setInputCol("rawTokens")
215+
.setOutputCol("tokens")
216+
stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
217+
val countVectorizer = new CountVectorizer()
218+
.setVocabSize(vocabSize)
219+
.setInputCol("tokens")
220+
.setOutputCol("vectors")
221+
222+
val pipeline = new Pipeline()
223+
.setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
224+
225+
val model = pipeline.fit(df)
226+
val documents = model.transform(df)
227+
.select("vectors")
228+
.map { case Row(features: Vector) => features }
229+
.zipWithIndex()
230+
.map(_.swap)
231+
232+
(documents,
233+
model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary
234+
documents.map(_._2.numActives).sum().toLong) // total token count
247235
}
248236
}
249237

250-
/**
251-
* Simple Tokenizer.
252-
*
253-
* TODO: Formalize the interface, and make this a public class in mllib.feature
254-
*/
255-
private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
256-
257-
private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
258-
Set.empty[String]
259-
} else {
260-
val stopwordText = sc.textFile(stopwordFile).collect()
261-
stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
262-
}
263-
264-
// Matches sequences of Unicode letters
265-
private val allWordRegex = "^(\\p{L}*)$".r
266-
267-
// Ignore words shorter than this length.
268-
private val minWordLength = 3
269-
270-
def getWords(text: String): IndexedSeq[String] = {
271-
272-
val words = new mutable.ArrayBuffer[String]()
273-
274-
// Use Java BreakIterator to tokenize text into words.
275-
val wb = BreakIterator.getWordInstance
276-
wb.setText(text)
277-
278-
// current,end index start,end of each word
279-
var current = wb.first()
280-
var end = wb.next()
281-
while (end != BreakIterator.DONE) {
282-
// Convert to lowercase
283-
val word: String = text.substring(current, end).toLowerCase
284-
// Remove short words and strings that aren't only letters
285-
word match {
286-
case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
287-
words += w
288-
case _ =>
289-
}
290-
291-
current = end
292-
try {
293-
end = wb.next()
294-
} catch {
295-
case e: Exception =>
296-
// Ignore remaining text in line.
297-
// This is a known bug in BreakIterator (for some Java versions),
298-
// which fails when it sees certain characters.
299-
end = BreakIterator.DONE
300-
}
301-
}
302-
words
303-
}
304-
305-
}
306-
// scalastyle:on println

0 commit comments

Comments
 (0)