diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 19b90dfd6e16..f6c96e3e9f53 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -84,7 +84,7 @@ class NaiveBayesModel(object): - pi: vector of logs of class priors (dimension C) - theta: matrix of logs of class conditional probabilities (CxD) - >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3) + >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3) >>> model = NaiveBayes.train(sc.parallelize(data)) >>> model.predict(array([0.0, 1.0])) 0 @@ -98,7 +98,7 @@ def __init__(self, pi, theta): def predict(self, x): """Return the most likely class for a data vector x""" - return numpy.argmax(self.pi + dot(x, self.theta)) + return numpy.argmax(self.pi + dot(x, self.theta.transpose())) class NaiveBayes(object): @classmethod