@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
2424
2525import org .scalatest .{BeforeAndAfter , BeforeAndAfterAll }
2626
27+ import org .apache .spark .streaming .dstream .{TrackStateDStream , TrackStateDStreamImpl }
2728import org .apache .spark .util .{ManualClock , Utils }
2829import org .apache .spark .{SparkConf , SparkContext , SparkFunSuite }
2930
@@ -166,7 +167,8 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
166167 sum
167168 }
168169
169- testOperation(inputData, StateSpec .function(trackStateFunc), outputData, stateData)
170+ testOperation[String , Int , Int ](
171+ inputData, StateSpec .function(trackStateFunc), outputData, stateData)
170172 }
171173
172174 test(" trackStateByKey - basic operations with advanced API" ) {
@@ -213,6 +215,55 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
213215 testOperation(inputData, StateSpec .function(trackStateFunc), outputData, stateData)
214216 }
215217
218+ test(" trackStateByKey - type inferencing and class tags" ) {
219+
220+ // Simple track state function with value as Int, state as Double and emitted type as Double
221+ val simpleFunc = (value : Option [Int ], state : State [Double ]) => {
222+ 0L
223+ }
224+
225+ // Advanced track state function with key as String, value as Int, state as Double and
226+ // emitted type as Double
227+ val advancedFunc = (time : Time , key : String , value : Option [Int ], state : State [Double ]) => {
228+ Some (0L )
229+ }
230+
231+ def testTypes (dstream : TrackStateDStream [_, _, _, _]): Unit = {
232+ val dstreamImpl = dstream.asInstanceOf [TrackStateDStreamImpl [_, _, _, _]]
233+ assert(dstreamImpl.keyClass === classOf [String ])
234+ assert(dstreamImpl.valueClass === classOf [Int ])
235+ assert(dstreamImpl.stateClass === classOf [Double ])
236+ assert(dstreamImpl.emittedClass === classOf [Long ])
237+ }
238+
239+ val inputStream = new TestInputStream [(String , Int )](ssc, Seq .empty, numPartitions = 2 )
240+
241+ // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
242+ val simpleFunctionStateStream1 = inputStream.trackStateByKey(
243+ StateSpec .function(simpleFunc).numPartitions(1 ))
244+ testTypes(simpleFunctionStateStream1)
245+
246+ // Separately defining StateSpec with simple function requires explicitly specifying types
247+ val simpleFuncSpec = StateSpec .function[String , Int , Double , Long ](simpleFunc)
248+ val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
249+ testTypes(simpleFunctionStateStream2)
250+
251+ // Separately defining StateSpec with advanced function implicitly gets the types
252+ val advFuncSpec1 = StateSpec .function(advancedFunc)
253+ val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
254+ testTypes(advFunctionStateStream1)
255+
256+ // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
257+ val advFunctionStateStream2 = inputStream.trackStateByKey(
258+ StateSpec .function(simpleFunc).numPartitions(1 ))
259+ testTypes(advFunctionStateStream2)
260+
261+ // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
262+ val advFuncSpec2 = StateSpec .function[String , Int , Double , Long ](advancedFunc)
263+ val advFunctionStateStream3 = inputStream.trackStateByKey[Double , Long ](advFuncSpec2)
264+ testTypes(advFunctionStateStream3)
265+ }
266+
216267 test(" trackStateByKey - states as emitted records" ) {
217268 val inputData =
218269 Seq (
0 commit comments