Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ class BisectingKMeansModel private[ml] (
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}

/**
 * Scores this model against the supplied dataset and packages the outcome as a summary.
 *
 * The clustering cost comes from `computeCost` and the predictions from `transform`,
 * both applied to the same input dataset.
 *
 * @param dataset Dataset to evaluate the model on.
 * @return Summary holding the transformed predictions together with the computed cost.
 */
@Since("2.2.0")
def evaluate(dataset: Dataset[_]): BisectingKMeansSummary = {
  // The cost is computed before the prediction DataFrame is built.
  val cost = computeCost(dataset)
  val predictions = transform(dataset)
  new BisectingKMeansSummary(
    predictions,
    $(predictionCol),
    $(featuresCol),
    $(k),
    cost)
}
}

object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
Expand Down Expand Up @@ -265,8 +276,9 @@ class BisectingKMeans @Since("2.0.0") (
.setSeed($(seed))
val parentModel = bkm.run(rdd)
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
val wssse = model.computeCost(dataset)
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.transform(dataset), $(predictionCol), $(featuresCol), $(k), wssse)
model.setSummary(Some(summary))
instr.logSuccess(model)
model
Expand Down Expand Up @@ -295,11 +307,14 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
* @param predictionCol Name for column of predicted clusters in `predictions`.
* @param featuresCol Name for column of features in `predictions`.
* @param k Number of clusters.
* @param wssse Within Set Sum of Squared Error.
*/
@Since("2.1.0")
@Experimental
class BisectingKMeansSummary private[clustering] (
predictions: DataFrame,
predictionCol: String,
featuresCol: String,
k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k)
k: Int,
@Since("2.2.0") val wssse: Double)
extends ClusteringSummary(predictions, predictionCol, featuresCol, k)
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,53 @@ class GaussianMixtureModel private[ml] (
throw new RuntimeException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}

/**
 * Return the total log-likelihood for this model on the given data.
 *
 * For every row, the mixture likelihood is the sum over all components of
 * `EPSILON + weight * pdf(feature)`; the natural logarithm of that sum is then
 * accumulated over all rows of the dataset.
 *
 * @param dataset Dataset containing the `featuresCol` column, which is checked
 *                against `VectorUDT` before any computation starts.
 * @return Sum of the per-row log-likelihoods.
 */
private[clustering] def computeLogLikelihood(dataset: Dataset[_]): Double = {
  SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
  val spark = dataset.sparkSession
  import spark.implicits._

  // Ship the (weight, distribution) pairs to the executors once instead of per task.
  val bcWeightAndDists = spark.sparkContext.broadcast(weights.zip(gaussians))
  try {
    dataset.select(col($(featuresCol))).map {
      case Row(feature: Vector) =>
        val likelihood = bcWeightAndDists.value.map {
          case (weight, dist) => EPSILON + weight * dist.pdf(feature)
        }.sum
        math.log(likelihood)
    }.reduce(_ + _)
  } finally {
    // The reduce above is the only consumer of the broadcast and has finished (or failed)
    // by this point, so release its resources instead of leaving them behind.
    bcWeightAndDists.destroy()
  }
}

/**
 * Picks the model and probability column name to use when building a summary.
 *
 * When the probability column param is the empty string, a copy of this model is
 * returned with a freshly generated, UUID-suffixed probability column name set on it.
 * Otherwise this model is returned unchanged together with its configured column name.
 *
 * @return Pair of (model to transform with, probability column name).
 */
private[clustering] def findSummaryModelAndProbabilityCol():
    (GaussianMixtureModel, String) = {
  val configuredCol = $(probabilityCol)
  if (configuredCol == "") {
    val generatedCol = "probability_" + java.util.UUID.randomUUID.toString
    val modelWithCol = copy(ParamMap.empty).setProbabilityCol(generatedCol)
    (modelWithCol, generatedCol)
  } else {
    (this, configuredCol)
  }
}

/**
 * Scores this model against the supplied dataset and packages the outcome as a summary.
 *
 * The log-likelihood comes from `computeLogLikelihood`; the predictions come from
 * transforming the dataset with the model returned by `findSummaryModelAndProbabilityCol`.
 *
 * @param dataset Dataset to evaluate the model on.
 * @return Summary holding the transformed predictions and the total log-likelihood.
 */
@Since("2.2.0")
def evaluate(dataset: Dataset[_]): GaussianMixtureSummary = {
  // Resolve a usable probability column name (and matching model) before transforming.
  val resolved = findSummaryModelAndProbabilityCol()
  val scoringModel = resolved._1
  val probColName = resolved._2
  val totalLogLikelihood = computeLogLikelihood(dataset)
  val predictions = scoringModel.transform(dataset)
  new GaussianMixtureSummary(
    predictions, $(predictionCol), probColName, $(featuresCol), $(k), totalLogLikelihood)
}
}

@Since("2.0.0")
Expand Down
19 changes: 17 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ class KMeansModel private[ml] (
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}")
}

/**
 * Scores this model against the supplied dataset and packages the outcome as a summary.
 *
 * The clustering cost comes from `computeCost` and the predictions from `transform`,
 * both applied to the same input dataset.
 *
 * @param dataset Dataset to evaluate the model on.
 * @return Summary holding the transformed predictions together with the computed cost.
 */
@Since("2.2.0")
def evaluate(dataset: Dataset[_]): KMeansSummary = {
  // The cost is computed before the prediction DataFrame is built.
  val cost = computeCost(dataset)
  val predictions = transform(dataset)
  new KMeansSummary(
    predictions,
    $(predictionCol),
    $(featuresCol),
    $(k),
    cost)
}
}

@Since("1.6.0")
Expand Down Expand Up @@ -324,8 +335,9 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val wssse = model.computeCost(dataset)
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.transform(dataset), $(predictionCol), $(featuresCol), $(k), wssse)

model.setSummary(Some(summary))
instr.logSuccess(model)
Expand Down Expand Up @@ -356,11 +368,14 @@ object KMeans extends DefaultParamsReadable[KMeans] {
* @param predictionCol Name for column of predicted clusters in `predictions`.
* @param featuresCol Name for column of features in `predictions`.
* @param k Number of clusters.
* @param wssse Within Set Sum of Squared Error.
*/
@Since("2.0.0")
@Experimental
class KMeansSummary private[clustering] (
predictions: DataFrame,
predictionCol: String,
featuresCol: String,
k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k)
k: Int,
@Since("2.2.0") val wssse: Double)
extends ClusteringSummary(predictions, predictionCol, featuresCol, k)
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ class BisectingKMeansSuite
testEstimatorAndModelReadWrite(
bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
}

test("evaluate on test set") {
  val estimator = new BisectingKMeans().setK(k).setSeed(1)
  val fitted = estimator.fit(dataset)
  // Evaluating on the training data should reproduce the cost stored in the training summary.
  val trainingSummary = fitted.summary
  val evaluatedSummary = fitted.evaluate(dataset)
  assert(trainingSummary.wssse === evaluatedSummary.wssse)
}
}

object BisectingKMeansSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues)
assert(symmetricMatrix === expectedMatrix)
}

test("evaluate on test set") {
  val estimator = new GaussianMixture().setK(k).setMaxIter(2).setSeed(1)
  val fitted = estimator.fit(dataset)
  // Evaluating on the training data is compared with the training summary's
  // log-likelihood using an absolute tolerance rather than exact equality.
  val trainingSummary = fitted.summary
  val evaluatedSummary = fitted.evaluate(dataset)
  assert(trainingSummary.logLikelihood ~== evaluatedSummary.logLikelihood absTol 2)
}
}

object GaussianMixtureSuite extends SparkFunSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
val kmeans = new KMeans()
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
}

test("evaluate on test set") {
  val estimator = new KMeans().setK(k).setSeed(1)
  val fitted = estimator.fit(dataset)
  // Evaluating on the training data should reproduce the cost stored in the training summary.
  val trainingSummary = fitted.summary
  val evaluatedSummary = fitted.evaluate(dataset)
  assert(trainingSummary.wssse === evaluatedSummary.wssse)
}
}

object KMeansSuite {
Expand Down
6 changes: 5 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ object MimaExcludes {
// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"),

// [SPARK-19303][ML] Add evaluate method in clustering models
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this")
)

// Exclude rules for 2.1.x
Expand Down
50 changes: 48 additions & 2 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def summary(self):
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)

@since("2.2.0")
def evaluate(self, dataset):
    """
    Evaluates the model on a test dataset.

    :param dataset: Test dataset to evaluate the model on.
    :return: :py:class:`GaussianMixtureSummary` wrapping the result of the
             ``evaluate`` call made through ``_call_java``.
    """
    java_summary = self._call_java("evaluate", dataset)
    return GaussianMixtureSummary(java_summary)


@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
Expand Down Expand Up @@ -177,6 +184,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
[2, 2, 2]
>>> summary.logLikelihood
8.14636...
>>> same_summary = model.evaluate(df)
>>> abs(summary.logLikelihood - same_summary.logLikelihood) < 1e-3
True
>>> weights = model.weights
>>> len(weights)
3
Expand Down Expand Up @@ -300,7 +310,13 @@ class KMeansSummary(ClusteringSummary):

.. versionadded:: 2.1.0
"""
pass
@property
@since("2.2.0")
def wssse(self):
    """
    Within Set Sum of Squared Error, read through ``_call_java("wssse")``.
    """
    cost = self._call_java("wssse")
    return cost


class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
Expand Down Expand Up @@ -344,6 +360,13 @@ def summary(self):
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)

@since("2.2.0")
def evaluate(self, dataset):
    """
    Evaluates the model on a test dataset.

    :param dataset: Test dataset to evaluate the model on.
    :return: :py:class:`KMeansSummary` wrapping the result of the
             ``evaluate`` call made through ``_call_java``.
    """
    java_summary = self._call_java("evaluate", dataset)
    return KMeansSummary(java_summary)


@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
Expand Down Expand Up @@ -376,6 +399,11 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
2
>>> summary.clusterSizes
[2, 2]
>>> summary.wssse
2.000...
>>> same_summary = model.evaluate(df)
>>> abs(summary.wssse - same_summary.wssse) < 1e-3
True
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
Expand Down Expand Up @@ -517,6 +545,13 @@ def summary(self):
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)

@since("2.2.0")
def evaluate(self, dataset):
    """
    Evaluates the model on a test dataset.

    :param dataset: Test dataset to evaluate the model on.
    :return: :py:class:`BisectingKMeansSummary` wrapping the result of the
             ``evaluate`` call made through ``_call_java``.
    """
    java_summary = self._call_java("evaluate", dataset)
    return BisectingKMeansSummary(java_summary)


@inherit_doc
class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed,
Expand Down Expand Up @@ -549,6 +584,11 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
2
>>> summary.clusterSizes
[2, 2]
>>> summary.wssse
2.000...
>>> same_summary = model.evaluate(df)
>>> abs(summary.wssse - same_summary.wssse) < 1e-3
True
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction
Expand Down Expand Up @@ -646,7 +686,13 @@ class BisectingKMeansSummary(ClusteringSummary):

.. versionadded:: 2.1.0
"""
pass
@property
@since("2.2.0")
def wssse(self):
    """
    Within Set Sum of Squared Error, read through ``_call_java("wssse")``.
    """
    cost = self._call_java("wssse")
    return cost


@inherit_doc
Expand Down