mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala (21 changes: 10 additions & 11 deletions)
@@ -25,9 +25,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.stat.Statistics
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -70,7 +69,7 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
+    val input = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
     }
     val summary = Statistics.colStats(input)
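Review note: the fit() path bridges from the new ml vectors to the old mllib API so it can reuse Statistics.colStats for the column-wise min/max. A minimal standalone sketch of that call, assuming an existing SparkContext named sc (the name is illustrative, not from this diff):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Two rows, two columns; colStats summarizes each column independently.
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, -4.0),
  Vectors.dense(-3.0, 2.0)))
val summary = Statistics.colStats(data)
// summary.min == [-3.0,-4.0], summary.max == [1.0,2.0]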
@@ -79,7 +78,7 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String)
     val n = minVals.length
     val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) }
 
-    copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this))
+    copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs).compressed).setParent(this))
   }
 
   @Since("2.0.0")
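Review note: maxAbs(i) is max(|min(i)|, |max(i)|) per column, and the added .compressed call stores the model vector in whichever of its dense or sparse encodings is smaller, which helps when many columns have a max-abs of zero. A small sketch of both steps, with made-up values for illustration:

import org.apache.spark.ml.linalg.Vectors

val minVals = Array(-3.0, 0.0, 0.5)
val maxVals = Array(2.0, 0.0, 4.0)
val maxAbs = Array.tabulate(minVals.length) { i =>
  math.max(math.abs(minVals(i)), math.abs(maxVals(i)))
}
// maxAbs == Array(3.0, 0.0, 4.0); the zero entry is what a sparse
// encoding can exploit once .compressed picks the smaller form.
val modelVec = Vectors.dense(maxAbs).compressed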
@@ -121,13 +120,13 @@ class MaxAbsScalerModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    // TODO: this looks hack, we may have to handle sparse and dense vectors separately.
-    val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
-    val reScale = udf { (vector: Vector) =>
-      val brz = vector.asBreeze / maxAbsUnzero.asBreeze
-      Vectors.fromBreeze(brz)
-    }
-    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+
+    val scale = maxAbs.toArray.map { v => if (v == 0) 1.0 else 1 / v }
+    val func = StandardScalerModel.getTransformFunc(
+      Array.empty, scale, false, true)
+    val transformer = udf(func)
+
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
   }
 
   @Since("2.0.0")
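Review note: the old udf divided through breeze, which did not treat sparse input specially (hence the deleted TODO); the new code delegates to StandardScalerModel.getTransformFunc with an empty shift array, shifting off, and scaling on, so each element is multiplied by 1 / maxAbs(i) and zero-max-abs columns keep scale 1.0. A hedged sketch of the equivalent element-wise rescaling, written out by hand rather than via getTransformFunc, with sparsity preserved:

import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

def rescale(v: Vector, maxAbs: Array[Double]): Vector = {
  // Guard zero columns so they pass through unchanged.
  val scale = maxAbs.map(m => if (m == 0) 1.0 else 1.0 / m)
  v match {
    // Sparse input stays sparse: only the stored values are scaled.
    case sv: SparseVector =>
      Vectors.sparse(sv.size, sv.indices,
        sv.indices.zip(sv.values).map { case (i, x) => x * scale(i) })
    case dv: DenseVector =>
      Vectors.dense(dv.values.zip(scale).map { case (x, s) => x * s })
  }
}

// rescale(Vectors.dense(2.0, -3.0), Array(4.0, 3.0)) == Vectors.dense(0.5, -1.0)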