Closed
Changes from 3 commits
28 changes: 18 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -82,12 +82,20 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
* Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
* out of integer range.
Contributor

Should probably add to comment "... out of integer range or contains a fractional part"

*/
-  protected val checkedCast = udf { (n: Double) =>
-    if (n > Int.MaxValue || n < Int.MinValue) {
-      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
-        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
-    } else {
-      n.toInt
+  protected val checkedCast = udf { (n: Any) =>
+    n match {
+      case v: Int => v // Avoid unnecessary casting
+      case v: Number =>
+        val intV = v.intValue()
+        // Checks if number within Int range and has no fractional part.
+        if (v.doubleValue == intV) {
Contributor

Sorry, I'm not sure this is a good idea because of floating-point precision. The code above doesn't do this check; it just calls toInt. If the check is absolutely necessary, I would hope we could give the user some way to specify the Int range or precision. Also, if we go ahead with this change, we should add tests that verify the exception is thrown, but without some way to specify the precision I'm not sure this is the correct thing to do (?).

Contributor Author

@datumbox datumbox Feb 28, 2017

@imatiach-msft: In this snippet we deal with the user id and the item id. Those should not have fractional parts. What I do here is convert the number to an Integer and compare the result with its double value. If the values are identical, one of the following is true:

  • The value is Byte or Short.
  • The value is Long but within the Integer range.
  • The value is Double or Float within the Integer range but without any fractional part.

I think this snippet is fine.
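
For readers following the thread, the check described above can be sketched outside Spark as follows. This is only an illustration under stated assumptions: `CheckedCastSketch` and `checkedToInt` are hypothetical names, the function is a plain Scala method rather than the UDF in the diff, and the error messages are shortened.

```scala
object CheckedCastSketch {
  // Hypothetical stand-alone version of the check; not the Spark UDF itself.
  def checkedToInt(n: Any): Int = n match {
    case v: Int => v // already an Int, nothing to verify
    case v: Number =>
      val intV = v.intValue()
      // The equality holds only if v is inside Int range and has no fractional part:
      // Long.intValue wraps around and Double.intValue truncates or saturates,
      // so any value that does not fit changes when converted.
      if (v.doubleValue == intV) intV
      else throw new IllegalArgumentException(
        s"Value $n is out of Int range or has a fractional part.")
    case _ =>
      throw new IllegalArgumentException(s"Value $n is not numeric.")
  }
}
```

Under these assumptions, `checkedToInt(7L)` and `checkedToInt(7.0)` both return 7, while `checkedToInt(7.5)` and `checkedToInt(Int.MaxValue.toLong + 1)` throw.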

Contributor

Yeah the reality is we mostly expect Int or Long ids, but want to support any numeric ID as long as it falls within Int range - so passing in fractional float/double just doesn't make sense and we'll throw an error.

Member

I think the test does work, but it may be slightly more direct to check whether v.doubleValue % 1.0 == 0.0; this also works.

Contributor Author

The way it is written, it checks two things at the same time: 1) that the value is within integer range (if it overflows, the equality will not hold) and 2) that it has no fractional part. The modulo check only covers point #2, not #1.
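
A small sketch of the difference between the two checks, assuming plain Scala doubles rather than the boxed `Number` handled in the diff. The object and method names are hypothetical and exist only for this illustration.

```scala
object RangeVersusFraction {
  // The modulo test only asks: is there a fractional part?
  def hasNoFraction(d: Double): Boolean = d % 1.0 == 0.0

  // The round-trip test asks: does the value survive conversion to Int unchanged?
  // Double.toInt saturates at Int.MaxValue / Int.MinValue, so out-of-range values differ.
  def survivesIntRoundTrip(d: Double): Boolean = d == d.toInt

  // A whole number that is larger than Int.MaxValue (2147483647).
  val tooLarge: Double = 3.0e9
}
```

With these definitions, `hasNoFraction(tooLarge)` is true while `survivesIntRoundTrip(tooLarge)` is false, which is the gap the comment points out.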

Contributor

Also, if we for some reason do decide to go with the check against the fractional part, which I would not recommend, we should add a test to verify that the exception is thrown, and possibly document this change in accepted input format, since it may break some users (they may have to convert their values to int before calling this API if they run into such floating-point precision issues).

Contributor Author

@datumbox datumbox Feb 28, 2017

@imatiach-msft I am aware of the implications of floating point precision and I understand your concerns.

Having said that though, even allowing user and item ids to be double/float is not a good idea. We just keep it for backwards compatibility, I guess. Also note that the current implementation in Spark 2.1 will actually take your 0.9999999999999996 value and silently cast it to Int (so it becomes 0)! For me the only permitted types should have been Integer, Long and BigInteger.

I don't have strong opinions about refactoring anything in the Number case, as performance-wise it simply does not matter. The point of this PR is to optimize the general case where the id is Int, because the casting in the current approach generates twice as much data as the original dataset (of course it is GCed, but at a cost).

@MLnick @srowen It's your call.
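
The silent truncation mentioned above can be reproduced with a sketch of the pre-change logic. This assumes a plain method in place of the UDF; `OldCastSketch` and `oldCheckedCast` are hypothetical names and the message is shortened.

```scala
object OldCastSketch {
  // Mirrors the removed lines of the diff: a range check only, followed by truncation.
  def oldCheckedCast(n: Double): Int =
    if (n > Int.MaxValue || n < Int.MinValue) {
      throw new IllegalArgumentException(s"Value $n was out of Integer range.")
    } else {
      n.toInt // drops any fractional part without complaint
    }
}
```

Here `oldCheckedCast(0.9999999999999996)` returns 0 instead of raising an error, so an id that was meant to be 1 would be matched against id 0.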

Contributor

@datumbox I agree that calling toInt in the previous code was not the best decision either; if we really wanted to support double type correctly we would have probably done the check with some precision. Maybe if we add to the documentation that double types for user/item should be avoided I would be ok with it.
Also, could you add a test case to verify the exception is thrown for doubles with a fraction, similar to the test case I sent you (with a catch on the exception type and withClue on the message)? We should clean up the error message too, since it is thrown not only when the value is not in Integer range (Int.Max/Min) but also when there is a fractional part.

Contributor Author

@datumbox datumbox Feb 28, 2017

@imatiach-msft I already tested it, check this snippet posted here. I can add it in ALSSuite if necessary. The exception message is technically correct since a number with a fractional part is not in Integer range (plus you get to see the actual value of the number). Returning a different exception would require having two separate if checks instead of one. Do we really want that?

Guys, to be honest, this PR solves a very simple thing and I did not anticipate it being so controversial. May I suggest we agree on the final set of changes and merge it? Perhaps we can tackle any other concerns in new pull requests?

Contributor

@imatiach-msft imatiach-msft Feb 28, 2017

@datumbox Sorry, the link doesn't seem to work for me. If you could add it to the ALSSuite, that would be great. Oh no, I don't mean a different check, just a clearer exception message, something like: throw new IllegalArgumentException(s"ALS only supports values in Integer range and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. Value $n was either out of Integer range or contained a fractional part that could not be converted.")

As a general rule, all check-ins should have tests associated with them to verify functionality, and we should aim for 100% code coverage to validate all production code paths, even if it is not always realistically possible.
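
A rough sketch of what the suggested message and a matching check could look like, written without Spark or ScalaTest so it runs on its own. Everything here is an assumption made for illustration: `ClearerMessageSketch`, `checkedToInt`, and `failureMessage` are hypothetical names, the column names are plain parameters rather than Spark params, and a real test in ALSSuite would more likely use ScalaTest's `intercept` and `withClue` as mentioned earlier in the thread.

```scala
object ClearerMessageSketch {
  // Hypothetical helper: column names are ordinary parameters, not Spark Params.
  def checkedToInt(n: Any, userCol: String = "user", itemCol: String = "item"): Int =
    n match {
      case v: Int => v
      case v: Number if v.doubleValue == v.intValue() => v.intValue()
      case _: Number =>
        throw new IllegalArgumentException(
          "ALS only supports values in Integer range and without fractional part " +
          s"for columns $userCol and $itemCol. Value $n was either out of Integer " +
          "range or contained a fractional part that could not be converted.")
      case _ =>
        throw new IllegalArgumentException(
          s"ALS only supports numeric values for columns $userCol and $itemCol. " +
          s"Value $n was not numeric.")
    }

  // Returns the exception message if the conversion fails, None if it succeeds.
  def failureMessage(n: Any): Option[String] =
    try {
      checkedToInt(n)
      None
    } catch {
      case e: IllegalArgumentException => Some(e.getMessage)
    }
}
```

A test along these lines would assert that `failureMessage(12.5)` mentions the fractional part and that `failureMessage(12.0)` is empty.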

+          intV
+        } else {
+          throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
+            s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
+        }
+      case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
+        s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.")
     }
   }
 }
@@ -262,9 +270,9 @@ class ALSModel private[ml] (
     }
     dataset
       .join(userFactors,
-        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
+        checkedCast(dataset($(userCol))) === userFactors("id"), "left")
       .join(itemFactors,
-        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
+        checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
@@ -451,8 +459,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]

     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
-      .select(checkedCast(col($(userCol)).cast(DoubleType)),
-        checkedCast(col($(itemCol)).cast(DoubleType)), r)
+      .select(checkedCast(col($(userCol))),
+        checkedCast(col($(itemCol))), r)
Contributor

Nit - but can this now go on one line if short enough?

       .rdd
       .map { row =>
         Rating(row.getInt(0), row.getInt(1), row.getFloat(2))