
Commit 8cc0ac2

vvysotskyi, ekrivokonmapr authored and committed
MapR [SPARK-737] Calling poll(1000) from paranoidPoll causes batch scheduling delay (apache#675)
1 parent 5755374 · commit 8cc0ac2

2 files changed: 786 additions & 0 deletions

File shown below: 398 additions & 0 deletions
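The commit title refers to the paranoidPoll helper in the listing below: it previously called Consumer.poll(1000), which can block for up to a second per batch even though every partition is paused, and that wait shows up directly as batch scheduling delay. A rough, hypothetical sketch of the effect with a standalone KafkaConsumer (the broker address, topic, and group id are illustrative assumptions, not taken from this commit):

import java.{util => ju}
import java.util.Properties

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

object PausedPollDelay {
  def main(args: Array[String]): Unit = {
    // All connection settings here are illustrative assumptions.
    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092")
    props.put("group.id", "example-group")
    props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")
    props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer")

    val consumer = new KafkaConsumer[String, String](props)
    val tp = new TopicPartition("example-topic", 0)
    consumer.assign(ju.Arrays.asList(tp))
    consumer.pause(consumer.assignment()) // paused: poll should return no records

    val t0 = System.nanoTime()
    consumer.poll(1000) // waits up to ~1000 ms for records that can never arrive
    val blockedMs = (System.nanoTime() - t0) / 1e6
    println(f"poll(1000) on a paused consumer blocked for about $blockedMs%.0f ms")

    consumer.poll(0) // returns almost immediately, which is why the fix polls with 0
    consumer.close()
  }
}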
@@ -0,0 +1,398 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.streaming.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicReference

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.kafka.clients.consumer._
import org.apache.kafka.common.TopicPartition

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
import org.apache.spark.streaming.scheduler.rate.RateEstimator

/**
 * A DStream where each given Kafka topic/partition corresponds to an RDD partition.
 * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
 * of messages per second that each '''partition''' will accept.
 * @param locationStrategy In most cases, pass in [[LocationStrategies.PreferConsistent]],
 *   see [[LocationStrategy]] for more details.
 * @param consumerStrategy In most cases, pass in [[ConsumerStrategies.Subscribe]],
 *   see [[ConsumerStrategy]] for more details.
 * @param ppc configuration of settings such as max rate on a per-partition basis,
 *   see [[PerPartitionConfig]] for more details.
 * @tparam K type of Kafka message key
 * @tparam V type of Kafka message value
 */
private[spark] class DirectKafkaInputDStream[K, V](
    _ssc: StreamingContext,
    locationStrategy: LocationStrategy,
    consumerStrategy: ConsumerStrategy[K, V],
    ppc: PerPartitionConfig
  ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets {

  val executorKafkaParams = {
    val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams)
    KafkaUtils.fixKafkaParams(ekp)
    ekp
  }

  protected var currentOffsets = Map[TopicPartition, Long]()

  @transient private var kc: Consumer[K, V] = null
  def consumer(): Consumer[K, V] = this.synchronized {
    if (null == kc) {
      kc = consumerStrategy.onStart(
        currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).toMap.asJava)
    }
    kc
  }

  @transient private var sc: Consumer[K, V] = null
  def serviceConsumer: Consumer[K, V] = this.synchronized {
    if (null == sc) {
      sc = consumerStrategy.serviceConsumer
    }
    sc
  }

  override def persist(newLevel: StorageLevel): DStream[ConsumerRecord[K, V]] = {
    logError("Kafka ConsumerRecord is not serializable. " +
      "Use .map to extract fields before calling .persist or .window")
    super.persist(newLevel)
  }

  protected def getBrokers = {
    val c = consumer
    val result = new ju.HashMap[TopicPartition, String]()
    val hosts = new ju.HashMap[TopicPartition, String]()
    val assignments = c.assignment().iterator()
    while (assignments.hasNext()) {
      val tp: TopicPartition = assignments.next()
      if (null == hosts.get(tp)) {
        val infos = c.partitionsFor(tp.topic).iterator()
        while (infos.hasNext()) {
          val i = infos.next()
          hosts.put(new TopicPartition(i.topic(), i.partition()), i.leader.host())
        }
      }
      result.put(tp, hosts.get(tp))
    }
    result
  }

  protected def getPreferredHosts: ju.Map[TopicPartition, String] = {
    locationStrategy match {
      case PreferBrokers => getBrokers
      case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]()
      case PreferFixed(hostMap) => hostMap
    }
  }

  private[streaming] override def name: String = s"Kafka 0.10 direct stream [$id]"

  protected[streaming] override val checkpointData =
    new DirectKafkaInputDStreamCheckpointData

  /**
   * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
   */
  override protected[streaming] val rateController: Option[RateController] = {
    if (RateController.isBackPressureEnabled(ssc.conf)) {
      Some(new DirectKafkaRateController(id,
        RateEstimator.create(ssc.conf, context.graph.batchDuration)))
    } else {
      None
    }
  }

  protected[streaming] def maxMessagesPerPartition(
      offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = {
    val estimatedRateLimit = rateController.map(_.getLatestRate())

    // calculate a per-partition rate limit based on current lag
    val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
      case Some(rate) =>
        val lagPerPartition = offsets.map { case (tp, offset) =>
          tp -> Math.max(offset - currentOffsets(tp), 0)
        }
        val totalLag = lagPerPartition.values.sum

        lagPerPartition.map { case (tp, lag) =>
          val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
          val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
          tp -> (if (maxRateLimitPerPartition > 0) {
            Math.min(backpressureRate, maxRateLimitPerPartition)
          } else {
            backpressureRate
          })
        }
      case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) }
    }

    if (effectiveRateLimitPerPartition.values.sum > 0) {
      val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
      Some(effectiveRateLimitPerPartition.map {
        case (tp, limit) => tp -> (secsPerBatch * limit).toLong
      })
    } else {
      None
    }
  }

  /**
   * The concern here is that poll might consume messages despite being paused,
   * which would throw off consumer position. Fix position if this happens.
   */
  private def paranoidPoll(c: Consumer[K, V]): Unit = {
    val msgs = c.poll(0)

    val newAssignment = c.assignment()
    val parts = if (currentOffsets.size < newAssignment.size()) {
      newAssignment
    } else currentOffsets.keySet.asJava

    if (serviceConsumer.assignment().size() < parts.size) {
      serviceConsumer.assign(parts)
    }
    if (!msgs.isEmpty) {
      val waitingForAssignmentTimeout = SparkEnv.get.conf.
        getLong("spark.mapr.WaitingForAssignmentTimeout", 600000)
      // position should be the minimum offset per TopicPartition
      msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) =>
        val tp = new TopicPartition(m.topic, m.partition)
        val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset)
        acc + (tp -> off)
      }.foreach { case (tp, off) =>
        logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate")
        serviceConsumer.seek(tp, off)
        withRetries(waitingForAssignmentTimeout)(c.seek(tp, off))
      }
    }
  }

  @tailrec
  private def withRetries[T](t: Long)(f: => T): T = {
    Try(f) match {
      case Success(v) => v
      case _ if t > 0 =>
        Try(Thread.sleep(500))
        withRetries(t - 500)(f)
      case Failure(e) => throw e
    }
  }

  /**
   * Returns the latest (highest) available offsets, taking new partitions into account.
   */
  protected def latestOffsets(): Map[TopicPartition, Long] = {
    val c = consumer
    paranoidPoll(c)

    val parts = c.assignment().asScala

    if (parts.size < currentOffsets.keySet.size) {
      logWarning("Assignment() returned fewer partitions than the previous call")
    }

    if (serviceConsumer.assignment().size() < parts.size) {
      serviceConsumer.assign(parts.asJava)
    }

    // make sure new partitions are reflected in currentOffsets
    val newPartitions = parts.diff(currentOffsets.keySet)

    // Check whether any partition has been revoked because of a consumer rebalance.
    val revokedPartitions = currentOffsets.keySet.diff(parts)
    if (revokedPartitions.nonEmpty) {
      throw new IllegalStateException(s"Previously tracked partitions " +
        s"${revokedPartitions.mkString("[", ",", "]")} have been revoked by Kafka because of a " +
        "consumer rebalance. This is usually caused by another stream with the same group id " +
        "joining; please check whether different streaming applications are misconfigured to " +
        "use the same group id. Fundamentally different streams should use different group ids.")
    }

    // position for new partitions determined by auto.offset.reset if no commit
    currentOffsets = currentOffsets ++ newPartitions
      .map(tp => tp -> serviceConsumer.position(tp)).toMap
    // don't want to consume messages, so pause
    c.pause(newPartitions.asJava)
    // find latest available offsets

    if (!serviceConsumer.assignment().isEmpty) {
      serviceConsumer.seekToEnd(currentOffsets.keySet.asJava)
    }

    // c.seekToEnd(currentOffsets.keySet.asJava)
    parts.map(tp => tp -> serviceConsumer.position(tp)).toMap
  }

  // limits the maximum number of messages per partition
  protected def clamp(
    offsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {

    maxMessagesPerPartition(offsets).map { mmp =>
      mmp.map { case (tp, messages) =>
        val uo = offsets(tp)
        tp -> Math.min(currentOffsets(tp) + messages, uo)
      }
    }.getOrElse(offsets)
  }

  override def compute(validTime: Time): Option[KafkaRDD[K, V]] = {
    val untilOffsets = clamp(latestOffsets())
    val offsetRanges = untilOffsets.map { case (tp, uo) =>
      val fo = currentOffsets(tp)
      OffsetRange(tp.topic, tp.partition, fo, uo)
    }
    val useConsumerCache = context.conf.get(CONSUMER_CACHE_ENABLED)
    val rdd = new KafkaRDD[K, V](context.sparkContext, executorKafkaParams, offsetRanges.toArray,
      getPreferredHosts, useConsumerCache)

    // Report the record number and metadata of this batch interval to InputInfoTracker.
    val description = offsetRanges.filter { offsetRange =>
      // Don't display empty ranges.
      offsetRange.fromOffset != offsetRange.untilOffset
    }.toSeq.sortBy(-_.count()).map { offsetRange =>
      s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
        s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}\t" +
        s"count: ${offsetRange.count()}"
    }.mkString("\n")
    // Copy offsetRanges to immutable.List to prevent it from being modified by the user
    val metadata = Map(
      "offsets" -> offsetRanges.toList,
      StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
    val inputInfo = StreamInputInfo(id, rdd.count, metadata)
    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)

    currentOffsets = untilOffsets
    commitAll()
    Some(rdd)
  }

  override def start(): Unit = {
    val c = consumer
    paranoidPoll(c)
    if (currentOffsets.isEmpty) {
      currentOffsets = c.assignment().asScala.map { tp =>
        tp -> c.position(tp)
      }.toMap
    }

    // don't actually want to consume any messages, so pause all partitions
    c.pause(currentOffsets.keySet.asJava)
  }

  override def stop(): Unit = this.synchronized {
    if (kc != null) {
      kc.close()
    }

    serviceConsumer.close()
  }

  protected val commitQueue = new ConcurrentLinkedQueue[OffsetRange]
  protected val commitCallback = new AtomicReference[OffsetCommitCallback]

  /**
   * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
   * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
   */
  def commitAsync(offsetRanges: Array[OffsetRange]): Unit = {
    commitAsync(offsetRanges, null)
  }

  /**
   * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
   * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
   * @param callback Only the most recently provided callback will be used at commit.
   */
  def commitAsync(offsetRanges: Array[OffsetRange], callback: OffsetCommitCallback): Unit = {
    commitCallback.set(callback)
    commitQueue.addAll(ju.Arrays.asList(offsetRanges: _*))
  }

  protected def commitAll(): Unit = {
    val m = new ju.HashMap[TopicPartition, OffsetAndMetadata]()
    var osr = commitQueue.poll()
    while (null != osr) {
      val tp = osr.topicPartition
      val x = m.get(tp)
      val offset = if (null == x) { osr.untilOffset } else { Math.max(x.offset, osr.untilOffset) }
      m.put(tp, new OffsetAndMetadata(offset))
      osr = commitQueue.poll()
    }
    if (!m.isEmpty) {
      if (KafkaUtils.isStreams(currentOffsets)) {
        serviceConsumer.commitAsync(m, commitCallback.get)
      } else {
        consumer.commitAsync(m, commitCallback.get)
      }
    }
  }

  private[streaming]
  class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) {
    def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = {
      data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]]
    }

    override def update(time: Time): Unit = {
      batchForTime.clear()
      generatedRDDs.foreach { kv =>
        val a = kv._2.asInstanceOf[KafkaRDD[K, V]].offsetRanges.map(_.toTuple)
        batchForTime += kv._1 -> a
      }
    }

    override def cleanup(time: Time): Unit = { }

    override def restore(): Unit = {
      batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
        logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
        generatedRDDs += t -> new KafkaRDD[K, V](
          context.sparkContext,
          executorKafkaParams,
          b.map(OffsetRange(_)),
          getPreferredHosts,
          // during restore, it's possible the same partition will be consumed from multiple
          // threads, so do not use the cache.
          false
        )
      }
    }
  }

  /**
   * A RateController to retrieve the rate from RateEstimator.
   */
  private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator)
    extends RateController(id, estimator) {
    override def publish(rate: Long): Unit = ()
  }
}
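For completeness, a minimal usage sketch of the stream this class backs, assuming the standard spark-streaming-kafka-0-10 entry point KafkaUtils.createDirectStream; the broker address, topic, and group id (localhost:9092, example-topic, example-group) are assumptions for illustration, not taken from this commit:

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010._
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent

object DirectStreamSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("direct-kafka-sketch")
    val ssc = new StreamingContext(conf, Seconds(5))

    // Hypothetical Kafka settings; adjust for a real cluster.
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "localhost:9092",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example-group",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )

    // createDirectStream returns the DirectKafkaInputDStream defined above.
    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](Seq("example-topic"), kafkaParams)
    )

    stream.foreachRDD { rdd =>
      // Offsets come from HasOffsetRanges; commits go through CanCommitOffsets.commitAsync,
      // which queues ranges for the commitAll() call made at the end of each batch.
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      println(s"records in batch: ${rdd.count()}")
      stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }

    ssc.start()
    ssc.awaitTermination()
  }
}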
