Skip to content

Commit c26c4fc

Browse files
committed
update DecisionTree to use RDD[Vector]
1 parent 11999c7 commit c26c4fc

5 files changed

Lines changed: 19 additions & 13 deletions

File tree

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class DenseVector(val values: Array[Double]) extends Vector {
151151
override def toArray: Array[Double] = values
152152

153153
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
154+
155+
override def apply(i: Int) = values(i)
154156
}
155157

156158
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
3030
import org.apache.spark.mllib.tree.model._
3131
import org.apache.spark.rdd.RDD
3232
import org.apache.spark.util.random.XORShiftRandom
33+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3334

3435
/**
3536
* A class that implements a decision tree algorithm for classification and regression. It
@@ -295,7 +296,7 @@ object DecisionTree extends Serializable with Logging {
295296
val numNodes = scala.math.pow(2, level).toInt
296297
logDebug("numNodes = " + numNodes)
297298
// Find the number of features by looking at the first sample.
298-
val numFeatures = input.first().features.length
299+
val numFeatures = input.first().features.size
299300
logDebug("numFeatures = " + numFeatures)
300301
val numBins = bins(0).length
301302
logDebug("numBins = " + numBins)
@@ -902,7 +903,7 @@ object DecisionTree extends Serializable with Logging {
902903
val count = input.count()
903904

904905
// Find the number of features by looking at the first sample
905-
val numFeatures = input.take(1)(0).features.length
906+
val numFeatures = input.take(1)(0).features.size
906907

907908
val maxBins = strategy.maxBins
908909
val numBins = if (maxBins <= count) maxBins else count.toInt
@@ -1116,7 +1117,7 @@ object DecisionTree extends Serializable with Logging {
11161117
sc.textFile(dir).map { line =>
11171118
val parts = line.trim().split(",")
11181119
val label = parts(0).toDouble
1119-
val features = parts.slice(1,parts.length).map(_.toDouble)
1120+
val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
11201121
LabeledPoint(label, features)
11211122
}
11221123
}
@@ -1127,7 +1128,7 @@ object DecisionTree extends Serializable with Logging {
11271128
*/
11281129
private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
11291130
threshold: Double = 0.5): Double = {
1130-
def predictedValue(features: Array[Double]) = {
1131+
def predictedValue(features: Vector) = {
11311132
if (model.predict(features) < threshold) 0.0 else 1.0
11321133
}
11331134
val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model
1919

2020
import org.apache.spark.mllib.tree.configuration.Algo._
2121
import org.apache.spark.rdd.RDD
22+
import org.apache.spark.mllib.linalg.Vector
2223

2324
/**
2425
* Model to store the decision tree parameters
@@ -33,7 +34,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
3334
* @param features array representing a single data point
3435
* @return Double prediction from the trained model
3536
*/
36-
def predict(features: Array[Double]): Double = {
37+
def predict(features: Vector): Double = {
3738
topNode.predictIfLeaf(features)
3839
}
3940

@@ -43,7 +44,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
4344
* @param features RDD representing data points to be predicted
4445
* @return RDD[Int] where each entry contains the corresponding prediction
4546
*/
46-
def predict(features: RDD[Array[Double]]): RDD[Double] = {
47+
def predict(features: RDD[Vector]): RDD[Double] = {
4748
features.map(x => predict(x))
4849
}
4950
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.model
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.mllib.tree.configuration.FeatureType._
22+
import org.apache.spark.mllib.linalg.Vector
2223

2324
/**
2425
* Node in a decision tree
@@ -54,8 +55,8 @@ class Node (
5455
logDebug("stats = " + stats)
5556
logDebug("predict = " + predict)
5657
if (!isLeaf) {
57-
val leftNodeIndex = id*2 + 1
58-
val rightNodeIndex = id*2 + 2
58+
val leftNodeIndex = id * 2 + 1
59+
val rightNodeIndex = id * 2 + 2
5960
leftNode = Some(nodes(leftNodeIndex))
6061
rightNode = Some(nodes(rightNodeIndex))
6162
leftNode.get.build(nodes)
@@ -68,7 +69,7 @@ class Node (
6869
* @param feature feature value
6970
* @return predicted value
7071
*/
71-
def predictIfLeaf(feature: Array[Double]) : Double = {
72+
def predictIfLeaf(feature: Vector) : Double = {
7273
if (isLeaf) {
7374
predict
7475
} else{

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.tree.model.Filter
2727
import org.apache.spark.mllib.tree.configuration.Strategy
2828
import org.apache.spark.mllib.tree.configuration.Algo._
2929
import org.apache.spark.mllib.tree.configuration.FeatureType._
30+
import org.apache.spark.mllib.linalg.Vectors
3031

3132
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
3233

@@ -396,7 +397,7 @@ object DecisionTreeSuite {
396397
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
397398
val arr = new Array[LabeledPoint](1000)
398399
for (i <- 0 until 1000){
399-
val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
400+
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
400401
arr(i) = lp
401402
}
402403
arr
@@ -405,7 +406,7 @@ object DecisionTreeSuite {
405406
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
406407
val arr = new Array[LabeledPoint](1000)
407408
for (i <- 0 until 1000){
408-
val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
409+
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
409410
arr(i) = lp
410411
}
411412
arr
@@ -415,9 +416,9 @@ object DecisionTreeSuite {
415416
val arr = new Array[LabeledPoint](1000)
416417
for (i <- 0 until 1000){
417418
if (i < 600){
418-
arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
419+
arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
419420
} else {
420-
arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
421+
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
421422
}
422423
}
423424
arr

0 commit comments

Comments
 (0)