Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ object UnsupportedOperationChecker extends Logging {
case p if p.isStreaming =>
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)

case f: FlatMapGroupsWithState =>
if (f.hasInitialState) {
throwError("Batch [flatMap|map]GroupsWithState queries should not" +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initial state is not supported in [flatMap|map]GroupsWithState operation on a batch DataFrame/Dataset

" pass an initial state.")(f)
}

case _ =>
}
}
Expand Down Expand Up @@ -232,6 +238,10 @@ object UnsupportedOperationChecker extends Logging {
// Check compatibility with output modes and aggregations in query
val aggsInQuery = collectStreamingAggregates(plan)

if (m.initialState.isStreaming) {
// initial state has to be a batch relation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-streaming DataFrame/Dataset is not supported as the initial state in [flatMap|map]GroupsWithState operation on a streamiing DataFrame/Dataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

throwError("Initial state cannot be a streaming DataFrame/Dataset.")
}
if (m.isMapGroupsWithState) { // check mapGroupsWithState
// allowed only in update query output mode and without aggregation
if (aggsInQuery.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ object FlatMapGroupsWithState {
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan): LogicalPlan = {
val encoder = encoderFor[S]
val stateEncoder = encoderFor[S]

val mapped = new FlatMapGroupsWithState(
func,
Expand All @@ -449,10 +449,49 @@ object FlatMapGroupsWithState {
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
encoder.asInstanceOf[ExpressionEncoder[Any]],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = false,
groupingAttributes,
dataAttributes,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
LocalRelation(stateEncoder.schema.toAttributes), // empty data set
child
)
CatalystSerde.serialize[U](mapped)
}

def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputMode: OutputMode,
isMapGroupsWithState: Boolean,
timeout: GroupStateTimeout,
child: LogicalPlan,
initialStateGroupAttrs: Seq[Attribute],
initialStateDataAttrs: Seq[Attribute],
initialState: LogicalPlan): LogicalPlan = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: init and initial. be consistent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val stateEncoder = encoderFor[S]

val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
groupingAttributes,
dataAttributes,
CatalystSerde.generateObjAttr[U],
stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
isMapGroupsWithState,
timeout,
hasInitialState = true,
initialStateGroupAttrs,
initialStateDataAttrs,
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
initialState,
child)
CatalystSerde.serialize[U](mapped)
}
Expand All @@ -474,6 +513,12 @@ object FlatMapGroupsWithState {
* @param outputMode the output mode of `func`
* @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
* @param timeout used to timeout groups that have not received data in a while
* @param hasInitialState Indicates whether initial state needs to be applied or not.
* @param initialStateGroupAttrs grouping attributes for the initial state
* @param initialStateDataAttrs used to read the initial state
* @param initialStateDeserializer used to extract the initial state objects.
* @param initialState user defined initial state that is applied in the first batch.
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
Expand All @@ -486,14 +531,24 @@ case class FlatMapGroupsWithState(
outputMode: OutputMode,
isMapGroupsWithState: Boolean = false,
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
hasInitialState: Boolean = false,
initialStateGroupAttrs: Seq[Attribute] = Seq.empty,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the point of making only some of these new param have default values if there is no function call that does NOT use those default params.

initialStateDataAttrs: Seq[Attribute] = Seq.empty,
initialStateDeserializer: Expression,
initialState: LogicalPlan,
child: LogicalPlan) extends BinaryNode with ObjectProducer {

if (isMapGroupsWithState) {
assert(outputMode == OutputMode.Update)
}

override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsWithState =
copy(child = newChild)
override def left: LogicalPlan = child

override def right: LogicalPlan = initialState

override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
copy(child = newLeft, initialState = newRight)
}

/** Factory for constructing new `FlatMapGroupsInR` nodes. */
Expand Down
Loading