Skip to content

Commit 7160786

Browse files
committed
Add unit tests
1 parent de4ef2b commit 7160786

3 files changed

Lines changed: 226 additions & 36 deletions

File tree

extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@
3131
import org.apache.spark.HashPartitioner;
3232
import org.apache.spark.api.java.JavaPairRDD;
3333
import org.apache.spark.api.java.JavaRDD;
34+
import org.apache.spark.api.java.function.Function2;
35+
import org.apache.spark.api.java.function.Function4;
3436
import org.apache.spark.api.java.function.PairFunction;
3537
import org.apache.spark.streaming.api.java.JavaDStream;
3638
import org.apache.spark.streaming.api.java.JavaPairDStream;
39+
import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
3740

3841
/**
3942
* Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -617,7 +620,7 @@ public void testCombineByKey() {
617620
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
618621

619622
JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(i -> i,
620-
(x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2));
623+
(x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2));
621624

622625
JavaTestUtils.attachTestOutputStream(combined);
623626
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -700,7 +703,7 @@ public void testReduceByKeyAndWindowWithInverse() {
700703

701704
JavaPairDStream<String, Integer> reduceWindowed =
702705
pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000),
703-
new Duration(1000));
706+
new Duration(1000));
704707
JavaTestUtils.attachTestOutputStream(reduceWindowed);
705708
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
706709

@@ -831,4 +834,44 @@ public void testFlatMapValues() {
831834
Assert.assertEquals(expected, result);
832835
}
833836

837+
/**
838+
* This test is only for testing the APIs. It's not necessary to run it.
839+
*/
840+
public void testTrackStateByAPI() {
841+
JavaPairRDD<String, Boolean> initialRDD = null;
842+
JavaPairDStream<String, Integer> wordsDstream = null;
843+
844+
JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
845+
wordsDstream.trackStateByKey(
846+
StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
847+
// Use all State's methods here
848+
state.exists();
849+
state.get();
850+
state.isTimingOut();
851+
state.remove();
852+
state.update(true);
853+
return Optional.of(2.0);
854+
}).initialState(initialRDD)
855+
.numPartitions(10)
856+
.partitioner(new HashPartitioner(10))
857+
.timeout(Durations.seconds(10)));
858+
859+
JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
860+
861+
JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
862+
wordsDstream.trackStateByKey(
863+
StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
864+
state.exists();
865+
state.get();
866+
state.isTimingOut();
867+
state.remove();
868+
state.update(true);
869+
return 2.0;
870+
}).initialState(initialRDD)
871+
.numPartitions(10)
872+
.partitioner(new HashPartitioner(10))
873+
.timeout(Durations.seconds(10)));
874+
875+
JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
876+
}
834877
}

streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
267267

268268
// Read the data of the delta
269269
val deltaMapSize = inputStream.readInt()
270-
deltaMap = new OpenHashMap[K, StateInfo[S]](deltaMapSize)
270+
deltaMap = if (deltaMapSize != 0) {
271+
new OpenHashMap[K, StateInfo[S]](deltaMapSize)
272+
} else {
273+
new OpenHashMap[K, StateInfo[S]](initialCapacity)
274+
}
271275
var deltaMapCount = 0
272276
while (deltaMapCount < deltaMapSize) {
273277
val key = inputStream.readObject().asInstanceOf[K]

streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java

Lines changed: 176 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,51 +17,194 @@
1717

1818
package org.apache.spark.streaming;
1919

20+
import java.io.Serializable;
21+
import java.util.Arrays;
22+
import java.util.Collections;
23+
import java.util.List;
24+
import java.util.Set;
25+
26+
import scala.Tuple2;
27+
2028
import com.google.common.base.Optional;
29+
import com.google.common.collect.Lists;
30+
import com.google.common.collect.Sets;
31+
import org.apache.spark.api.java.JavaRDD;
32+
import org.apache.spark.api.java.function.Function;
33+
import org.apache.spark.streaming.api.java.JavaDStream;
34+
import org.apache.spark.util.ManualClock;
35+
import org.junit.Assert;
36+
import org.junit.Test;
37+
2138
import org.apache.spark.HashPartitioner;
2239
import org.apache.spark.api.java.JavaPairRDD;
40+
import org.apache.spark.api.java.function.Function2;
2341
import org.apache.spark.api.java.function.Function4;
24-
import org.apache.spark.api.java.function.PairFunction;
25-
import org.apache.spark.streaming.Durations;
2642
import org.apache.spark.streaming.api.java.JavaPairDStream;
2743
import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
28-
import org.junit.Test;
29-
import scala.Tuple2;
30-
31-
import java.io.Serializable;
3244

3345
public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
3446

3547
/**
3648
* This test is only for testing the APIs. It's not necessary to run it.
3749
*/
3850
public void testAPI() {
39-
// TODO
40-
// JavaPairRDD<String, Integer> initialRDD = null;
41-
// JavaPairDStream<String, Integer> wordsDstream = null;
42-
// final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<String>>
43-
// trackStateFunc =
44-
// new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<String>>() {
45-
//
46-
// @Override
47-
// public Optional<String> call(Time time, String word, Optional<Integer> one,
48-
// State<Integer> state) {
49-
// // Use all State's methods here
50-
// state.exists();
51-
// state.get();
52-
// state.isTimingOut();
53-
// state.remove();
54-
// state.update(10);
55-
// return "test";
56-
// }
57-
// };
58-
//
59-
// JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
60-
// wordsDstream.trackStateByKey(
61-
// StateSpec.function(trackStateFunc)
62-
// .initialState(initialRDD)
63-
// .numPartitions(10)
64-
// .partitioner(new HashPartitioner(10))
65-
// .timeout(Durations.seconds(10)));
51+
JavaPairRDD<String, Boolean> initialRDD = null;
52+
JavaPairDStream<String, Integer> wordsDstream = null;
53+
54+
final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
55+
trackStateFunc =
56+
new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
57+
58+
@Override
59+
public Optional<Double> call(
60+
Time time, String word, Optional<Integer> one, State<Boolean> state) {
61+
// Use all State's methods here
62+
state.exists();
63+
state.get();
64+
state.isTimingOut();
65+
state.remove();
66+
state.update(true);
67+
return Optional.of(2.0);
68+
}
69+
};
70+
71+
JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
72+
wordsDstream.trackStateByKey(
73+
StateSpec.function(trackStateFunc)
74+
.initialState(initialRDD)
75+
.numPartitions(10)
76+
.partitioner(new HashPartitioner(10))
77+
.timeout(Durations.seconds(10)));
78+
79+
JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
80+
81+
final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
82+
new Function2<Optional<Integer>, State<Boolean>, Double>() {
83+
84+
@Override
85+
public Double call(Optional<Integer> one, State<Boolean> state) {
86+
// Use all State's methods here
87+
state.exists();
88+
state.get();
89+
state.isTimingOut();
90+
state.remove();
91+
state.update(true);
92+
return 2.0;
93+
}
94+
};
95+
96+
JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
97+
wordsDstream.trackStateByKey(
98+
StateSpec.<String, Integer, Boolean, Double> function(trackStateFunc2)
99+
.initialState(initialRDD)
100+
.numPartitions(10)
101+
.partitioner(new HashPartitioner(10))
102+
.timeout(Durations.seconds(10)));
103+
104+
JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
105+
}
106+
107+
@Test
108+
public void testBasicFunction() {
109+
List<List<String>> inputData = Arrays.asList(
110+
Collections.<String>emptyList(),
111+
Arrays.asList("a"),
112+
Arrays.asList("a", "b"),
113+
Arrays.asList("a", "b", "c"),
114+
Arrays.asList("a", "b"),
115+
Arrays.asList("a"),
116+
Collections.<String>emptyList()
117+
);
118+
119+
List<Set<Integer>> outputData = Arrays.asList(
120+
Collections.<Integer>emptySet(),
121+
Sets.newHashSet(1),
122+
Sets.newHashSet(2, 1),
123+
Sets.newHashSet(3, 2, 1),
124+
Sets.newHashSet(4, 3),
125+
Sets.newHashSet(5),
126+
Collections.<Integer>emptySet()
127+
);
128+
129+
List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
130+
Collections.<Tuple2<String, Integer>>emptySet(),
131+
Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
132+
Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)),
133+
Sets.newHashSet(
134+
new Tuple2<String, Integer>("a", 3),
135+
new Tuple2<String, Integer>("b", 2),
136+
new Tuple2<String, Integer>("c", 1)),
137+
Sets.newHashSet(
138+
new Tuple2<String, Integer>("a", 4),
139+
new Tuple2<String, Integer>("b", 3),
140+
new Tuple2<String, Integer>("c", 1)),
141+
Sets.newHashSet(
142+
new Tuple2<String, Integer>("a", 5),
143+
new Tuple2<String, Integer>("b", 3),
144+
new Tuple2<String, Integer>("c", 1)),
145+
Sets.newHashSet(
146+
new Tuple2<String, Integer>("a", 5),
147+
new Tuple2<String, Integer>("b", 3),
148+
new Tuple2<String, Integer>("c", 1))
149+
);
150+
151+
Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
152+
new Function2<Optional<Integer>, State<Integer>, Integer>() {
153+
154+
@Override
155+
public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
156+
int sum = value.or(0) + (state.exists() ? state.get() : 0);
157+
state.update(sum);
158+
return sum;
159+
}
160+
};
161+
testOperation(
162+
inputData,
163+
StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
164+
outputData,
165+
stateData);
166+
}
167+
168+
private <K, S, T> void testOperation(
169+
List<List<K>> input,
170+
StateSpec<K, Integer, S, T> trackStateSpec,
171+
List<Set<T>> expectedOutputs,
172+
List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
173+
int numBatches = expectedOutputs.size();
174+
JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
175+
JavaTrackStateDStream<K, Integer, S, T> trackeStateStream =
176+
JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
177+
@Override
178+
public Tuple2<K, Integer> call(K x) throws Exception {
179+
return new Tuple2<K, Integer>(x, 1);
180+
}
181+
})).trackStateByKey(trackStateSpec);
182+
183+
final List<Set<T>> collectedOutputs =
184+
Collections.synchronizedList(Lists.<Set<T>>newArrayList());
185+
trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
186+
@Override
187+
public Void call(JavaRDD<T> rdd) throws Exception {
188+
collectedOutputs.add(Sets.newHashSet(rdd.collect()));
189+
return null;
190+
}
191+
});
192+
final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
193+
Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
194+
trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
195+
@Override
196+
public Void call(JavaPairRDD<K, S> rdd) throws Exception {
197+
collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
198+
return null;
199+
}
200+
});
201+
BatchCounter batchCounter = new BatchCounter(ssc.ssc());
202+
ssc.start();
203+
((ManualClock) ssc.ssc().scheduler().clock())
204+
.advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
205+
batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
206+
207+
Assert.assertEquals(expectedOutputs, collectedOutputs);
208+
Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
66209
}
67210
}

0 commit comments

Comments (0)