Skip to content

Commit b47ac23

Browse files
committed
Add test
1 parent b8c70ab commit b47ac23

1 file changed

Lines changed: 32 additions & 3 deletions

File tree

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ import org.scalatest.exceptions.TestFailedException
2525

2626
import org.apache.spark.SparkException
2727
import org.apache.spark.api.java.Optional
28-
import org.apache.spark.api.java.function.{FlatMapGroupsWithStateFunction}
29-
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder}
28+
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
29+
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, KeyValueGroupedDataset}
3030
import org.apache.spark.sql.catalyst.InternalRow
3131
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
3232
import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
@@ -1413,7 +1413,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
14131413
)
14141414
}
14151415

1416-
test("flatMapGroupsWithState - initial state - streaming initial state") {
1416+
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
14171417
val initialStateData = MemoryStream[(String, RunningCount)]
14181418
initialStateData.addData(("a", new RunningCount(1)))
14191419

@@ -1438,6 +1438,35 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
14381438
assert(e.message.contains(expectedError))
14391439
}
14401440

1441+
test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") {
1442+
val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
1443+
val initialState: KeyValueGroupedDataset[String, RunningCount] =
1444+
initialStateDS.groupByKey(_._1).mapValues(_._2)
1445+
.mapGroupsWithState(
1446+
GroupStateTimeout.NoTimeout())(
1447+
(key: String, values: Iterator[RunningCount], state: GroupState[Boolean]) => {
1448+
(key, values.next())
1449+
}
1450+
).groupByKey(_._1).mapValues(_._2)
1451+
1452+
val inputData = MemoryStream[String]
1453+
1454+
val result =
1455+
inputData.toDS()
1456+
.groupByKey(x => x)
1457+
.flatMapGroupsWithState(
1458+
Update, NoTimeout(), initialState
1459+
)(flatMapGroupsWithStateFunc)
1460+
1461+
testStream(result, Update)(
1462+
AddData(inputData, "keyInStateAndData"),
1463+
CheckNewAnswer(
1464+
("keyInStateAndData", Seq[String]("keyInStateAndData"), "2")
1465+
),
1466+
StopStream
1467+
)
1468+
}
1469+
14411470
testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
14421471
val mapGroupsWithStateFunc =
14431472
(key: String, values: Iterator[String], state: GroupState[RunningCount]) => {

0 commit comments

Comments
 (0)