-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-32092][ML][PySpark] Fix parameters not being copied in CrossValidatorModel.copy(), read() and write() #29445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
c745198
202ffcf
fe5837d
c3a8407
850d307
b662ee0
a7a9163
b831161
e7d79be
ba994fd
fcfac36
1c98218
8ae74d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -536,7 +536,7 @@ def copy(self, extra=None): | |
| bestModel = self.bestModel.copy(extra) | ||
| avgMetrics = self.avgMetrics | ||
| subModels = self.subModels | ||
| return CrossValidatorModel(bestModel, avgMetrics, subModels) | ||
| return self._copyValues(CrossValidatorModel(bestModel, avgMetrics, subModels), extra=extra) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can. However I just found one potential issue with using cvModelCopied = cvModel.copy()
cvModel.avgMetrics[0] = 'foo'
assert cvModelCopied.avgMetrics[0] != 'foo' # This will failBased on the Scala equivalent I think
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You meant avgMetrics should be or should not be shallow copied?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm? I think the test makes sure it isn't shallow copy but deep copy, isn't?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By shallow copy I mean
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I get your point above. You meant if we just shallow copy the model itself, reassigning of element in You want to shallow copy
No matter deep copy or shallow copy, I think reassigning
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed with the above. |
||
|
|
||
| @since("2.3.0") | ||
| def write(self): | ||
|
|
@@ -560,8 +560,17 @@ def _from_java(cls, java_stage): | |
| avgMetrics = _java2py(sc, java_stage.avgMetrics()) | ||
| estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) | ||
|
|
||
| py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)._set(estimator=estimator) | ||
| py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator) | ||
| py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics) | ||
| params = { | ||
| "evaluator": evaluator, | ||
| "estimator": estimator, | ||
| "estimatorParamMaps": epms, | ||
| "numFolds": java_stage.getNumFolds(), | ||
| "foldCol": java_stage.getFoldCol(), | ||
| "seed": java_stage.getSeed(), | ||
| } | ||
| for param_name, param_val in params.items(): | ||
| py_stage = py_stage._set(**{param_name: param_val}) | ||
|
|
||
| if java_stage.hasSubModels(): | ||
| py_stage.subModels = [[JavaParams._from_java(sub_model) | ||
|
|
@@ -585,9 +594,16 @@ def _to_java(self): | |
| _py2java(sc, self.avgMetrics)) | ||
| estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() | ||
|
|
||
| _java_obj.set("evaluator", evaluator) | ||
| _java_obj.set("estimator", estimator) | ||
| _java_obj.set("estimatorParamMaps", epms) | ||
| params = { | ||
| "evaluator": evaluator, | ||
| "estimator": estimator, | ||
| "estimatorParamMaps": epms, | ||
| "numFolds": self.getNumFolds(), | ||
| "foldCol": self.getFoldCol(), | ||
| "seed": self.getSeed(), | ||
| } | ||
| for param_name, param_val in params.items(): | ||
| _java_obj.set(param_name, param_val) | ||
|
|
||
| if self.subModels is not None: | ||
| java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.