@@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats
2929import org .apache .spark .rdd .RDD
3030import org .apache .spark .storage .StorageLevel
3131import org .apache .spark .util .collection .BitSet
32+ import org .roaringbitmap .RoaringBitmap
3233
3334
3435/**
@@ -135,7 +136,7 @@ private[ml] object AltDT extends Logging {
135136 }
136137 val labelsBc = input.sparkContext.broadcast(labels)
137138 // NOTE: Labels are not sorted with features since that would require 1 copy per feature,
138- // rather than 1 copy per worker. This means a lot of random accesses.
139+ // rather than 1 copy per worker. This means a lot of random accesses.
139140 // We could improve this by applying first-level sorting (by node) to labels.
140141
141142 // Sort each column by feature values.
@@ -196,23 +197,19 @@ private[ml] object AltDT extends Logging {
196197 doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0
197198
198199 if (! doneLearning) {
199- // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right.
200- val aggBitVectors : Array [BitSubvector ] =
201- collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1))
200+ val splits : Array [Option [Split ]] = bestSplitsAndGains.map(_._1)
202201
203- // Broadcast aggregated bit vectors. On each partition, update instance--node map.
204- val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors )
202+ // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right
203+ val aggBitVector : RoaringBitmap = aggregateBitVector(partitionInfos, splits, numRows )
205204 val newPartitionInfos = partitionInfos.map { partitionInfo =>
206- partitionInfo.update(aggBitVectorsBc.value , numNodeOffsets)
205+ partitionInfo.update(aggBitVector , numNodeOffsets)
207206 }
208207 // TODO: remove. For some reason, this is needed to make things work.
209208 // Probably messing up somewhere above...
210209 newPartitionInfos.cache().count()
211210 partitionInfos = newPartitionInfos
212211
213- // TODO: unpersist aggBitVectorsBc after action.
214212 }
215-
216213 currentLevel += 1
217214 }
218215
@@ -333,42 +330,52 @@ private[ml] object AltDT extends Logging {
333330 * @param bestSplits Split for each active node, or None if that node will not be split
334331 * @return Single bit vector, OR-ed across all partitions and indexed by row index
335332 */
336- private [impl] def collectBitVectors (
333+ private [impl] def aggregateBitVector (
337334 partitionInfos : RDD [PartitionInfo ],
338- bestSplits : Array [Option [Split ]]): Array [BitSubvector ] = {
335+ bestSplits : Array [Option [Split ]],
336+ numRows : Int ): RoaringBitmap = {
339337 val bestSplitsBc : Broadcast [Array [Option [Split ]]] =
340338 partitionInfos.sparkContext.broadcast(bestSplits)
341- val workerBitSubvectors : RDD [Array [ BitSubvector ] ] = partitionInfos.map {
339+ val workerBitSubvectors : RDD [RoaringBitmap ] = partitionInfos.map {
342340 case PartitionInfo (columns : Array [FeatureVector ], nodeOffsets : Array [Int ],
343341 activeNodes : BitSet ) =>
344342 val localBestSplits : Array [Option [Split ]] = bestSplitsBc.value
345343 // localFeatureIndex[feature index] = index into PartitionInfo.columns
346344 val localFeatureIndex : Map [Int , Int ] = columns.map(_.featureIndex).zipWithIndex.toMap
347- activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
345+ val bitSetForNodes : Iterator [RoaringBitmap ] = activeNodes.iterator
346+ .zip(localBestSplits.iterator).flatMap {
348347 case (nodeIndexInLevel : Int , Some (split : Split )) =>
349348 if (localFeatureIndex.contains(split.featureIndex)) {
350349 // This partition has the column (feature) used for this split.
351350 val fromOffset = nodeOffsets(nodeIndexInLevel)
352351 val toOffset = nodeOffsets(nodeIndexInLevel + 1 )
353352 val colIndex : Int = localFeatureIndex(split.featureIndex)
354- Iterator (bitSubvectorFromSplit (columns(colIndex), fromOffset, toOffset, split))
353+ Iterator (bitVectorFromSplit (columns(colIndex), fromOffset, toOffset, split, numRows ))
355354 } else {
356355 Iterator ()
357356 }
358357 case (nodeIndexInLevel : Int , None ) =>
359- // Do not create a BitSubvector when there is no split.
360- // This requires PartitionInfo.update to handle missing BitSubvectors.
358+ // Do not create a bitVector when there is no split.
359+ // PartitionInfo.update will detect that there is no
360+ // split, by how many instances go left/right.
361361 Iterator ()
362- }.toArray
362+ }
363+ if (bitSetForNodes.isEmpty) {
364+ new RoaringBitmap ()
365+ } else {
366+ bitSetForNodes.reduce[RoaringBitmap ] { (acc, bitv) => acc.or(bitv); acc }
367+ }
368+ }
369+ val aggBitVector : RoaringBitmap = workerBitSubvectors.reduce { (acc, bitv) =>
370+ acc.or(bitv)
371+ acc
363372 }
364- val aggBitVectors : Array [BitSubvector ] = workerBitSubvectors.reduce(BitSubvector .merge)
365373 bestSplitsBc.unpersist()
366- aggBitVectors
374+ aggBitVector
367375 }
368376
369377 /**
370378 * Choose the best split for a feature at a node.
371- *
372379 * TODO: Return null or None when the split is invalid, such as putting all instances on one
373380 * child node.
374381 *
@@ -787,20 +794,21 @@ private[ml] object AltDT extends Logging {
787794 * second by sorted row indices within the node's rows.
788795 * bit[index in sorted array of row indices] = false for left, true for right
789796 */
790- private [impl] def bitSubvectorFromSplit (
797+ private [impl] def bitVectorFromSplit (
791798 col : FeatureVector ,
792799 fromOffset : Int ,
793800 toOffset : Int ,
794- split : Split ) : BitSubvector = {
795- val nodeRowIndices = col.indices.slice(fromOffset, toOffset)
796- val nodeRowValues = col.values .slice(fromOffset, toOffset)
797- val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2 )
798- val bitv = new BitSubvector (fromOffset, toOffset )
801+ split : Split ,
802+ numRows : Int ) : RoaringBitmap = {
803+ val nodeRowIndices = col.indices.view .slice(fromOffset, toOffset)
804+ val nodeRowValues = col.values.view.slice(fromOffset, toOffset )
805+ val bitv = new RoaringBitmap ( )
799806 var i = 0
800- while (i < nodeRowValuesSortedByIndices.length) {
801- val value = nodeRowValuesSortedByIndices(i)
807+ while (i < nodeRowValues.length) {
808+ val value = nodeRowValues(i)
809+ val idx = nodeRowIndices(i)
802810 if (! split.shouldGoLeft(value)) {
803- bitv.set(fromOffset + i )
811+ bitv.add(idx )
804812 }
805813 i += 1
806814 }
@@ -833,6 +841,11 @@ private[ml] object AltDT extends Logging {
833841 activeNodes : BitSet )
834842 extends Serializable {
835843
844+ // pre-allocated temporary buffers that we use to sort
845+ // instances in left and right children during update
846+ val tempVals : Array [Double ] = new Array [Double ](columns(0 ).values.length)
847+ val tempIndices : Array [Int ] = new Array [Int ](columns(0 ).values.length)
848+
836849 /** For debugging */
837850 override def toString : String = {
838851 " PartitionInfo(" +
@@ -854,82 +867,82 @@ private[ml] object AltDT extends Logging {
854867 * Update nodeOffsets, activeNodes:
855868 * Split offsets for nodes which split (which can be identified using the bit vector).
856869 *
857- * @param bitVectors Bit vectors encoding splits for the next level of the tree.
870+ * @param instanceBitVector Bit vector encoding splits for the next level of the tree.
858871 * These must follow a 2-level ordering, where the first level is by node
859872 * and the second level is by row index.
860873 * bitVector(i) = false iff instance i goes to the left child.
861874 * For instances at inactive (leaf) nodes, the value can be arbitrary.
862- * When an active node is not split (e.g., because no good split was found),
863- * then the corresponding BitSubvector can be missing.
864875 * @return Updated partition info
865876 */
866- def update (bitVectors : Array [BitSubvector ], newNumNodeOffsets : Int ): PartitionInfo = {
867- val newColumns = columns.map { oldCol =>
868- val col = oldCol.deepCopy()
869- var curBitVecIdx = 0
877+ def update (instanceBitVector : RoaringBitmap , newNumNodeOffsets : Int ):
878+ PartitionInfo = {
879+ // Create a 2-level representation of the new nodeOffsets (to be flattened).
880+ // These 2 levels correspond to original nodes and their children (if split).
881+ val newNodeOffsets = nodeOffsets.map(Array (_))
882+
883+ val newColumns = columns.map { col =>
870884 activeNodes.iterator.foreach { nodeIdx =>
871885 val from = nodeOffsets(nodeIdx)
872886 val to = nodeOffsets(nodeIdx + 1 )
873- if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) {
874- // If there are no more BitVectors, curBitVecIdx stays at the last bitVector,
875- // which is acceptable (since it will not cover further nodes which were not split).
876- curBitVecIdx += 1
877- }
878- val curBitVector = bitVectors(curBitVecIdx)
879- // If the current BitVector does not cover this node, then this node was not split,
880- // so we do not need to update its part of the column. Otherwise, we update it.
881- if (curBitVector.from <= from && to <= curBitVector.to) {
882- // Sort range [from, to) based on indices. This is required to match the bit vector
883- // across all workers. See [[bitSubvectorFromSplit]] for details.
884- val rangeIndices = col.indices.view.slice(from, to).toArray
885- val rangeValues = col.values.view.slice(from, to).toArray
886- val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
887- // Sort range [from, to) based on bit vector.
888- sortedRange.zipWithIndex.map { case ((idx, value), i) =>
889- val bit = curBitVector.get(from + i)
890- // TODO: In-place merge, rather than general sort.
891- // TODO: We don't actually need to sort the categorical features using our approach.
892- (bit, value, idx)
893- }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
894- col.values(from + i) = value
895- col.indices(from + i) = idx
887+ val rangeIndices = col.indices.view.slice(from, to)
888+ val rangeValues = col.values.view.slice(from, to)
889+
890+ // If this is the very first time we split,
891+ // we don't use rangeIndices to count the number of bits set;
892+ // the entire bit vector will be used, so getCardinality
893+ // will give us the same result more cheaply.
894+ val numBitsSet = if (nodeOffsets.length == 2 ) instanceBitVector.getCardinality
895+ else rangeIndices.count(instanceBitVector.contains)
896+
897+ val numBitsNotSet = to - from - numBitsSet // number of instances splitting left
898+ val oldOffset = newNodeOffsets(nodeIdx).head
899+
900+ // If numBitsNotSet or numBitsSet equals 0, then this node was not split,
901+ // so we do not need to update its part of the column. Otherwise, we update it.
902+ if (numBitsNotSet != 0 && numBitsSet != 0 ) {
903+ newNodeOffsets(nodeIdx) = Array (oldOffset, oldOffset + numBitsNotSet)
904+
905+ // BEGIN SORTING
906+ // We sort the [from, to) slice of col based on instance bit, then
907+ // instance value. This is required to match the bit vector across all
908+ // workers. All instances going "left" in the split (which are false)
909+ // should be ordered before the instances going "right". The instanceBitVector
910+ // gives us the bit value for each instance based on the instance's index.
911+ // Then both [from, from + numBitsNotSet) and [from + numBitsNotSet, to) need to be
912+ // sorted by value.
913+ // Since the column is already sorted by value, we can compute
914+ // this sort in a single pass over the data. We iterate from start to finish
915+ // (which preserves the sorted order), and then copy the values
916+ // into tempVals and tempIndices either:
917+ // 1) in the [from, from + numBitsNotSet) range if the bit is false, or
918+ // 2) in the [from + numBitsNotSet, to) range if the bit is true.
919+ var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet)
920+ var idx = 0
921+ while (idx < rangeValues.length) {
922+ val indexForVal = rangeIndices(idx)
923+ val bit = instanceBitVector.contains(indexForVal)
924+ if (bit) {
925+ tempVals(rightInstanceIdx) = rangeValues(idx)
926+ tempIndices(rightInstanceIdx) = indexForVal
927+ rightInstanceIdx += 1
928+ } else {
929+ tempVals(leftInstanceIdx) = rangeValues(idx)
930+ tempIndices(leftInstanceIdx) = indexForVal
931+ leftInstanceIdx += 1
932+ }
933+ idx += 1
896934 }
897- }
898- }
899- col
900- }
935+ // END SORTING
901936
902- // Create a 2-level representation of the new nodeOffsets (to be flattened).
903- // These 2 levels correspond to original nodes and their children (if split).
904- val newNodeOffsets = nodeOffsets.map(Array (_))
905- var curBitVecIdx = 0
906- activeNodes.iterator.foreach { nodeIdx =>
907- val from = nodeOffsets(nodeIdx)
908- val to = nodeOffsets(nodeIdx + 1 )
909- if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) {
910- // If there are no more BitVectors, curBitVecIdx stays at the last bitVector,
911- // which is acceptable (since it will not cover further nodes which were not split).
912- curBitVecIdx += 1
913- }
914- val curBitVector = bitVectors(curBitVecIdx)
915- // If the current BitVector does not cover this node, then this node was not split,
916- // so we do not need to create a new node offset. Otherwise, we create an offset.
917- if (curBitVector.from <= from && to <= curBitVector.to) {
918- // Count number of values splitting to left vs. right
919- val numRight = Range (from, to).count(curBitVector.get)
920- val numLeft = to - from - numRight
921- if (numLeft != 0 && numRight != 0 ) {
922- // node is split
923- val oldOffset = newNodeOffsets(nodeIdx).head
924- newNodeOffsets(nodeIdx) = Array (oldOffset, oldOffset + numLeft)
937+ // Copy the re-ordered [from, to) slice of values and indices
938+ // from the temporary buffers back into the column.
939+ Array .copy(tempVals, from, col.values, from, rangeValues.length)
940+ Array .copy(tempIndices, from, col.indices, from, rangeValues.length)
925941 }
926942 }
943+ col
927944 }
928945
929- assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets,
930- s " (W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}, " +
931- s " newNumNodeOffsets: $newNumNodeOffsets" )
932-
933946 // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets.
934947 val newActiveNodes = new BitSet (newNumNodeOffsets - 1 )
935948 var newNodeOffsetsIdx = 0
0 commit comments