docs/ml-features.md
@@ -1009,6 +1009,51 @@ for more details on the API.
</div>
</div>


## RobustScaler

`RobustScaler` transforms a dataset of `Vector` rows, removing the median and scaling the data according to a specific quantile range (by default the IQR: Interquartile Range, the quantile range between the 1st quartile and the 3rd quartile). Its behavior is quite similar to `StandardScaler`, but the median and the quantile range are used instead of the mean and standard deviation, which makes it robust to outliers. It takes parameters:

* `lower`: 0.25 by default. Lower quantile used to compute the quantile range, shared by all features.
* `upper`: 0.75 by default. Upper quantile used to compute the quantile range, shared by all features.
* `withScaling`: True by default. Scales the data to the quantile range.
* `withCentering`: False by default. Centers the data with the median before scaling. Centering builds a dense output, so take care when applying it to sparse input.

`RobustScaler` is an `Estimator` which can be `fit` on a dataset to produce a `RobustScalerModel`; this amounts to computing quantile statistics. The model can then transform a `Vector` column in a dataset so that each feature has unit quantile range and/or zero median.

Note that if the quantile range of a feature is zero, the model returns the default value of `0.0` for that feature in the transformed `Vector`.
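Concretely, with both centering and scaling enabled, each feature value `x` is mapped to `(x - median) / (Q3 - Q1)`; with the defaults (`withCentering=false`), only the division by the quantile range is applied. The following self-contained sketch illustrates the arithmetic on a single toy feature (plain Scala with a crude nearest-rank quantile; Spark itself computes approximate quantiles, so treat this as an illustration, not the implementation):

```scala
object RobustScaleSketch {
  def main(args: Array[String]): Unit = {
    // Toy single-feature values; note the outlier at the end.
    val xs = Seq(1.0, 2.0, 3.0, 4.0, 100.0)
    val sorted = xs.sorted

    // Crude nearest-rank quantile; Spark uses approximate quantiles instead.
    def quantile(q: Double): Double = sorted(((sorted.size - 1) * q).round.toInt)

    val median = quantile(0.5)                 // 3.0
    val iqr = quantile(0.75) - quantile(0.25)  // 4.0 - 2.0 = 2.0

    // withCentering=true, withScaling=true:
    println(xs.map(x => (x - median) / iqr))   // List(-1.0, -0.5, 0.0, 0.5, 48.5)
    // Defaults (withCentering=false, withScaling=true):
    println(xs.map(_ / iqr))                   // List(0.5, 1.0, 1.5, 2.0, 50.0)
  }
}
```

Either way, the bulk of the data ends up on a small scale while the outlier has virtually no influence on the fitted statistics, which is the point of using the median and quantile range rather than the mean and standard deviation.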

**Examples**

The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to unit quantile range.

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [RobustScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RobustScaler)
for more details on the API.

{% include_example scala/org/apache/spark/examples/ml/RobustScalerExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [RobustScaler Java docs](api/java/org/apache/spark/ml/feature/RobustScaler.html)
for more details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaRobustScalerExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [RobustScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RobustScaler)
for more details on the API.

{% include_example python/ml/robust_scaler_example.py %}
</div>
</div>


## MinMaxScaler

`MinMaxScaler` transforms a dataset of `Vector` rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters:
examples/src/main/java/org/apache/spark/examples/ml/JavaRobustScalerExample.java
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import org.apache.spark.sql.SparkSession;

// $example on$
import org.apache.spark.ml.feature.RobustScaler;
import org.apache.spark.ml.feature.RobustScalerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$

public class JavaRobustScalerExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaRobustScalerExample")
      .getOrCreate();

    // $example on$
    Dataset<Row> dataFrame =
      spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

    RobustScaler scaler = new RobustScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setWithScaling(true)
      .setWithCentering(false)
      .setLower(0.25)
      .setUpper(0.75);

    // Compute summary statistics by fitting the RobustScaler
    RobustScalerModel scalerModel = scaler.fit(dataFrame);

    // Transform each feature to have unit quantile range.
    Dataset<Row> scaledData = scalerModel.transform(dataFrame);
    scaledData.show();
    // $example off$
    spark.stop();
  }
}
examples/src/main/python/ml/robust_scaler_example.py
@@ -0,0 +1,45 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

# $example on$
from pyspark.ml.feature import RobustScaler
# $example off$
from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("RobustScalerExample")\
        .getOrCreate()

    # $example on$
    dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    scaler = RobustScaler(inputCol="features", outputCol="scaledFeatures",
                          withScaling=True, withCentering=False,
                          lower=0.25, upper=0.75)

    # Compute summary statistics by fitting the RobustScaler
    scalerModel = scaler.fit(dataFrame)

    # Transform each feature to have unit quantile range.
    scaledData = scalerModel.transform(dataFrame)
    scaledData.show()
    # $example off$

    spark.stop()
examples/src/main/scala/org/apache/spark/examples/ml/RobustScalerExample.scala
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.RobustScaler
// $example off$
import org.apache.spark.sql.SparkSession

object RobustScalerExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("RobustScalerExample")
      .getOrCreate()

    // $example on$
    val dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    val scaler = new RobustScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setWithScaling(true)
      .setWithCentering(false)
      .setLower(0.25)
      .setUpper(0.75)

    // Compute summary statistics by fitting the RobustScaler.
    val scalerModel = scaler.fit(dataFrame)

    // Transform each feature to have unit quantile range.
    val scaledData = scalerModel.transform(dataFrame)
    scaledData.show()
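    // The fitted statistics can also be inspected directly (assuming Spark 3.0+,
    // where RobustScalerModel exposes the per-feature medians and quantile ranges).
    println(s"Medians: ${scalerModel.median}")
    println(s"Quantile ranges: ${scalerModel.range}")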
    // $example off$

    spark.stop()
  }
}
// scalastyle:on println