Commit ae2b607

Implement custom pool for spark-sql-kafka producer, to support 'expire' properly
1 parent: d773873

4 files changed, 260 additions & 114 deletions

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala

Lines changed: 96 additions & 61 deletions
@@ -18,111 +18,146 @@
 package org.apache.spark.sql.kafka010

 import java.{util => ju}
-import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}
+import java.util.concurrent.TimeUnit

-import com.google.common.cache._
-import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
 import org.apache.kafka.clients.producer.KafkaProducer
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal

 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil}

-private[kafka010] object CachedKafkaProducer extends Logging {
+private[kafka010] class CachedKafkaProducer(
+    val cacheKey: Seq[(String, Object)],
+    val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging {
+  val id: String = ju.UUID.randomUUID().toString

-  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
+  private def close(): Unit = {
+    try {
+      logInfo(s"Closing the KafkaProducer with id: $id.")
+      producer.close()
+    } catch {
+      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+    }
+  }
+}

-  private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10)
+private[kafka010] object CachedKafkaProducer extends Logging {
+  private type CacheKey = Seq[(String, Object)]
+  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

-  private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get)
-    .map(_.conf.get(PRODUCER_CACHE_TIMEOUT))
-    .getOrElse(defaultCacheExpireTimeout)
+  /**
+   * This class is used as metadata of cache, and shouldn't be exposed to the public (it would be
+   * fine for testing). This class assumes thread-safety is guaranteed by the caller.
+   */
+  class CachedProducerEntry(val producer: CachedKafkaProducer) {
+    private var refCount: Long = 0L
+    private var expireAt: Long = Long.MaxValue

-  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
-    override def load(config: Seq[(String, Object)]): Producer = {
-      createKafkaProducer(config)
+    def handleBorrowed(): Unit = {
+      refCount += 1
+      expireAt = Long.MaxValue
     }
-  }

-  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
-    override def onRemoval(
-        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
-      val paramsSeq: Seq[(String, Object)] = notification.getKey
-      val producer: Producer = notification.getValue
-      if (log.isDebugEnabled()) {
-        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-        logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " +
-          s"due to ${notification.getCause}")
+    def handleReturned(): Unit = {
+      refCount -= 1
+      if (refCount <= 0) {
+        expireAt = System.currentTimeMillis() + cacheExpireTimeout
       }
-      close(paramsSeq, producer)
     }
-  }

-  private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
-    CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
-      .removalListener(removalListener)
-      .build[Seq[(String, Object)], Producer](cacheLoader)
+    def expired: Boolean = refCount <= 0 && expireAt < System.currentTimeMillis()

-  private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
-    val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
-    if (log.isDebugEnabled()) {
-      val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-      logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
+    /** expose for testing, don't call otherwise */
+    private[kafka010] def injectDebugValues(refCnt: Long, expire: Long): Unit = {
+      refCount = refCnt
+      expireAt = expire
     }
-    kafkaProducer
   }

+  private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10)
+
+  private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get)
+    .map(_.conf.get(PRODUCER_CACHE_TIMEOUT))
+    .getOrElse(defaultCacheExpireTimeout)
+
+  private var acquireCount: Long = 0
+
+  private val cache = new mutable.HashMap[CacheKey, CachedProducerEntry]
+
   /**
    * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't
    * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep
    * one instance per specified kafkaParams.
    */
-  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
+  private[kafka010] def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = {
+    acquireCount += 1
+    if (acquireCount % 100 == 0) {
+      expire()
+    }
+
     val updatedKafkaProducerConfiguration =
       KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
         .setAuthenticationConfigIfNeeded()
         .build()
     val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration)
-    try {
-      guavaCache.get(paramsSeq)
-    } catch {
-      case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
-        if e.getCause != null =>
-        throw e.getCause
+    synchronized {
+      val entry = cache.getOrElseUpdate(paramsSeq, {
+        val producer = createKafkaProducer(paramsSeq)
+        val cachedProducer = new CachedKafkaProducer(paramsSeq, producer)
+        new CachedProducerEntry(cachedProducer)
+      })
+      entry.handleBorrowed()
+      entry.producer
     }
   }

-  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
-    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
-    paramsSeq
+  private[kafka010] def release(producer: CachedKafkaProducer): Unit = {
+    def closeProducerNotInCache(producer: CachedKafkaProducer): Unit = {
+      logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.")
+      producer.close()
+    }
+
+    synchronized {
+      cache.get(producer.cacheKey) match {
+        case Some(entry) if entry.producer.id == producer.id => entry.handleReturned()
+        case _ => closeProducerNotInCache(producer)
+      }
+    }
   }

-  /** For explicitly closing kafka producer */
-  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
-    val paramsSeq = paramsToSeq(kafkaParams)
-    guavaCache.invalidate(paramsSeq)
+  private def removeFromCache(key: CacheKey): Unit = {
+    cache.remove(key).foreach { instance => instance.producer.close() }
   }

-  /** Auto close on cache evict */
-  private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
-    try {
-      if (log.isInfoEnabled()) {
-        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-        logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
-      }
-      producer.close()
-    } catch {
-      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+  /** expose for testing */
+  private[kafka010] def expire(): Unit = synchronized {
+    cache.filter { case (_, v) => v.expired }.keys.foreach(removeFromCache)
+  }
+
+  private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
+    val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
+    if (log.isDebugEnabled()) {
+      val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
+      logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
     }
+    kafkaProducer
+  }
+
+  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
+    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
+    paramsSeq
   }

   private[kafka010] def clear(): Unit = {
-    logInfo("Cleaning up guava cache.")
-    guavaCache.invalidateAll()
+    logInfo("Cleaning up cache.")
+    synchronized {
+      cache.keys.foreach(removeFromCache)
+    }
   }

   // Intended for testing purpose only.
-  private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
+  private def getAsMap: Map[CacheKey, CachedProducerEntry] = cache.toMap
 }
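In short, this file swaps Guava's expireAfterAccess eviction, which could evict and close a producer while a task was still writing through it, for a hand-rolled map whose entries carry a reference count: acquire increments the count and disarms expiry, release decrements it, and the idle timer is armed only once the count reaches zero. A minimal, self-contained sketch of that scheme follows; the names (SimplePool, Entry) and the string-keyed map are hypothetical stand-ins, not the commit's code:

import scala.collection.mutable

// Hypothetical stand-in for the commit's scheme: entries expire only while idle.
object SimplePool {
  private val ttlMs = 10 * 60 * 1000L // mirrors defaultCacheExpireTimeout

  final class Entry[V](val value: V) {
    private var refCount = 0L
    private var expireAt = Long.MaxValue
    def borrow(): Unit = { refCount += 1; expireAt = Long.MaxValue }
    def giveBack(): Unit = {
      refCount -= 1
      // Arm the idle timer only once no caller holds the value.
      if (refCount <= 0) expireAt = System.currentTimeMillis() + ttlMs
    }
    def expired: Boolean = refCount <= 0 && expireAt < System.currentTimeMillis()
  }

  private val cache = new mutable.HashMap[String, Entry[String]]

  def acquire(key: String): Entry[String] = synchronized {
    val entry = cache.getOrElseUpdate(key, new Entry(s"value-for-$key"))
    entry.borrow()
    entry
  }

  def release(key: String, entry: Entry[String]): Unit = synchronized {
    cache.get(key) match {
      case Some(e) if e eq entry => e.giveBack()
      case _ => // stale handle: its entry was already evicted
    }
  }

  // Remove only entries that are both idle and past their deadline.
  def expire(): Unit = synchronized {
    cache.filter { case (_, e) => e.expired }.keys.foreach(cache.remove)
  }
}

Note that expire() removes an entry only when it is both unreferenced and past its deadline, which is exactly what lets 'expire' work properly for producers still in use.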

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala

Lines changed: 27 additions & 14 deletions
@@ -44,31 +44,44 @@ private[kafka010] class KafkaDataWriter(
     inputSchema: Seq[Attribute])
   extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {

-  private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams)
+  private var producer: Option[CachedKafkaProducer] = None

   def write(row: InternalRow): Unit = {
     checkForErrors()
-    sendRow(row, producer)
+    if (producer.isEmpty) {
+      producer = Some(CachedKafkaProducer.acquire(producerParams))
+    }
+    producer.foreach { p => sendRow(row, p.producer) }
   }

   def commit(): WriterCommitMessage = {
-    // Send is asynchronous, but we can't commit until all rows are actually in Kafka.
-    // This requires flushing and then checking that no callbacks produced errors.
-    // We also check for errors before to fail as soon as possible - the check is cheap.
-    checkForErrors()
-    producer.flush()
-    checkForErrors()
-    KafkaDataWriterCommitMessage
+    try {
+      // Send is asynchronous, but we can't commit until all rows are actually in Kafka.
+      // This requires flushing and then checking that no callbacks produced errors.
+      // We also check for errors before to fail as soon as possible - the check is cheap.
+      checkForErrors()
+      producer.foreach(_.producer.flush())
+      checkForErrors()
+      KafkaDataWriterCommitMessage
+    } finally {
+      producer.foreach(CachedKafkaProducer.release)
+      producer = None
+    }
   }

-  def abort(): Unit = {}
+  def abort(): Unit = {
+    producer.foreach(CachedKafkaProducer.release)
+    producer = None
+  }

   def close(): Unit = {
-    checkForErrors()
-    if (producer != null) {
-      producer.flush()
+    try {
+      checkForErrors()
+      producer.foreach(_.producer.flush())
       checkForErrors()
-      CachedKafkaProducer.close(producerParams)
+    } finally {
+      producer.foreach(CachedKafkaProducer.release)
+      producer = None
+    }
   }
 }
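The writer now acquires the producer lazily on the first write and releases it in a finally block in commit, abort, and close, so an error path can no longer leak a reference count. A condensed sketch of that discipline against the CachedKafkaProducer API above; commitLike and its params argument are hypothetical, and the code would have to live in the kafka010 package since acquire and release are package-private:

import java.{util => ju}

// Sketch of the acquire/flush/release discipline the writer now follows.
def commitLike(params: ju.Map[String, Object]): Unit = {
  var producer: Option[CachedKafkaProducer] = None
  try {
    producer = Some(CachedKafkaProducer.acquire(params))
    // flush() blocks until all pending asynchronous sends have completed.
    producer.foreach(_.producer.flush())
  } finally {
    // Always return the handle, even on failure, so refCount stays balanced.
    producer.foreach(CachedKafkaProducer.release)
    producer = None
  }
}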

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala

Lines changed: 12 additions & 7 deletions
@@ -39,25 +39,30 @@ private[kafka010] class KafkaWriteTask(
     inputSchema: Seq[Attribute],
     topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
   // used to synchronize with Kafka callbacks
-  private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
+  private var producer: Option[CachedKafkaProducer] = None

   /**
    * Writes key value data out to topics.
    */
   def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
+    producer = Some(CachedKafkaProducer.acquire(producerConfiguration))
+    val internalProducer = producer.get.producer
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
-      sendRow(currentRow, producer)
+      sendRow(currentRow, internalProducer)
     }
   }

   def close(): Unit = {
-    checkForErrors()
-    if (producer != null) {
-      producer.flush()
+    try {
       checkForErrors()
-      producer = null
+      producer.foreach { p =>
+        p.producer.flush()
+        checkForErrors()
+      }
+    } finally {
+      producer.foreach(CachedKafkaProducer.release)
+      producer = None
     }
   }
 }
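For reference, a hypothetical end-to-end use of the new pool, roughly what execute/close above do. It assumes a broker at localhost:9092, uses the standard Kafka byte-array serializers, and would also need to sit in the org.apache.spark.sql.kafka010 package because acquire and release are package-private:

import java.{util => ju}
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.kafka.common.serialization.ByteArraySerializer

val params = new ju.HashMap[String, Object]()
params.put("bootstrap.servers", "localhost:9092") // assumed local broker
params.put("key.serializer", classOf[ByteArraySerializer].getName)
params.put("value.serializer", classOf[ByteArraySerializer].getName)

val cached = CachedKafkaProducer.acquire(params)
try {
  cached.producer.send(new ProducerRecord("some-topic", "key".getBytes, "value".getBytes))
  cached.producer.flush()
} finally {
  // Returns the producer to the pool; the pool closes it later via expire(), not here.
  CachedKafkaProducer.release(cached)
}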
