Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ When reading files the API accepts several options:
* `samplingRatio`: Sampling ratio for inferring schema (0.0 ~ 1). Default is 1. Possible types are `StructType`, `ArrayType`, `StringType`, `LongType`, `DoubleType`, `BooleanType`, `TimestampType` and `NullType`, unless user provides a schema for this.
* `excludeAttribute` : Whether you want to exclude attributes in elements or not. Default is false.
* `treatEmptyValuesAsNulls` : Whether you want to treat whitespaces as a null value. Default is false.
* `failFast` : Whether you want to fail when it fails to parse malformed rows in XML files, instead of dropping the rows. Default is false.
* `mode`: The mode for dealing with corrupt records during parsing. Default is `PERMISSIVE`.
* `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the malformed string into a new field configured by `columnNameOfCorruptRecord`. When a schema is set by user, it sets `null` for extra fields.
* `DROPMALFORMED` : ignores the whole corrupted records.
* `FAILFAST` : throws an exception when it meets corrupted records.
* `columnNameOfCorruptRecord`: The name of new field where malformed strings are stored. Default is `_corrupt_record`.
* `attributePrefix`: The prefix for attributes so that we can differentiate attributes and elements. This will be the prefix for field names. Default is `_`.
* `valueTag`: The tag used for the value when there are attributes in the element having no child. Default is `_VALUE`.
* `charset`: Defaults to 'UTF-8' but can be set to other valid charset names
Expand Down
25 changes: 24 additions & 1 deletion src/main/scala/com/databricks/spark/xml/XmlOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
*/
package com.databricks.spark.xml

import org.slf4j.LoggerFactory

import com.databricks.spark.xml.util.ParseModes

/**
* Options for the XML data source.
*/
private[xml] class XmlOptions(
@transient private val parameters: Map[String, String])
extends Serializable{
private val logger = LoggerFactory.getLogger(XmlRelation.getClass)

val charset = parameters.getOrElse("charset", XmlOptions.DEFAULT_CHARSET)
val codec = parameters.get("compression").orElse(parameters.get("codec")).orNull
Expand All @@ -30,11 +35,29 @@ private[xml] class XmlOptions(
val excludeAttributeFlag = parameters.get("excludeAttribute").map(_.toBoolean).getOrElse(false)
val treatEmptyValuesAsNulls =
parameters.get("treatEmptyValuesAsNulls").map(_.toBoolean).getOrElse(false)
val failFastFlag = parameters.get("failFast").map(_.toBoolean).getOrElse(false)
val attributePrefix =
parameters.getOrElse("attributePrefix", XmlOptions.DEFAULT_ATTRIBUTE_PREFIX)
val valueTag = parameters.getOrElse("valueTag", XmlOptions.DEFAULT_VALUE_TAG)
val nullValue = parameters.getOrElse("nullValue", XmlOptions.DEFAULT_NULL_VALUE)
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", "_corrupt_record")

// Leave this option for backwards compatibility.
private val failFastFlag = parameters.get("failFast").map(_.toBoolean).getOrElse(false)
private val parseMode = if (failFastFlag) {
parameters.getOrElse("mode", ParseModes.FAIL_FAST_MODE)
} else {
parameters.getOrElse("mode", ParseModes.PERMISSIVE_MODE)
}

// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
}

val failFast = ParseModes.isFailFastMode(parseMode)
val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
val permissive = ParseModes.isPermissiveMode(parseMode)

require(rowTag.nonEmpty, "'rowTag' option should not be empty string.")
require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
Expand Down
10 changes: 10 additions & 0 deletions src/main/scala/com/databricks/spark/xml/XmlReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ class XmlReader extends Serializable {
this
}

def withParseMode(valueTag: String): XmlReader = {
parameters += ("mode" -> valueTag)
this
}

def withAttributePrefix(attributePrefix: String): XmlReader = {
parameters += ("attributePrefix" -> attributePrefix)
this
Expand All @@ -72,6 +77,11 @@ class XmlReader extends Serializable {
this
}

def withColumnNameOfCorruptRecord(name: String): XmlReader = {
parameters += ("columnNameOfCorruptRecord" -> name)
this
}

def withSchema(schema: StructType): XmlReader = {
this.schema = schema
this
Expand Down
22 changes: 3 additions & 19 deletions src/main/scala/com/databricks/spark/xml/XmlRelation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,12 @@ case class XmlRelation protected[spark] (
val schemaFields = schema.fields
if (schemaFields.deep == requiredFields.deep) {
buildScan()
} else if (options.failFastFlag) {
val safeRequestedSchema = StructType(requiredFields)
StaxXmlParser.parse(
baseRDD(),
safeRequestedSchema,
options)
} else {
// If `failFast` is disabled, then it needs to parse all the values
// so that we can decide which row is malformed.
val safeRequestedSchema = StructType(
requiredFields ++ schema.fields.filterNot(requiredFields.contains(_)))
val rows = StaxXmlParser.parse(
val requestedSchema = StructType(requiredFields)
StaxXmlParser.parse(
baseRDD(),
safeRequestedSchema,
requestedSchema,
options)

val rowSize = requiredFields.length
rows.mapPartitions { iter =>
iter.flatMap { xml =>
Some(Row.fromSeq(xml.toSeq.take(rowSize)))
}
}
}
}

Expand Down
41 changes: 26 additions & 15 deletions src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,25 @@ private[xml] object StaxXmlParser {
xml: RDD[String],
schema: StructType,
options: XmlOptions): RDD[Row] = {
val failFast = options.failFastFlag
def failedRecord(record: String): Option[Row] = {
// create a row even if no corrupt record column is present
if (options.failFast) {
throw new RuntimeException(
s"Malformed line in FAILFAST mode: ${record.replaceAll("\n", "")}")
} else if (options.dropMalformed) {
logger.warn(s"Dropping malformed line: ${record.replaceAll("\n", "")}")
None
} else {
val row = new Array[Any](schema.length)
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(options.columnNameOfCorruptRecord).foreach { corruptIndex =>
require(schema(corruptIndex).dataType == StringType)
row.update(corruptIndex, record)
}
Some(Row.fromSeq(row))
}
}

xml.mapPartitions { iter =>
val factory = XMLInputFactory.newInstance()
factory.setProperty(XMLInputFactory.IS_NAMESPACE_AWARE, false)
Expand All @@ -63,22 +81,15 @@ private[xml] object StaxXmlParser {
StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.START_ELEMENT)
val rootAttributes =
rootEvent.asStartElement.getAttributes.map(_.asInstanceOf[Attribute]).toArray

Some(convertObject(parser, schema, options, rootAttributes))
.orElse(failedRecord(xml))
} catch {
case _: java.lang.NumberFormatException if !failFast =>
logger.warn("Number format exception. " +
s"Dropping malformed line: ${xml.replaceAll("\n", "")}")
None
case _: java.text.ParseException | _: IllegalArgumentException if !failFast =>
logger.warn("Parse exception. " +
s"Dropping malformed line: ${xml.replaceAll("\n", "")}")
None
case _: XMLStreamException if failFast =>
throw new RuntimeException(s"Malformed row (failing fast): ${xml.replaceAll("\n", "")}")
case _: XMLStreamException if !failFast =>
logger.warn(s"Dropping malformed row: ${xml.replaceAll("\n", "")}")
None
case _: java.lang.NumberFormatException =>
failedRecord(xml)
case _: java.text.ParseException | _: IllegalArgumentException =>
failedRecord(xml)
case _: XMLStreamException =>
failedRecord(xml)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ private[xml] object InferSchema {
def infer(xml: RDD[String], options: XmlOptions): StructType = {
require(options.samplingRatio > 0,
s"samplingRatio ($options.samplingRatio) should be greater than 0")
val shouldHandleCorruptRecord = options.permissive
val schemaData = if (options.samplingRatio > 0.99) {
xml
} else {
xml.sample(withReplacement = false, options.samplingRatio, 1)
}
val failFast = options.failFastFlag
// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
val factory = XMLInputFactory.newInstance()
Expand All @@ -100,11 +100,10 @@ private[xml] object InferSchema {

Some(inferObject(parser, options, rootAttributes))
} catch {
case _: XMLStreamException if !failFast =>
logger.warn(s"Dropping malformed row: ${xml.replaceAll("\n", "")}")
case _: XMLStreamException if shouldHandleCorruptRecord =>
Some(StructType(Seq(StructField(options.columnNameOfCorruptRecord, StringType))))
case _: XMLStreamException =>
None
case _: XMLStreamException if failFast =>
throw new RuntimeException(s"Malformed row (failing fast): ${xml.replaceAll("\n", "")}")
}
}
}.treeAggregate[DataType](StructType(Seq()))(
Expand Down
39 changes: 39 additions & 0 deletions src/main/scala/com/databricks/spark/xml/util/ParseModes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml.util

private[xml] object ParseModes {
val PERMISSIVE_MODE = "PERMISSIVE"
val DROP_MALFORMED_MODE = "DROPMALFORMED"
val FAIL_FAST_MODE = "FAILFAST"

val DEFAULT = PERMISSIVE_MODE

def isValidMode(mode: String): Boolean = {
mode.toUpperCase match {
case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true
case _ => false
}
}

def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE
def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE
def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) {
mode.toUpperCase == PERMISSIVE_MODE
} else {
true // We default to permissive is the mode string is not valid
}
}
89 changes: 43 additions & 46 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.{SparkException, SparkContext}
import org.apache.spark.sql.{SaveMode, Row, SQLContext}
import org.apache.spark.sql.types._
import com.databricks.spark.xml.XmlOptions._
import com.databricks.spark.xml.util.ParseModes

class XmlSuite extends FunSuite with BeforeAndAfterAll {
val tempEmptyDir = "target/test/empty/"
Expand Down Expand Up @@ -211,58 +212,52 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {

test("DSL test for parsing a malformed XML file") {
val results = new XmlReader()
.withFailFast(false)
.withParseMode(ParseModes.DROP_MALFORMED_MODE)
.xmlFile(sqlContext, carsMalformedFile)

assert(results.count() === 1)
}

test("DSL test for dropping malformed rows") {
val schema = new StructType(
Array(
StructField("color", IntegerType, true),
StructField("make", TimestampType, true),
StructField("model", DoubleType, true),
StructField("comment", StringType, true),
StructField("year", DoubleType, true)
)
)
val results = new XmlReader()
.withSchema(schema)
.xmlFile(sqlContext, carsUnbalancedFile)
.count()
val cars = new XmlReader()
.withParseMode(ParseModes.DROP_MALFORMED_MODE)
.xmlFile(sqlContext, carsMalformedFile)

assert(results === 0)
assert(cars.count() == 1)
assert(cars.head().toSeq === Seq("Chevy", "Volt", 2015))
}

test("DSL test for failing fast") {
// Fail fast in type inference
val exceptionInSchema = intercept[SparkException] {
new XmlReader()
.withFailFast(true)
.xmlFile(sqlContext, carsMalformedFile)
.printSchema()
}
assert(exceptionInSchema.getMessage.contains("Malformed row (failing fast)"))

// Fail fast in parsing data
val schema = new StructType(
Array(
StructField("color", StringType, true),
StructField("make", StringType, true),
StructField("model", StringType, true),
StructField("comment", StringType, true),
StructField("year", StringType, true)
)
)
val exceptionInParse = intercept[SparkException] {
new XmlReader()
.withFailFast(true)
.withSchema(schema)
.xmlFile(sqlContext, carsMalformedFile)
.collect()
}
assert(exceptionInParse.getMessage.contains("Malformed row (failing fast)"))
assert(exceptionInParse.getMessage.contains("Malformed line in FAILFAST mode"))
}

test("DSL test for permissive mode for corrupt records") {
val carsDf = new XmlReader()
.withParseMode(ParseModes.PERMISSIVE_MODE)
.withColumnNameOfCorruptRecord("_malformed_records")
.xmlFile(sqlContext, carsMalformedFile)
val cars = carsDf.collect()
assert(cars.length == 3)

val malformedRowOne = carsDf.select("_malformed_records").first().toSeq.head.toString
val malformedRowTwo = carsDf.select("_malformed_records").take(2).last.toSeq.head.toString
val expectedMalformedRowOne = "<ROW><year>2012</year><make>Tesla</make><model>>S" +
"<comment>No comment</comment></ROW>"
val expectedMalformedRowTwo = "<ROW></year><make>Ford</make><model>E350</model>model></model>" +
"<comment>Go get one now they are going fast</comment></ROW>"

assert(malformedRowOne.replaceAll("\\s", "") === expectedMalformedRowOne.replaceAll("\\s", ""))
assert(malformedRowTwo.replaceAll("\\s", "") === expectedMalformedRowTwo.replaceAll("\\s", ""))
assert(cars(2).toSeq.head === null)
assert(cars(0).toSeq.takeRight(3) === Seq(null, null, null))
assert(cars(1).toSeq.takeRight(3) === Seq(null, null, null))
assert(cars(2).toSeq.takeRight(3) === Seq("Chevy", "Volt", 2015))
}

test("DSL test with empty file and known schema") {
Expand Down Expand Up @@ -421,7 +416,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val schemaCopy = StructType(
List(StructField("a", ArrayType(
StructType(List(StructField("item", ArrayType(StringType), nullable = true)))),
nullable = true)))
nullable = true)))
val dfCopy = sqlContext.xmlFile(copyFilePath + "/")

assert(dfCopy.count == df.count)
Expand Down Expand Up @@ -461,8 +456,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
df.saveAsXmlFile(copyFilePath)

val dfCopy = new XmlReader()
.withSchema(schema)
.xmlFile(sqlContext, copyFilePath + "/")
.withSchema(schema)
.xmlFile(sqlContext, copyFilePath + "/")

assert(dfCopy.collect() === df.collect())
assert(dfCopy.schema === df.schema)
Expand Down Expand Up @@ -551,11 +546,11 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
StructField("price", DoubleType, nullable = true),
StructField("publish_dates", StructType(
List(StructField("publish_date",
ArrayType(StructType(
List(StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}tag", StringType, nullable = true),
StructField("day", LongType, nullable = true),
StructField("month", LongType, nullable = true),
StructField("year", LongType, nullable = true))))))),
ArrayType(StructType(
List(StructField(s"${DEFAULT_ATTRIBUTE_PREFIX}tag", StringType, nullable = true),
StructField("day", LongType, nullable = true),
StructField("month", LongType, nullable = true),
StructField("year", LongType, nullable = true))))))),
nullable = true),
StructField("title", StringType, nullable = true))
))
Expand Down Expand Up @@ -656,9 +651,11 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
}

test("DSL test nullable fields") {
val schema = StructType(
StructField("name", StringType, false) ::
StructField("age", StringType, true) :: Nil)
val results = new XmlReader()
.withSchema(StructType(List(StructField("name", StringType, false),
StructField("age", StringType, true))))
.withSchema(schema)
.xmlFile(sqlContext, nullNumbersFile)
.collect()

Expand Down