18 | 18 | package org.apache.spark.sql.kafka010 |
19 | 19 |
20 | 20 | import java.{util => ju} |
21 | | -import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit} |
| 21 | +import java.util.concurrent.TimeUnit |
22 | 22 |
23 | | -import com.google.common.cache._ |
24 | | -import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} |
25 | 23 | import org.apache.kafka.clients.producer.KafkaProducer |
26 | 24 | import scala.collection.JavaConverters._ |
| 25 | +import scala.collection.mutable |
27 | 26 | import scala.util.control.NonFatal |
28 | 27 |
29 | 28 | import org.apache.spark.SparkEnv |
30 | 29 | import org.apache.spark.internal.Logging |
31 | 30 | import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil} |
32 | 31 |
33 | | -private[kafka010] object CachedKafkaProducer extends Logging { |
| 32 | +private[kafka010] class CachedKafkaProducer( |
| 33 | + val cacheKey: Seq[(String, Object)], |
| 34 | + val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging { |
| 35 | + val id: String = ju.UUID.randomUUID().toString |
34 | 36 |
35 | | - private type Producer = KafkaProducer[Array[Byte], Array[Byte]] |
| 37 | + private def close(): Unit = { |
| 38 | + try { |
| 39 | + logInfo(s"Closing the KafkaProducer with id: $id.") |
| 40 | + producer.close() |
| 41 | + } catch { |
| 42 | + case NonFatal(e) => logWarning("Error while closing kafka producer.", e) |
| 43 | + } |
| 44 | + } |
| 45 | +} |
36 | 46 |
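The wrapper class above pairs each KafkaProducer with its cache key and a random UUID; release() later in this diff uses that id to tell whether a returned producer is still the cache's current instance for its key. A minimal sketch of that identity property, assuming code living inside the org.apache.spark.sql.kafka010 package (the class is private[kafka010]); the config values are placeholders, and no live broker is needed just to construct a KafkaProducer:

    import java.{util => ju}
    import scala.collection.JavaConverters._
    import org.apache.kafka.clients.producer.KafkaProducer

    // Hypothetical illustration: two wrappers built over the same config share
    // a cache key but get distinct ids, which is exactly what release() checks.
    val conf: Map[String, Object] = Map(
      "bootstrap.servers" -> "localhost:9092",
      "key.serializer" -> "org.apache.kafka.common.serialization.ByteArraySerializer",
      "value.serializer" -> "org.apache.kafka.common.serialization.ByteArraySerializer")
    val key: Seq[(String, Object)] = conf.toSeq.sortBy(_._1)
    val p1 = new CachedKafkaProducer(key, new KafkaProducer[Array[Byte], Array[Byte]](conf.asJava))
    val p2 = new CachedKafkaProducer(key, new KafkaProducer[Array[Byte], Array[Byte]](conf.asJava))
    assert(p1.cacheKey == p2.cacheKey && p1.id != p2.id)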
37 | | - private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) |
| 47 | +private[kafka010] object CachedKafkaProducer extends Logging { |
| 48 | + private type CacheKey = Seq[(String, Object)] |
| 49 | + private type Producer = KafkaProducer[Array[Byte], Array[Byte]] |
38 | 50 |
39 | | - private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get) |
40 | | - .map(_.conf.get(PRODUCER_CACHE_TIMEOUT)) |
41 | | - .getOrElse(defaultCacheExpireTimeout) |
| 51 | + /** |
| 52 | + * This class is used as the cache's metadata entry, and shouldn't be exposed to the public
| 53 | + * (exposing it for testing is fine). This class assumes thread-safety is guaranteed by the caller.
| 54 | + */ |
| 55 | + class CachedProducerEntry(val producer: CachedKafkaProducer) { |
| 56 | + private var refCount: Long = 0L |
| 57 | + private var expireAt: Long = Long.MaxValue |
42 | 58 |
43 | | - private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] { |
44 | | - override def load(config: Seq[(String, Object)]): Producer = { |
45 | | - createKafkaProducer(config) |
| 59 | + def handleBorrowed(): Unit = { |
| 60 | + refCount += 1 |
| 61 | + expireAt = Long.MaxValue |
46 | 62 | } |
47 | | - } |
48 | 63 |
49 | | - private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() { |
50 | | - override def onRemoval( |
51 | | - notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = { |
52 | | - val paramsSeq: Seq[(String, Object)] = notification.getKey |
53 | | - val producer: Producer = notification.getValue |
54 | | - if (log.isDebugEnabled()) { |
55 | | - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) |
56 | | - logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " + |
57 | | - s"due to ${notification.getCause}") |
| 64 | + def handleReturned(): Unit = { |
| 65 | + refCount -= 1 |
| 66 | + if (refCount <= 0) { |
| 67 | + expireAt = System.currentTimeMillis() + cacheExpireTimeout |
58 | 68 | } |
59 | | - close(paramsSeq, producer) |
60 | 69 | } |
61 | | - } |
62 | 70 |
63 | | - private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] = |
64 | | - CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS) |
65 | | - .removalListener(removalListener) |
66 | | - .build[Seq[(String, Object)], Producer](cacheLoader) |
| 71 | + def expired: Boolean = refCount <= 0 && expireAt < System.currentTimeMillis() |
67 | 72 |
68 | | - private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = { |
69 | | - val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava) |
70 | | - if (log.isDebugEnabled()) { |
71 | | - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) |
72 | | - logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.") |
| 73 | + /** Exposed for testing; don't call otherwise. */
| 74 | + private[kafka010] def injectDebugValues(refCnt: Long, expire: Long): Unit = { |
| 75 | + refCount = refCnt |
| 76 | + expireAt = expire |
73 | 77 | } |
74 | | - kafkaProducer |
75 | 78 | } |
76 | 79 |
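CachedProducerEntry is a small reference-counting state machine: every borrow bumps the count and cancels any pending expiry, and only when the last borrower returns does the expiry clock start. Below is a standalone sketch of those transitions, a toy with a made-up 100 ms timeout (the real entry reads the timeout from PRODUCER_CACHE_TIMEOUT, defaulting to 10 minutes, and relies on the enclosing object's lock for thread safety):

    object RefCountExpirySketch {
      final class Entry(timeoutMs: Long) {
        private var refCount = 0L
        private var expireAt = Long.MaxValue
        def borrow(): Unit = { refCount += 1; expireAt = Long.MaxValue }
        def giveBack(): Unit = {
          refCount -= 1
          if (refCount <= 0) expireAt = System.currentTimeMillis() + timeoutMs
        }
        def expired: Boolean = refCount <= 0 && expireAt < System.currentTimeMillis()
      }

      def main(args: Array[String]): Unit = {
        val e = new Entry(timeoutMs = 100L)
        e.borrow(); e.borrow()   // two tasks share the same producer
        e.giveBack()             // one returns: still referenced, no expiry pending
        assert(!e.expired)
        e.giveBack()             // last return: the expiry clock starts
        Thread.sleep(150L)
        assert(e.expired)        // idle past the timeout
        e.borrow()               // a fresh borrow cancels the pending expiry
        assert(!e.expired)
      }
    }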
| 80 | + private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10) |
| 81 | + |
| 82 | + private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get) |
| 83 | + .map(_.conf.get(PRODUCER_CACHE_TIMEOUT)) |
| 84 | + .getOrElse(defaultCacheExpireTimeout) |
| 85 | + |
| 86 | + private var acquireCount: Long = 0 |
| 87 | + |
| 88 | + private val cache = new mutable.HashMap[CacheKey, CachedProducerEntry] |
77 | 90 | /** |
78 | 91 | * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
79 | 92 | * exist, a new KafkaProducer will be created. KafkaProducer is thread-safe, so it is best to
80 | 93 | * keep one instance per specified kafkaParams.
81 | 94 | */ |
82 | | - private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = { |
| 95 | + private[kafka010] def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = { |
| 96 | + acquireCount += 1 |
| 97 | + if (acquireCount % 100 == 0) { |
| 98 | + expire() |
| 99 | + } |
| 100 | + |
83 | 101 | val updatedKafkaProducerConfiguration = |
84 | 102 | KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) |
85 | 103 | .setAuthenticationConfigIfNeeded() |
86 | 104 | .build() |
87 | 105 | val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration) |
88 | | - try { |
89 | | - guavaCache.get(paramsSeq) |
90 | | - } catch { |
91 | | - case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError) |
92 | | - if e.getCause != null => |
93 | | - throw e.getCause |
| 106 | + synchronized { |
| 107 | + val entry = cache.getOrElseUpdate(paramsSeq, { |
| 108 | + val producer = createKafkaProducer(paramsSeq) |
| 109 | + val cachedProducer = new CachedKafkaProducer(paramsSeq, producer) |
| 110 | + new CachedProducerEntry(cachedProducer) |
| 111 | + }) |
| 112 | + entry.handleBorrowed() |
| 113 | + entry.producer |
94 | 114 | } |
95 | 115 | } |
96 | 116 |
97 | | - private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { |
98 | | - val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) |
99 | | - paramsSeq |
| 117 | + private[kafka010] def release(producer: CachedKafkaProducer): Unit = { |
| 118 | + def closeProducerNotInCache(producer: CachedKafkaProducer): Unit = { |
| 119 | + logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.") |
| 120 | + producer.close() |
| 121 | + } |
| 122 | + |
| 123 | + synchronized { |
| 124 | + cache.get(producer.cacheKey) match { |
| 125 | + case Some(entry) if entry.producer.id == producer.id => entry.handleReturned() |
| 126 | + case _ => closeProducerNotInCache(producer) |
| 127 | + } |
| 128 | + } |
100 | 129 | } |
101 | 130 |
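With getOrCreate/close replaced by the acquire/release pair, callers are expected to bracket every use, typically with try/finally so the ref count is decremented even on failure. A hedged sketch of the intended pattern, assuming caller code inside the kafka010 package (both methods are private[kafka010]) and placeholder config values:

    import java.{util => ju}

    // Placeholder config; the serializers must be set because createKafkaProducer
    // instantiates the KafkaProducer from exactly these params.
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("bootstrap.servers", "localhost:9092")
    kafkaParams.put("key.serializer",
      "org.apache.kafka.common.serialization.ByteArraySerializer")
    kafkaParams.put("value.serializer",
      "org.apache.kafka.common.serialization.ByteArraySerializer")

    val cached = CachedKafkaProducer.acquire(kafkaParams)
    try {
      val producer = cached.producer  // the shared KafkaProducer[Array[Byte], Array[Byte]]
      // ... send records ...
    } finally {
      // Decrements the entry's ref count; the underlying producer is closed only
      // after the entry has sat unreferenced past the cache timeout (or on clear()).
      CachedKafkaProducer.release(cached)
    }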
102 | | - /** For explicitly closing kafka producer */ |
103 | | - private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = { |
104 | | - val paramsSeq = paramsToSeq(kafkaParams) |
105 | | - guavaCache.invalidate(paramsSeq) |
| 131 | + private def removeFromCache(key: CacheKey): Unit = { |
| 132 | + cache.remove(key).foreach { instance => instance.producer.close() } |
106 | 133 | } |
107 | 134 |
108 | | - /** Auto close on cache evict */ |
109 | | - private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = { |
110 | | - try { |
111 | | - if (log.isInfoEnabled()) { |
112 | | - val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) |
113 | | - logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.") |
114 | | - } |
115 | | - producer.close() |
116 | | - } catch { |
117 | | - case NonFatal(e) => logWarning("Error while closing kafka producer.", e) |
| 135 | + /** Exposed for testing. */
| 136 | + private[kafka010] def expire(): Unit = synchronized { |
| 137 | + cache.filter { case (_, v) => v.expired }.keys.foreach(removeFromCache) |
| 138 | + } |
| 139 | + |
| 140 | + private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = { |
| 141 | + val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava) |
| 142 | + if (log.isDebugEnabled()) { |
| 143 | + val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq) |
| 144 | + logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.") |
118 | 145 | } |
| 146 | + kafkaProducer |
| 147 | + } |
| 148 | + |
| 149 | + private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = { |
| 150 | + val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1) |
| 151 | + paramsSeq |
119 | 152 | } |
120 | 153 |
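paramsToSeq canonicalizes the producer config by sorting on the key, so two ju.Maps holding the same entries in different iteration order yield the same cache key and therefore share one producer. A small standalone sketch of that property (the map contents are illustrative):

    import java.{util => ju}
    import scala.collection.JavaConverters._

    // Same entries, different insertion order.
    val a = new ju.LinkedHashMap[String, Object]()
    a.put("bootstrap.servers", "localhost:9092")
    a.put("acks", "all")
    val b = new ju.LinkedHashMap[String, Object]()
    b.put("acks", "all")
    b.put("bootstrap.servers", "localhost:9092")

    def toKey(params: ju.Map[String, Object]): Seq[(String, Object)] =
      params.asScala.toSeq.sortBy(_._1)

    assert(toKey(a) == toKey(b))  // identical cache keys despite the ordering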
121 | 154 | private[kafka010] def clear(): Unit = { |
122 | | - logInfo("Cleaning up guava cache.") |
123 | | - guavaCache.invalidateAll() |
| 155 | + logInfo("Cleaning up cache.") |
| 156 | + synchronized { |
| 157 | + cache.keys.foreach(removeFromCache) |
| 158 | + } |
124 | 159 | } |
125 | 160 |
126 | 161 | // Intended for testing purposes only.
127 | | - private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap() |
| 162 | + private def getAsMap: Map[CacheKey, CachedProducerEntry] = cache.toMap |
128 | 163 | } |