From bc53ce897692d2dcd3341f5e8a2032ccf2629ea7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 16 Jul 2014 17:43:27 -0700 Subject: [PATCH] fix NaiveBayes --- python/pyspark/mllib/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 19b90dfd6e16..f6c96e3e9f53 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -84,7 +84,7 @@ class NaiveBayesModel(object): - pi: vector of logs of class priors (dimension C) - theta: matrix of logs of class conditional probabilities (CxD) - >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3) + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3) >>> model = NaiveBayes.train(sc.parallelize(data)) >>> model.predict(array([0.0, 1.0])) 0 @@ -98,7 +98,7 @@ def __init__(self, pi, theta): def predict(self, x): """Return the most likely class for a data vector x""" - return numpy.argmax(self.pi + dot(x, self.theta)) + return numpy.argmax(self.pi + dot(x, self.theta.transpose())) class NaiveBayes(object): @classmethod