diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 91b0707dec3f..87685e4c01eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml._
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -160,15 +160,88 @@ class StandardScalerModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
-
-    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
-    val transformer: Vector => Vector = v => scaler.transform(OldVectors.fromML(v)).asML
+    val transformer: Vector => Vector = v => transform(v)
     val scale = udf(transformer)
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
 
+  /**
+   * Since `shift` is only used in the `withMean` branch, we make it a
+   * `lazy val` so it is only evaluated in that branch. Note that we don't
+   * want to create this array multiple times in the `transform` function.
+   */
+  private lazy val shift: Array[Double] = mean.toArray
+
+  /**
+   * Applies the standardization transformation to a vector.
+   *
+   * @param vector Vector to be standardized.
+   * @return Standardized vector. If the std of a column is zero, `0.0` is returned
+   *         for that column.
+   */
+  private[spark] def transform(vector: Vector): Vector = {
+    require(mean.size == vector.size)
+    if ($(withMean)) {
+      /**
+       * By default, Scala generates Java methods for member variables, so every time
+       * a member variable is accessed, `invokespecial` is called. This is an expensive
+       * operation and can be avoided by keeping a local reference to `shift`.
+       */
+      val localShift = shift
+      /** Must have a copy of the values since they will be modified in place. */
+      val values = vector match {
+        /** Handle DenseVector specially because its `toArray` method does not clone the values. */
+        case d: DenseVector => d.values.clone()
+        case v: Vector => v.toArray
+      }
+      val size = values.length
+      if ($(withStd)) {
+        var i = 0
+        while (i < size) {
+          values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
+          i += 1
+        }
+      } else {
+        var i = 0
+        while (i < size) {
+          values(i) -= localShift(i)
+          i += 1
+        }
+      }
+      Vectors.dense(values)
+    } else if ($(withStd)) {
+      vector match {
+        case DenseVector(vs) =>
+          val values = vs.clone()
+          val size = values.length
+          var i = 0
+          while (i < size) {
+            values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
+            i += 1
+          }
+          Vectors.dense(values)
+        case SparseVector(size, indices, vs) =>
+          /**
+           * For a sparse vector, the `indices` array inside the sparse vector object is not
+           * changed, so we can reuse it to save memory.
+           */
+          val values = vs.clone()
+          val nnz = values.length
+          var i = 0
+          while (i < nnz) {
+            values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
+            i += 1
+          }
+          Vectors.sparse(size, indices, values)
+        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+      }
+    } else {
+      /** Note that this is safe since we always assume that the data in an RDD is immutable. */
+      vector
+    }
+  }
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
   }
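
For reference, a minimal usage sketch (not part of the patch) that exercises the new native transform path through the public StandardScaler API. The column names `features` and `scaled`, the toy data, and the local SparkSession setup are illustrative assumptions:

import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object StandardScalerUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("StandardScalerUsage")
      .getOrCreate()

    // Two dense rows and one sparse row; with withMean disabled, the new
    // transform scales the sparse row in place without densifying it.
    val df = spark.createDataFrame(Seq(
      Vectors.dense(1.0, 10.0, 100.0),
      Vectors.dense(2.0, 20.0, 200.0),
      Vectors.sparse(3, Array(1), Array(30.0))
    ).map(Tuple1.apply)).toDF("features")

    val model = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaled")
      .setWithStd(true)
      .setWithMean(false) // keeps sparse vectors sparse in the new transform
      .fit(df)

    model.transform(df).show(truncate = false)
    spark.stop()
  }
}

Note that a column whose standard deviation is zero is mapped to 0.0 rather than triggering a division by zero, matching the `@return` contract documented on the new private `transform` method.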