Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,68 @@ public MapIterator destructiveIterator() {
return new MapIterator(numValues, new Location(), true);
}

/**
* Iterator for the entries of this map. This is to first iterate over key indices in
* `longArray` then accessing values in `dataPages`. NOTE: this is different from `MapIterator`
* in the sense that key index is preserved here
* (See `UnsafeHashedRelation` for example of usage).
*/
public final class MapIteratorWithKeyIndex implements Iterator<Location> {

/**
* The index in `longArray` where the key is stored.
*/
private int keyIndex = 0;

private int numRecords;
private final Location loc;

private MapIteratorWithKeyIndex() {
this.numRecords = numValues;
this.loc = new Location();
}

@Override
public boolean hasNext() {
return numRecords > 0;
}

@Override
public Location next() {
if (!loc.isDefined() || !loc.nextValue()) {
while (longArray.get(keyIndex * 2) == 0) {
keyIndex++;
}
loc.with(keyIndex, 0, true);
keyIndex++;
}
numRecords--;
return loc;
}
}

/**
* Returns an iterator for iterating over the entries of this map,
* by first iterating over the key index inside hash map's `longArray`.
*
* For efficiency, all calls to `next()` will return the same {@link Location} object.
*
* The returned iterator is NOT thread-safe. If the map is modified while iterating over it,
* the behavior of the returned iterator is undefined.
*/
public MapIteratorWithKeyIndex iteratorWithKeyIndex() {
return new MapIteratorWithKeyIndex();
}

/**
* The maximum number of allowed keys index.
*
* The value of allowed keys index is in the range of [0, maxNumKeysIndex - 1].
*/
public int maxNumKeysIndex() {
return (int) (longArray.size() / 2);
}

/**
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
Expand Down Expand Up @@ -601,6 +663,14 @@ public boolean isDefined() {
return isDefined;
}

/**
* Returns index for key.
*/
public int getKeyIndex() {
assert (isDefined);
return pos;
}

/**
* Returns the base object for key.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ public void emptyMap() {
final byte[] key = getRandomByteArray(keyLengthInWords);
Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
Assert.assertFalse(map.iterator().hasNext());
Assert.assertFalse(map.iteratorWithKeyIndex().hasNext());
} finally {
map.free();
}
Expand Down Expand Up @@ -233,9 +234,10 @@ public void setAndRetrieveAKey() {
}
}

private void iteratorTestBase(boolean destructive) throws Exception {
private void iteratorTestBase(boolean destructive, boolean isWithKeyIndex) throws Exception {
final int size = 4096;
BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
Assert.assertEquals(size / 2, map.maxNumKeysIndex());
try {
for (long i = 0; i < size; i++) {
final long[] value = new long[] { i };
Expand Down Expand Up @@ -267,6 +269,8 @@ private void iteratorTestBase(boolean destructive) throws Exception {
final Iterator<BytesToBytesMap.Location> iter;
if (destructive) {
iter = map.destructiveIterator();
} else if (isWithKeyIndex) {
iter = map.iteratorWithKeyIndex();
} else {
iter = map.iterator();
}
Expand All @@ -291,6 +295,12 @@ private void iteratorTestBase(boolean destructive) throws Exception {
countFreedPages++;
}
}
if (keyLength != 0 && isWithKeyIndex) {
final BytesToBytesMap.Location expectedLoc = map.lookup(
loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength());
Assert.assertTrue(expectedLoc.isDefined() &&
expectedLoc.getKeyIndex() == loc.getKeyIndex());
}
}
if (destructive) {
// Latest page is not freed by iterator but by map itself
Expand All @@ -304,12 +314,17 @@ private void iteratorTestBase(boolean destructive) throws Exception {

@Test
public void iteratorTest() throws Exception {
iteratorTestBase(false);
iteratorTestBase(false, false);
}

@Test
public void destructiveIteratorTest() throws Exception {
iteratorTestBase(true);
iteratorTestBase(true, false);
}

@Test
public void iteratorWithKeyIndexTest() throws Exception {
iteratorTestBase(false, true);
}

@Test
Expand Down Expand Up @@ -603,6 +618,12 @@ public void multipleValuesForSameKey() {
final BytesToBytesMap.Location loc = iter.next();
assert loc.isDefined();
}
BytesToBytesMap.MapIteratorWithKeyIndex iterWithKeyIndex = map.iteratorWithKeyIndex();
for (i = 0; i < 2048; i++) {
assert iterWithKeyIndex.hasNext();
final BytesToBytesMap.Location loc = iterWithKeyIndex.next();
assert loc.isDefined() && loc.getKeyIndex() >= 0;
}
} finally {
map.free();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ trait JoinSelectionHelper {
canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint)
}
getBuildSide(
canBuildLeft(joinType) && buildLeft,
canBuildRight(joinType) && buildRight,
canBuildBroadcastLeft(joinType) && buildLeft,
canBuildBroadcastRight(joinType) && buildRight,
left,
right
)
Expand All @@ -260,8 +260,8 @@ trait JoinSelectionHelper {
canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
}
getBuildSide(
canBuildLeft(joinType) && buildLeft,
canBuildRight(joinType) && buildRight,
canBuildShuffledHashJoinLeft(joinType) && buildLeft,
canBuildShuffledHashJoinRight(joinType) && buildRight,
left,
right
)
Expand All @@ -278,20 +278,35 @@ trait JoinSelectionHelper {
plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
}

def canBuildLeft(joinType: JoinType): Boolean = {
def canBuildBroadcastLeft(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | RightOuter => true
case _ => false
}
}

def canBuildRight(joinType: JoinType): Boolean = {
def canBuildBroadcastRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
}

def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | RightOuter | FullOuter => true
case _ => false
}
}

def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = {
joinType match {
case _: InnerLike | LeftOuter | FullOuter |
LeftSemi | LeftAnti | _: ExistenceJoin => true
case _ => false
}
}

def hintToBroadcastLeft(hint: JoinHint): Boolean = {
hint.leftHint.exists(_.strategy.contains(BROADCAST))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,9 @@ object SQLConf {

val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin")
.internal()
.doc("When true, prefer sort merge join over shuffle hash join.")
.doc("When true, prefer sort merge join over shuffled hash join. " +
"Note that shuffled hash join supports all join types (e.g. full outer) " +
"that sort merge join supports.")
.version("2.0.0")
.booleanConf
.createWithDefault(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*
* - Shuffle hash join:
* Only supported for equi-joins, while the join keys do not need to be sortable.
* Supported for all join types except full outer joins.
* Supported for all join types.
* Building hash map from table is a memory-intensive operation and it could cause OOM
* when the build side is big.
*
* - Shuffle sort merge join (SMJ):
* Only supported for equi-joins and the join keys have to be sortable.
Expand Down Expand Up @@ -260,7 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// it's a right join, and broadcast right side if it's a left join.
// TODO: revisit it. If left side is much smaller than the right side, it may be better
// to broadcast the left side even if it's a left join.
if (canBuildLeft(joinType)) BuildLeft else BuildRight
if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight
}

def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
}
}

@transient private lazy val (buildOutput, streamedOutput) = {
@transient protected lazy val (buildOutput, streamedOutput) = {
buildSide match {
case BuildLeft => (left.output, right.output)
case BuildRight => (right.output, left.output)
Expand All @@ -133,7 +133,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
protected def streamSideKeyGenerator(): UnsafeProjection =
UnsafeProjection.create(streamedBoundKeys)

@transient private[this] lazy val boundCondition = if (condition.isDefined) {
@transient protected[this] lazy val boundCondition = if (condition.isDefined) {
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
} else {
(r: InternalRow) => true
Expand Down
Loading