Skip to content

Commit 9f05d95

Browse files
committed
Merge pull request apache#23 from fabuzaid21/dt-features-linear-sort
PR apache#6 Dt features linear sort. Dependent on PR apache#5
2 parents fa949c5 + 402b80b commit 9f05d95

2 files changed

Lines changed: 140 additions & 132 deletions

File tree

mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala

Lines changed: 103 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.model.ImpurityStats
2929
import org.apache.spark.rdd.RDD
3030
import org.apache.spark.storage.StorageLevel
3131
import org.apache.spark.util.collection.BitSet
32+
import org.roaringbitmap.RoaringBitmap
3233

3334

3435
/**
@@ -135,7 +136,7 @@ private[ml] object AltDT extends Logging {
135136
}
136137
val labelsBc = input.sparkContext.broadcast(labels)
137138
// NOTE: Labels are not sorted with features since that would require 1 copy per feature,
138-
// rather than 1 copy per worker. This means a lot of random accesses.
139+
// rather than 1 copy per worker. This means a lot of random accesses.
139140
// We could improve this by applying first-level sorting (by node) to labels.
140141

141142
// Sort each column by feature values.
@@ -196,23 +197,19 @@ private[ml] object AltDT extends Logging {
196197
doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0
197198

198199
if (!doneLearning) {
199-
// Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right.
200-
val aggBitVectors: Array[BitSubvector] =
201-
collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1))
200+
val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1)
202201

203-
// Broadcast aggregated bit vectors. On each partition, update instance--node map.
204-
val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors)
202+
// Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right
203+
val aggBitVector: RoaringBitmap = aggregateBitVector(partitionInfos, splits, numRows)
205204
val newPartitionInfos = partitionInfos.map { partitionInfo =>
206-
partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets)
205+
partitionInfo.update(aggBitVector, numNodeOffsets)
207206
}
208207
// TODO: remove. For some reason, this is needed to make things work.
209208
// Probably messing up somewhere above...
210209
newPartitionInfos.cache().count()
211210
partitionInfos = newPartitionInfos
212211

213-
// TODO: unpersist aggBitVectorsBc after action.
214212
}
215-
216213
currentLevel += 1
217214
}
218215

@@ -333,42 +330,52 @@ private[ml] object AltDT extends Logging {
333330
* @param bestSplits Split for each active node, or None if that node will not be split
334331
* @return Array of bit vectors, ordered by offset ranges
335332
*/
336-
private[impl] def collectBitVectors(
333+
private[impl] def aggregateBitVector(
337334
partitionInfos: RDD[PartitionInfo],
338-
bestSplits: Array[Option[Split]]): Array[BitSubvector] = {
335+
bestSplits: Array[Option[Split]],
336+
numRows: Int): RoaringBitmap = {
339337
val bestSplitsBc: Broadcast[Array[Option[Split]]] =
340338
partitionInfos.sparkContext.broadcast(bestSplits)
341-
val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map {
339+
val workerBitSubvectors: RDD[RoaringBitmap] = partitionInfos.map {
342340
case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
343341
activeNodes: BitSet) =>
344342
val localBestSplits: Array[Option[Split]] = bestSplitsBc.value
345343
// localFeatureIndex[feature index] = index into PartitionInfo.columns
346344
val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap
347-
activeNodes.iterator.zip(localBestSplits.iterator).flatMap {
345+
val bitSetForNodes: Iterator[RoaringBitmap] = activeNodes.iterator
346+
.zip(localBestSplits.iterator).flatMap {
348347
case (nodeIndexInLevel: Int, Some(split: Split)) =>
349348
if (localFeatureIndex.contains(split.featureIndex)) {
350349
// This partition has the column (feature) used for this split.
351350
val fromOffset = nodeOffsets(nodeIndexInLevel)
352351
val toOffset = nodeOffsets(nodeIndexInLevel + 1)
353352
val colIndex: Int = localFeatureIndex(split.featureIndex)
354-
Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split))
353+
Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows))
355354
} else {
356355
Iterator()
357356
}
358357
case (nodeIndexInLevel: Int, None) =>
359-
// Do not create a BitSubvector when there is no split.
360-
// This requires PartitionInfo.update to handle missing BitSubvectors.
358+
// Do not create a bitVector when there is no split.
359+
// PartitionInfo.update will detect that there is no
360+
// split, by how many instances go left/right.
361361
Iterator()
362-
}.toArray
362+
}
363+
if (bitSetForNodes.isEmpty) {
364+
new RoaringBitmap()
365+
} else {
366+
bitSetForNodes.reduce[RoaringBitmap] { (acc, bitv) => acc.or(bitv); acc }
367+
}
368+
}
369+
val aggBitVector: RoaringBitmap = workerBitSubvectors.reduce { (acc, bitv) =>
370+
acc.or(bitv)
371+
acc
363372
}
364-
val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge)
365373
bestSplitsBc.unpersist()
366-
aggBitVectors
374+
aggBitVector
367375
}
368376

369377
/**
370378
* Choose the best split for a feature at a node.
371-
*
372379
* TODO: Return null or None when the split is invalid, such as putting all instances on one
373380
* child node.
374381
*
@@ -787,20 +794,21 @@ private[ml] object AltDT extends Logging {
787794
* second by sorted row indices within the node's rows.
788795
* bit[index in sorted array of row indices] = false for left, true for right
789796
*/
790-
private[impl] def bitSubvectorFromSplit(
797+
private[impl] def bitVectorFromSplit(
791798
col: FeatureVector,
792799
fromOffset: Int,
793800
toOffset: Int,
794-
split: Split): BitSubvector = {
795-
val nodeRowIndices = col.indices.slice(fromOffset, toOffset)
796-
val nodeRowValues = col.values.slice(fromOffset, toOffset)
797-
val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2)
798-
val bitv = new BitSubvector(fromOffset, toOffset)
801+
split: Split,
802+
numRows: Int): RoaringBitmap = {
803+
val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset)
804+
val nodeRowValues = col.values.view.slice(fromOffset, toOffset)
805+
val bitv = new RoaringBitmap()
799806
var i = 0
800-
while (i < nodeRowValuesSortedByIndices.length) {
801-
val value = nodeRowValuesSortedByIndices(i)
807+
while (i < nodeRowValues.length) {
808+
val value = nodeRowValues(i)
809+
val idx = nodeRowIndices(i)
802810
if (!split.shouldGoLeft(value)) {
803-
bitv.set(fromOffset + i)
811+
bitv.add(idx)
804812
}
805813
i += 1
806814
}
@@ -833,6 +841,11 @@ private[ml] object AltDT extends Logging {
833841
activeNodes: BitSet)
834842
extends Serializable {
835843

844+
// pre-allocated temporary buffers that we use to sort
845+
// instances in left and right children during update
846+
val tempVals: Array[Double] = new Array[Double](columns(0).values.length)
847+
val tempIndices: Array[Int] = new Array[Int](columns(0).values.length)
848+
836849
/** For debugging */
837850
override def toString: String = {
838851
"PartitionInfo(" +
@@ -854,82 +867,82 @@ private[ml] object AltDT extends Logging {
854867
* Update nodeOffsets, activeNodes:
855868
* Split offsets for nodes which split (which can be identified using the bit vector).
856869
*
857-
* @param bitVectors Bit vectors encoding splits for the next level of the tree.
870+
* @param instanceBitVector Bit vector encoding splits for the next level of the tree.
858871
* These must follow a 2-level ordering, where the first level is by node
859872
* and the second level is by row index.
860873
* bitVector(i) = false iff instance i goes to the left child.
861874
* For instances at inactive (leaf) nodes, the value can be arbitrary.
862-
* When an active node is not split (e.g., because no good split was found),
863-
* then the corresponding BitSubvector can be missing.
864875
* @return Updated partition info
865876
*/
866-
def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = {
867-
val newColumns = columns.map { oldCol =>
868-
val col = oldCol.deepCopy()
869-
var curBitVecIdx = 0
877+
def update(instanceBitVector: RoaringBitmap, newNumNodeOffsets: Int):
878+
PartitionInfo = {
879+
// Create a 2-level representation of the new nodeOffsets (to be flattened).
880+
// These 2 levels correspond to original nodes and their children (if split).
881+
val newNodeOffsets = nodeOffsets.map(Array(_))
882+
883+
val newColumns = columns.map { col =>
870884
activeNodes.iterator.foreach { nodeIdx =>
871885
val from = nodeOffsets(nodeIdx)
872886
val to = nodeOffsets(nodeIdx + 1)
873-
if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) {
874-
// If there are no more BitVectors, curBitVecIdx stays at the last bitVector,
875-
// which is acceptable (since it will not cover further nodes which were not split).
876-
curBitVecIdx += 1
877-
}
878-
val curBitVector = bitVectors(curBitVecIdx)
879-
// If the current BitVector does not cover this node, then this node was not split,
880-
// so we do not need to update its part of the column. Otherwise, we update it.
881-
if (curBitVector.from <= from && to <= curBitVector.to) {
882-
// Sort range [from, to) based on indices. This is required to match the bit vector
883-
// across all workers. See [[bitSubvectorFromSplit]] for details.
884-
val rangeIndices = col.indices.view.slice(from, to).toArray
885-
val rangeValues = col.values.view.slice(from, to).toArray
886-
val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1)
887-
// Sort range [from, to) based on bit vector.
888-
sortedRange.zipWithIndex.map { case ((idx, value), i) =>
889-
val bit = curBitVector.get(from + i)
890-
// TODO: In-place merge, rather than general sort.
891-
// TODO: We don't actually need to sort the categorical features using our approach.
892-
(bit, value, idx)
893-
}.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) =>
894-
col.values(from + i) = value
895-
col.indices(from + i) = idx
887+
val rangeIndices = col.indices.view.slice(from, to)
888+
val rangeValues = col.values.view.slice(from, to)
889+
890+
// If this is the very first time we split,
891+
// we don't use rangeIndices to count the number of bits set;
892+
// the entire bit vector will be used, so getCardinality
893+
// will give us the same result more cheaply.
894+
val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.getCardinality
895+
else rangeIndices.count(instanceBitVector.contains)
896+
897+
val numBitsNotSet = to - from - numBitsSet // number of instances splitting left
898+
val oldOffset = newNodeOffsets(nodeIdx).head
899+
900+
// If numBitsNotSet or numBitsSet equals 0, then this node was not split,
901+
// so we do not need to update its part of the column. Otherwise, we update it.
902+
if (numBitsNotSet != 0 && numBitsSet != 0) {
903+
newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet)
904+
905+
// BEGIN SORTING
906+
// We sort the [from, to) slice of col based on instance bit, then
907+
// instance value. This is required to match the bit vector across all
908+
// workers. All instances going "left" in the split (which are false)
909+
// should be ordered before the instances going "right". The instanceBitVector
910+
// gives us the bit value for each instance based on the instance's index.
911+
// Then both [from, numBitsNotSet) and [numBitsNotSet, to) need to be sorted
912+
// by value.
913+
// Since the column is already sorted by value, we can compute
914+
// this sort in a single pass over the data. We iterate from start to finish
915+
// (which preserves the sorted order), and then copy the values
916+
// into @tempVals and @tempIndices either:
917+
// 1) in the [from, numBitsNotSet) range if the bit is false, or
918+
// 2) in the [numBitsNotSet, to) range if the bit is true.
919+
var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet)
920+
var idx = 0
921+
while (idx < rangeValues.length) {
922+
val indexForVal = rangeIndices(idx)
923+
val bit = instanceBitVector.contains(indexForVal)
924+
if (bit) {
925+
tempVals(rightInstanceIdx) = rangeValues(idx)
926+
tempIndices(rightInstanceIdx) = indexForVal
927+
rightInstanceIdx += 1
928+
} else {
929+
tempVals(leftInstanceIdx) = rangeValues(idx)
930+
tempIndices(leftInstanceIdx) = indexForVal
931+
leftInstanceIdx += 1
932+
}
933+
idx += 1
896934
}
897-
}
898-
}
899-
col
900-
}
935+
// END SORTING
901936

902-
// Create a 2-level representation of the new nodeOffsets (to be flattened).
903-
// These 2 levels correspond to original nodes and their children (if split).
904-
val newNodeOffsets = nodeOffsets.map(Array(_))
905-
var curBitVecIdx = 0
906-
activeNodes.iterator.foreach { nodeIdx =>
907-
val from = nodeOffsets(nodeIdx)
908-
val to = nodeOffsets(nodeIdx + 1)
909-
if (curBitVecIdx + 1 < bitVectors.length && bitVectors(curBitVecIdx).to <= from) {
910-
// If there are no more BitVectors, curBitVecIdx stays at the last bitVector,
911-
// which is acceptable (since it will not cover further nodes which were not split).
912-
curBitVecIdx += 1
913-
}
914-
val curBitVector = bitVectors(curBitVecIdx)
915-
// If the current BitVector does not cover this node, then this node was not split,
916-
// so we do not need to create a new node offset. Otherwise, we create an offset.
917-
if (curBitVector.from <= from && to <= curBitVector.to) {
918-
// Count number of values splitting to left vs. right
919-
val numRight = Range(from, to).count(curBitVector.get)
920-
val numLeft = to - from - numRight
921-
if (numLeft != 0 && numRight != 0) {
922-
// node is split
923-
val oldOffset = newNodeOffsets(nodeIdx).head
924-
newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft)
937+
// update the column values and indices
938+
// with the corresponding indices
939+
Array.copy(tempVals, from, col.values, from, rangeValues.length)
940+
Array.copy(tempIndices, from, col.indices, from, rangeValues.length)
925941
}
926942
}
943+
col
927944
}
928945

929-
assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets,
930-
s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," +
931-
s" newNumNodeOffsets: $newNumNodeOffsets")
932-
933946
// Identify the new activeNodes based on the 2-level representation of the new nodeOffsets.
934947
val newActiveNodes = new BitSet(newNumNodeOffsets - 1)
935948
var newNodeOffsetsIdx = 0

0 commit comments

Comments
 (0)