Skip to content

Commit 182460f

Browse files
committed
add guide for naive Bayes
1 parent 137fd1d commit 182460f

1 file changed

Lines changed: 61 additions & 0 deletions

File tree

docs/mllib-naive-bayes.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
---
2+
layout: global
3+
title: MLlib - Classification and Regression - Naive Bayes
4+
---
5+
6+
Naive Bayes is a simple multiclass classification algorithm with the assumption of independence between every pair of features. Naive Bayes can be trained very efficiently. Within a single pass to the training data, it computes the conditional probability distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. For more details, please visit the wikipedia page [Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier).
7+
8+
In MLlib, we implemented multinomial naive Bayes, which is typically used for document classification. Within that context, each observation is a document, each feature represents a term, whose value is the frequency of the term. For its formulation, please visit the wikipedia page [Multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) or the section [Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) from the book Introduction to Information Retrieval. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$. For document classification, the input feature vectors are usually sparse. Please supply sparse vectors as input to take advantage of sparsity. Since the training data is only used once, it is not necessary to cache it.
9+
10+
## Interfaces and examples
11+
12+
<div class="codetabs">
13+
<div data-lang="scala" markdown="1">
14+
[NaiveBayes](api/mllib/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/mllib/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional smoothing parameter `lambda` as input, and output a [NaiveBayesModel](api/mllib/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction.
15+
16+
{% highlight scala %}
17+
import org.apache.spark.mllib.classification.NaiveBayes
18+
19+
val training: RDD[LabeledPoint] = ... // training set
20+
val test: RDD[LabeledPoint] = ... // test set
21+
22+
val model = NaiveBayes.train(training, lambda = 1.0)
23+
val prediction = model.predict(test.map(_.features))
24+
25+
val predictionAndLabel = prediction.zip(test.map(_.label))
26+
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
27+
{% endhighlight %}
28+
29+
</div>
30+
31+
<div data-lang="java" markdown="1">
32+
[NaiveBayes](api/mllib/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes a Scala RDD of [LabeledPoint](api/mllib/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional smoothing parameter `lambda` as input, and output a [NaiveBayesModel](api/mllib/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction.
33+
34+
{% highlight java %}
35+
import org.apache.spark.mllib.classification.NaiveBayes;
36+
37+
JavaRDD<LabeledPoint> training = ... // training set
38+
JavaRDD<LabeledPoint> test = ... // test set
39+
40+
NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
41+
42+
JavaRDD<Double> prediction = model.predict(test.map(new Function<LabeledPoint, Vector>() {
43+
public Vector call(LabeledPoint p) {
44+
return p.features();
45+
}
46+
})
47+
JavaPairRDD<Double, Double> predictionAndLabel =
48+
prediction.zip(test.map(new Function<LabeledPoint, Double>() {
49+
public Double call(LabeledPoint p) {
50+
return p.label();
51+
}
52+
})
53+
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
54+
public Boolean call(Tuple2<Double, Double> pl) {
55+
return pl._1() == pl._2();
56+
}
57+
}).count() / test.count()
58+
59+
{% endhighlight %}
60+
</div>
61+
</div>

0 commit comments

Comments
 (0)