Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,20 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
* Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
* out of integer range.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably add to comment "... out of integer range or contains a fractional part"

*/
protected val checkedCast = udf { (n: Double) =>
if (n > Int.MaxValue || n < Int.MinValue) {
throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
} else {
n.toInt
protected val checkedCast = udf { (n: Any) =>
n match {
case v: Int => v // Avoid unnecessary casting
case v: Number =>
val intV = v.intValue()
if (v == intV) { // True for Byte/Short, Long within the Int range and Double/Float with no fractional part.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One could write this differently and explicitly handle all the permitted types. Unfortunately this would lead to duplicate code. Instead what I do here is convert the number into Integer and compare it with the original Number. If the values are identical one of the following is true:

  • The value is Byte or Short.
  • The value is Long but within the Integer range.
  • The value is Double or Float but without any fractional part.

intV
}
else {
throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
}
case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n is not numeric.")
}
}
}
Expand Down Expand Up @@ -262,9 +270,9 @@ class ALSModel private[ml] (
}
dataset
.join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
checkedCast(dataset($(userCol))) === userFactors("id"), "left")
.join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
Expand Down Expand Up @@ -451,8 +459,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]

val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
.select(checkedCast(col($(userCol)).cast(DoubleType)),
checkedCast(col($(itemCol)).cast(DoubleType)), r)
.select(checkedCast(col($(userCol))),
checkedCast(col($(itemCol))), r)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - but can this now go on one line if short enough?

.rdd
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
Expand Down