From f1785959e04e462877fd74e6c767e47e0fb207d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 6 Jul 2017 17:38:01 +0800 Subject: [PATCH 01/16] BUG: cache weightCol if necessary --- .../org/apache/spark/ml/classification/OneVsRest.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7cbcccf2720a3..8a521435ff68e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -317,7 +318,12 @@ final class OneVsRest @Since("1.4.0") ( val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) instr.logNumClasses(numClasses) - val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) + val multiclassLabeled = getClassifier match { + // SPARK-21306: cache weightCol if necessary + case c: HasWeightCol if c.isDefined(c.weightCol) && !c.getWeightCol.isEmpty => + dataset.select($(labelCol), $(featuresCol), c.getWeightCol) + case _ => dataset.select($(labelCol), $(featuresCol)) + } // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE From a3e6e967a23f06dbd50aee08c34b07055c5646ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 6 Jul 2017 17:55:12 +0800 Subject: [PATCH 02/16] TST: add unit test --- .../org/apache/spark/ml/classification/OneVsRest.scala | 2 +- .../apache/spark/ml/classification/OneVsRestSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 8a521435ff68e..6e77c2a4b3cb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -320,7 +320,7 @@ final class OneVsRest @Since("1.4.0") ( val multiclassLabeled = getClassifier match { // SPARK-21306: cache weightCol if necessary - case c: HasWeightCol if c.isDefined(c.weightCol) && !c.getWeightCol.isEmpty => + case c: HasWeightCol if c.isDefined(c.weightCol) && c.getWeightCol.nonEmpty => dataset.select($(labelCol), $(featuresCol), c.getWeightCol) case _ => dataset.select($(labelCol), $(featuresCol)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index c02e38ad64e3e..a227b5e293261 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -156,6 +156,14 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + test("SPARK-21306: OneVsRest should cache weightCol if necessary") { + val dataset2 = dataset.withColumn("weight", lit(1)) + val ova = new OneVsRest().setClassifier(new LogisticRegression().setWeightCol("weight")) + // run without any exception. + val ovaModel = ova.fit(dataset2) + assert(ovaModel !== null) + } + test("OneVsRest.copy and OneVsRestModel.copy") { val lr = new LogisticRegression() .setMaxIter(1) From 931d02d218eddeb829d621916636ca7411f38c64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 6 Jul 2017 18:03:08 +0800 Subject: [PATCH 03/16] TST: concise comment --- .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index a227b5e293261..daa101e94bf36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -159,7 +159,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("SPARK-21306: OneVsRest should cache weightCol if necessary") { val dataset2 = dataset.withColumn("weight", lit(1)) val ova = new OneVsRest().setClassifier(new LogisticRegression().setWeightCol("weight")) - // run without any exception. + // failed if weightCol is not cached. val ovaModel = ova.fit(dataset2) assert(ovaModel !== null) } From 25d681f38ff670c8209dd1ae5c874e10ac89b402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 12:15:45 +0800 Subject: [PATCH 04/16] ENH: python, cache weightCol --- python/pyspark/ml/classification.py | 5 ++++- python/pyspark/ml/tests.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 948806a5c936c..64c59c02c89ac 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1546,7 +1546,10 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - multiclassLabeled = dataset.select(labelCol, featuresCol) + if isinstance(classifier, HasWeightCol) and classifier.getWeightCol(): + multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol()) + else: + multiclassLabeled = dataset.select(labelCol, featuresCol) # persist if underlying dataset is not persistent. handlePersistence = \ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7870047651601..8d2b70b7ab993 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1255,6 +1255,16 @@ def test_output_columns(self): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) + def test_cache_weightCol_if_necessary(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + self.assertIsNone(model) + class HashingTFTest(SparkSessionTestCase): From c380ba7f4b4bd4ee36795ae84f6f670ff73faf59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 16:38:00 +0800 Subject: [PATCH 05/16] BUG: check weightCol defined --- python/pyspark/ml/classification.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 64c59c02c89ac..a2c5054d2a3eb 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1546,7 +1546,9 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - if isinstance(classifier, HasWeightCol) and classifier.getWeightCol(): + if (isinstance(classifier, HasWeightCol) and + classifier.isDefined(classifier.weightCol) and + classifier.getWeightCol()): multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol()) else: multiclassLabeled = dataset.select(labelCol, featuresCol) From e511b905dbf089824d7e6e0d60c10e642ffc87e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 16:41:54 +0800 Subject: [PATCH 06/16] TST: model is not None --- python/pyspark/ml/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 8d2b70b7ab993..aedd4f89ca9a6 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1263,7 +1263,7 @@ def test_cache_weightCol_if_necessary(self): lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") ovr = OneVsRest(classifier=lr) model = ovr.fit(df) - self.assertIsNone(model) + self.assertIsNotNone(model) class HashingTFTest(SparkSessionTestCase): From 1c215f3c682a8e13db7dd71803f36985fc20b8da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 7 Jul 2017 17:02:11 +0800 Subject: [PATCH 07/16] CLN: use if-indent style --- python/pyspark/ml/classification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index a2c5054d2a3eb..e489857566108 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1546,9 +1546,9 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - if (isinstance(classifier, HasWeightCol) and - classifier.isDefined(classifier.weightCol) and - classifier.getWeightCol()): + if (isinstance(classifier, HasWeightCol) + and classifier.isDefined(classifier.weightCol) + and classifier.getWeightCol()): multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol()) else: multiclassLabeled = dataset.select(labelCol, featuresCol) From 2d3ce90a314d84601623160a0f8f323d0cdadac2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 14:05:06 +0800 Subject: [PATCH 08/16] ENH: add weightCol --- .../spark/ml/classification/OneVsRest.scala | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 6e77c2a4b3cb1..d6b3d77b0cc1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -54,7 +54,8 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { +private[ml] trait OneVsRestParams extends PredictorParams + with ClassifierTypeTrait with HasWeightCol { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -295,6 +296,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.3.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) @@ -318,11 +323,20 @@ final class OneVsRest @Since("1.4.0") ( val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) instr.logNumClasses(numClasses) - val multiclassLabeled = getClassifier match { - // SPARK-21306: cache weightCol if necessary - case c: HasWeightCol if c.isDefined(c.weightCol) && c.getWeightCol.nonEmpty => - dataset.select($(labelCol), $(featuresCol), c.getWeightCol) - case _ => dataset.select($(labelCol), $(featuresCol)) + // SPARK-21306: cache weightCol if necessary + val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && { + getClassifier match { + case _: HasWeightCol => true + case c => + logWarning(s"weightCol is ignored, as it is not supported by $c now.") + false + } + } + + val multiclassLabeled = if (weightColIsUsed) { + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + } else { + dataset.select($(labelCol), $(featuresCol)) } // persist if underlying dataset is not persistent. @@ -343,7 +357,13 @@ final class OneVsRest @Since("1.4.0") ( paramMap.put(classifier.labelCol -> labelColName) paramMap.put(classifier.featuresCol -> getFeaturesCol) paramMap.put(classifier.predictionCol -> getPredictionCol) - classifier.fit(trainingDataset, paramMap) + if (weightColIsUsed) { + val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol] + paramMap.put(classifier_.weightCol -> getWeightCol) + classifier_.fit(trainingDataset, paramMap) + } else { + classifier.fit(trainingDataset, paramMap) + } }.toArray[ClassificationModel[_, _]] instr.logNumFeatures(models.head.numFeatures) From ea92228c1601a2d0178aa28d60ba440253bc7e71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 14:36:43 +0800 Subject: [PATCH 09/16] TST: modify test case --- .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index daa101e94bf36..757368760c6e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -158,7 +158,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("SPARK-21306: OneVsRest should cache weightCol if necessary") { val dataset2 = dataset.withColumn("weight", lit(1)) - val ova = new OneVsRest().setClassifier(new LogisticRegression().setWeightCol("weight")) + val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) // failed if weightCol is not cached. val ovaModel = ova.fit(dataset2) assert(ovaModel !== null) From 57fc4b35916d8b6293a8543a5a265990c8742caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 15:29:09 +0800 Subject: [PATCH 10/16] ENH: add weightCol for python part --- python/pyspark/ml/classification.py | 29 ++++++++++++++++++++--------- python/pyspark/ml/tests.py | 4 ++-- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index e489857566108..f4cb70019e256 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1447,7 +1447,7 @@ def weights(self): return self._call_java("weights") -class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol): """ Parameters for OneVsRest and OneVsRestModel. """ @@ -1517,10 +1517,10 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - classifier=None): + weightCol=None, classifier=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - classifier=None) + weightCol=None, classifier=None) """ super(OneVsRest, self).__init__() kwargs = self._input_kwargs @@ -1528,9 +1528,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only @since("2.0.0") - def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, + weightCol=None, classifier=None): """ - setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \ + weightCol=None, classifier=None): Sets params for OneVsRest. """ kwargs = self._input_kwargs @@ -1546,10 +1548,17 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - if (isinstance(classifier, HasWeightCol) - and classifier.isDefined(classifier.weightCol) - and classifier.getWeightCol()): - multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol()) + # SPARK - 21306: cache weightCol if necessary + weightCol = None + if (self.isDefined(self.weightCol) and self.getWeightCol()): + if isinstance(classifier, HasWeightCol): + weightCol = self.getWeightCol() + else: + warnings.warn("weightCol is ignored, " + "as it is not supported by {} now.".format(classifier)) + + if weightCol: + multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol) else: multiclassLabeled = dataset.select(labelCol, featuresCol) @@ -1567,6 +1576,8 @@ def trainSingleClass(index): paramMap = dict([(classifier.labelCol, binaryLabelCol), (classifier.featuresCol, featuresCol), (classifier.predictionCol, predictionCol)]) + if weightCol: + paramMap[classifier.weightCol] = weightCol return classifier.fit(trainingDataset, paramMap) # TODO: Parallel training for all classes. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index aedd4f89ca9a6..acd31377f794a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1260,8 +1260,8 @@ def test_cache_weightCol_if_necessary(self): (1.0, Vectors.sparse(2, [], []), 1.0), (2.0, Vectors.dense(0.5, 0.5), 1.0)], ["label", "features", "weight"]) - lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") - ovr = OneVsRest(classifier=lr) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") model = ovr.fit(df) self.assertIsNotNone(model) From 00a7ed82cee5a9f2eb5f819e1177f480d9cb8f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 15:48:05 +0800 Subject: [PATCH 11/16] CLN: revise comment --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 1 - .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 3 +-- python/pyspark/ml/classification.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index d6b3d77b0cc1c..08df7a6133b42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -323,7 +323,6 @@ final class OneVsRest @Since("1.4.0") ( val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) instr.logNumClasses(numClasses) - // SPARK-21306: cache weightCol if necessary val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && { getClassifier match { case _: HasWeightCol => true diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 757368760c6e5..fdd6db3275aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -156,10 +156,9 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } - test("SPARK-21306: OneVsRest should cache weightCol if necessary") { + test("SPARK-21306: OneVsRest should support setWeightCol") { val dataset2 = dataset.withColumn("weight", lit(1)) val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) - // failed if weightCol is not cached. val ovaModel = ova.fit(dataset2) assert(ovaModel !== null) } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f4cb70019e256..2f9ddd1d64fcf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1548,7 +1548,6 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - # SPARK - 21306: cache weightCol if necessary weightCol = None if (self.isDefined(self.weightCol) and self.getWeightCol()): if isinstance(classifier, HasWeightCol): From a57f096eb34e57d6a72221f29d84b5ef0c296b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 15:48:24 +0800 Subject: [PATCH 12/16] TST: python test for setWeightCol --- python/pyspark/ml/tests.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index acd31377f794a..451f462b2b61a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1255,15 +1255,16 @@ def test_output_columns(self): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) - def test_cache_weightCol_if_necessary(self): + def test_support_for_weightCol(self): df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), (1.0, Vectors.sparse(2, [], []), 1.0), (2.0, Vectors.dense(0.5, 0.5), 1.0)], ["label", "features", "weight"]) lr = LogisticRegression(maxIter=5, regParam=0.01) ovr = OneVsRest(classifier=lr, weightCol="weight") - model = ovr.fit(df) - self.assertIsNotNone(model) + self.assertIsNotNone(ovr.fit(df)) + ovr2 = OneVsRest(classifier=lr).setWeightCol("weight") + self.assertIsNotNone(ovr2.fit(df)) class HashingTFTest(SparkSessionTestCase): From 54e0fcae245259bbec1294a2d2397c21bb61ef70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 13 Jul 2017 16:20:32 +0800 Subject: [PATCH 13/16] DOC: add description --- .../org/apache/spark/ml/classification/OneVsRest.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 08df7a6133b42..05b8c3ab5456e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -296,7 +296,15 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ + /** + * Sets the value of param [[weightCol]]. + * + * This is ignored if weight is not supported by [[classifier]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ @Since("2.3.0") def setWeightCol(value: String): this.type = set(weightCol, value) From 9ba0e2be0b6bb0b8012c6003984c307f569dda1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 19 Jul 2017 13:55:44 +0800 Subject: [PATCH 14/16] TST: classifier doesn't have weightCol --- .../apache/spark/ml/classification/OneVsRestSuite.scala | 7 +++++-- python/pyspark/ml/tests.py | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index fdd6db3275aeb..17f82827b74e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -158,9 +158,12 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("SPARK-21306: OneVsRest should support setWeightCol") { val dataset2 = dataset.withColumn("weight", lit(1)) + // classifier inherits hasWeightCol val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) - val ovaModel = ova.fit(dataset2) - assert(ovaModel !== null) + assert(ova.fit(dataset2) !== null) + // classifier doesn't inherit hasWeightCol + val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier()) + assert(ova2.fit(dataset2) !== null) } test("OneVsRest.copy and OneVsRestModel.copy") { diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 451f462b2b61a..78de674b84edd 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1260,11 +1260,18 @@ def test_support_for_weightCol(self): (1.0, Vectors.sparse(2, [], []), 1.0), (2.0, Vectors.dense(0.5, 0.5), 1.0)], ["label", "features", "weight"]) + # classifier inherits hasWeightCol lr = LogisticRegression(maxIter=5, regParam=0.01) ovr = OneVsRest(classifier=lr, weightCol="weight") self.assertIsNotNone(ovr.fit(df)) ovr2 = OneVsRest(classifier=lr).setWeightCol("weight") self.assertIsNotNone(ovr2.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr3 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr3.fit(df)) + ovr4 = OneVsRest(classifier=dt).setWeightCol("weight") + self.assertIsNotNone(ovr4.fit(df)) class HashingTFTest(SparkSessionTestCase): From db303a0d1a5352d0c5f945a9506d8fe9bf2db678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 26 Jul 2017 20:49:00 +0800 Subject: [PATCH 15/16] CLN: mv weightCol to the end --- python/pyspark/ml/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 2f9ddd1d64fcf..4aa2a0b16eb01 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1517,10 +1517,10 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - weightCol=None, classifier=None): + classifier=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - weightCol=None, classifier=None) + classifier=None, weightCol=None) """ super(OneVsRest, self).__init__() kwargs = self._input_kwargs @@ -1529,10 +1529,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only @since("2.0.0") def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, - weightCol=None, classifier=None): + classifier=None, weightCol=None): """ setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \ - weightCol=None, classifier=None): + classifier=None, weightCol=None): Sets params for OneVsRest. """ kwargs = self._input_kwargs From 8c0beba4fdc29f7a78c7933e49b2fdf58995f8e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 26 Jul 2017 20:50:55 +0800 Subject: [PATCH 16/16] TST: remove ovr2 nad ovr4 --- python/pyspark/ml/tests.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 78de674b84edd..37f1e98dfe787 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1264,14 +1264,10 @@ def test_support_for_weightCol(self): lr = LogisticRegression(maxIter=5, regParam=0.01) ovr = OneVsRest(classifier=lr, weightCol="weight") self.assertIsNotNone(ovr.fit(df)) - ovr2 = OneVsRest(classifier=lr).setWeightCol("weight") - self.assertIsNotNone(ovr2.fit(df)) # classifier doesn't inherit hasWeightCol dt = DecisionTreeClassifier() - ovr3 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr3.fit(df)) - ovr4 = OneVsRest(classifier=dt).setWeightCol("weight") - self.assertIsNotNone(ovr4.fit(df)) + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) class HashingTFTest(SparkSessionTestCase):