Skip to content

Commit ae64786

Browse files
committed
Addressed type issue in StateSpec.function
1 parent 77c9a66 commit ae64786

File tree

4 files changed

+67
-7
lines changed

4 files changed

+67
-7
lines changed

streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ object StateSpec {
150150
* @tparam StateType Class of the states data
151151
* @tparam EmittedType Class of the emitted data
152152
*/
153-
def function[ValueType, StateType, EmittedType](
153+
def function[KeyType, ValueType, StateType, EmittedType](
154154
trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
155-
): StateSpec[Any, ValueType, StateType, EmittedType] = {
155+
): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
156156
ClosureCleaner.clean(trackingFunction, checkSerializable = true)
157157
val wrappedFunction =
158158
(time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
159159
Some(trackingFunction(value, state))
160160
}
161-
new StateSpecImpl[Any, ValueType, StateType, EmittedType](wrappedFunction)
161+
new StateSpecImpl(wrappedFunction)
162162
}
163163
}
164164

streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
380380
*/
381381
@Experimental
382382
def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
383-
spec: StateSpec[K, V, StateType, EmittedType]): TrackStateDStream[K, StateType, EmittedType] = {
383+
spec: StateSpec[K, V, StateType, EmittedType]
384+
): TrackStateDStream[K, V, StateType, EmittedType] = {
384385
new TrackStateDStreamImpl[K, V, StateType, EmittedType](
385386
self,
386387
spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]

streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
3939
* @tparam EmittedType Class of the emitted records
4040
*/
4141
@Experimental
42-
sealed abstract class TrackStateDStream[KeyType, StateType, EmittedType: ClassTag](
42+
sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
4343
ssc: StreamingContext) extends DStream[EmittedType](ssc) {
4444

4545
/** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
@@ -51,7 +51,7 @@ private[streaming] class TrackStateDStreamImpl[
5151
KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
5252
dataStream: DStream[(KeyType, ValueType)],
5353
spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
54-
extends TrackStateDStream[KeyType, StateType, EmittedType](dataStream.context) {
54+
extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
5555

5656
private val internalStream =
5757
new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
@@ -78,6 +78,14 @@ private[streaming] class TrackStateDStreamImpl[
7878
internalStream.flatMap {
7979
_.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
8080
}
81+
82+
def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
83+
84+
def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
85+
86+
def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
87+
88+
def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
8189
}
8290

8391
/**

streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
2424

2525
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
2626

27+
import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl}
2728
import org.apache.spark.util.{ManualClock, Utils}
2829
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
2930

@@ -166,7 +167,8 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
166167
sum
167168
}
168169

169-
testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
170+
testOperation[String, Int, Int](
171+
inputData, StateSpec.function(trackStateFunc), outputData, stateData)
170172
}
171173

172174
test("trackStateByKey - basic operations with advanced API") {
@@ -213,6 +215,55 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
213215
testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
214216
}
215217

218+
test("trackStateByKey - type inferencing and class tags") {
219+
220+
// Simple track state function with value as Int, state as Double and emitted type as Double
221+
val simpleFunc = (value: Option[Int], state: State[Double]) => {
222+
0L
223+
}
224+
225+
// Advanced track state function with key as String, value as Int, state as Double and
226+
// emitted type as Double
227+
val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
228+
Some(0L)
229+
}
230+
231+
def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
232+
val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
233+
assert(dstreamImpl.keyClass === classOf[String])
234+
assert(dstreamImpl.valueClass === classOf[Int])
235+
assert(dstreamImpl.stateClass === classOf[Double])
236+
assert(dstreamImpl.emittedClass === classOf[Long])
237+
}
238+
239+
val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
240+
241+
// Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
242+
val simpleFunctionStateStream1 = inputStream.trackStateByKey(
243+
StateSpec.function(simpleFunc).numPartitions(1))
244+
testTypes(simpleFunctionStateStream1)
245+
246+
// Separately defining StateSpec with simple function requires explicitly specifying types
247+
val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
248+
val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
249+
testTypes(simpleFunctionStateStream2)
250+
251+
// Separately defining StateSpec with advanced function implicitly gets the types
252+
val advFuncSpec1 = StateSpec.function(advancedFunc)
253+
val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
254+
testTypes(advFunctionStateStream1)
255+
256+
// Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
257+
val advFunctionStateStream2 = inputStream.trackStateByKey(
258+
StateSpec.function(simpleFunc).numPartitions(1))
259+
testTypes(advFunctionStateStream2)
260+
261+
// Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
262+
val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
263+
val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
264+
testTypes(advFunctionStateStream3)
265+
}
266+
216267
test("trackStateByKey - states as emitted records") {
217268
val inputData =
218269
Seq(

0 commit comments

Comments
 (0)