
Commit 5663386

gaborgsomogyi authored and cloud-fan committed
[SPARK-28163][SS] Use CaseInsensitiveMap for KafkaOffsetReader
## What changes were proposed in this pull request?

There are "unsafe" conversions in the Kafka connector. A `CaseInsensitiveStringMap` comes in, which is then converted the following way:

```
...
options.asScala.toMap
...
```

The main problem is that the result loses its case-insensitive nature (a case-insensitive map converts the key to lower case when `get`/`contains` is called). In this PR I'm using `CaseInsensitiveMap` to solve this problem.

## How was this patch tested?

Existing + additional unit tests.

Closes #24967 from gaborgsomogyi/SPARK-28163.

Authored-by: Gabor Somogyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
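To illustrate the difference the commit message describes, here is a minimal sketch of the lookup behavior, assuming Spark's `org.apache.spark.sql.catalyst.util.CaseInsensitiveMap` lower-cases keys on `get`/`contains` as stated above; the option key used is only an illustrative example, not part of this diff:

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

object CaseInsensitiveLookupDemo {
  def main(args: Array[String]): Unit = {
    // A plain Scala Map, e.g. the result of options.asScala.toMap:
    // lookups are case sensitive, so a differently cased key misses.
    val plain: Map[String, String] = Map("kafkaConsumer.pollTimeoutMs" -> "1000")
    println(plain.get("kafkaconsumer.polltimeoutms"))       // None
    println(plain.contains("KafkaConsumer.PollTimeoutMs"))  // false

    // Wrapped in CaseInsensitiveMap, get/contains lower-case the key first,
    // so the same lookups succeed regardless of the caller's casing.
    val insensitive = CaseInsensitiveMap(plain)
    println(insensitive.get("kafkaconsumer.polltimeoutms"))       // Some(1000)
    println(insensitive.contains("KafkaConsumer.PollTimeoutMs"))  // true
  }
}
```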
1 parent 5368eaa commit 5663386

File tree: 7 files changed (+112, -89 lines)


external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatch.scala

Lines changed: 3 additions & 2 deletions
@@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}
 
 
 private[kafka010] class KafkaBatch(
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
@@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
   assert(endingOffsets != EarliestOffsetRangeLimit,
     "Ending offset not allowed to be set to earliest offsets.")
 
-  private val pollTimeoutMs = sourceOptions.getOrElse(
+  private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     (SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
   ).toLong

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala

Lines changed: 2 additions & 2 deletions
@@ -46,15 +46,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * properly read.
  */
 class KafkaContinuousStream(
-    offsetReader: KafkaOffsetReader,
+    private[kafka010] val offsetReader: KafkaOffsetReader,
     kafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
     initialOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
   extends ContinuousStream with Logging {
 
-  private val pollTimeoutMs =
+  private[kafka010] val pollTimeoutMs =
     options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
 
   // Initialized when creating reader factories. If this diverges from the partitions at the latest

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala

Lines changed: 4 additions & 4 deletions
@@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaMicroBatchStream(
-    kafkaOffsetReader: KafkaOffsetReader,
+    private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
     executorKafkaParams: ju.Map[String, Object],
     options: CaseInsensitiveStringMap,
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {
 
-  private val pollTimeoutMs = options.getLong(
+  private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
     SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)
 
-  private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
-    .map(_.toLong)
+  private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
+    KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)
 
   private val rangeCalculator = KafkaOffsetRangeCalculator(options)

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala

Lines changed: 4 additions & 3 deletions
@@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsume
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 
@@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
 private[kafka010] class KafkaOffsetReader(
     consumerStrategy: ConsumerStrategy,
     val driverKafkaParams: ju.Map[String, Object],
-    readerOptions: Map[String, String],
+    readerOptions: CaseInsensitiveMap[String],
     driverGroupIdPrefix: String) extends Logging {
   /**
    * Used to ensure execute fetch operations execute in an UninterruptibleThread
@@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
     _consumer
   }
 
-  private val maxOffsetFetchAttempts =
+  private[kafka010] val maxOffsetFetchAttempts =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt
 
-  private val offsetFetchAttemptIntervalMs =
+  private[kafka010] val offsetFetchAttemptIntervalMs =
     readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong
 
   private def nextGroupId(): String = {

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
@@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
 private[kafka010] class KafkaRelation(
     override val sqlContext: SQLContext,
     strategy: ConsumerStrategy,
-    sourceOptions: Map[String, String],
+    sourceOptions: CaseInsensitiveMap[String],
     specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala

Lines changed: 33 additions & 40 deletions
@@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): Source = {
-    validateStreamOptions(parameters)
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateStreamOptions(caseInsensitiveParameters)
     // Each running query should use its own group id. Otherwise, the query may be only assigned
     // partial data since Kafka will assign partitions to multiple consumers having the same group
     // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
+    val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)
 
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)
 
-    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
 
     val kafkaOffsetReader = new KafkaOffsetReader(
-      strategy(caseInsensitiveParams),
+      strategy(caseInsensitiveParameters),
       kafkaParamsForDriver(specifiedKafkaParams),
-      parameters,
+      caseInsensitiveParameters,
       driverGroupIdPrefix = s"$uniqueGroupId-driver")
 
     new KafkaSource(
       sqlContext,
       kafkaOffsetReader,
       kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
-      parameters,
+      caseInsensitiveParameters,
       metadataPath,
       startingStreamOffsets,
-      failOnDataLoss(caseInsensitiveParams))
+      failOnDataLoss(caseInsensitiveParameters))
   }
 
   override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
@@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
-    validateBatchOptions(parameters)
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
+    validateBatchOptions(caseInsensitiveParameters)
     val specifiedKafkaParams = convertToSpecifiedParams(parameters)
 
     val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-      caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
+      caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
     assert(startingRelationOffsets != LatestOffsetRangeLimit)
 
-    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
-      ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+    val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
+      caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)
 
     new KafkaRelation(
       sqlContext,
-      strategy(caseInsensitiveParams),
-      sourceOptions = parameters,
+      strategy(caseInsensitiveParameters),
+      sourceOptions = caseInsensitiveParameters,
       specifiedKafkaParams = specifiedKafkaParams,
-      failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+      failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
       startingOffsets = startingRelationOffsets,
       endingOffsets = endingRelationOffsets)
   }
@@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
 
     override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
       // partial data since Kafka will assign partitions to multiple consumers having the same group
       // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
 
-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams = convertToSpecifiedParams(parameters)
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
 
       val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
 
       val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(parameters),
+        strategy(caseInsensitiveOptions),
         kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
         driverGroupIdPrefix = s"$uniqueGroupId-driver")
 
       new KafkaMicroBatchStream(
@@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         options,
         checkpointLocation,
         startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
     }
 
     override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
-      val parameters = options.asScala.toMap
-      validateStreamOptions(parameters)
+      val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
+      validateStreamOptions(caseInsensitiveOptions)
       // Each running query should use its own group id. Otherwise, the query may be only assigned
      // partial data since Kafka will assign partitions to multiple consumers having the same group
      // id. Hence, we should generate a unique id for each query.
-      val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
+      val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)
 
-      val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-      val specifiedKafkaParams =
-        parameters
-          .keySet
-          .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
-          .map { k => k.drop(6).toString -> parameters(k) }
-          .toMap
+      val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
 
      val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
-        caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+        caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
 
      val kafkaOffsetReader = new KafkaOffsetReader(
-        strategy(caseInsensitiveParams),
+        strategy(caseInsensitiveOptions),
        kafkaParamsForDriver(specifiedKafkaParams),
-        parameters,
+        caseInsensitiveOptions,
        driverGroupIdPrefix = s"$uniqueGroupId-driver")
 
      new KafkaContinuousStream(
@@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
        options,
        checkpointLocation,
        startingStreamOffsets,
-        failOnDataLoss(caseInsensitiveParams))
+        failOnDataLoss(caseInsensitiveOptions))
    }
  }
 }
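Note that the diff above replaces the inline "kafka."-prefix filtering in `toContinuousStream` with a call to the existing `convertToSpecifiedParams` helper. The helper's definition is not part of this commit; judging from the inline code the commit removes, it presumably does roughly the following (a sketch, not the actual Spark source; the wrapping object is hypothetical):

```scala
import java.util.Locale

object KafkaParamConversion {
  // Sketch inferred from the inline code this commit removes: keep only options whose
  // key starts with "kafka." (case-insensitively), strip that 6-character prefix, and
  // pass the values straight through as Kafka client configuration.
  def convertToSpecifiedParams(parameters: Map[String, String]): Map[String, String] = {
    parameters
      .keySet
      .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
      .map { k => k.drop(6) -> parameters(k) }
      .toMap
  }
}

// Example: Map("kafka.bootstrap.servers" -> "host:9092", "startingOffsets" -> "earliest")
// would become Map("bootstrap.servers" -> "host:9092").
```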
