Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
2d2aaaf
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 15, 2014
6b5651e
Updates based on code review. 1 major change: persisting to memory +…
jkbradley Aug 15, 2014
5f2dec2
Fixed scalastyle issue in TreePoint
jkbradley Aug 15, 2014
f40381c
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 15, 2014
797f68a
Fixed DecisionTreeSuite bug for training second level. Needed to upd…
jkbradley Aug 15, 2014
931a3a7
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 15, 2014
6a38f48
Added DTMetadata class for cleaner code
jkbradley Aug 16, 2014
db0d773
scala style fix
jkbradley Aug 16, 2014
ac0b9f8
Small updates based on code review.
jkbradley Aug 16, 2014
3726d20
Small code improvements based on code review.
jkbradley Aug 17, 2014
a0ed0da
Renamed DTMetadata to DecisionTreeMetadata. Small doc updates.
jkbradley Aug 17, 2014
66d076f
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 18, 2014
85bbc1f
Merge remote-tracking branch 'upstream/master' into dt-opt2
jkbradley Aug 18, 2014
b7b2922
Fixed bug in python example decision_tree_runner.py with missing argu…
jkbradley Aug 18, 2014
8e4665d
Added tree.py to python doc tests. Fixed bug from missing categorica…
jkbradley Aug 18, 2014
b5114fa
Fixed python tree.py doc test (extra newline)
jkbradley Aug 18, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/src/main/python/mllib/decision_tree_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def usage():
(reindexedData, origToNewLabels) = reindexClassLabels(points)

# Train a classifier.
model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
categoricalFeaturesInfo={} # no categorical features
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

= -> = (spaces around =)

model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
print "Trained DecisionTree for classification:"
print " Model numNodes: %d\n" % model.numNodes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
Expand Down Expand Up @@ -826,7 +827,7 @@ object DecisionTree extends Serializable with Logging {
// Calculate bin aggregates.
timer.start("aggregation")
val binAggregates = {
input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
}
timer.stop("aggregation")
logDebug("binAggregates.length = " + binAggregates.length)
Expand Down
14 changes: 8 additions & 6 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class DecisionTree(object):
It will probably be modified for Spark v1.2.

Example usage:
>>> from numpy import array, ndarray
>>> from numpy import array
>>> import sys
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
Expand All @@ -99,15 +100,15 @@ class DecisionTree(object):
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... ]
>>>
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
>>> print(model)
>>> categoricalFeaturesInfo = {} # no categorical features
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
... categoricalFeaturesInfo=categoricalFeaturesInfo)
>>> sys.stdout.write(model)
DecisionTreeModel classifier
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)
Predict: 1.0

>>> model.predict(array([1.0])) > 0
True
>>> model.predict(array([0.0])) == 0
Expand All @@ -119,7 +120,8 @@ class DecisionTree(object):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
... categoricalFeaturesInfo=categoricalFeaturesInfo)
>>> model.predict(array([0.0, 1.0])) == 1
True
>>> model.predict(array([0.0, 0.0])) == 0
Expand Down
1 change: 1 addition & 0 deletions python/run-tests
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/tests.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"

if [[ $FAILED == 0 ]]; then
Expand Down