@@ -25,8 +25,8 @@ import org.scalatest.exceptions.TestFailedException
2525
2626import org .apache .spark .SparkException
2727import org .apache .spark .api .java .Optional
28- import org .apache .spark .api .java .function .{ FlatMapGroupsWithStateFunction }
29- import org .apache .spark .sql .{AnalysisException , DataFrame , Dataset , Encoder }
28+ import org .apache .spark .api .java .function .FlatMapGroupsWithStateFunction
29+ import org .apache .spark .sql .{AnalysisException , DataFrame , Dataset , Encoder , KeyValueGroupedDataset }
3030import org .apache .spark .sql .catalyst .InternalRow
3131import org .apache .spark .sql .catalyst .expressions .{GenericInternalRow , UnsafeProjection , UnsafeRow }
3232import org .apache .spark .sql .catalyst .plans .logical .FlatMapGroupsWithState
@@ -1413,7 +1413,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
14131413 )
14141414 }
14151415
1416- test (" flatMapGroupsWithState - initial state - streaming initial state" ) {
1416+ testQuietly (" flatMapGroupsWithState - initial state - streaming initial state" ) {
14171417 val initialStateData = MemoryStream [(String , RunningCount )]
14181418 initialStateData.addData((" a" , new RunningCount (1 )))
14191419
@@ -1438,6 +1438,35 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
14381438 assert(e.message.contains(expectedError))
14391439 }
14401440
1441+ test(" flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState" ) {
1442+ val initialStateDS = Seq ((" keyInStateAndData" , new RunningCount (1 ))).toDS()
1443+ val initialState : KeyValueGroupedDataset [String , RunningCount ] =
1444+ initialStateDS.groupByKey(_._1).mapValues(_._2)
1445+ .mapGroupsWithState(
1446+ GroupStateTimeout .NoTimeout ())(
1447+ (key : String , values : Iterator [RunningCount ], state : GroupState [Boolean ]) => {
1448+ (key, values.next())
1449+ }
1450+ ).groupByKey(_._1).mapValues(_._2)
1451+
1452+ val inputData = MemoryStream [String ]
1453+
1454+ val result =
1455+ inputData.toDS()
1456+ .groupByKey(x => x)
1457+ .flatMapGroupsWithState(
1458+ Update , NoTimeout (), initialState
1459+ )(flatMapGroupsWithStateFunc)
1460+
1461+ testStream(result, Update )(
1462+ AddData (inputData, " keyInStateAndData" ),
1463+ CheckNewAnswer (
1464+ (" keyInStateAndData" , Seq [String ](" keyInStateAndData" ), " 2" )
1465+ ),
1466+ StopStream
1467+ )
1468+ }
1469+
14411470 testWithAllStateVersions(" mapGroupsWithState - initial state - null key" ) {
14421471 val mapGroupsWithStateFunc =
14431472 (key : String , values : Iterator [String ], state : GroupState [RunningCount ]) => {
0 commit comments