[SPARK-19826][ML][PYTHON] Add spark.ml Python API for PIC #21119
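Background, for reviewers new to the algorithm (not part of the diff): Power Iteration Clustering (Lin and Cohen, ICML 2010) embeds the vertices of a similarity graph into one dimension by running a truncated power iteration on the row-normalized affinity matrix, then clusters the embedding with k-means. With affinity matrix $A$ and degree matrix $D$, the update is

$$W = D^{-1} A, \qquad v_{t+1} = \frac{W v_t}{\lVert W v_t \rVert_1},$$

stopped after a small number of iterations rather than run to convergence, so that $v_t$ still separates the clusters instead of collapsing onto the dominant eigenvector.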
@@ -19,7 +19,7 @@
from pyspark import since, keyword_only
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc

@@ -1156,6 +1156,201 @@ def getKeepLastCheckpoint(self):
        return self.getOrDefault(self.keepLastCheckpoint)


class _PowerIterationClusteringParams(JavaParams, HasMaxIter, HasPredictionCol):
    """
    Params for :py:class:`PowerIterationClustering`.

    .. versionadded:: 2.4.0
    """

    k = Param(Params._dummy(), "k",
              "The number of clusters to create. Must be > 1.",
              typeConverter=TypeConverters.toInt)
    initMode = Param(Params._dummy(), "initMode",
                     "The initialization algorithm. This can be either " +
                     "'random' to use a random vector as vertex properties, or 'degree' to use " +
                     "a normalized sum of similarities with other vertices. Supported options: " +
                     "'random' and 'degree'.",
                     typeConverter=TypeConverters.toString)
    idCol = Param(Params._dummy(), "idCol",
                  "Name of the input column for vertex IDs.",
                  typeConverter=TypeConverters.toString)
    neighborsCol = Param(Params._dummy(), "neighborsCol",
                         "Name of the input column for neighbors in the adjacency list " +
                         "representation.",
                         typeConverter=TypeConverters.toString)
    similaritiesCol = Param(Params._dummy(), "similaritiesCol",
                            "Name of the input column for non-negative weights (similarities) " +
                            "of edges between the vertex in `idCol` and each neighbor in " +
                            "`neighborsCol`.",
                            typeConverter=TypeConverters.toString)

| @since("2.4.0") | ||
| def getK(self): | ||
| """ | ||
| Gets the value of `k` | ||
|
||
| """ | ||
| return self.getOrDefault(self.k) | ||
|
|
||
| @since("2.4.0") | ||
| def getInitMode(self): | ||
| """ | ||
| Gets the value of `initMode` | ||
| """ | ||
| return self.getOrDefault(self.initMode) | ||
|
|
||
| @since("2.4.0") | ||
| def getIdCol(self): | ||
| """ | ||
| Gets the value of `idCol` | ||
| """ | ||
| return self.getOrDefault(self.idCol) | ||
|
|
||
| @since("2.4.0") | ||
| def getNeighborsCol(self): | ||
| """ | ||
| Gets the value of `neighborsCol` | ||
| """ | ||
| return self.getOrDefault(self.neighborsCol) | ||
|
|
||
| @since("2.4.0") | ||
| def getSimilaritiesCol(self): | ||
| """ | ||
| Gets the value of `similaritiesCol` | ||
| """ | ||
| return self.getOrDefault(self.binary) | ||
|
|
||
|
|
||
@inherit_doc
class PowerIterationClustering(JavaTransformer, _PowerIterationClusteringParams, JavaMLReadable,
                               JavaMLWritable):
    """
    Power Iteration Clustering (PIC), a scalable graph clustering algorithm. PIC finds a very
    low-dimensional embedding of a dataset using truncated power iteration on a normalized
    pairwise similarity matrix of the data, and then clusters that embedding.

    >>> from pyspark.sql.types import ArrayType, DoubleType, LongType, StructField, StructType
    >>> import math
    >>> def genCircle(r, n):
    ...     points = []
    ...     for i in range(0, n):
    ...         theta = 2.0 * math.pi * i / n
    ...         points.append((r * math.cos(theta), r * math.sin(theta)))
    ...     return points
    >>> def sim(x, y):
    ...     dist = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1])
    ...     return math.exp(-dist / 2.0)
    >>> r1 = 1.0
    >>> n1 = 10
    >>> r2 = 4.0
    >>> n2 = 40
    >>> n = n1 + n2
    >>> points = genCircle(r1, n1) + genCircle(r2, n2)
    >>> similarities = []
    >>> for i in range(1, n):
    ...     neighbor = []
    ...     weight = []
    ...     for j in range(i):
    ...         neighbor.append(j)
    ...         weight.append(sim(points[i], points[j]))
    ...     similarities.append([i, neighbor, weight])

    >>> rdd = sc.parallelize(similarities, 2)
    >>> schema = StructType([StructField("id", LongType(), False), \
                             StructField("neighbors", ArrayType(LongType(), False), True), \
                             StructField("similarities", ArrayType(DoubleType(), False), True)])
    >>> df = spark.createDataFrame(rdd, schema)
    >>> pic = PowerIterationClustering()
    >>> result = pic.setK(2).setMaxIter(40).transform(df)
    >>> predictions = sorted(set([(i[0], i[1]) for i in result.select(result.id, result.prediction)
    ...     .collect()]), key=lambda x: x[0])
    >>> predictions[0]
    (1, 1)
    >>> predictions[8]
    (9, 1)
    >>> predictions[9]
    (10, 0)
    >>> predictions[20]
    (21, 0)
    >>> predictions[48]
    (49, 0)
    >>> pic_path = temp_path + "/pic"
    >>> pic.save(pic_path)
    >>> pic2 = PowerIterationClustering.load(pic_path)
    >>> pic2.getK()
    2
    >>> pic2.getMaxIter()
    40
    >>> pic3 = PowerIterationClustering(k=4, initMode="degree")
    >>> pic3.getIdCol()
    'id'
    >>> pic3.getK()
    4
    >>> pic3.getMaxIter()
    20
    >>> pic3.getInitMode()
    'degree'

    .. versionadded:: 2.4.0
    """
    @keyword_only
    def __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
        """
        __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random", \
                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")
        """
        super(PowerIterationClustering, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
        self._setDefault(k=2, maxIter=20, initMode="random", idCol="id", neighborsCol="neighbors",
                         similaritiesCol="similarities")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.4.0")
    def setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
        """
        setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random", \
                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")

        Sets params for PowerIterationClustering.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

| @since("2.4.0") | ||
| def setK(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`k`. | ||
| """ | ||
| return self._set(k=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setInitMode(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`initMode`. | ||
| """ | ||
| return self._set(initMode=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setIdCol(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`idCol`. | ||
| """ | ||
| return self._set(idCol=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setNeighborsCol(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`neighborsCol. | ||
|
||
| """ | ||
| return self._set(neighborsCol=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setSimilaritiesCol(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`similaritiesCol`. | ||
| """ | ||
| return self._set(similaritiesCol=value) | ||
|
|
||
|
|
||
if __name__ == "__main__":
    import doctest
    import pyspark.ml.clustering
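A minimal end-to-end usage sketch of the API as defined in this PR. The four-node graph, its weights, and the app name are made up for illustration; the default column names ("id", "neighbors", "similarities") and the transform-based API are taken from the diff above.

    from pyspark.sql import SparkSession
    from pyspark.ml.clustering import PowerIterationClustering

    spark = SparkSession.builder.master("local[2]").appName("pic-sketch").getOrCreate()

    # Adjacency-list input: one row per vertex, holding its id, the ids of its
    # neighbors, and the non-negative edge similarities to those neighbors.
    df = spark.createDataFrame([
        (0, [1], [0.9]),
        (1, [0, 2], [0.9, 0.1]),
        (2, [1, 3], [0.1, 0.9]),
        (3, [2], [0.9]),
    ], ["id", "neighbors", "similarities"])

    pic = PowerIterationClustering(k=2, maxIter=10)
    result = pic.transform(df)  # appends the "prediction" column with cluster ids
    result.show()

Schema inference here maps Python ints to LongType and floats to DoubleType, matching the types the doctest declares explicitly via StructType.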
Good catch!