diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index df176c579fc8..6811fa6b3b15 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
+ encoding=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -237,6 +238,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+ :param encoding: allows to forcibly set one of standard basic or extended encoding for
+ the JSON files. For example UTF-16BE, UTF-32LE. If None is set,
+ the encoding of input JSON will be detected automatically
+ when the multiLine option is set to ``true``.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
:param samplingRatio: defines fraction of input JSON objects used for schema inferring.
@@ -259,7 +264,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
- samplingRatio=samplingRatio)
+ samplingRatio=samplingRatio, encoding=encoding)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -752,7 +757,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
@since(1.4)
def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
- lineSep=None):
+ lineSep=None, encoding=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON `_) at the
specified path.
@@ -776,6 +781,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
+ :param encoding: specifies encoding (charset) of saved json files. If None is set,
+ the default UTF-8 charset will be used.
:param lineSep: defines the line separator that should be used for writing. If None is
set, it uses the default value, ``\\n``.
@@ -784,7 +791,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
self.mode(mode)
self._set_opts(
compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
- lineSep=lineSep)
+ lineSep=lineSep, encoding=encoding)
self._jwrite.json(path)
@since(1.4)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6b28c557a803..e0cd2aa41a2d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -685,6 +685,13 @@ def test_multiline_json(self):
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())
+ def test_encoding_json(self):
+ people_array = self.spark.read\
+ .json("python/test_support/sql/people_array_utf16le.json",
+ multiLine=True, encoding="UTF-16LE")
+ expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
+ self.assertEqual(people_array.collect(), expected)
+
def test_linesep_json(self):
df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
expected = [Row(_corrupt_record=None, name=u'Michael'),
diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json
new file mode 100644
index 000000000000..9c657fa30ac9
Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
index 025a388aacaa..3e8e6db1dbd2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
@@ -18,10 +18,14 @@
package org.apache.spark.sql.catalyst.json
import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
+import java.nio.channels.Channels
+import java.nio.charset.Charset
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text
+import sun.nio.cs.StreamDecoder
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
private[sql] object CreateJacksonParser extends Serializable {
@@ -43,7 +47,48 @@ private[sql] object CreateJacksonParser extends Serializable {
jsonFactory.createParser(record.getBytes, 0, record.getLength)
}
- def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
- jsonFactory.createParser(record)
+ // Jackson parsers can be ranked according to their performance:
+ // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser
+ // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically
+ // by checking leading bytes of the array.
+ // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected
+ // automatically by analyzing first bytes of the input stream.
+ // 3. Reader based parser. This is the slowest parser used here but it allows to create
+ // a reader with specific encoding.
+ // The method creates a reader for an array with given encoding and sets size of internal
+ // decoding buffer according to size of input array.
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
+ val bais = new ByteArrayInputStream(in, 0, length)
+ val byteChannel = Channels.newChannel(bais)
+ val decodingBufferSize = Math.min(length, 8192)
+ val decoder = Charset.forName(enc).newDecoder()
+
+ StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
+ }
+
+ def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
+ val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
+ jsonFactory.createParser(sd)
+ }
+
+ def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(is)
+ }
+
+ def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(new InputStreamReader(is, enc))
+ }
+
+ def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val ba = row.getBinary(0)
+
+ jsonFactory.createParser(ba, 0, ba.length)
+ }
+
+ def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val binary = row.getBinary(0)
+ val sd = getStreamDecoder(enc, binary, binary.length)
+
+ jsonFactory.createParser(sd)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5c9adc3332bc..5f130af606e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.json
-import java.nio.charset.StandardCharsets
+import java.nio.charset.{Charset, StandardCharsets}
import java.util.{Locale, TimeZone}
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
@@ -86,14 +86,43 @@ private[sql] class JSONOptions(
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
+ /**
+ * A string between two consecutive JSON records.
+ */
val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
sep
}
- // Note that the option 'lineSep' uses a different default value in read and write.
- val lineSeparatorInRead: Option[Array[Byte]] =
- lineSeparator.map(_.getBytes(StandardCharsets.UTF_8))
- // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
+
+ /**
+ * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
+ * If the encoding is not specified (None), it will be detected automatically
+ * when the multiLine option is set to `true`.
+ */
+ val encoding: Option[String] = parameters.get("encoding")
+ .orElse(parameters.get("charset")).map { enc =>
+ // The following encodings are not supported in per-line mode (multiline is false)
+ // because they cause some problems in reading files with BOM which is supposed to
+ // present in the files with such encodings. After splitting input files by lines,
+ // only the first lines will have the BOM which leads to impossibility for reading
+ // the rest lines. Besides of that, the lineSep option must have the BOM in such
+ // encodings which can never present between lines.
+ val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32"))
+ val isBlacklisted = blacklist.contains(Charset.forName(enc))
+ require(multiLine || !isBlacklisted,
+ s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled:
+ | ${blacklist.mkString(", ")}""".stripMargin)
+
+ val isLineSepRequired = !(multiLine == false &&
+ Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty)
+ require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
+
+ enc
+ }
+
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(encoding.getOrElse("UTF-8"))
+ }
val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n")
/** Sets config options on a Jackson [[JsonFactory]]. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 7f6956994f31..a5a4a13eb608 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.json
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, CharConversionException}
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
@@ -361,6 +361,14 @@ class JacksonParser(
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => recordLiteral(record), () => None, e)
+ case e: CharConversionException if options.encoding.isEmpty =>
+ val msg =
+ """JSON parser cannot handle a character in its input.
+ |Specifying encoding as an input option explicitly might help to resolve the issue.
+ |""".stripMargin + e.getMessage
+ val wrappedCharException = new CharConversionException(msg)
+ wrappedCharException.initCause(e)
+ throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index b44552f0eb17..6b2ea6c06d3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -372,6 +372,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.
*
`multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+ * `encoding` (by default it is not set): allows to forcibly set one of standard basic
+ * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding
+ * is not specified and `multiLine` is set to `true`, it will be detected automatically.
* `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
* that should be used for parsing.
* `samplingRatio` (default is 1.0): defines fraction of input JSON objects used
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index bbc063148a72..e183fa6f9542 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -518,8 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
- * `lineSep` (default `\n`): defines the line separator that should
- * be used for writing.
+ * `encoding` (by default it is not set): specifies encoding (charset) of saved json
+ * files. If it is not set, the UTF-8 charset will be used.
+ * `lineSep` (default `\n`): defines the line separator that should be used for writing.
*
*
* @since 1.4.0
@@ -589,8 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).
- * `lineSep` (default `\n`): defines the line separator that should
- * be used for writing.
+ * `lineSep` (default `\n`): defines the line separator that should be used for writing.
*
*
* @since 1.6.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 5769c09c9a1d..983a5f0dcade 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -31,11 +31,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -92,26 +92,30 @@ object TextInputJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
- val json: Dataset[String] = createBaseDataset(
- sparkSession, inputPaths, parsedOptions.lineSeparator)
+ val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+
inferFromDataset(json, parsedOptions)
}
def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
- val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
- JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+ val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd
+ val rowParser = parsedOptions.encoding.map { enc =>
+ CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
+ }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
+
+ JsonInferSchema.infer(rdd, parsedOptions, rowParser)
}
private def createBaseDataset(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
- lineSeparator: Option[String]): Dataset[String] = {
- val textOptions = lineSeparator.map { lineSep =>
- Map(TextOptions.LINE_SEPARATOR -> lineSep)
- }.getOrElse(Map.empty[String, String])
-
+ parsedOptions: JSONOptions): Dataset[String] = {
val paths = inputPaths.map(_.getPath.toString)
+ val textOptions = Map.empty[String, String] ++
+ parsedOptions.encoding.map("encoding" -> _) ++
+ parsedOptions.lineSeparator.map("lineSep" -> _)
+
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
@@ -129,8 +133,12 @@ object TextInputJsonDataSource extends JsonDataSource {
schema: StructType): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ val textParser = parser.options.encoding
+ .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
+ .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text))
+
val safeParser = new FailureSafeParser[Text](
- input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
+ input => parser.parse(input, textParser, textToUTF8String),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
@@ -153,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource {
parsedOptions: JSONOptions): StructType = {
val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
- JsonInferSchema.infer(sampled, parsedOptions, createParser)
+ val parser = parsedOptions.encoding
+ .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
+ .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
+
+ JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
}
private def createBaseRdd(
@@ -175,11 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource {
.values
}
- private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
- val path = new Path(record.getPath())
- CreateJacksonParser.inputStream(
- jsonFactory,
- CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path))
+ private def dataToInputStream(dataStream: PortableDataStream): InputStream = {
+ val path = new Path(dataStream.getPath())
+ CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path)
+ }
+
+ private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream))
+ }
+
+ private def createParser(enc: String, jsonFactory: JsonFactory,
+ stream: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream))
}
override def readFile(
@@ -194,9 +213,12 @@ object MultiLineJsonDataSource extends JsonDataSource {
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
}
+ val streamParser = parser.options.encoding
+ .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream))
+ .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream))
val safeParser = new FailureSafeParser[InputStream](
- input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
+ input => parser.parse[InputStream](input, streamParser, partitionedFileString),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 0862c746fffa..3b04510d2969 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
+import java.nio.charset.{Charset, StandardCharsets}
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -151,7 +153,13 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with Logging {
- private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+ private val encoding = options.encoding match {
+ case Some(charsetName) => Charset.forName(charsetName)
+ case None => StandardCharsets.UTF_8
+ }
+
+ private val writer = CodecStreams.createOutputStreamWriter(
+ context, new Path(path), encoding)
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
index 5c1a35434f7b..e4e201995faa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.text
-import java.nio.charset.StandardCharsets
+import java.nio.charset.{Charset, StandardCharsets}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}
@@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti
*/
val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean
- private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep =>
- require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
- sep
+ val encoding: Option[String] = parameters.get(ENCODING)
+
+ val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep =>
+ require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
+
+ lineSep
}
+
// Note that the option 'lineSep' uses a different default value in read and write.
- val lineSeparatorInRead: Option[Array[Byte]] =
- lineSeparator.map(_.getBytes(StandardCharsets.UTF_8))
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8))
+ }
val lineSeparatorInWrite: Array[Byte] =
lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8))
}
@@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti
private[datasources] object TextOptions {
val COMPRESSION = "compression"
val WHOLETEXT = "wholetext"
+ val ENCODING = "encoding"
val LINE_SEPARATOR = "lineSep"
}
diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json
new file mode 100644
index 000000000000..ce4117fd299d
Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16LE.json differ
diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json
new file mode 100644
index 000000000000..cf4d29328b86
Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16WithBOM.json differ
diff --git a/sql/core/src/test/resources/test-data/utf32BEWithBOM.json b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json
new file mode 100644
index 000000000000..6c7733c57787
Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
new file mode 100644
index 000000000000..85cf054e51f6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.json
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.util.{Benchmark, Utils}
+
+/**
+ * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't.
+ * To run this:
+ * spark-submit --class --jars
+ */
+object JSONBenchmarks {
+ val conf = new SparkConf()
+
+ val spark = SparkSession.builder
+ .master("local[1]")
+ .appName("benchmark-json-datasource")
+ .config(conf)
+ .getOrCreate()
+ import spark.implicits._
+
+ def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path) finally Utils.deleteRecursively(path)
+ }
+
+
+ def schemaInferring(rowsNum: Int): Unit = {
+ val benchmark = new Benchmark("JSON schema inferring", rowsNum)
+
+ withTempPath { path =>
+ // scalastyle:off println
+ benchmark.out.println("Preparing data for benchmarking ...")
+ // scalastyle:on println
+
+ spark.sparkContext.range(0, rowsNum, 1)
+ .map(_ => "a")
+ .toDF("fieldA")
+ .write
+ .option("encoding", "UTF-8")
+ .json(path.getAbsolutePath)
+
+ benchmark.addCase("No encoding", 3) { _ =>
+ spark.read.json(path.getAbsolutePath)
+ }
+
+ benchmark.addCase("UTF-8 is set", 3) { _ =>
+ spark.read
+ .option("encoding", "UTF-8")
+ .json(path.getAbsolutePath)
+ }
+
+ /*
+ Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz
+
+ JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ --------------------------------------------------------------------------------------------
+ No encoding 38902 / 39282 2.6 389.0 1.0X
+ UTF-8 is set 56959 / 57261 1.8 569.6 0.7X
+ */
+ benchmark.run()
+ }
+ }
+
+ def perlineParsing(rowsNum: Int): Unit = {
+ val benchmark = new Benchmark("JSON per-line parsing", rowsNum)
+
+ withTempPath { path =>
+ // scalastyle:off println
+ benchmark.out.println("Preparing data for benchmarking ...")
+ // scalastyle:on println
+
+ spark.sparkContext.range(0, rowsNum, 1)
+ .map(_ => "a")
+ .toDF("fieldA")
+ .write.json(path.getAbsolutePath)
+ val schema = new StructType().add("fieldA", StringType)
+
+ benchmark.addCase("No encoding", 3) { _ =>
+ spark.read
+ .schema(schema)
+ .json(path.getAbsolutePath)
+ .count()
+ }
+
+ benchmark.addCase("UTF-8 is set", 3) { _ =>
+ spark.read
+ .option("encoding", "UTF-8")
+ .schema(schema)
+ .json(path.getAbsolutePath)
+ .count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz
+
+ JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ --------------------------------------------------------------------------------------------
+ No encoding 25947 / 26188 3.9 259.5 1.0X
+ UTF-8 is set 46319 / 46417 2.2 463.2 0.6X
+ */
+ benchmark.run()
+ }
+ }
+
+ def perlineParsingOfWideColumn(rowsNum: Int): Unit = {
+ val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum)
+
+ withTempPath { path =>
+ // scalastyle:off println
+ benchmark.out.println("Preparing data for benchmarking ...")
+ // scalastyle:on println
+
+ spark.sparkContext.range(0, rowsNum, 1)
+ .map { i =>
+ val s = "abcdef0123456789ABCDEF" * 20
+ s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}"""
+ }
+ .toDF().write.text(path.getAbsolutePath)
+ val schema = new StructType()
+ .add("a", StringType).add("b", LongType)
+ .add("c", StringType).add("d", LongType)
+ .add("e", StringType).add("f", LongType)
+ .add("x", StringType).add("y", LongType)
+ .add("z", StringType)
+
+ benchmark.addCase("No encoding", 3) { _ =>
+ spark.read
+ .schema(schema)
+ .json(path.getAbsolutePath)
+ .count()
+ }
+
+ benchmark.addCase("UTF-8 is set", 3) { _ =>
+ spark.read
+ .option("encoding", "UTF-8")
+ .schema(schema)
+ .json(path.getAbsolutePath)
+ .count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz
+
+ JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ --------------------------------------------------------------------------------------------
+ No encoding 45543 / 45660 0.2 4554.3 1.0X
+ UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X
+ */
+ benchmark.run()
+ }
+ }
+
+ def main(args: Array[String]): Unit = {
+ schemaInferring(100 * 1000 * 1000)
+ perlineParsing(100 * 1000 * 1000)
+ perlineParsingOfWideColumn(10 * 1000 * 1000)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index a58dff827b92..0db688fec9a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
-import java.io.{File, StringWriter}
-import java.nio.charset.StandardCharsets
+import java.io.{File, FileOutputStream, StringWriter}
+import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
import java.nio.file.{Files, Paths, StandardOpenOption}
import java.sql.{Date, Timestamp}
import java.util.Locale
@@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter {
class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
import testImplicits._
+ def testFile(fileName: String): String = {
+ Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+ }
+
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
assert(expected.getClass == actual.getClass,
@@ -2167,4 +2171,241 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val sampled = spark.read.option("samplingRatio", 1.0).json(ds)
assert(sampled.count() == ds.count())
}
+
+ test("SPARK-23723: json in UTF-16 with BOM") {
+ val fileName = "test-data/utf16WithBOM.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .option("encoding", "UTF-16")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood")))
+ }
+
+ test("SPARK-23723: multi-line json in UTF-32BE with BOM") {
+ val fileName = "test-data/utf32BEWithBOM.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+ test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") {
+ val fileName = "test-data/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .options(Map("encoding" -> "UTF-16LE"))
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+ test("SPARK-23723: Unsupported encoding name") {
+ val invalidCharset = "UTF-128"
+ val exception = intercept[UnsupportedCharsetException] {
+ spark.read
+ .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n"))
+ .json(testFile("test-data/utf16LE.json"))
+ .count()
+ }
+
+ assert(exception.getMessage.contains(invalidCharset))
+ }
+
+ test("SPARK-23723: checking that the encoding option is case agnostic") {
+ val fileName = "test-data/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val jsonDF = spark.read.schema(schema)
+ .option("multiline", "true")
+ .options(Map("encoding" -> "uTf-16lE"))
+ .json(testFile(fileName))
+
+ checkAnswer(jsonDF, Seq(Row("Chris", "Baird")))
+ }
+
+
+ test("SPARK-23723: specified encoding is not matched to actual encoding") {
+ val fileName = "test-data/utf16LE.json"
+ val schema = new StructType().add("firstName", StringType).add("lastName", StringType)
+ val exception = intercept[SparkException] {
+ spark.read.schema(schema)
+ .option("mode", "FAILFAST")
+ .option("multiline", "true")
+ .options(Map("encoding" -> "UTF-16BE"))
+ .json(testFile(fileName))
+ .count()
+ }
+ val errMsg = exception.getMessage
+
+ assert(errMsg.contains("Malformed records are detected in record parsing"))
+ }
+
+ def checkEncoding(expectedEncoding: String, pathToJsonFiles: String,
+ expectedContent: String): Unit = {
+ val jsonFiles = new File(pathToJsonFiles)
+ .listFiles()
+ .filter(_.isFile)
+ .filter(_.getName.endsWith("json"))
+ val actualContent = jsonFiles.map { file =>
+ new String(Files.readAllBytes(file.toPath), expectedEncoding)
+ }.mkString.trim
+
+ assert(actualContent == expectedContent)
+ }
+
+ test("SPARK-23723: save json in UTF-32BE") {
+ val encoding = "UTF-32BE"
+ withTempPath { path =>
+ val df = spark.createDataset(Seq(("Dog", 42)))
+ df.write
+ .options(Map("encoding" -> encoding, "lineSep" -> "\n"))
+ .json(path.getCanonicalPath)
+
+ checkEncoding(
+ expectedEncoding = encoding,
+ pathToJsonFiles = path.getCanonicalPath,
+ expectedContent = """{"_1":"Dog","_2":42}""")
+ }
+ }
+
+ test("SPARK-23723: save json in default encoding - UTF-8") {
+ withTempPath { path =>
+ val df = spark.createDataset(Seq(("Dog", 42)))
+ df.write.json(path.getCanonicalPath)
+
+ checkEncoding(
+ expectedEncoding = "UTF-8",
+ pathToJsonFiles = path.getCanonicalPath,
+ expectedContent = """{"_1":"Dog","_2":42}""")
+ }
+ }
+
+ test("SPARK-23723: wrong output encoding") {
+ val encoding = "UTF-128"
+ val exception = intercept[UnsupportedCharsetException] {
+ withTempPath { path =>
+ val df = spark.createDataset(Seq((0)))
+ df.write
+ .options(Map("encoding" -> encoding, "lineSep" -> "\n"))
+ .json(path.getCanonicalPath)
+ }
+ }
+
+ assert(exception.getMessage == encoding)
+ }
+
+ test("SPARK-23723: read back json in UTF-16LE") {
+ val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n")
+ withTempPath { path =>
+ val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2)
+ ds.write.options(options).json(path.getCanonicalPath)
+
+ val readBack = spark
+ .read
+ .options(options)
+ .json(path.getCanonicalPath)
+
+ checkAnswer(readBack.toDF(), ds.toDF())
+ }
+ }
+
+ def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = {
+ test(s"SPARK-23724: checks reading json in ${encoding} #${id}") {
+ val schema = new StructType().add("f1", StringType).add("f2", IntegerType)
+ withTempPath { path =>
+ val records = List(("a", 1), ("b", 2))
+ val data = records
+ .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding))
+ .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2)
+ val os = new FileOutputStream(path)
+ os.write(data)
+ os.close()
+ val reader = if (inferSchema) {
+ spark.read
+ } else {
+ spark.read.schema(schema)
+ }
+ val readBack = reader
+ .option("encoding", encoding)
+ .option("lineSep", lineSep)
+ .json(path.getCanonicalPath)
+ checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2)))
+ }
+ }
+ }
+
+ // scalastyle:off nonascii
+ List(
+ (0, "|", "UTF-8", false),
+ (1, "^", "UTF-16BE", true),
+ (2, "::", "ISO-8859-1", true),
+ (3, "!!!@3", "UTF-32LE", false),
+ (4, 0x1E.toChar.toString, "UTF-8", true),
+ (5, "아", "UTF-32BE", false),
+ (6, "куку", "CP1251", true),
+ (7, "sep", "utf-8", false),
+ (8, "\r\n", "UTF-16LE", false),
+ (9, "\r\n", "utf-16be", true),
+ (10, "\u000d\u000a", "UTF-32BE", false),
+ (11, "\u000a\u000d", "UTF-8", true),
+ (12, "===", "US-ASCII", false),
+ (13, "$^+", "utf-32le", true)
+ ).foreach {
+ case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum)
+ }
+ // scalastyle:on nonascii
+
+ test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") {
+ val encoding = "UTF-16LE"
+ val exception = intercept[IllegalArgumentException] {
+ spark.read
+ .options(Map("encoding" -> encoding))
+ .json(testFile("test-data/utf16LE.json"))
+ .count()
+ }
+
+ assert(exception.getMessage.contains(
+ s"""The lineSep option must be specified for the $encoding encoding"""))
+ }
+
+ private val badJson = "\u0000\u0000\u0000A\u0001AAA"
+
+ test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") {
+ withTempPath { tempDir =>
+ val path = tempDir.getAbsolutePath
+ Seq(badJson + """{"a":1}""").toDS().write.text(path)
+ val expected = s"""${badJson}{"a":1}\n"""
+ val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType)
+ val df = spark.read.format("json")
+ .option("mode", "PERMISSIVE")
+ .option("multiLine", true)
+ .option("encoding", "UTF-8")
+ .schema(schema).load(path)
+ checkAnswer(df, Row(null, expected))
+ }
+ }
+
+ test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") {
+ withTempPath { tempDir =>
+ val path = tempDir.getAbsolutePath
+ Seq(badJson, """{"a":1}""").toDS().write.text(path)
+ val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType)
+ val df = spark.read.format("json")
+ .option("mode", "PERMISSIVE")
+ .option("multiLine", false)
+ .option("encoding", "UTF-8")
+ .schema(schema).load(path)
+ checkAnswer(df, Seq(Row(1, null), Row(null, badJson)))
+ }
+ }
+
+ test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") {
+ checkAnswer(
+ spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()),
+ Row(badJson))
+ }
}