package org.apache.spark.ml.clustering

import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._

/**
 * Common params for PowerIterationClustering
 */
private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
-  with HasPredictionCol {
+  with HasWeightCol {

  /**
   * The number of clusters to create (k). Must be > 1. Default: 2.
@@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
  def getInitMode: String = $(initMode)

  /**
-   * Param for the name of the input column for vertex IDs.
-   * Default: "id"
+   * Param for the name of the input column for source vertex IDs.
+   * Default: "src"
   * @group param
   */
  @Since("2.4.0")
-  val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.",
+  val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.",
    (value: String) => value.nonEmpty)

-  setDefault(idCol, "id")
-
-  /** @group getParam */
-  @Since("2.4.0")
-  def getIdCol: String = getOrDefault(idCol)
-
-  /**
-   * Param for the name of the input column for neighbors in the adjacency list representation.
-   * Default: "neighbors"
-   * @group param
-   */
-  @Since("2.4.0")
-  val neighborsCol = new Param[String](this, "neighborsCol",
-    "Name of the input column for neighbors in the adjacency list representation.",
-    (value: String) => value.nonEmpty)
-
-  setDefault(neighborsCol, "neighbors")
-
  /** @group getParam */
  @Since("2.4.0")
-  def getNeighborsCol: String = $(neighborsCol)
+  def getSrcCol: String = getOrDefault(srcCol)

  /**
-   * Param for the name of the input column for neighbors in the adjacency list representation.
-   * Default: "similarities"
+   * Name of the input column for destination vertex IDs.
+   * Default: "dst"
   * @group param
   */
  @Since("2.4.0")
-  val similaritiesCol = new Param[String](this, "similaritiesCol",
-    "Name of the input column for neighbors in the adjacency list representation.",
+  val dstCol = new Param[String](this, "dstCol",
+    "Name of the input column for destination vertex IDs.",
    (value: String) => value.nonEmpty)

-  setDefault(similaritiesCol, "similarities")
-
  /** @group getParam */
  @Since("2.4.0")
-  def getSimilaritiesCol: String = $(similaritiesCol)
+  def getDstCol: String = $(dstCol)

-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType))
-    SchemaUtils.checkColumnTypes(schema, $(neighborsCol),
-      Seq(ArrayType(IntegerType, containsNull = false),
-        ArrayType(LongType, containsNull = false)))
-    SchemaUtils.checkColumnTypes(schema, $(similaritiesCol),
-      Seq(ArrayType(FloatType, containsNull = false),
-        ArrayType(DoubleType, containsNull = false)))
-    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-  }
+  setDefault(srcCol -> "src", dstCol -> "dst")
}

/**
@@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
 * PIC finds a very low-dimensional embedding of a dataset using truncated power
 * iteration on a normalized pair-wise similarity matrix of the data.
 *
- * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix
- * is a symmetric matrix whose entries are non-negative similarities between items.
- * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes:
- *  - `idCol`: vertex ID
- *  - `neighborsCol`: neighbors of vertex in `idCol`
- *  - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex
- *    in `idCol` and each neighbor in `neighborsCol`
- * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol`
- * containing the cluster assignment in `[0,k)` for each row (vertex).
- *
- * Notes:
- *  - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation.
- *    Transform runs the iterative PIC algorithm to cluster the whole input dataset.
- *  - Input validation: This validates that similarities are non-negative but does NOT validate
- *    that the input matrix is symmetric.
+ * This class is not yet an Estimator/Transformer; use the `assignClusters` method to run the
+ * PowerIterationClustering algorithm.
 *
 * @see <a href=http://en.wikipedia.org/wiki/Spectral_clustering>
 * Spectral clustering (Wikipedia)</a>
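
The rewritten doc above changes the usage contract: instead of transforming a dataset, callers now hand an edge-list DataFrame directly to `assignClusters`. A minimal sketch of that input under the new default column names, assuming a SparkSession named `spark` (the session name and sample values are illustrative, not part of this commit):

    // Hypothetical symmetric affinity matrix over four vertices.
    // Only one direction of each (i, j) pair needs to be listed,
    // per the assignClusters contract added later in this diff.
    val affinity = spark.createDataFrame(Seq(
      (0L, 1L, 1.0),
      (1L, 2L, 1.0),
      (2L, 3L, 1.0),
      (0L, 3L, 0.1)
    )).toDF("src", "dst", "weight")
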
@@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
@Experimental
class PowerIterationClustering private[clustering] (
    @Since("2.4.0") override val uid: String)
-  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
+  extends PowerIterationClusteringParams with DefaultParamsWritable {

  setDefault(
    k -> 2,
@@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] (
  @Since("2.4.0")
  def this() = this(Identifiable.randomUID("PowerIterationClustering"))

-  /** @group setParam */
-  @Since("2.4.0")
-  def setPredictionCol(value: String): this.type = set(predictionCol, value)
-
  /** @group setParam */
  @Since("2.4.0")
  def setK(value: Int): this.type = set(k, value)
@@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] (

  /** @group setParam */
  @Since("2.4.0")
-  def setIdCol(value: String): this.type = set(idCol, value)
+  def setSrcCol(value: String): this.type = set(srcCol, value)

  /** @group setParam */
  @Since("2.4.0")
-  def setNeighborsCol(value: String): this.type = set(neighborsCol, value)
+  def setDstCol(value: String): this.type = set(dstCol, value)

  /** @group setParam */
  @Since("2.4.0")
-  def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value)
+  def setWeightCol(value: String): this.type = set(weightCol, value)

+  /**
+   * Run the PIC algorithm and return a cluster assignment for each input vertex.
+   *
+   * @param dataset A dataset with columns src, dst, weight representing the affinity matrix,
+   *                which is the matrix A in the PIC paper. Suppose the src column value is i,
+   *                the dst column value is j, and the weight column value is the similarity
+   *                s,,ij,,, which must be nonnegative. This is a symmetric matrix and hence
+   *                s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be
+   *                either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are
+   *                ignored, because we assume s,,ij,, = 0.0.
+   *
+   * @return A dataset that contains columns of vertex id and the corresponding cluster for
+   *         the id. Its schema will be:
+   *          - id: Long
+   *          - cluster: Int
+   */
  @Since("2.4.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+  def assignClusters(dataset: Dataset[_]): DataFrame = {
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+      lit(1.0)
+    } else {
+      col($(weightCol)).cast(DoubleType)
+    }

-    val sparkSession = dataset.sparkSession
-    val idColValue = $(idCol)
-    val rdd: RDD[(Long, Long, Double)] =
-      dataset.select(
-        col($(idCol)).cast(LongType),
-        col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)),
-        col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false))
-      ).rdd.flatMap {
-        case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) =>
-          require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " +
-            s"equal to the the length of the neighbor similarity list. Row for ID " +
-            s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " +
-            s"of length ${sims.length}.")
-          nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map {
-            case (nbr, similarity) => (id, nbr, similarity)
-          }
-      }
+    SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType))
+    SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType))
+    val rdd: RDD[(Long, Long, Double)] = dataset.select(
+      col($(srcCol)).cast(LongType),
+      col($(dstCol)).cast(LongType),
+      w).rdd.map {
+      case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight)
+    }
    val algorithm = new MLlibPowerIterationClustering()
      .setK($(k))
      .setInitializationMode($(initMode))
      .setMaxIterations($(maxIter))
    val model = algorithm.run(rdd)

-    val predictionsRDD: RDD[Row] = model.assignments.map { assignment =>
-      Row(assignment.id, assignment.cluster)
-    }
-
-    val predictionsSchema = StructType(Seq(
-      StructField($(idCol), LongType, nullable = false),
-      StructField($(predictionCol), IntegerType, nullable = false)))
-    val predictions = {
-      val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
-      dataset.schema($(idCol)).dataType match {
-        case _: LongType =>
-          uncastPredictions
-        case otherType =>
-          uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)))
-      }
-    }
-
-    dataset.join(predictions, $(idCol))
-  }
-
-  @Since("2.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    import dataset.sparkSession.implicits._
+    model.assignments.toDF
  }

  @Since("2.4.0")
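
For orientation, a hedged end-to-end sketch of how the reworked API is meant to be called, continuing from the hypothetical `affinity` DataFrame sketched earlier; only setters that appear in this diff are used, and the output columns (`id`, `cluster`) follow the `assignClusters` scaladoc above:

    // Configure the algorithm through the renamed params and run it.
    val pic = new PowerIterationClustering()
      .setK(2)
      .setSrcCol("src")
      .setDstCol("dst")
      .setWeightCol("weight")  // optional: when unset, every edge weight defaults to 1.0

    // One row per distinct vertex id, with columns id: Long and cluster: Int.
    val assignments = pic.assignClusters(affinity)
    assignments.show()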