Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ case class StreamingSymmetricHashJoinExec(
postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
}
}

// NOTE: we need to make sure `outerOutputIter` is evaluated "after" exhausting all of
// elements in `innerOutputIter`, because evaluation of `innerOutputIter` may update
// the match flag which the logic for outer join is relying on.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: this comment is not related to the bug and just to document an existing assumption?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes right.

TBH I suspected this first and crafted a patch including the part with new iterator explicitly runs the logic after evaluating innerOutputIter, and later realized current logic already dealt with this properly, because removeOldState() doesn't materialize the candidates and evaluate lazily. This patch contains minimal change.

Worth to mention how it works for someone who may need to touch here.

val removedRowIter = leftSideJoiner.removeOldState()
val outerOutputIter = removedRowIter.filterNot { kv =>
stateFormatVersion match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,13 @@ class SymmetricHashJoinStateManager(
=====================================================
*/

/** Get all the values of a key */
/**
* Get all the values of a key.
*
* NOTE: the returned row "may" be reused during execution (to avoid initialization of object),
* so the caller should ensure that the logic doesn't affect by such behavior. Call copy()
* against the row if needed.
*/
def get(key: UnsafeRow): Iterator[UnsafeRow] = {
val numValues = keyToNumValues.get(key)
keyWithIndexToValue.getAll(key, numValues).map(_.value)
Expand All @@ -99,6 +105,10 @@ class SymmetricHashJoinStateManager(
/**
* Get all the matched values for given join condition, with marking matched.
* This method is designed to mark joined rows properly without exposing internal index of row.
*
* NOTE: the "value" field in JoinedRow "may" be reused during execution
* (to avoid initialization of object), so the caller should ensure that the logic
* doesn't affect by such behavior. Call copy() against these rows if needed.
*/
def getJoinedRows(
key: UnsafeRow,
Expand Down Expand Up @@ -250,7 +260,7 @@ class SymmetricHashJoinStateManager(
}

override def getNext(): KeyToValuePair = {
val currentValue = findNextValueForIndex()
var currentValue = findNextValueForIndex()

// If there's no value, clean up and finish. There aren't any more available.
if (currentValue == null) {
Expand All @@ -259,6 +269,9 @@ class SymmetricHashJoinStateManager(
return null
}

// Make a copy on value row, as below cleanup logic may update the value row silently.
currentValue = currentValue.copy(value = currentValue.value.copy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is the only place to do copy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That wasn't necessary for format V1 as the original row was stored into state store, and state store (strictly saying, the implementation of HDFS state store provider) makes sure these rows are copied version.

For other places, it can propagate to the callers outside of state manager, and looks like these callers don't need to copy the row. (It's super tricky for me to determine whether the copy is necessary or not, if the code is not in a simple loop or stream.)

Copy link
Contributor

@cloud-fan cloud-fan Jul 6, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After seeing the new changes, I think the first version looks better. The caller sides is nested and we still have unnecessary copies for v1 format. What do you think? @viirya

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, also prefer the first approach personally. As the issue was in v2 format, the first version is a straightforward way.
@cloud-fan typo? ... unnecessary copies for v1 format

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I'll roll back the change. I'll also leave a commit sha so we can do back and forth depending on the decision.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just reverted the latest commit to leave the history and pick the commit selectively according to the decision.


// The backing store is arraylike - we as the caller are responsible for filling back in
// any hole. So we swap the last element into the hole and decrement numValues to shorten.
// clean
Expand Down Expand Up @@ -451,10 +464,26 @@ class SymmetricHashJoinStateManager(
}

private trait KeyWithIndexToValueRowConverter {
/** Defines the schema of the value row (the value side of K-V in state store). */
def valueAttributes: Seq[Attribute]

/**
* Convert the value row to (actual value, match) pair.
*
* NOTE: depending on the implementation, the row (actual value) in the pair "may" be reused
* during execution (to avoid initialization of object), so the caller should ensure that
* the logic doesn't affect by such behavior. Call copy() against the row if needed.
*/
def convertValue(value: UnsafeRow): ValueAndMatchPair

/**
* Build the value row from (actual value, match) pair. This is expected to be called just
* before storing to the state store.
*
* NOTE: depending on the implementation, the result row "may" be reused during execution
* (to avoid initialization of object), so the caller should ensure that the logic doesn't
* affect by such behavior. Call copy() against the result row if needed.
*/
def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
}

Expand Down Expand Up @@ -530,13 +559,21 @@ class SymmetricHashJoinStateManager(
protected val stateStore = getStateStore(keyWithIndexSchema,
valueRowConverter.valueAttributes.toStructType)

/**
* NOTE: the "value" field in return value "may" be reused during execution
* (to avoid initialization of object), so the caller should ensure that the logic
* doesn't affect by such behavior. Call copy() against the row if needed.
*/
def get(key: UnsafeRow, valueIndex: Long): ValueAndMatchPair = {
valueRowConverter.convertValue(stateStore.get(keyWithIndexRow(key, valueIndex)))
}

/**
* Get all values and indices for the provided key.
* Should not return null.
* Get all values and indices for the provided key. Should not return null.
*
* NOTE: the "key" and "value" field in return value "may" be reused during execution
* (to avoid initialization of object), so the caller should ensure that the logic
* doesn't affect by such behavior. Call copy() against these rows if needed.
*/
def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
val keyWithIndexAndValue = new KeyWithIndexAndValue()
Expand Down Expand Up @@ -583,6 +620,11 @@ class SymmetricHashJoinStateManager(
}
}

/**
* NOTE: the "key" and "value" field in return value "may" be reused during execution
* (to avoid initialization of object), so the caller should ensure that the logic
* doesn't affect by such behavior. Call copy() against these rows if needed.
*/
def iterator: Iterator[KeyWithIndexAndValue] = {
val keyWithIndexAndValue = new KeyWithIndexAndValue()
stateStore.getRange(None, None).map { pair =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.streaming

import java.io.File
import java.sql.Timestamp
import java.util.{Locale, UUID}

import scala.util.Random
Expand Down Expand Up @@ -996,4 +997,47 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
)
}
}

test("SPARK-32148 stream-stream join regression on Spark 3.0.0") {
val input1 = MemoryStream[(Timestamp, String, String)]
val df1 = input1.toDF
.selectExpr("_1 as eventTime", "_2 as id", "_3 as comment")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason why not use select? I don't see any expression here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's pretty much simpler and more readable than select('_1.as("eventTime"), '_2.as("id"), '_3.as("comment")) (or even with col(...) if ' notation doesn't work for _1, _2, _3).

.withWatermark(s"eventTime", "2 minutes")

val input2 = MemoryStream[(Timestamp, String, String)]
val df2 = input2.toDF
.selectExpr("_1 as eventTime", "_2 as id", "_3 as name")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here as well.

.withWatermark(s"eventTime", "4 minutes")

val joined = df1.as("left")
.join(df2.as("right"),
expr(s"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why string interpolation needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah that's not necessary. Will remove.

|left.id = right.id AND left.eventTime BETWEEN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: indent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation of """ looks vary on the codebase, and I can find same indentation on the codebase.

| right.eventTime - INTERVAL 30 seconds AND
| right.eventTime + INTERVAL 30 seconds
""".stripMargin),
joinType = "leftOuter")

val inputDataForInput1 = Seq(
(Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner"),
(Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A"),
(Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B"))

val inputDataForInput2 = Seq(
(Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"),
(Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B"),
(Timestamp.valueOf("2020-01-02 02:00:00"), "abc", "C"))

val expectedOutput = Seq(
(Timestamp.valueOf("2020-01-01 00:00:00"), "abc", "has no join partner", null, null, null),
(Timestamp.valueOf("2020-01-02 00:00:00"), "abc", "joined with A",
Timestamp.valueOf("2020-01-02 00:00:10"), "abc", "A"),
(Timestamp.valueOf("2020-01-02 01:00:00"), "abc", "joined with B",
Timestamp.valueOf("2020-01-02 00:59:59"), "abc", "B"))

testStream(joined)(
MultiAddData((input1, inputDataForInput1), (input2, inputDataForInput2)),
CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*)
)
}
}