[SPARK-21306][ML] OneVsRest should support setWeightCol #18554
Changes from 13 commits
```diff
@@ -1447,7 +1447,7 @@ def weights(self):
         return self._call_java("weights")


-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
     """
     Parameters for OneVsRest and OneVsRestModel.
     """
```
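The hunk above only changes the list of base classes: adding `HasWeightCol` to `OneVsRestParams` is enough for `weightCol` and its getter/setter to be inherited through Python's MRO. A minimal, hypothetical sketch of this shared-param mixin pattern (illustrative class names, not the real pyspark implementations):

```python
# Hypothetical sketch of pyspark-style shared-param mixins (not the real classes).
class Params(object):
    def __init__(self):
        self._paramMap = {}

    def _set(self, **kwargs):
        self._paramMap.update(kwargs)
        return self


class HasWeightCol(Params):
    """Mixin that contributes a 'weightCol' param to any estimator inheriting it."""

    def setWeightCol(self, value):
        return self._set(weightCol=value)

    def getWeightCol(self):
        return self._paramMap.get("weightCol")


class OneVsRestParams(HasWeightCol):
    """Adding HasWeightCol to the bases is all it takes to expose the param."""
    pass


ovr = OneVsRestParams()
ovr.setWeightCol("weight")
print(ovr.getWeightCol())  # weight
```

This is why the real diff touches no method bodies here: the mixin carries the param definition and accessors.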
```diff
@@ -1517,20 +1517,22 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None):
+                 weightCol=None, classifier=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None)
+                 weightCol=None, classifier=None)
         """
         super(OneVsRest, self).__init__()
         kwargs = self._input_kwargs
         self._set(**kwargs)

     @keyword_only
     @since("2.0.0")
-    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+                  weightCol=None, classifier=None):
         """
-        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+                  weightCol=None, classifier=None):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
```
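Both signatures rely on pyspark's `@keyword_only` decorator, which records only the keyword arguments the caller actually passed in `self._input_kwargs`, so `self._set(**kwargs)` never clobbers defaults with unset values. A simplified, self-contained sketch of that mechanism (the `Example` class is illustrative):

```python
import functools


# Simplified sketch of pyspark's @keyword_only: it forbids positional arguments
# and stashes the explicitly passed keyword arguments in self._input_kwargs.
def keyword_only(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError("Method %s forces keyword arguments." % func.__name__)
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper


class Example(object):
    @keyword_only
    def __init__(self, featuresCol="features", weightCol=None, classifier=None):
        # Only what the caller explicitly passed ends up here.
        self.params = dict(self._input_kwargs)


e = Example(weightCol="weight")
print(e.params)  # {'weightCol': 'weight'}
```

Because `_input_kwargs` holds just `{'weightCol': 'weight'}` here, the defaults for `featuresCol` and `classifier` remain untouched by `_set`.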
```diff
@@ -1546,7 +1548,18 @@ def _fit(self, dataset):
         numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1

-        multiclassLabeled = dataset.select(labelCol, featuresCol)
+        weightCol = None
+        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+            if isinstance(classifier, HasWeightCol):
+                weightCol = self.getWeightCol()
+            else:
+                warnings.warn("weightCol is ignored, "
+                              "as it is not supported by {} now.".format(classifier))
+
+        if weightCol:
+            multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+        else:
+            multiclassLabeled = dataset.select(labelCol, featuresCol)

         # persist if underlying dataset is not persistent.
         handlePersistence = \
```
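The key decision in this hunk is gating on `isinstance(classifier, HasWeightCol)`: the weight column is honored only when the wrapped classifier supports it, and otherwise a warning is emitted instead of failing. The same logic can be exercised standalone with stub classes (all names below are illustrative, not the real pyspark types):

```python
import warnings


# Stub stand-ins for the real pyspark mixin and classifiers.
class HasWeightCol(object):
    pass


class WeightedClassifier(HasWeightCol):
    pass


class UnweightedClassifier(object):
    pass


def resolve_weight_col(classifier, requested):
    """Return the weight column to use, or None with a warning when the
    classifier does not mix in HasWeightCol (mirrors the _fit logic above)."""
    if requested and isinstance(classifier, HasWeightCol):
        return requested
    if requested:
        warnings.warn("weightCol is ignored, as it is not supported by {} now."
                      .format(type(classifier).__name__))
    return None


print(resolve_weight_col(WeightedClassifier(), "weight"))    # weight
print(resolve_weight_col(UnweightedClassifier(), "weight"))  # None, with a warning
```

Warning rather than raising keeps existing pipelines working when a user sets `weightCol` alongside a classifier that cannot use it.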
```diff
@@ -1562,6 +1575,8 @@ def trainSingleClass(index):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                              (classifier.featuresCol, featuresCol),
                              (classifier.predictionCol, predictionCol)])
+            if weightCol:
+                paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)

         # TODO: Parallel training for all classes.
```
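In `trainSingleClass`, each class gets a binarized copy of the label column, and the per-fit `paramMap` now carries `weightCol` through to every binary sub-problem. A pure-Python sketch of those two pieces (the helper names are hypothetical, not pyspark API):

```python
def binarize_labels(labels, index):
    """One binary problem per class: label 1.0 for the current class index,
    0.0 for all the rest (what binaryLabelCol encodes in the real code)."""
    return [1.0 if label == index else 0.0 for label in labels]


def build_param_map(features_col, prediction_col, weight_col=None):
    """Mimics the paramMap built in trainSingleClass: weightCol is only added
    when a weight column was resolved for the wrapped classifier."""
    param_map = {"featuresCol": features_col, "predictionCol": prediction_col}
    if weight_col:
        param_map["weightCol"] = weight_col
    return param_map


print(binarize_labels([0.0, 1.0, 2.0, 1.0], index=1))  # [0.0, 1.0, 0.0, 1.0]
print(build_param_map("features", "prediction", weight_col="weight"))
```

Passing the override via the fit-time param map, rather than mutating the classifier, keeps the shared `classifier` instance untouched across the per-class fits.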
```diff
@@ -1255,6 +1255,17 @@ def test_output_columns(self):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])

+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        ovr2 = OneVsRest(classifier=lr).setWeightCol("weight")
+        self.assertIsNotNone(ovr2.fit(df))


 class HashingTFTest(SparkSessionTestCase):
```

Review comment on `test_support_for_weightCol`:

Contributor: Would it make sense to also test with a classifier that doesn't have a weight col?

Contributor (author): Sure. Use DecisionTreeClassifier to test.
Contributor: `OneVsRest` is a classification estimator, so I think we should make `weightCol` a member param of it, like `featuresCol`: the features column set on `OneVsRest` overrides the corresponding one set on `OneVsRest.classifier`. We should follow the same approach for `weightCol`. Thanks.
Contributor: Hi, @yanboliang. As @MLnick said, not all classifiers inherit `HasWeightCol`, so it might cause confusion. In my opinion, `setWeightCol` is an attribute owned by one specific classifier itself, like `setProbabilityCol` and `setRawPredictionCol` for LogisticRegression. So I'd suggest that users configure the classifier itself, rather than `OneVsRest`.
Contributor: @facaiy It doesn't matter. If the classifier doesn't inherit from `HasWeightCol`, we don't run `setWeightCol` for that classifier, but print a warning log saying that `weightCol` doesn't take effect. You can refer to these lines of code to learn how `featuresCol` is handled; we can do it in a similar way. Thanks.