@@ -94,13 +94,13 @@ private[spark] object RandomForest extends Logging {
       instr: Option[Instrumentation[_]],
       parentUID: Option[String] = None): Array[DecisionTreeModel] = {

-    val multiTimer = new MultiStopwatch(input.sparkContext)
+    val timers = new MultiStopwatch(input.sparkContext)

-    multiTimer.addLocal("total")
-    multiTimer("total").start()
+    timers.addLocal("total")
+    timers("total").start()

-    multiTimer.addLocal("init")
-    multiTimer("init").start()
+    timers.addLocal("init")
+    timers("init").start()

     val retaggedInput = input.retag(classOf[LabeledPoint])
     val metadata =
@@ -116,10 +116,10 @@ private[spark] object RandomForest extends Logging {

     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    multiTimer.addLocal("findSplitsBins")
-    multiTimer("findSplitsBins").start()
+    timers.addLocal("findSplitsBins")
+    timers("findSplitsBins").start()
     val splits = findSplits(retaggedInput, metadata, seed)
-    multiTimer("findSplitsBins").stop()
+    timers("findSplitsBins").stop()

     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -146,7 +146,7 @@ private[spark] object RandomForest extends Logging {
     val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")

-    multiTimer("init").stop()
+    timers("init").stop()
     /*
      * The main idea here is to perform group-wise training of the decision tree nodes thus
      * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
@@ -175,9 +175,9 @@ private[spark] object RandomForest extends Logging {
     // Allocate and queue root nodes.
     val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
     Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
-    multiTimer.addLocal("findBestSplits")
-    multiTimer.addLocal("chooseSplits")
-    multiTimer.addDistributed("binsToBestSplit")
+    timers.addLocal("findBestSplits")
+    timers.addLocal("chooseSplits")
+    timers.addDistributed("binsToBestSplit")

     while (nodeQueue.nonEmpty) {
       // Collect some nodes to split, and choose features for each node (if subsampling).
@@ -189,18 +189,18 @@ private[spark] object RandomForest extends Logging {
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

       // Choose node splits, and enqueue new nodes as needed.
-      multiTimer("findBestSplits").start()
+      timers("findBestSplits").start()
       RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, nodeQueue, multiTimer, nodeIdCache)
-      multiTimer("findBestSplits").stop()
+        treeToNodeToIndexInfo, splits, nodeQueue, Option(timers), nodeIdCache)
+      timers("findBestSplits").stop()
     }

     baggedInput.unpersist()

-    multiTimer("total").stop()
+    timers("total").stop()

     logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$multiTimer")
+    logInfo(s"$timers")

     // Delete any remaining checkpoints used for node Id cache.
     if (nodeIdCache.nonEmpty) {
@@ -362,7 +362,7 @@ private[spark] object RandomForest extends Logging {
       treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
       splits: Array[Array[Split]],
       nodeQueue: mutable.Queue[(Int, LearningNode)],
-      multiTimer: MultiStopwatch,
+      multiStopwatch: Option[MultiStopwatch] = None,
       nodeIdCache: Option[NodeIdCache] = None): Unit = {

     /*
@@ -485,6 +485,13 @@ private[spark] object RandomForest extends Logging {
       }
     }

+    val timers = multiStopwatch match {
+      case Some(timers) => timers
+      case None => new MultiStopwatch(input.sparkContext)
+        .addLocal("chooseSplits")
+        .addLocal("binsToBestSplit")
+    }
+
     // array of nodes to train indexed by node index in group
     val nodes = new Array[LearningNode](numNodes)
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
@@ -494,7 +501,7 @@ private[spark] object RandomForest extends Logging {
     }

     // Calculate best splits for all nodes in the group
-    multiTimer("chooseSplits").start()
+    timers("chooseSplits").start()

     // In each partition, iterate all instances and compute aggregate stats for each node,
     // yield a (nodeIndex, nodeAggregateStats) pair for each node.
@@ -550,14 +557,14 @@ private[spark] object RandomForest extends Logging {
         }

         // find best split for each node
-        multiTimer("binsToBestSplit").start()
+        timers("binsToBestSplit").start()
         val (split: Split, stats: ImpurityStats) =
           binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
-        multiTimer("binsToBestSplit").stop()
+        timers("binsToBestSplit").stop()
         (nodeIndex, (split, stats))
     }.collectAsMap()

-    multiTimer("chooseSplits").stop()
+    timers("chooseSplits").stop()

     val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
       Array.fill[mutable.Map[Int, NodeIndexUpdater]](
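For context on the calls this diff introduces: `addLocal`, `addDistributed`, `apply(name).start()`/`stop()`, and interpolating the whole registry via `toString` assume a `MultiStopwatch`-style timer registry. The following is a minimal, driver-side sketch inferred from those call sites only; the class names are prefixed with `Sketch` to signal it is not the actual Spark class, and a real distributed timer would be backed by a Spark accumulator so executor time is aggregated back on the driver.

```scala
import scala.collection.mutable

// Sketch of a single named timer supporting the start()/stop() calls seen in the diff.
class SketchStopwatch(val name: String) {
  private var startTime: Long = 0L
  private var totalNanos: Long = 0L
  private var running: Boolean = false

  def start(): Unit = {
    require(!running, s"Stopwatch '$name' is already running")
    running = true
    startTime = System.nanoTime()
  }

  def stop(): Long = {
    require(running, s"Stopwatch '$name' is not running")
    totalNanos += System.nanoTime() - startTime
    running = false
    totalNanos
  }

  override def toString: String = f"$name: ${totalNanos / 1e6}%.1f ms"
}

// Sketch of the registry the diff calls `timers`: register timers by name, then look
// them up with apply, e.g. timers("total").start(). addLocal/addDistributed return
// `this` so registrations can be chained, as in the None branch of findBestSplits.
class SketchMultiStopwatch {
  private val stopwatches = mutable.LinkedHashMap.empty[String, SketchStopwatch]

  def addLocal(name: String): this.type = {
    stopwatches(name) = new SketchStopwatch(name)
    this
  }

  // The real class would accumulate executor-side time via an accumulator;
  // this sketch only tracks driver-side time.
  def addDistributed(name: String): this.type = addLocal(name)

  def apply(name: String): SketchStopwatch = stopwatches(name)

  override def toString: String =
    stopwatches.values.mkString("{\n  ", ",\n  ", "\n}")
}
```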
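The second half of the diff makes the timer parameter of `findBestSplits` optional (`multiStopwatch: Option[MultiStopwatch] = None`) and falls back to a private instance when the caller does not supply one. Using the sketch classes above, the calling pattern looks roughly like this; the method body is an illustrative placeholder, not the real training logic.

```scala
object TimerThreadingSketch {
  // Illustrative only: mirrors how the Option-valued timer parameter added in
  // this diff is threaded from run() into findBestSplits().
  def findBestSplitsSketch(multiStopwatch: Option[SketchMultiStopwatch] = None): Unit = {
    // Reuse the caller's registry when one is provided; otherwise build a
    // private one so the timing calls below never have to guard against None.
    val timers = multiStopwatch match {
      case Some(t) => t
      case None =>
        new SketchMultiStopwatch()
          .addLocal("chooseSplits")
          .addLocal("binsToBestSplit")
    }

    timers("chooseSplits").start()
    // ... compute aggregate stats and pick the best split per node ...
    timers("chooseSplits").stop()
  }

  def main(args: Array[String]): Unit = {
    // Caller side, mirroring run(): register timers once and pass them down.
    val timers = new SketchMultiStopwatch()
      .addLocal("findBestSplits")
      .addLocal("chooseSplits")
      .addDistributed("binsToBestSplit")

    timers("findBestSplits").start()
    findBestSplitsSketch(Option(timers))
    timers("findBestSplits").stop()

    println(s"Internal timing:\n$timers")
  }
}
```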