Skip to content

Commit 6570665

Browse files
committed
Replace the type of parameters with CaseInsensitiveMap.
1 parent 8fbd18d commit 6570665

10 files changed

Lines changed: 82 additions & 90 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,45 +31,40 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
3131
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
3232
*/
3333
private[sql] class JSONOptions(
34-
@transient private val parameters: Map[String, String])
34+
@transient private val parameters: CaseInsensitiveMap)
3535
extends Logging with Serializable {
3636

37-
private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
37+
def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
3838

3939
val samplingRatio =
40-
caseInsensitiveOptions.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
40+
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
4141
val primitivesAsString =
42-
caseInsensitiveOptions.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
42+
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
4343
val prefersDecimal =
44-
caseInsensitiveOptions.get("prefersDecimal").map(_.toBoolean).getOrElse(false)
44+
parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false)
4545
val allowComments =
46-
caseInsensitiveOptions.get("allowComments").map(_.toBoolean).getOrElse(false)
46+
parameters.get("allowComments").map(_.toBoolean).getOrElse(false)
4747
val allowUnquotedFieldNames =
48-
caseInsensitiveOptions.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false)
48+
parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false)
4949
val allowSingleQuotes =
50-
caseInsensitiveOptions.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true)
50+
parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true)
5151
val allowNumericLeadingZeros =
52-
caseInsensitiveOptions.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false)
52+
parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false)
5353
val allowNonNumericNumbers =
54-
caseInsensitiveOptions.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
54+
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
5555
val allowBackslashEscapingAnyCharacter =
56-
caseInsensitiveOptions.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean)
57-
.getOrElse(false)
58-
val compressionCodec =
59-
caseInsensitiveOptions.get("compression").map(CompressionCodecs.getCodecClassName)
60-
private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE")
61-
val columnNameOfCorruptRecord = caseInsensitiveOptions.get("columnNameOfCorruptRecord")
56+
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
57+
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
58+
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
59+
val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
6260

6361
// Uses `FastDateFormat`, which is thread-safe and can be used as a direct replacement for `SimpleDateFormat`.
6462
val dateFormat: FastDateFormat =
65-
FastDateFormat.getInstance(
66-
caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"),
67-
Locale.US)
63+
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
6864

6965
val timestampFormat: FastDateFormat =
7066
FastDateFormat.getInstance(
71-
caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"),
72-
Locale.US)
67+
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US)
7368

7469
// Parse mode flags
7570
if (!ParseModes.isValidMode(parseMode)) {

sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.catalog._
3333
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3434
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison}
3535
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, PredicateHelper}
36-
import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, PartitioningUtils}
36+
import org.apache.spark.sql.execution.datasources.PartitioningUtils
3737
import org.apache.spark.sql.types._
3838
import org.apache.spark.util.SerializableConfiguration
3939

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ import org.apache.commons.lang3.time.FastDateFormat
2525
import org.apache.spark.internal.Logging
2626
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
2727

28-
private[csv] class CSVOptions(@transient private val parameters: Map[String, String])
28+
private[csv] class CSVOptions(@transient private val parameters: CaseInsensitiveMap)
2929
extends Logging with Serializable {
3030

31-
private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
31+
def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
3232

3333
private def getChar(paramName: String, default: Char): Char = {
34-
val paramValue = caseInsensitiveOptions.get(paramName)
34+
val paramValue = parameters.get(paramName)
3535
paramValue match {
3636
case None => default
3737
case Some(null) => default
@@ -42,7 +42,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
4242
}
4343

4444
private def getInt(paramName: String, default: Int): Int = {
45-
val paramValue = caseInsensitiveOptions.get(paramName)
45+
val paramValue = parameters.get(paramName)
4646
paramValue match {
4747
case None => default
4848
case Some(null) => default
@@ -56,7 +56,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
5656
}
5757

5858
private def getBool(paramName: String, default: Boolean = false): Boolean = {
59-
val param = caseInsensitiveOptions.getOrElse(paramName, default.toString)
59+
val param = parameters.getOrElse(paramName, default.toString)
6060
if (param == null) {
6161
default
6262
} else if (param.toLowerCase == "true") {
@@ -69,10 +69,10 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
6969
}
7070

7171
val delimiter = CSVTypeCast.toChar(
72-
caseInsensitiveOptions.getOrElse("sep", caseInsensitiveOptions.getOrElse("delimiter", ",")))
73-
private val parseMode = caseInsensitiveOptions.getOrElse("mode", "PERMISSIVE")
74-
val charset = caseInsensitiveOptions.getOrElse("encoding",
75-
caseInsensitiveOptions.getOrElse("charset", StandardCharsets.UTF_8.name()))
72+
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
73+
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
74+
val charset = parameters.getOrElse("encoding",
75+
parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))
7676

7777
val quote = getChar("quote", '\"')
7878
val escape = getChar("escape", '\\')
@@ -92,28 +92,26 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
9292
val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
9393
val permissive = ParseModes.isPermissiveMode(parseMode)
9494

95-
val nullValue = caseInsensitiveOptions.getOrElse("nullValue", "")
95+
val nullValue = parameters.getOrElse("nullValue", "")
9696

97-
val nanValue = caseInsensitiveOptions.getOrElse("nanValue", "NaN")
97+
val nanValue = parameters.getOrElse("nanValue", "NaN")
9898

99-
val positiveInf = caseInsensitiveOptions.getOrElse("positiveInf", "Inf")
100-
val negativeInf = caseInsensitiveOptions.getOrElse("negativeInf", "-Inf")
99+
val positiveInf = parameters.getOrElse("positiveInf", "Inf")
100+
val negativeInf = parameters.getOrElse("negativeInf", "-Inf")
101101

102102

103103
val compressionCodec: Option[String] = {
104-
val name = caseInsensitiveOptions.get("compression").orElse(caseInsensitiveOptions.get("codec"))
104+
val name = parameters.get("compression").orElse(parameters.get("codec"))
105105
name.map(CompressionCodecs.getCodecClassName)
106106
}
107107

108108
// Uses `FastDateFormat`, which is thread-safe and can be used as a direct replacement for `SimpleDateFormat`.
109109
val dateFormat: FastDateFormat =
110-
FastDateFormat.getInstance(
111-
caseInsensitiveOptions.getOrElse("dateFormat", "yyyy-MM-dd"),
112-
Locale.US)
110+
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
113111

114112
val timestampFormat: FastDateFormat =
115113
FastDateFormat.getInstance(
116-
caseInsensitiveOptions.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US)
114+
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), Locale.US)
117115

118116
val maxColumns = getInt("maxColumns", 20480)
119117

@@ -132,7 +130,7 @@ private[csv] class CSVOptions(@transient private val parameters: Map[String, Str
132130

133131
object CSVOptions {
134132

135-
def apply(): CSVOptions = new CSVOptions(Map.empty)
133+
def apply(): CSVOptions = new CSVOptions(new CaseInsensitiveMap(Map.empty))
136134

137135
def apply(paramName: String, paramValue: String): CSVOptions = {
138136
new CSVOptions(Map(paramName -> paramValue))

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,43 +28,42 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
2828
* Options for the JDBC data source.
2929
*/
3030
class JDBCOptions(
31-
@transient private val parameters: Map[String, String])
31+
@transient private val parameters: CaseInsensitiveMap)
3232
extends Serializable {
3333

3434
import JDBCOptions._
3535

36-
private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
36+
def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
3737

3838
def this(url: String, table: String, parameters: Map[String, String]) = {
39-
this(parameters ++ Map(
39+
this(new CaseInsensitiveMap(parameters ++ Map(
4040
JDBCOptions.JDBC_URL -> url,
41-
JDBCOptions.JDBC_TABLE_NAME -> table))
41+
JDBCOptions.JDBC_TABLE_NAME -> table)))
4242
}
4343

4444
val asConnectionProperties: Properties = {
4545
val properties = new Properties()
4646
// We should avoid passing the options into properties. See SPARK-17776.
47-
caseInsensitiveOptions.filterKeys(!jdbcOptionNames.contains(_))
47+
parameters.filterKeys(!jdbcOptionNames.contains(_))
4848
.foreach { case (k, v) => properties.setProperty(k, v) }
4949
properties
5050
}
5151

5252
// ------------------------------------------------------------
5353
// Required parameters
5454
// ------------------------------------------------------------
55-
require(caseInsensitiveOptions.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.")
56-
require(caseInsensitiveOptions.isDefinedAt(JDBC_TABLE_NAME),
57-
s"Option '$JDBC_TABLE_NAME' is required.")
55+
require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.")
56+
require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.")
5857
// a JDBC URL
59-
val url = caseInsensitiveOptions(JDBC_URL)
58+
val url = parameters(JDBC_URL)
6059
// name of table
61-
val table = caseInsensitiveOptions(JDBC_TABLE_NAME)
60+
val table = parameters(JDBC_TABLE_NAME)
6261

6362
// ------------------------------------------------------------
6463
// Optional parameters
6564
// ------------------------------------------------------------
6665
val driverClass = {
67-
val userSpecifiedDriverClass = caseInsensitiveOptions.get(JDBC_DRIVER_CLASS)
66+
val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS)
6867
userSpecifiedDriverClass.foreach(DriverRegistry.register)
6968

7069
// Performing this part of the logic on the driver guards against the corner-case where the
@@ -79,19 +78,19 @@ class JDBCOptions(
7978
// Optional parameters only for reading
8079
// ------------------------------------------------------------
8180
// the column used to partition
82-
val partitionColumn = caseInsensitiveOptions.getOrElse(JDBC_PARTITION_COLUMN, null)
81+
val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null)
8382
// the lower bound of partition column
84-
val lowerBound = caseInsensitiveOptions.getOrElse(JDBC_LOWER_BOUND, null)
83+
val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null)
8584
// the upper bound of the partition column
86-
val upperBound = caseInsensitiveOptions.getOrElse(JDBC_UPPER_BOUND, null)
85+
val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null)
8786
// the number of partitions
88-
val numPartitions = caseInsensitiveOptions.getOrElse(JDBC_NUM_PARTITIONS, null)
87+
val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null)
8988
require(partitionColumn == null ||
9089
(lowerBound != null && upperBound != null && numPartitions != null),
9190
s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," +
9291
s" and '$JDBC_NUM_PARTITIONS' are required.")
9392
val fetchSize = {
94-
val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
93+
val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
9594
require(size >= 0,
9695
s"Invalid value `${size.toString}` for parameter " +
9796
s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " +
@@ -103,20 +102,20 @@ class JDBCOptions(
103102
// Optional parameters only for writing
104103
// ------------------------------------------------------------
105104
// whether to truncate the table in the JDBC database
106-
val isTruncate = caseInsensitiveOptions.getOrElse(JDBC_TRUNCATE, "false").toBoolean
105+
val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
107106
// the create table option, which can be table_options or partition_options.
108107
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
109108
// TODO: reuse the existing partition parameters for those partition-specific options
110-
val createTableOptions = caseInsensitiveOptions.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
109+
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
111110
val batchSize = {
112-
val size = caseInsensitiveOptions.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
111+
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
113112
require(size >= 1,
114113
s"Invalid value `${size.toString}` for parameter " +
115114
s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")
116115
size
117116
}
118117
val isolationLevel =
119-
caseInsensitiveOptions.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match {
118+
parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match {
120119
case "NONE" => Connection.TRANSACTION_NONE
121120
case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED
122121
case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@ import org.apache.spark.sql.internal.SQLConf
2525
/**
2626
* Options for the Parquet data source.
2727
*/
28-
private[sql] class ParquetOptions(
29-
@transient private val parameters: Map[String, String],
28+
private[parquet] class ParquetOptions(
29+
@transient private val parameters: CaseInsensitiveMap,
3030
@transient private val sqlConf: SQLConf)
3131
extends Serializable {
3232

3333
import ParquetOptions._
3434

35-
private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
35+
def this(parameters: Map[String, String], sqlConf: SQLConf) =
36+
this(new CaseInsensitiveMap(parameters), sqlConf)
3637

3738
/**
3839
* Compression codec to use. By default use the value specified in SQLConf.
3940
* Acceptable values are defined in [[shortParquetCompressionCodecNames]].
4041
*/
4142
val compressionCodecClassName: String = {
42-
val codecName =
43-
caseInsensitiveOptions.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
43+
val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
4444
if (!shortParquetCompressionCodecNames.contains(codecName)) {
4545
val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
4646
throw new IllegalArgumentException(s"Codec [$codecName] " +
@@ -53,7 +53,7 @@ private[sql] class ParquetOptions(
5353
* Whether it merges schemas or not. When the given Parquet files have different schemas,
5454
* the schemas can be merged. By default use the value specified in SQLConf.
5555
*/
56-
val mergeSchema: Boolean = caseInsensitiveOptions
56+
val mergeSchema: Boolean = parameters
5757
.get(MERGE_SCHEMA)
5858
.map(_.toBoolean)
5959
.getOrElse(sqlConf.isParquetSchemaMergingEnabled)

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,15 @@ import org.apache.spark.util.Utils
2626
/**
2727
* User specified options for file streams.
2828
*/
29-
class FileStreamOptions(parameters: Map[String, String]) extends Logging {
29+
class FileStreamOptions(parameters: CaseInsensitiveMap) extends Logging {
3030

31-
private val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
31+
def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))
3232

33-
val maxFilesPerTrigger: Option[Int] = caseInsensitiveOptions.get("maxFilesPerTrigger").map {
34-
str =>
35-
Try(str.toInt).toOption.filter(_ > 0).getOrElse {
36-
throw new IllegalArgumentException(
37-
s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer")
38-
}
33+
val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str =>
34+
Try(str.toInt).toOption.filter(_ > 0).getOrElse {
35+
throw new IllegalArgumentException(
36+
s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer")
37+
}
3938
}
4039

4140
/**
@@ -49,9 +48,9 @@ class FileStreamOptions(parameters: Map[String, String]) extends Logging {
4948
* Default to a week.
5049
*/
5150
val maxFileAgeMs: Long =
52-
Utils.timeStringAsMs(caseInsensitiveOptions.getOrElse("maxFileAge", "7d"))
51+
Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d"))
5352

5453
/** Options as specified by the user, in a case-insensitive map, without "path" set. */
5554
val optionMapWithoutPath: Map[String, String] =
56-
new CaseInsensitiveMap(caseInsensitiveOptions).filterKeys(_ != "path")
55+
parameters.filterKeys(_ != "path")
5756
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
13661366

13671367
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
13681368
// This is really a test that it doesn't throw an exception
1369-
val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map()))
1369+
val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
13701370
assert(StructType(Seq()) === emptySchema)
13711371
}
13721372

@@ -1390,7 +1390,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
13901390
}
13911391

13921392
test("SPARK-8093 Erase empty structs") {
1393-
val emptySchema = InferSchema.infer(emptyRecords, "", new JSONOptions(Map()))
1393+
val emptySchema = InferSchema.infer(
1394+
emptyRecords, "", new JSONOptions(Map.empty[String, String]))
13941395
assert(StructType(Seq()) === emptySchema)
13951396
}
13961397

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
736736
}
737737
}
738738
}
739+
740+
test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
741+
withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") {
742+
val option = new ParquetOptions(Map("Compression" -> "uncompressed"), spark.sessionState.conf)
743+
assert(option.compressionCodecClassName == "UNCOMPRESSED")
744+
}
745+
}
739746
}
740747

741748
class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)

0 commit comments

Comments
 (0)