
Commit abff51b

Make MultiStopWatch optional
1 parent e5b077d commit abff51b

1 file changed: 29 additions & 22 deletions

mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -94,13 +94,13 @@ private[spark] object RandomForest extends Logging {
       instr: Option[Instrumentation[_]],
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {

-    val multiTimer = new MultiStopwatch(input.sparkContext)
+    val timers = new MultiStopwatch(input.sparkContext)

-    multiTimer.addLocal("total")
-    multiTimer("total").start()
+    timers.addLocal("total")
+    timers("total").start()

-    multiTimer.addLocal("init")
-    multiTimer("init").start()
+    timers.addLocal("init")
+    timers("init").start()

     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata =
@@ -116,10 +116,10 @@ private[spark] object RandomForest extends Logging {

     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    multiTimer.addLocal("findSplitsBins")
-    multiTimer("findSplitsBins").start()
+    timers.addLocal("findSplitsBins")
+    timers("findSplitsBins").start()
     val splits = findSplits(retaggedInput, metadata, seed)
-    multiTimer("findSplitsBins").stop()
+    timers("findSplitsBins").stop()

     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -146,7 +146,7 @@ private[spark] object RandomForest extends Logging {
     val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")

-    multiTimer("init").stop()
+    timers("init").stop()
     /*
      * The main idea here is to perform group-wise training of the decision tree nodes thus
      * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
@@ -175,9 +175,9 @@ private[spark] object RandomForest extends Logging {
     // Allocate and queue root nodes.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
     Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
-    multiTimer.addLocal("findBestSplits")
-    multiTimer.addLocal("chooseSplits")
-    multiTimer.addDistributed("binsToBestSplit")
+    timers.addLocal("findBestSplits")
+    timers.addLocal("chooseSplits")
+    timers.addDistributed("binsToBestSplit")

     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
@@ -189,18 +189,18 @@ private[spark] object RandomForest extends Logging {
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

       // Choose node splits, and enqueue new nodes as needed.
-      multiTimer("findBestSplits").start()
+      timers("findBestSplits").start()
       RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, multiTimer, nodeIdCache)
-      multiTimer("findBestSplits").stop()
+        treeToNodeToIndexInfo, splits, nodeQueue, Option(timers), nodeIdCache)
+      timers("findBestSplits").stop()
     }

     baggedInput.unpersist()

-    multiTimer("total").stop()
+    timers("total").stop()

     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$multiTimer")
+    logInfo(s"$timers")

     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
@@ -362,7 +362,7 @@ private[spark] object RandomForest extends Logging {
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       nodeQueue: mutable.Queue[(Int, LearningNode)],
-      multiTimer: MultiStopwatch,
+      multiStopwatch: Option[MultiStopwatch] = None,
       nodeIdCache: Option[NodeIdCache] = None): Unit = {

     /*
@@ -485,6 +485,13 @@ private[spark] object RandomForest extends Logging {
       }
     }

+    val timers = multiStopwatch match {
+      case Some(timers) => timers
+      case None => new MultiStopwatch(input.sparkContext)
+        .addLocal("chooseSplits")
+        .addLocal("binsToBestSplit")
+    }
+
     // array of nodes to train indexed by node index in group
     val nodes = new Array[LearningNode](numNodes)
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
@@ -494,7 +501,7 @@ private[spark] object RandomForest extends Logging {
     }

     // Calculate best splits for all nodes in the group
-    multiTimer("chooseSplits").start()
+    timers("chooseSplits").start()

     // In each partition, iterate all instances and compute aggregate stats for each node,
     // yield a (nodeIndex, nodeAggregateStats) pair for each node.
@@ -550,14 +557,14 @@ private[spark] object RandomForest extends Logging {
         }

         // find best split for each node
-        multiTimer("binsToBestSplit").start()
+        timers("binsToBestSplit").start()
         val (split: Split, stats: ImpurityStats) =
           binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
-        multiTimer("binsToBestSplit").stop()
+        timers("binsToBestSplit").stop()
         (nodeIndex, (split, stats))
       }.collectAsMap()

-    multiTimer("chooseSplits").stop()
+    timers("chooseSplits").stop()

     val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
       Array.fill[mutable.Map[Int, NodeIndexUpdater]](
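
For context, a minimal caller-side sketch of the pattern this commit introduces: findBestSplits now takes an Option[MultiStopwatch] and builds its own fallback instance when the caller passes nothing, so the timing calls inside the method never need Option handling. This sketch is not part of the commit; the object and method names below are illustrative, the import path for MultiStopwatch is assumed, and the code would have to live inside the org.apache.spark namespace since MultiStopwatch is private[spark].

// Hypothetical sketch only; MultiStopwatch, addLocal, and apply(name) are the
// methods exercised by the diff above, everything else here is illustrative.
package org.apache.spark.ml.tree.impl

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.MultiStopwatch  // assumed package for MultiStopwatch

object TimingSketch {

  // Mirrors the commit's pattern: accept an optional shared MultiStopwatch and
  // fall back to a private instance when none is supplied (equivalent to the
  // match expression the diff adds inside findBestSplits).
  def timedWork(sc: SparkContext, shared: Option[MultiStopwatch] = None): Unit = {
    val timers = shared.getOrElse(
      new MultiStopwatch(sc)
        .addLocal("chooseSplits")
        .addLocal("binsToBestSplit"))

    timers("chooseSplits").start()
    // ... the actual per-group work would go here ...
    timers("chooseSplits").stop()
  }
}

Passing Option(timers) from run(), as the diff does, lets the per-group timings accumulate into the caller's stopwatches, while the None default keeps findBestSplits callable on its own (for example from tests) without constructing a MultiStopwatch first.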
