From 7f8eb364f216c0e4e776f115192acc01c5e3d0f0 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 7 Apr 2014 09:45:48 -0700 Subject: [PATCH 1/5] change labelParser from anonymous function to trait --- .../spark/mllib/util/LabelParsers.scala | 55 +++++++++++++++++++ .../org/apache/spark/mllib/util/MLUtils.scala | 26 ++------- .../spark/mllib/util/MLUtilsSuite.scala | 4 +- 3 files changed, 61 insertions(+), 24 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala new file mode 100644 index 000000000000..21d633d60504 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +/** Trait for label parsers. */ +trait LabelParser extends Serializable { + /** Parses a string label into a double label. */ + def apply(labelString: String): Double +} + +/** + * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, + * or 0.0 (negative) otherwise. 
So it works with +1/-1 labeling and +1/0 labeling. + */ +class BinaryLabelParser extends LabelParser { + /** + * Parses the input label into positive (1.0) if the value is greater than 0.5, + * or negative (0.0) otherwise. + */ + override def apply(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 +} + +object BinaryLabelParser { + private lazy val instance = new BinaryLabelParser() + /** Gets the default instance of BinaryLabelParser. */ + def apply() = instance +} + +/** + * Label parser for multiclass labels, which converts the input label to double. + */ +class MulticlassLabelParser extends LabelParser { + override def apply(labelString: String): Double = labelString.toDouble +} + +object MulticlassLabelParser { + private lazy val instance = new MulticlassLabelParser() + /** Gets the default instance of MulticlassLabelParser. */ + def apply() = instance +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index cb85e433bfc7..b28c8950922b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -38,17 +38,6 @@ object MLUtils { eps } - /** - * Multiclass label parser, which parses a string into double. - */ - val multiclassLabelParser: String => Double = _.toDouble - - /** - * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5, - * or 0.0 (negative) otherwise. - */ - val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0 - /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. 
@@ -69,7 +58,7 @@ object MLUtils { def loadLibSVMData( sc: SparkContext, path: String, - labelParser: String => Double, + labelParser: LabelParser, numFeatures: Int, minSplits: Int): RDD[LabeledPoint] = { val parsed = sc.textFile(path, minSplits) @@ -107,14 +96,7 @@ object MLUtils { * with number of features determined automatically and the default number of partitions. */ def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] = - loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits) - - /** - * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], - * with number of features specified explicitly and the default number of partitions. - */ - def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = - loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits) + loadLibSVMData(sc, path, BinaryLabelParser(), -1, sc.defaultMinSplits) /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], @@ -124,7 +106,7 @@ object MLUtils { def loadLibSVMData( sc: SparkContext, path: String, - labelParser: String => Double): RDD[LabeledPoint] = + labelParser: LabelParser): RDD[LabeledPoint] = loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits) /** @@ -135,7 +117,7 @@ object MLUtils { def loadLibSVMData( sc: SparkContext, path: String, - labelParser: String => Double, + labelParser: LabelParser, numFeatures: Int): RDD[LabeledPoint] = loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 27d41c7869aa..4b08169caf2a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Files.write(lines, file, Charsets.US_ASCII) 
val path = tempDir.toURI.toString - val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect() + val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser(), 6).collect() val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { @@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) } - val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect() + val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser()).collect() assert(multiclassPoints.length === 3) assert(multiclassPoints(0).label === 1.0) assert(multiclassPoints(1).label === -1.0) From 11c94e0876a14b679b8f538139c38d6e2d824996 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 7 Apr 2014 11:59:21 -0700 Subject: [PATCH 2/5] add return types --- .../main/scala/org/apache/spark/mllib/util/LabelParsers.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala index 21d633d60504..6cb3d53c15f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -38,7 +38,7 @@ class BinaryLabelParser extends LabelParser { object BinaryLabelParser { private lazy val instance = new BinaryLabelParser() /** Gets the default instance of BinaryLabelParser. */ - def apply() = instance + def apply(): BinaryLabelParser = instance } /** @@ -51,5 +51,5 @@ class MulticlassLabelParser extends LabelParser { object MulticlassLabelParser { private lazy val instance = new MulticlassLabelParser() /** Gets the default instance of MulticlassLabelParser. 
*/ - def apply() = instance + def apply(): MulticlassLabelParser = instance } From c2e571c2572dbf1d59fbaa6b6dff177afa0b0f66 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 7 Apr 2014 14:14:49 -0700 Subject: [PATCH 3/5] rename LabelParser.apply to LabelParser.parse use extends for singleton --- .../apache/spark/mllib/util/LabelParsers.scala | 16 +++++++--------- .../org/apache/spark/mllib/util/MLUtils.scala | 4 ++-- .../apache/spark/mllib/util/MLUtilsSuite.scala | 4 ++-- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala index 6cb3d53c15f6..c328718bef94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.util /** Trait for label parsers. */ trait LabelParser extends Serializable { /** Parses a string label into a double label. */ - def apply(labelString: String): Double + def parse(labelString: String): Double } /** @@ -32,24 +32,22 @@ class BinaryLabelParser extends LabelParser { * Parses the input label into positive (1.0) if the value is greater than 0.5, * or negative (0.0) otherwise. */ - override def apply(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 + override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 } -object BinaryLabelParser { - private lazy val instance = new BinaryLabelParser() +object BinaryLabelParser extends BinaryLabelParser { /** Gets the default instance of BinaryLabelParser. */ - def apply(): BinaryLabelParser = instance + def getInstance(): BinaryLabelParser = this } /** * Label parser for multiclass labels, which converts the input label to double. 
*/ class MulticlassLabelParser extends LabelParser { - override def apply(labelString: String): Double = labelString.toDouble + override def parse(labelString: String): Double = labelString.toDouble } -object MulticlassLabelParser { - private lazy val instance = new MulticlassLabelParser() +object MulticlassLabelParser extends MulticlassLabelParser { /** Gets the default instance of MulticlassLabelParser. */ - def apply(): MulticlassLabelParser = instance + def getInstance(): MulticlassLabelParser = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index b28c8950922b..83d1bd3fd57f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -78,7 +78,7 @@ object MLUtils { }.reduce(math.max) } parsed.map { items => - val label = labelParser(items.head) + val label = labelParser.parse(items.head) val (indices, values) = items.tail.map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 @@ -96,7 +96,7 @@ object MLUtils { * with number of features determined automatically and the default number of partitions. 
*/ def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] = - loadLibSVMData(sc, path, BinaryLabelParser(), -1, sc.defaultMinSplits) + loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits) /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 4b08169caf2a..e451c350b8d8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString - val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser(), 6).collect() + val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect() val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect() for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { @@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) } - val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser()).collect() + val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect() assert(multiclassPoints.length === 3) assert(multiclassPoints(0).label === 1.0) assert(multiclassPoints(1).label === -1.0) From 3b1a7c60d8c87cb0babdaace6470cdebf5983ff3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 8 Apr 2014 10:18:12 -0700 Subject: [PATCH 4/5] add tests for label parsers --- .../spark/mllib/util/LabelParsersSuite.scala | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala new file mode 100644 index 000000000000..ac85677f2f01 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util + +import org.scalatest.FunSuite + +class LabelParsersSuite extends FunSuite { + test("binary label parser") { + for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) { + assert(parser.parse("+1") === 1.0) + assert(parser.parse("1") === 1.0) + assert(parser.parse("0") === 0.0) + assert(parser.parse("-1") === 0.0) + } + } + + test("multiclass label parser") { + for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) { + assert(parser.parse("0") == 0.0) + assert(parser.parse("+1") === 1.0) + assert(parser.parse("1") === 1.0) + assert(parser.parse("2") === 2.0) + assert(parser.parse("3") === 3.0) + } + } +} From ac444093df8bf395de8bfb7e8cbe9ab5bf7b2fee Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 8 Apr 2014 11:22:31 -0700 Subject: [PATCH 5/5] use singleton objects for label parsers --- .../spark/mllib/util/LabelParsers.scala | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala index c328718bef94..f7966d3ebb61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -27,7 +27,10 @@ trait LabelParser extends Serializable { * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. */ -class BinaryLabelParser extends LabelParser { +object BinaryLabelParser extends LabelParser { + /** Gets the default instance of BinaryLabelParser. */ + def getInstance(): LabelParser = this + /** * Parses the input label into positive (1.0) if the value is greater than 0.5, * or negative (0.0) otherwise. 
@@ -35,19 +38,12 @@ class BinaryLabelParser extends LabelParser { override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 } -object BinaryLabelParser extends BinaryLabelParser { - /** Gets the default instance of BinaryLabelParser. */ - def getInstance(): BinaryLabelParser = this -} - /** * Label parser for multiclass labels, which converts the input label to double. */ -class MulticlassLabelParser extends LabelParser { - override def parse(labelString: String): Double = labelString.toDouble -} - -object MulticlassLabelParser extends MulticlassLabelParser { +object MulticlassLabelParser extends LabelParser { /** Gets the default instance of MulticlassLabelParser. */ - def getInstance(): MulticlassLabelParser = this + def getInstance(): LabelParser = this + + override def parse(labelString: String): Double = labelString.toDouble }