Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -317,7 +318,12 @@ final class OneVsRest @Since("1.4.0") (
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
instr.logNumClasses(numClasses)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val multiclassLabeled = getClassifier match {
// SPARK-21306: cache weightCol if necessary
case c: HasWeightCol if c.isDefined(c.weightCol) && c.getWeightCol.nonEmpty =>
dataset.select($(labelCol), $(featuresCol), c.getWeightCol)
case _ => dataset.select($(labelCol), $(featuresCol))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OneVsRest is a classification estimator, I think we should make weightCol a member param of it like featuresCol. For example:

val dataset = dataset       // This dataset has column: a, b, c.
val ova = new OneVsRest().setFeaturesCol("a").setClassifier(new LogisticRegression().setFeaturesCol("b"))

The features column used by OneVsRest is a. The features column set for OneVsRest will override corresponding set in OneVsRest.classifier. So we should follow this way for weightCol as well. Thanks.

Copy link
Contributor Author

@facaiy facaiy Jul 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @yanboliang . As @MLnick said, not all classifiers inherits HasWeightCol, so it might cause confusion.

In my opinion, setWeightCol is an attribute owned by one specific classifier itself, like setProbabilityCol and setRawPredictionCol for Logistic Regreesion. So I'd suggest that user should configure the classifier itself, rather than OneVsRest.

Copy link
Contributor

@yanboliang yanboliang Jul 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@facaiy It doesn't matter. If the classifier doesn't inherit from HasWeightCol, we don't run setWeightCol for that classifier but to print out warning log to say weightCol doesn't take effect. You can refer these lines of code to learn how featuresCol be handled. We can do it in similar way. Thanks.


// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}

test("SPARK-21306: OneVsRest should cache weightCol if necessary") {
val dataset2 = dataset.withColumn("weight", lit(1))
val ova = new OneVsRest().setClassifier(new LogisticRegression().setWeightCol("weight"))
// failed if weightCol is not cached.
val ovaModel = ova.fit(dataset2)
assert(ovaModel !== null)
}

test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
Expand Down
7 changes: 6 additions & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,12 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
if (isinstance(classifier, HasWeightCol)
and classifier.isDefined(classifier.weightCol)
and classifier.getWeightCol()):
multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol())
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,16 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_cache_weightCol_if_necessary(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
(1.0, Vectors.sparse(2, [], []), 1.0),
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
["label", "features", "weight"])
lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
ovr = OneVsRest(classifier=lr)
model = ovr.fit(df)
self.assertIsNotNone(model)


class HashingTFTest(SparkSessionTestCase):

Expand Down