Commit e8c1a0c

WeichenXu123 authored and mengxr committed
[SPARK-15784] Add Power Iteration Clustering to spark.ml
## What changes were proposed in this pull request?

According to the discussion on JIRA, I rewrote the Power Iteration Clustering API in `spark.ml`.

## How was this patch tested?

Unit test.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: WeichenXu <[email protected]>

Closes #21493 from WeichenXu123/pic_api.
1 parent b3417b7 commit e8c1a0c

2 files changed

Lines changed: 125 additions & 211 deletions
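For context, a minimal usage sketch of the reworked API that this commit introduces. The edge data, cluster count, and the `spark` session are illustrative assumptions, not part of the patch; the setters and the `assignClusters` signature follow the diff below.

```scala
import org.apache.spark.ml.clustering.PowerIterationClustering

// Affinity matrix given as an edge list (src, dst, weight); values are made up.
// Only one of (i, j, s_ij) or (j, i, s_ji) is needed per pair, and rows with
// src == dst are ignored by the algorithm.
val affinity = spark.createDataFrame(Seq(
  (0L, 1L, 1.0),
  (1L, 2L, 1.0),
  (3L, 4L, 1.0),
  (4L, 0L, 0.1)
)).toDF("src", "dst", "weight")

val pic = new PowerIterationClustering()
  .setK(2)                 // number of clusters
  .setMaxIter(10)
  .setWeightCol("weight")  // srcCol/dstCol default to "src"/"dst"

// Runs the PIC algorithm; returns a DataFrame with columns id: Long, cluster: Int.
val assignments = pic.assignClusters(affinity)
assignments.show()
```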

File tree

mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala

Lines changed: 50 additions & 107 deletions
```diff
@@ -18,21 +18,20 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types._
 
 /**
  * Common params for PowerIterationClustering
  */
 private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
-  with HasPredictionCol {
+  with HasWeightCol {
 
   /**
    * The number of clusters to create (k). Must be &gt; 1. Default: 2.
@@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
   def getInitMode: String = $(initMode)
 
   /**
-   * Param for the name of the input column for vertex IDs.
-   * Default: "id"
+   * Param for the name of the input column for source vertex IDs.
+   * Default: "src"
    * @group param
    */
   @Since("2.4.0")
-  val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.",
+  val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.",
     (value: String) => value.nonEmpty)
 
-  setDefault(idCol, "id")
-
-  /** @group getParam */
-  @Since("2.4.0")
-  def getIdCol: String = getOrDefault(idCol)
-
-  /**
-   * Param for the name of the input column for neighbors in the adjacency list representation.
-   * Default: "neighbors"
-   * @group param
-   */
-  @Since("2.4.0")
-  val neighborsCol = new Param[String](this, "neighborsCol",
-    "Name of the input column for neighbors in the adjacency list representation.",
-    (value: String) => value.nonEmpty)
-
-  setDefault(neighborsCol, "neighbors")
-
   /** @group getParam */
   @Since("2.4.0")
-  def getNeighborsCol: String = $(neighborsCol)
+  def getSrcCol: String = getOrDefault(srcCol)
 
   /**
-   * Param for the name of the input column for neighbors in the adjacency list representation.
-   * Default: "similarities"
+   * Name of the input column for destination vertex IDs.
+   * Default: "dst"
    * @group param
    */
   @Since("2.4.0")
-  val similaritiesCol = new Param[String](this, "similaritiesCol",
-    "Name of the input column for neighbors in the adjacency list representation.",
+  val dstCol = new Param[String](this, "dstCol",
+    "Name of the input column for destination vertex IDs.",
     (value: String) => value.nonEmpty)
 
-  setDefault(similaritiesCol, "similarities")
-
   /** @group getParam */
   @Since("2.4.0")
-  def getSimilaritiesCol: String = $(similaritiesCol)
+  def getDstCol: String = $(dstCol)
 
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType))
-    SchemaUtils.checkColumnTypes(schema, $(neighborsCol),
-      Seq(ArrayType(IntegerType, containsNull = false),
-        ArrayType(LongType, containsNull = false)))
-    SchemaUtils.checkColumnTypes(schema, $(similaritiesCol),
-      Seq(ArrayType(FloatType, containsNull = false),
-        ArrayType(DoubleType, containsNull = false)))
-    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-  }
+  setDefault(srcCol -> "src", dstCol -> "dst")
 }
 
 /**
@@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
  * PIC finds a very low-dimensional embedding of a dataset using truncated power
  * iteration on a normalized pair-wise similarity matrix of the data.
  *
- * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix
- * is a symmetric matrix whose entries are non-negative similarities between items.
- * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes:
- *  - `idCol`: vertex ID
- *  - `neighborsCol`: neighbors of vertex in `idCol`
- *  - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex
- *    in `idCol` and each neighbor in `neighborsCol`
- * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol`
- * containing the cluster assignment in `[0,k)` for each row (vertex).
- *
- * Notes:
- *  - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation.
- *    Transform runs the iterative PIC algorithm to cluster the whole input dataset.
- *  - Input validation: This validates that similarities are non-negative but does NOT validate
- *    that the input matrix is symmetric.
+ * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the
+ * PowerIterationClustering algorithm.
  *
  * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering>
  * Spectral clustering (Wikipedia)</a>
@@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
 @Experimental
 class PowerIterationClustering private[clustering] (
     @Since("2.4.0") override val uid: String)
-  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+  extends PowerIterationClusteringParams with DefaultParamsWritable {
 
   setDefault(
     k -> 2,
@@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] (
   @Since("2.4.0")
   def this() = this(Identifiable.randomUID("PowerIterationClustering"))
 
-  /** @group setParam */
-  @Since("2.4.0")
-  def setPredictionCol(value: String): this.type = set(predictionCol, value)
-
   /** @group setParam */
   @Since("2.4.0")
   def setK(value: Int): this.type = set(k, value)
@@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] (
 
   /** @group setParam */
   @Since("2.4.0")
-  def setIdCol(value: String): this.type = set(idCol, value)
+  def setSrcCol(value: String): this.type = set(srcCol, value)
 
   /** @group setParam */
   @Since("2.4.0")
-  def setNeighborsCol(value: String): this.type = set(neighborsCol, value)
+  def setDstCol(value: String): this.type = set(dstCol, value)
 
   /** @group setParam */
   @Since("2.4.0")
-  def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value)
+  def setWeightCol(value: String): this.type = set(weightCol, value)
 
+  /**
+   * Run the PIC algorithm and returns a cluster assignment for each input vertex.
+   *
+   * @param dataset A dataset with columns src, dst, weight representing the affinity matrix,
+   *                which is the matrix A in the PIC paper. Suppose the src column value is i,
+   *                the dst column value is j, the weight column value is similarity s,,ij,,
+   *                which must be nonnegative. This is a symmetric matrix and hence
+   *                s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be
+   *                either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are
+   *                ignored, because we assume s,,ij,, = 0.0.
+   *
+   * @return A dataset that contains columns of vertex id and the corresponding cluster for the id.
+   *         The schema of it will be:
+   *          - id: Long
+   *          - cluster: Int
+   */
   @Since("2.4.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+  def assignClusters(dataset: Dataset[_]): DataFrame = {
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+      lit(1.0)
+    } else {
+      col($(weightCol)).cast(DoubleType)
+    }
 
-    val sparkSession = dataset.sparkSession
-    val idColValue = $(idCol)
-    val rdd: RDD[(Long, Long, Double)] =
-      dataset.select(
-        col($(idCol)).cast(LongType),
-        col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)),
-        col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false))
-      ).rdd.flatMap {
-        case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
-          require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " +
-            s"equal to the the length of the neighbor similarity list. Row for ID " +
-            s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " +
-            s"of length ${sims.length}.")
-          nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
-            case (nbr, similarity) => (id, nbr, similarity)
-          }
-      }
+    SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType))
+    SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType))
+    val rdd: RDD[(Long, Long, Double)] = dataset.select(
+      col($(srcCol)).cast(LongType),
+      col($(dstCol)).cast(LongType),
+      w).rdd.map {
+      case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight)
+    }
     val algorithm = new MLlibPowerIterationClustering()
       .setK($(k))
       .setInitializationMode($(initMode))
      .setMaxIterations($(maxIter))
     val model = algorithm.run(rdd)
 
-    val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
-      Row(assignment.id, assignment.cluster)
-    }
-
-    val predictionsSchema = StructType(Seq(
-      StructField($(idCol), LongType, nullable = false),
-      StructField($(predictionCol), IntegerType, nullable = false)))
-    val predictions = {
-      val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
-      dataset.schema($(idCol)).dataType match {
-        case _: LongType =>
-          uncastPredictions
-        case otherType =>
-          uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)))
-      }
-    }
-
-    dataset.join(predictions, $(idCol))
-  }
-
-  @Since("2.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    import dataset.sparkSession.implicits._
+    model.assignments.toDF
   }
 
   @Since("2.4.0")
```
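One behavioral difference worth noting: the removed `transform` joined the cluster assignments back onto the input rows (appending `predictionCol`), whereas `assignClusters` returns only the `(id, cluster)` pairs. Callers who want cluster labels next to their own vertex attributes now perform that join themselves. A rough sketch, continuing the hypothetical example above (the `vertices` table is an illustrative assumption):

```scala
// Hypothetical vertex-attribute table keyed by the same vertex ids as the edge list.
val vertices = spark.createDataFrame(Seq(
  (0L, "a"), (1L, "b"), (2L, "c"), (3L, "d"), (4L, "e")
)).toDF("id", "name")

// Roughly what the removed transform() did internally via dataset.join(predictions, idCol).
val labelled = vertices.join(pic.assignClusters(affinity), "id")
```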
