|
18 | 18 | // scalastyle:off println |
19 | 19 | package org.apache.spark.examples.mllib |
20 | 20 |
|
21 | | -import java.text.BreakIterator |
22 | | - |
23 | | -import scala.collection.mutable |
| 21 | +import org.apache.spark.ml.Pipeline |
| 22 | +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} |
| 23 | +import org.apache.spark.sql.{Row, SQLContext} |
24 | 24 |
|
25 | 25 | import scopt.OptionParser |
26 | 26 |
|
@@ -118,7 +118,7 @@ object LDAExample { |
118 | 118 | // Load documents, and prepare them for LDA. |
119 | 119 | val preprocessStart = System.nanoTime() |
120 | 120 | val (corpus, vocabArray, actualNumTokens) = |
121 | | - preprocess(sc, params.input, params.vocabSize, params.stopwordFile) |
| 121 | + preProcess(sc, params.input, params.vocabSize, params.stopwordFile) |
122 | 122 | corpus.cache() |
123 | 123 | val actualCorpusSize = corpus.count() |
124 | 124 | val actualVocabSize = vocabArray.size |
@@ -186,121 +186,52 @@ object LDAExample { |
186 | 186 | * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors. |
187 | 187 | * @return (corpus, vocabulary as array, total token count in corpus) |
188 | 188 | */ |
189 | | - private def preprocess( |
| 189 | + private def preProcess( |
190 | 190 | sc: SparkContext, |
191 | 191 | paths: Seq[String], |
192 | 192 | vocabSize: Int, |
193 | | - stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { |
| 193 | + stopWordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { |
194 | 194 |
|
195 | 195 | // Get dataset of document texts |
196 | 196 | // One document per line in each text file. If the input consists of many small files, |
197 | 197 | // this can result in a large number of small partitions, which can degrade performance. |
198 | 198 | // In this case, consider using coalesce() to create fewer, larger partitions. |
199 | 199 | val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) |
200 | | - |
201 | | - // Split text into words |
202 | | - val tokenizer = new SimpleTokenizer(sc, stopwordFile) |
203 | | - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => |
204 | | - id -> tokenizer.getWords(text) |
205 | | - } |
206 | | - tokenized.cache() |
207 | | - |
208 | | - // Counts words: RDD[(word, wordCount)] |
209 | | - val wordCounts: RDD[(String, Long)] = tokenized |
210 | | - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } |
211 | | - .reduceByKey(_ + _) |
212 | | - wordCounts.cache() |
213 | | - val fullVocabSize = wordCounts.count() |
214 | | - // Select vocab |
215 | | - // (vocab: Map[word -> id], total tokens after selecting vocab) |
216 | | - val (vocab: Map[String, Int], selectedTokenCount: Long) = { |
217 | | - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { |
218 | | - // Use all terms |
219 | | - wordCounts.collect().sortBy(-_._2) |
220 | | - } else { |
221 | | - // Sort terms to select vocab |
222 | | - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) |
223 | | - } |
224 | | - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) |
| 200 | + val sqlContext = new SQLContext(sc) |
| 201 | + import sqlContext.implicits._ |
| 202 | + |
| 203 | + val df = textRDD.toDF("texts") |
| 204 | + val customizedStopWords: Array[String] = if (stopWordFile.isEmpty) { |
| 205 | + Array.empty[String] |
| 206 | + } else { |
| 207 | + val stopWordText = sc.textFile(stopWordFile).collect() |
| 208 | + stopWordText.flatMap(_.stripMargin.split("\\s+")) |
225 | 209 | } |
226 | | - |
227 | | - val documents = tokenized.map { case (id, tokens) => |
228 | | - // Filter tokens by vocabulary, and create word count vector representation of document. |
229 | | - val wc = new mutable.HashMap[Int, Int]() |
230 | | - tokens.foreach { term => |
231 | | - if (vocab.contains(term)) { |
232 | | - val termIndex = vocab(term) |
233 | | - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 |
234 | | - } |
235 | | - } |
236 | | - val indices = wc.keys.toArray.sorted |
237 | | - val values = indices.map(i => wc(i).toDouble) |
238 | | - |
239 | | - val sb = Vectors.sparse(vocab.size, indices, values) |
240 | | - (id, sb) |
241 | | - } |
242 | | - |
243 | | - val vocabArray = new Array[String](vocab.size) |
244 | | - vocab.foreach { case (term, i) => vocabArray(i) = term } |
245 | | - |
246 | | - (documents, vocabArray, selectedTokenCount) |
| 210 | + val tokenizer = new RegexTokenizer() |
| 211 | + .setInputCol("texts") |
| 212 | + .setOutputCol("rawTokens") |
| 213 | + val stopWordsRemover = new StopWordsRemover() |
| 214 | + .setInputCol("rawTokens") |
| 215 | + .setOutputCol("tokens") |
| 216 | +stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) // extend the default English stop words |
| 217 | + val countVectorizer = new CountVectorizer() |
| 218 | + .setVocabSize(vocabSize) |
| 219 | + .setInputCol("tokens") |
| 220 | + .setOutputCol("vectors") |
| 221 | + |
| 222 | + val pipeline = new Pipeline() |
| 223 | + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) |
| 224 | + |
| 225 | + val model = pipeline.fit(df) |
| 226 | + val documents = model.transform(df) |
| 227 | + .select("vectors") |
| 228 | + .map { case Row(features: Vector) => features } |
| 229 | + .zipWithIndex() |
| 230 | + .map(_.swap) |
| 231 | + |
| 232 | + (documents, |
| 233 | + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary |
| 234 | + documents.map(_._2.toArray.sum).sum().toLong) // total token count (numActives would count distinct terms only) |
247 | 235 | } |
248 | 236 | } |
249 | 237 |
|
250 | | -/** |
251 | | - * Simple Tokenizer. |
252 | | - * |
253 | | - * TODO: Formalize the interface, and make this a public class in mllib.feature |
254 | | - */ |
255 | | -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { |
256 | | - |
257 | | - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { |
258 | | - Set.empty[String] |
259 | | - } else { |
260 | | - val stopwordText = sc.textFile(stopwordFile).collect() |
261 | | - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet |
262 | | - } |
263 | | - |
264 | | - // Matches sequences of Unicode letters |
265 | | - private val allWordRegex = "^(\\p{L}*)$".r |
266 | | - |
267 | | - // Ignore words shorter than this length. |
268 | | - private val minWordLength = 3 |
269 | | - |
270 | | - def getWords(text: String): IndexedSeq[String] = { |
271 | | - |
272 | | - val words = new mutable.ArrayBuffer[String]() |
273 | | - |
274 | | - // Use Java BreakIterator to tokenize text into words. |
275 | | - val wb = BreakIterator.getWordInstance |
276 | | - wb.setText(text) |
277 | | - |
278 | | - // current,end index start,end of each word |
279 | | - var current = wb.first() |
280 | | - var end = wb.next() |
281 | | - while (end != BreakIterator.DONE) { |
282 | | - // Convert to lowercase |
283 | | - val word: String = text.substring(current, end).toLowerCase |
284 | | - // Remove short words and strings that aren't only letters |
285 | | - word match { |
286 | | - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => |
287 | | - words += w |
288 | | - case _ => |
289 | | - } |
290 | | - |
291 | | - current = end |
292 | | - try { |
293 | | - end = wb.next() |
294 | | - } catch { |
295 | | - case e: Exception => |
296 | | - // Ignore remaining text in line. |
297 | | - // This is a known bug in BreakIterator (for some Java versions), |
298 | | - // which fails when it sees certain characters. |
299 | | - end = BreakIterator.DONE |
300 | | - } |
301 | | - } |
302 | | - words |
303 | | - } |
304 | | - |
305 | | -} |
306 | | -// scalastyle:on println |
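A note on the unchanged comment above the `sc.textFile` call (the `coalesce()` hint): the suggestion can be made concrete. Below is a hypothetical variant of that line; the target of 32 partitions is an invented illustration, not a value from the patch.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// Hypothetical variant of the textRDD line in preProcess(): if `paths` expands to
// thousands of small files, coalesce() merges the resulting small partitions into
// fewer, larger ones without a shuffle. 32 is an arbitrary target; tune per cluster.
def loadCoalesced(sc: SparkContext, paths: Seq[String]): RDD[String] =
  sc.textFile(paths.mkString(",")).coalesce(32)
```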
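To try the new preprocessing in isolation: the sketch below rebuilds the same three-stage Pipeline (RegexTokenizer, StopWordsRemover, CountVectorizer) on a toy two-row dataset and prints the learned vocabulary. The column names and stage wiring follow the patch; the object name and the sample strings are invented for illustration.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object PipelineSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PipelineSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy corpus: one document per row, mirroring the "texts" column in the patch.
    val df = sc.parallelize(Seq(
      "spark mllib supports latent dirichlet allocation",
      "the pipeline tokenizes text and removes stop words"
    )).toDF("texts")

    // The same three stages the patch assembles inside preProcess().
    val pipeline = new Pipeline().setStages(Array(
      new RegexTokenizer().setInputCol("texts").setOutputCol("rawTokens"),
      new StopWordsRemover().setInputCol("rawTokens").setOutputCol("tokens"),
      new CountVectorizer().setVocabSize(100).setInputCol("tokens").setOutputCol("vectors")
    ))

    val model = pipeline.fit(df)
    // Stage 2 is the fitted CountVectorizerModel; it carries the learned vocabulary.
    val vocab = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
    println(vocab.mkString(", "))
    model.transform(df).select("vectors").show()
    sc.stop()
  }
}
```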
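Downstream, in the part of LDAExample this diff does not touch, the returned corpus goes to mllib's LDA. A minimal sketch of that hand-off, assuming the (docId, termCountVector) pairs and vocabulary array returned by preProcess:

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Sketch: feed the (docId, termCountVector) corpus returned by preProcess into
// mllib's LDA and print the top terms per topic. `vocabArray` is the vocabulary
// from the same return tuple; k is the topic count.
def runLda(corpus: RDD[(Long, Vector)], vocabArray: Array[String], k: Int): Unit = {
  val ldaModel = new LDA().setK(k).run(corpus)
  ldaModel.describeTopics(maxTermsPerTopic = 5).zipWithIndex.foreach {
    case ((termIndices, weights), topic) =>
      val terms = termIndices.map(vocabArray(_)).zip(weights)
      println(s"Topic $topic: " + terms.map { case (t, w) => f"$t ($w%.3f)" }.mkString(", "))
  }
}
```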