[SPARK-19733][ML]Removed unnecessary castings and refactored checked casts in ALS. #17059
Changes from 3 commits
98d6481
6e0ddf0
fc3acfb
eb33189
3050f6e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,12 +82,20 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo | |
| * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is | ||
| * out of integer range. | ||
| */ | ||
| protected val checkedCast = udf { (n: Double) => | ||
| if (n > Int.MaxValue || n < Int.MinValue) { | ||
| throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + | ||
| s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") | ||
| } else { | ||
| n.toInt | ||
| protected val checkedCast = udf { (n: Any) => | ||
| n match { | ||
| case v: Int => v // Avoid unnecessary casting | ||
| case v: Number => | ||
| val intV = v.intValue() | ||
| // Checks if number within Int range and has no fractional part. | ||
| if (v.doubleValue == intV) { | ||
|
Contributor
Sorry, I'm not sure if this is a good idea due to floating point precision. The code above doesn't seem to do this check, it just calls toInt. However, if this is absolutely necessary, I would hope that we could give the user some way to specify the Int range or precision. Also, if we are going to go ahead with this change, then we should add some tests to verify the case where the exception is thrown, but without some ability to specify the precision I'm not sure this is the correct thing to do (?).
Contributor
Author
@imatiach-msft: In this snippet we deal with the user id and item id. Those should not have fractional parts. What I do here is convert the number into an Integer and compare it with its double value. If the values are identical, one of the following is true:
I think this snippet is fine.
Contributor
Yeah the reality is we mostly expect Int or Long ids, but want to support any numeric ID as long as it falls within Int range - so passing in fractional float/double just doesn't make sense and we'll throw an error.
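The check discussed in this thread can be tried outside Spark. The following is a minimal sketch, assuming a plain function in place of the UDF; the names `CheckedCastSketch` and `toIntChecked` are hypothetical and the messages are shortened for illustration.

```scala
object CheckedCastSketch {
  // Hypothetical standalone helper; the PR's version is a Spark UDF in ALSModelParams.
  def toIntChecked(n: Any): Int = n match {
    case v: Int => v // already an Int, nothing to convert
    case v: Number =>
      val intV = v.intValue()
      // Equal only when v lies within Int range and has no fractional part.
      if (v.doubleValue() == intV) {
        intV
      } else {
        throw new IllegalArgumentException(
          s"Value $n was out of Integer range or had a fractional part.")
      }
    case _ =>
      throw new IllegalArgumentException(s"Value $n was not numeric.")
  }
}
```

With this sketch, `7`, `42L` and `3.0` are accepted, while `1.5`, `3000000000L` and a non-numeric value such as `"abc"` are rejected with an `IllegalArgumentException`.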
Member
I think the test does work, but may be slightly more direct to check if
Contributor
Also, if we for some reason do decide to go with the check against the fractional part, which I would not recommend, we should add a test to verify that the exception is thrown, and possibly document this change in the accepted input format, since it may break some users (they may have to convert their values to int before calling this API if they run into such floating point precision issues).
Contributor
Author
@imatiach-msft I am aware of the implications of floating point precision and I understand your concerns. Having said that, even allowing user and item ids to be double/float is not a good idea. We just keep it for backwards compatibility, I guess. Also note that the current implementation in Spark 2.1 will actually take your 0.9999999999999996 value and silently cast it to Int (so it becomes 0)! For me the only permitted types should have been Integer, Long and BigInteger. I don't have strong opinions about refactoring anything in the Number case, as performance-wise it simply does not matter. The point of this PR is to optimize the general case where the id is an Int, because the casting in the current approach generates twice as much data as the original dataset (of course it is GCed, but at a cost).
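The silent truncation mentioned above can be reproduced without Spark. This is a small sketch, assuming two hypothetical helpers: `oldCast` mirrors the range test followed by `toInt` from the removed lines, and `isExactInt` mirrors the comparison from the added lines.

```scala
object TruncationSketch {
  // Old-style check: range test only, then toInt, which truncates toward zero.
  def oldCast(n: Double): Int =
    if (n > Int.MaxValue || n < Int.MinValue) {
      throw new IllegalArgumentException(s"Value $n was out of Integer range.")
    } else {
      n.toInt
    }

  // New-style comparison: true only when n survives the round trip through Int.
  def isExactInt(n: Double): Boolean = n == n.toInt
}
```

Here `TruncationSketch.oldCast(0.9999999999999996)` returns `0`, whereas `TruncationSketch.isExactInt(0.9999999999999996)` is `false`, so the new check would reject the value instead of mapping it to id 0.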
Contributor
@datumbox I agree that calling toInt in the previous code was not the best decision either; if we really wanted to support double type correctly we would have probably done the check with some precision. Maybe if we add to the documentation that double types for user/item should be avoided I would be ok with it.
Contributor
Author
@imatiach-msft I already tested it, check this snippet posted here. I can add it to ALSSuite if necessary. The exception message is technically correct, since a number with a fractional part is not in Integer range (plus you get to see the actual value of the number). Returning a different exception would require two separate if checks instead of one. Do we really want that? Guys, to be honest, this PR solves a very simple thing and I did not anticipate it being so controversial. May I suggest we agree on the final set of changes and merge it? Perhaps we can tackle any other concerns in new pull requests?
Contributor
@datumbox Sorry the link doesn't seem to work for me. If you could add it to the ALSSuite that would be great. Oh no, I don't mean to have a different check, just to have a clearer exception message - something like "throw new IllegalArgumentException(s"ALS only supports values in Integer range and without fractional part for columns As a general rule all checkins should have tests associated with them to verify functionality, and we should aim for 100% code coverage to validate all production code paths, even if it is not always realistically possible. |
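A test along the lines requested here might look like the sketch below. It is not the test that was added to ALSSuite: it avoids ScalaTest so that it stays self-contained, the column names are placeholders for `$(userCol)` and `$(itemCol)`, and the tail of the exception message is assumed wording, since the suggested message above is cut off.

```scala
import scala.util.{Failure, Try}

object CheckedCastMessageSketch {
  // Hypothetical column names standing in for $(userCol) and $(itemCol).
  val userCol = "user"
  val itemCol = "item"

  def checkedCast(n: Any): Int = n match {
    case v: Int => v
    case v: Number if v.doubleValue() == v.intValue() => v.intValue()
    case _: Number =>
      // Message ending is an assumption; only the opening words come from the review.
      throw new IllegalArgumentException(
        "ALS only supports values in Integer range and without fractional part " +
          s"for columns $userCol and $itemCol. Value $n was out of Integer range " +
          "or contained a fractional part.")
    case _ =>
      throw new IllegalArgumentException(s"Value $n was not numeric.")
  }

  // Returns the message of the IllegalArgumentException thrown by `body`, if any.
  def messageOf(body: => Any): Option[String] = Try(body) match {
    case Failure(e: IllegalArgumentException) => Some(e.getMessage)
    case _ => None
  }
}
```

In a real suite the `messageOf` helper would probably be replaced by the test framework's own exception assertion; it is used here only to keep the sketch free of dependencies.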
||
| intV | ||
| } else { | ||
| throw new IllegalArgumentException(s"ALS only supports values in Integer range " + | ||
| s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") | ||
| } | ||
| case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + | ||
| s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.") | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -262,9 +270,9 @@ class ALSModel private[ml] ( | |
| } | ||
| dataset | ||
| .join(userFactors, | ||
| checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") | ||
| checkedCast(dataset($(userCol))) === userFactors("id"), "left") | ||
| .join(itemFactors, | ||
| checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") | ||
| checkedCast(dataset($(itemCol))) === itemFactors("id"), "left") | ||
| .select(dataset("*"), | ||
| predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) | ||
| } | ||
|
|
@@ -451,8 +459,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] | |
|
|
||
| val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) | ||
| val ratings = dataset | ||
| .select(checkedCast(col($(userCol)).cast(DoubleType)), | ||
| checkedCast(col($(itemCol)).cast(DoubleType)), r) | ||
| .select(checkedCast(col($(userCol))), | ||
| checkedCast(col($(itemCol))), r) | ||
|
||
| .rdd | ||
| .map { row => | ||
| Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) | ||
|
|
||
Should probably add to comment "... out of integer range or contains a fractional part"