Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ object UnsupportedOperationChecker extends Logging {
case p if p.isStreaming =>
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)

case f: FlatMapGroupsWithState =>
if (f.hasInitialState) {
throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
" operation on a batch DataFrame/Dataset")(f)
}

case _ =>
}
}
Expand Down Expand Up @@ -232,6 +238,12 @@ object UnsupportedOperationChecker extends Logging {
// Check compatibility with output modes and aggregations in query
val aggsInQuery = collectStreamingAggregates(plan)

if (m.initialState.isStreaming) {
// initial state has to be a batch relation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-streaming DataFrame/Dataset is not supported as the initial state in [flatMap|map]GroupsWithState operation on a streamiing DataFrame/Dataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

throwError("Non-streaming DataFrame/Dataset is not supported as the" +
" initial state in [flatMap|map]GroupsWithState operation on a streaming" +
" DataFrame/Dataset")
}
if (m.isMapGroupsWithState) { // check mapGroupsWithState
// allowed only in update query output mode and without aggregation
if (aggsInQuery.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ object FlatMapGroupsWithState {
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]
val stateEncoder = encoderFor[S]

val mapped = new FlatMapGroupsWithState(
func,
Expand All @@ -449,10 +449,49 @@ object FlatMapGroupsWithState {
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
encoder.asInstanceOf[ExpressionEncoder[Any]],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = false,
groupingAttributes,
dataAttributes,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
LocalRelation(stateEncoder.schema.toAttributes), // empty data set
child
)
CatalystSerde.serialize[U](mapped)
}

def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialState: LogicalPlan): LogicalPlan = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: init and initial. be consistent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val stateEncoder = encoderFor[S]

val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = true,
initialStateGroupAttrs,
initialStateDataAttrs,
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
initialState,
child)
CatalystSerde.serialize[U](mapped)
}
Expand All @@ -474,6 +513,12 @@ object FlatMapGroupsWithState {
* @param outputMode the output mode of `func`
* @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
* @param timeout used to timeout groups that have not received data in a while
* @param hasInitialState Indicates whether initial state needs to be applied or not.
* @param initialStateGroupAttrs grouping attributes for the initial state
* @param initialStateDataAttrs used to read the initial state
* @param initialStateDeserializer used to extract the initial state objects.
* @param initialState user defined initial state that is applied in the first batch.
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
Expand All @@ -486,14 +531,24 @@ case class FlatMapGroupsWithState(
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
hasInitialState: Boolean = false,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialStateDeserializer: Expression,
initialState: LogicalPlan,
child: LogicalPlan) extends BinaryNode with ObjectProducer {

if (isMapGroupsWithState) {
assert(outputMode == OutputMode.Update)
}

override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
copy(child = newChild)
override def left: LogicalPlan = child

override def right: LogicalPlan = initialState

override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
copy(child = newLeft, initialState = newRight)
}

/** Factory for constructing new `FlatMapGroupsInR` nodes. */
Expand Down
Loading