mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {

/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
* Sets the value of param [[weightCol]].
*
* This is ignored if weight is not supported by [[classifier]].
* If this is not set or empty, all instance weights are treated as 1.0.
*
* @group setParam
*/
@Since("2.3.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
instr.logNumClasses(numClasses)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

val multiclassLabeled = if (weightColIsUsed) {
dataset.select($(labelCol), $(featuresCol), $(weightCol))
} else {
dataset.select($(labelCol), $(featuresCol))
}
yanboliang (Contributor):

OneVsRest is a classification estimator, so I think we should make weightCol a member param of it, like featuresCol. For example:

// dataset has columns a, b, c
val ova = new OneVsRest()
  .setFeaturesCol("a")
  .setClassifier(new LogisticRegression().setFeaturesCol("b"))

The features column used by OneVsRest is a: a column set on OneVsRest overrides the corresponding setting on OneVsRest.classifier. We should follow the same approach for weightCol as well. Thanks.

facaiy (Contributor Author), Jul 12, 2017:

Hi, @yanboliang. As @MLnick said, not all classifiers inherit HasWeightCol, so it might cause confusion.

In my opinion, setWeightCol is an attribute owned by one specific classifier itself, like setProbabilityCol and setRawPredictionCol for LogisticRegression. So I'd suggest that the user configure the classifier itself, rather than OneVsRest.
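For illustration, the classifier-level configuration described above might look like the following in pyspark. This is a sketch only, and the "weight" column name is hypothetical:

from pyspark.ml.classification import LogisticRegression, OneVsRest

# Alternative shape: weightCol configured on the classifier itself,
# not on OneVsRest ("weight" is a hypothetical DataFrame column).
lr = LogisticRegression(weightCol="weight")
ovr = OneVsRest(classifier=lr)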

yanboliang (Contributor), Jul 12, 2017:

@facaiy It doesn't matter. If the classifier doesn't inherit from HasWeightCol, we don't set weightCol for that classifier; instead we print a warning log saying weightCol doesn't take effect. You can refer to these lines of code to see how featuresCol is handled, and we can do it in a similar way. Thanks.


// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
if (weightColIsUsed) {
val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
paramMap.put(classifier_.weightCol -> getWeightCol)
classifier_.fit(trainingDataset, paramMap)
} else {
classifier.fit(trainingDataset, paramMap)
}
}.toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -156,6 +156,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}

test("SPARK-21306: OneVsRest should support setWeightCol") {
val dataset2 = dataset.withColumn("weight", lit(1))
val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
val ovaModel = ova.fit(dataset2)
assert(ovaModel !== null)
}

test("OneVsRest.copy and OneVsRestModel.copy") {
val lr = new LogisticRegression()
.setMaxIter(1)
27 changes: 21 additions & 6 deletions python/pyspark/ml/classification.py
@@ -1447,7 +1447,7 @@ def weights(self):
return self._call_java("weights")


class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
"""
Parameters for OneVsRest and OneVsRestModel.
"""
Expand Down Expand Up @@ -1517,20 +1517,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):

@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
classifier=None):
weightCol=None, classifier=None):
Contributor:

Usually we append new arguments at the end: users may call with non-keyword (positional) arguments, and we should not break their existing code. FYI: https://docs.python.org/3/tutorial/controlflow.html#keyword-arguments
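As a generic illustration of the linked guidance (plain Python, not the pyspark API, whose @keyword_only decorator enforces keyword arguments):

my_classifier = object()  # stand-in for a classifier instance

def set_params(featuresCol="features", labelCol="label",
               predictionCol="prediction", classifier=None):
    return featuresCol, labelCol, predictionCol, classifier

set_params("f", "l", "p", my_classifier)  # positional call
# Inserting a new weightCol parameter before classifier would silently
# rebind my_classifier to weightCol; appending new parameters at the
# end keeps existing positional calls working.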

facaiy (Contributor Author):

Moved.

"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
classifier=None)
weightCol=None, classifier=None)
"""
super(OneVsRest, self).__init__()
kwargs = self._input_kwargs
self._set(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
weightCol=None, classifier=None):
"""
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
weightCol=None, classifier=None):
Sets params for OneVsRest.
"""
kwargs = self._input_kwargs
@@ -1546,7 +1548,18 @@ def _fit(self, dataset):

numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

multiclassLabeled = dataset.select(labelCol, featuresCol)
weightCol = None
if (self.isDefined(self.weightCol) and self.getWeightCol()):
if isinstance(classifier, HasWeightCol):
weightCol = self.getWeightCol()
else:
warnings.warn("weightCol is ignored, "
"as it is not supported by {} now.".format(classifier))

if weightCol:
multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
else:
multiclassLabeled = dataset.select(labelCol, featuresCol)

# persist if underlying dataset is not persistent.
handlePersistence = \
@@ -1562,6 +1575,8 @@ def trainSingleClass(index):
paramMap = dict([(classifier.labelCol, binaryLabelCol),
(classifier.featuresCol, featuresCol),
(classifier.predictionCol, predictionCol)])
if weightCol:
paramMap[classifier.weightCol] = weightCol
return classifier.fit(trainingDataset, paramMap)

# TODO: Parallel training for all classes.
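Taken together, the Python changes enable the following end-to-end usage. This is a minimal sketch assuming a DataFrame df with hypothetical label, features, and weight columns:

from pyspark.ml.classification import LogisticRegression, OneVsRest

# weightCol is set on OneVsRest and forwarded to the binary classifier
# through the param map at fit time, since LogisticRegression
# implements HasWeightCol.
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
model = ovr.fit(df)  # each per-class binary model is trained with weights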
11 changes: 11 additions & 0 deletions python/pyspark/ml/tests.py
@@ -1255,6 +1255,17 @@ def test_output_columns(self):
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "prediction"])

def test_support_for_weightCol(self):
Contributor:

Would it make sense to also test with a classifier that doesn't have a weight col?

facaiy (Contributor Author):

Sure. Use DecisionTreeClassifier to test.

df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
(1.0, Vectors.sparse(2, [], []), 1.0),
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
["label", "features", "weight"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
self.assertIsNotNone(ovr.fit(df))
ovr2 = OneVsRest(classifier=lr).setWeightCol("weight")
Contributor:

Minor: we can remove the ovr2 and ovr4 tests; setting the param in different ways runs the same code on the backend.

facaiy (Contributor Author):

Cleaned.

self.assertIsNotNone(ovr2.fit(df))
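A sketch of the complementary case raised above, a classifier without weight support: at the time of this PR, DecisionTreeClassifier did not implement HasWeightCol, so fitting should log the "weightCol is ignored" warning rather than fail (assumes DecisionTreeClassifier is imported from pyspark.ml.classification):

# Hypothetical follow-up test body, reusing df from the test above:
dt = DecisionTreeClassifier(maxDepth=2)
ovr_dt = OneVsRest(classifier=dt, weightCol="weight")
self.assertIsNotNone(ovr_dt.fit(df))  # weight column present but ignored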


class HashingTFTest(SparkSessionTestCase):
