@@ -103,8 +103,8 @@ public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<I
};

// This will give a DStream made of state (which is the cumulative count of the words)
- JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
- wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+ JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+ wordsDstream.mapWithState(StateSpec.function(trackStateFunc).initialState(initialRDD));

stateDstream.print();
ssc.start();
@@ -67,7 +67,7 @@ object StatefulNetworkWordCount {
Some(output)
}

- val stateDstream = wordDstream.trackStateByKey(
+ val stateDstream = wordDstream.mapWithState(
StateSpec.function(trackStateFunc).initialState(initialRDD))
stateDstream.print()
ssc.start()
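For context: a minimal, self-contained Scala sketch of the renamed API as both word-count examples above use it. The socket source, batch interval, checkpoint path, and seed pairs are illustrative assumptions, not part of this diff; the mapping function mirrors the four-argument (Time, key, value, state) => Option[mapped] signature shown in this commit.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext, Time}

object MapWithStateWordCountSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(
      new SparkConf().setMaster("local[2]").setAppName("MapWithStateWordCountSketch"), Seconds(1))
    ssc.checkpoint(".") // stateful DStreams require a checkpoint directory

    // (word, 1) pairs from a hypothetical local socket source
    val wordDstream = ssc.socketTextStream("localhost", 9999)
      .flatMap(_.split(" ")).map(word => (word, 1))

    // Seed the cumulative counts, as both examples do with initialRDD (values invented)
    val initialRDD = ssc.sparkContext.parallelize(Seq(("hello", 1), ("world", 1)))

    // Four-argument mapping function: fold the new count into the running sum
    val mappingFunc = (time: Time, word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      Some((word, sum))
    }

    val stateDstream = wordDstream.mapWithState(
      StateSpec.function(mappingFunc).initialState(initialRDD))
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
```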
@@ -32,12 +32,10 @@
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
- import org.apache.spark.api.java.function.Function2;
- import org.apache.spark.api.java.function.Function4;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
- import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+ import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;

/**
* Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -867,8 +865,8 @@ public void testTrackStateByAPI() {
JavaPairRDD<String, Boolean> initialRDD = null;
[Member] nit: the method name testTrackStateByAPI should be renamed to testMapWithStateAPI

JavaPairDStream<String, Integer> wordsDstream = null;

- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
- wordsDstream.trackStateByKey(
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+ wordsDstream.mapWithState(
StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
// Use all State's methods here
state.exists();
@@ -884,8 +882,8 @@ StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state)

JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();

- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
- wordsDstream.trackStateByKey(
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+ wordsDstream.mapWithState(
StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
state.exists();
state.get();
16 changes: 9 additions & 7 deletions streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental

/**
* :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in the mapping function used in the
+ * `mapWithState` operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Scala example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(data: Option[Int], state: State[Int]): Option[String] = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
@@ -52,8 +52,8 @@ import org.apache.spark.annotation.Experimental
*
* Java example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+ * // A mapping function that maintains an integer state and returns a String
+ * Function2<Optional<Integer>, State<Integer>, Optional<String>> mappingFunction =
* new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
*
* @Override
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
* }
* };
* }}}
+ *
+ * @tparam S Class of the state
*/
@Experimental
sealed abstract class State[S] {
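To make the State contract above concrete, here is a hedged sketch of a complete mapping function in the two-argument (data, state) shape the scaladoc describes; the removal threshold and output string are invented for illustration, not taken from this diff.

```scala
import org.apache.spark.streaming.State

object MappingFunctionSketch {
  // Mirrors the doc's (data, state) shape; the `> 100` removal threshold
  // is purely illustrative.
  def mappingFunction(data: Option[Int], state: State[Int]): Option[String] = {
    if (state.exists) {
      val existingState = state.get // read the current state for this key
      if (existingState > 100) {
        state.remove() // stop tracking this key
      } else {
        state.update(existingState + data.getOrElse(0)) // fold in the new value
      }
    } else if (data.isDefined) {
      state.update(data.get) // first record for this key: initialize the state
    }
    data.map(value => s"processed $value") // the mapped record for this input
  }
}
```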
118 changes: 62 additions & 56 deletions streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,42 +37,47 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ * def mappingFunction(data: Option[ValueType], wrappedState: State[StateType]): MappedType = {
* ...
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, MappedType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, MappedType>function(mappingFunction)
* .numPartitions(10);
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * JavaMapWithStateDStream<KeyType, ValueType, StateType, MappedType> mapWithStateDStream =
+ * javaPairDStream.<StateType, MappedType>mapWithState(spec);
* }}}
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped elements
*/
@Experimental
- sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+ sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {

- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(rdd: RDD[(KeyType, StateType)]): this.type

- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type

/**
- * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+ * Set the number of partitions by which the state RDDs generated by `mapWithState`
* will be partitioned. Hash partitioning will be used.
*/
def numPartitions(numPartitions: Int): this.type

/**
- * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+ * Set the partitioner by which the state RDDs generated by `mapWithState` will be
* partitioned.
*/
def partitioner(partitioner: Partitioner): this.type
@@ -91,113 +96,114 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
+ * that is used for specifying the parameters of the DStream transformation `mapWithState`
* that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ * def mappingFunction(data: Option[ValueType], wrappedState: State[StateType]): MappedType = {
* ...
* }
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- * StateSpec.function(trackingFunction).numPartitions(10))
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
[Member] I remember this line cannot be put here because the compiler cannot infer KeyType.

[Contributor Author] I am modifying the signature of the function to have the key. See jira SPARK-12245. So then this should not be a problem.
+ *
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, MappedType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, MappedType>function(mappingFunction)
* .numPartitions(10);
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * JavaMapWithStateDStream<KeyType, ValueType, StateType, MappedType> mapWithStateDStream =
+ * javaPairDStream.<StateType, MappedType>mapWithState(spec);
* }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
- new StateSpecImpl(trackingFunction)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ new StateSpecImpl(mappingFunction)
}

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (Option[ValueType], State[StateType]) => MappedType
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
val wrappedFunction =
(time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
- Some(trackingFunction(value, state))
+ Some(mappingFunction(value, state))
}
new StateSpecImpl(wrappedFunction)
}

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
- * the specifications of the `trackStateByKey` operation on a
+ * the specifications of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
- JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+ JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
- val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+ val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
Option(t.orNull)
}
StateSpec.function(trackingFunc)
}

/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: JFunction2[Optional[ValueType], State[StateType], MappedType]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
- javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+ mappingFunction.call(Optional.fromNullable(v.get), s)
}
StateSpec.function(trackingFunc)
}
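For comparison, a sketch of the two Scala overloads of StateSpec.function defined above, with illustrative types. Note the explicit type parameters on the two-argument form: as the review thread above points out, the compiler cannot infer KeyType from that signature at this commit (SPARK-12245 later adds the key to it).

```scala
import org.apache.spark.streaming.{State, StateSpec, Time}

// Four-argument overload: the batch time and the key are in the signature,
// and the mapping function returns an Option of the mapped value.
// KeyType is inferred from the lambda's key parameter.
val timedSpec = StateSpec.function(
  (time: Time, key: String, value: Option[Int], state: State[Int]) => {
    val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
    state.update(sum)
    Some((key, sum))
  })

// Two-argument overload: no key in the signature, so KeyType must be
// supplied explicitly (see the review comment above).
val simpleSpec = StateSpec.function[String, Int, Int, Int](
    (value: Option[Int], state: State[Int]) => {
      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      sum
    })
  .numPartitions(10) // hash-partition the state RDDs, as documented above
```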
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext
- import org.apache.spark.streaming.dstream.TrackStateDStream
+ import org.apache.spark.streaming.dstream.MapWithStateDStream

/**
* :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
* stream of state snapshots, that is, the state data of all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
- class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
- dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
- extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+ class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+ dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+ extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {

def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
new JavaPairDStream(dstream.stateSnapshots())(
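A short sketch of what stateSnapshots() yields, written against the Scala API for brevity; the socket source, types, and checkpoint path are illustrative assumptions, not part of this diff.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

object StateSnapshotsSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(
      new SparkConf().setMaster("local[2]").setAppName("StateSnapshotsSketch"), Seconds(1))
    ssc.checkpoint(".") // state DStreams need checkpointing

    val wordDstream = ssc.socketTextStream("localhost", 9999)
      .flatMap(_.split(" ")).map(word => (word, 1))

    val mapped = wordDstream.mapWithState(
      StateSpec.function[String, Int, Int, Int]((value: Option[Int], state: State[Int]) => {
        val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
        state.update(sum)
        sum
      }))

    // One element per input record: the values returned by the mapping function.
    mapped.print()
    // One (key, state) pair per tracked key after each batch, whether or not
    // the key appeared in that batch.
    mapped.stateSnapshots().print()

    ssc.start()
    ssc.awaitTermination()
  }
}
```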