567 changes: 328 additions & 239 deletions examples/python/data-preprocessing/SparkNLP_Partition_Demo.ipynb

Large diffs are not rendered by default.

339 changes: 339 additions & 0 deletions examples/python/reader/SparkNLP_XML_Reader_Demo.ipynb

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions python/sparknlp/reader/sparknlp_reader.py
@@ -322,4 +322,49 @@ def txt(self, docPath):
if not isinstance(docPath, str):
raise TypeError("docPath must be a string")
jdf = self._java_obj.txt(docPath)
return self.getDataFrame(self.spark, jdf)

def xml(self, docPath):
"""Reads XML files and returns a Spark DataFrame.

Parameters
----------
docPath : str
Path to an XML file or a directory containing XML files.

Returns
-------
pyspark.sql.DataFrame
A DataFrame containing parsed XML content.

Examples
--------
>>> from sparknlp.reader import SparkNLPReader
>>> xml_df = SparkNLPReader(spark).xml("home/user/xml-directory")

You can also use SparkNLP to read XML files in one line of code:

>>> import sparknlp
>>> xml_df = sparknlp.read().xml("home/user/xml-directory")
>>> xml_df.show(truncate=False)
+-----------------------------------------------------------+
|xml |
+-----------------------------------------------------------+
|[{Title, John Smith, {elementId -> ..., tag -> title}}] |
+-----------------------------------------------------------+

>>> xml_df.printSchema()
root
|-- path: string (nullable = true)
|-- xml: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- elementType: string (nullable = true)
| | |-- content: string (nullable = true)
| | |-- metadata: map (nullable = true)
| | | |-- key: string
| | | |-- value: string (valueContainsNull = true)
"""
if not isinstance(docPath, str):
raise TypeError("docPath must be a string")
jdf = self._java_obj.xml(docPath)
return self.getDataFrame(self.spark, jdf)
16 changes: 15 additions & 1 deletion python/test/sparknlp_test.py
@@ -125,4 +125,18 @@ def runTest(self):
txt_df = sparknlp.read().txt(self.txt_file)
txt_df.show()

self.assertTrue(txt_df.select("txt").count() > 0)


@pytest.mark.fast
class SparkNLPTestXMLFilesSpec(unittest.TestCase):

def setUp(self):
self.data = SparkContextForTest.data
self.xml_files = f"file:///{os.getcwd()}/../src/test/resources/reader/xml"

def runTest(self):
xml_df = sparknlp.read().xml(self.xml_files)
xml_df.show()

self.assertTrue(xml_df.select("xml").count() > 0)
38 changes: 38 additions & 0 deletions src/main/scala/com/johnsnowlabs/partition/HasXmlReaderProperties.scala
@@ -0,0 +1,38 @@
/*
* Copyright 2017-2025 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.partition

import com.johnsnowlabs.nlp.ParamsAndFeaturesWritable
import org.apache.spark.ml.param.Param

trait HasXmlReaderProperties extends ParamsAndFeaturesWritable {

val xmlKeepTags = new Param[Boolean](
this,
"xmlKeepTags",
"Whether to include XML tag names as metadata in the output.")

def setXmlKeepTags(value: Boolean): this.type = set(xmlKeepTags, value)

val onlyLeafNodes = new Param[Boolean](
this,
"onlyLeafNodes",
"If true, only processes XML leaf nodes (no nested children).")

def setOnlyLeafNodes(value: Boolean): this.type = set(onlyLeafNodes, value)

setDefault(xmlKeepTags -> false, onlyLeafNodes -> true)
}
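
The `PartitionTransformer` later in this diff mixes in this trait, so both flags are exposed as ordinary setters there. A minimal sketch of toggling them (input/output column configuration omitted, since those setters are not part of this diff):

```scala
import com.johnsnowlabs.partition.PartitionTransformer

// Emit every element (not only leaves) and record each tag name in metadata.
val partition = new PartitionTransformer()
  .setXmlKeepTags(true)
  .setOnlyLeafNodes(false)
```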
3 changes: 3 additions & 0 deletions src/main/scala/com/johnsnowlabs/partition/Partition.scala
@@ -188,6 +188,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
"application/vnd.openxmlformats-officedocument.presentationml.presentation" =>
sparkNLPReader.ppt
case "application/pdf" => sparkNLPReader.pdf
case "application/xml" => sparkNLPReader.xml
case _ => throw new IllegalArgumentException(s"Unsupported content type: $contentType")
}
}
@@ -199,6 +200,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
case "text/plain" => sparkNLPReader.txtToHTMLElement
case "text/html" => sparkNLPReader.htmlToHTMLElement
case "url" => sparkNLPReader.urlToHTMLElement
case "application/xml" => sparkNLPReader.xmlToHTMLElement
case _ => throw new IllegalArgumentException(s"Unsupported content type: $contentType")
}
}
@@ -234,6 +236,7 @@ class Partition(params: java.util.Map[String, String] = new java.util.HashMap())
case "xls" | "xlsx" => sparkNLPReader.xls
case "ppt" | "pptx" => sparkNLPReader.ppt
case "pdf" => sparkNLPReader.pdf
case "xml" => sparkNLPReader.xml
case _ => throw new IllegalArgumentException(s"Unsupported file type: $extension")
}
}
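With these three branches, XML routes through the same dispatch as the other formats. A hedged sketch of driving it by content type; the `contentType` params key and the `partition(path)` entry point are assumptions inferred from the dispatch code above, not confirmed by this diff:

```scala
import scala.collection.JavaConverters._
import com.johnsnowlabs.partition.Partition

// Assumed: "contentType" is the params key matched by the dispatch above,
// and partition(path) is the public entry point that applies the chosen reader.
val params = Map("contentType" -> "application/xml").asJava
val xmlDf = new Partition(params).partition("home/user/xml-directory")
```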
5 changes: 4 additions & 1 deletion src/main/scala/com/johnsnowlabs/partition/PartitionTransformer.scala
@@ -86,6 +86,7 @@ class PartitionTransformer(override val uid: String)
with HasPowerPointProperties
with HasTextReaderProperties
with HasPdfProperties
with HasXmlReaderProperties
with HasChunkerProperties {

def this() = this(Identifiable.randomUID("PartitionTransformer"))
@@ -157,7 +158,9 @@ class PartitionTransformer(override val uid: String)
"newAfterNChars" -> $(newAfterNChars).toString,
"overlap" -> $(overlap).toString,
"combineTextUnderNChars" -> $(combineTextUnderNChars).toString,
"overlapAll" -> $(overlapAll).toString)
"overlapAll" -> $(overlapAll).toString,
"xmlKeepTags" -> $(xmlKeepTags).toString,
"onlyLeafNodes" -> $(onlyLeafNodes).toString)
val partitionInstance = new Partition(params.asJava)

val inputColum = if (get(inputCols).isDefined) {
66 changes: 65 additions & 1 deletion src/main/scala/com/johnsnowlabs/reader/SparkNLPReader.scala
@@ -296,7 +296,6 @@ class SparkNLPReader(
* |-- width_dimension: integer (nullable = true)
* |-- content: binary (nullable = true)
* |-- exception: string (nullable = true)
* |-- pagenum: integer (nullable = true)
* }}}
*
* @param params
@@ -642,4 +641,69 @@
default = BLOCK_SPLIT_PATTERN)
}

/** Instantiates class to read XML files.
*
* xmlPath: this is a path to a directory of XML files or a path to an XML file. E.g.,
* "path/xml/files"
*
* ==Example==
* {{{
* val xmlPath = "home/user/xml-directory"
* val sparkNLPReader = new SparkNLPReader()
* val xmlDf = sparkNLPReader.xml(xmlPath)
* }}}
*
* ==Example 2==
* You can also use SparkNLP to read XML files in one line of code:
* {{{
* val xmlDf = SparkNLP.read.xml(xmlPath)
* }}}
*
* {{{
* xmlDf.select("xml").show(false)
* +------------------------------------------------------------------------------------------------------------------------+
* |xml |
* +------------------------------------------------------------------------------------------------------------------------+
* |[{Title, John Smith, {elementId -> ..., tag -> title}}, {UncategorizedText, Some content..., {elementId -> ...}}] |
* +------------------------------------------------------------------------------------------------------------------------+
*
* xmlDf.printSchema()
* root
* |-- path: string (nullable = true)
* |-- xml: array (nullable = true)
* | |-- element: struct (containsNull = true)
* | | |-- elementType: string (nullable = true)
* | | |-- content: string (nullable = true)
* | | |-- metadata: map (nullable = true)
* | | | |-- key: string
* | | | |-- value: string (valueContainsNull = true)
* }}}
*
* @param xmlPath
* Path to the XML file or directory
* @return
* A DataFrame with parsed XML as structured elements
*/
def xml(xmlPath: String): DataFrame = {
val xmlReader = new XMLReader(getStoreContent, getXmlKeepTags, getOnlyLeafNodes)
xmlReader.read(xmlPath)
}

/** Parses a raw XML string into structured [[HTMLElement]] objects. */
def xmlToHTMLElement(xml: String): Seq[HTMLElement] = {
val xmlReader = new XMLReader(getStoreContent, getXmlKeepTags, getOnlyLeafNodes)
xmlReader.parseXml(xml)
}

private def getXmlKeepTags: Boolean = {
getDefaultBoolean(params.asScala.toMap, Seq("xmlKeepTags", "xml_keep_tags"), default = false)
}

private def getOnlyLeafNodes: Boolean = {
getDefaultBoolean(
params.asScala.toMap,
Seq("onlyLeafNodes", "only_leaf_nodes"),
default = true)
}

}
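
Since `getXmlKeepTags` and `getOnlyLeafNodes` look up both camelCase and snake_case keys, either naming style works in the params map. A small sketch, assuming the map-accepting `SparkNLPReader` constructor used throughout this class:

```scala
import scala.collection.JavaConverters._
import com.johnsnowlabs.reader.SparkNLPReader

// "xml_keep_tags"/"xmlKeepTags" and "only_leaf_nodes"/"onlyLeafNodes"
// are interchangeable key spellings.
val params = Map("xml_keep_tags" -> "true", "only_leaf_nodes" -> "false").asJava
val xmlDf = new SparkNLPReader(params).xml("home/user/xml-directory")
```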
150 changes: 150 additions & 0 deletions src/main/scala/com/johnsnowlabs/reader/XMLReader.scala
@@ -0,0 +1,150 @@
/*
* Copyright 2017-2025 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.reader

import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.util.io.ResourceHelper.validFile
import com.johnsnowlabs.partition.util.PartitionHelper.datasetWithTextFile
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.xml.{Elem, Node, XML}

/** Class to parse and read XML files.
*
* @param storeContent
* Whether to include the raw XML content in the resulting DataFrame as a separate 'content'
* column. By default, this is false.
*
* @param xmlKeepTags
* Whether to retain original XML tag names and include them in the metadata for each extracted
* element. Useful for preserving structure. Default is false.
*
* @param onlyLeafNodes
* If true, only the deepest elements (those without child elements) are extracted. If false,
* all elements are extracted. Default is true.
*
* ==Input Format==
* Input must be a valid path to an XML file or a directory containing XML files.
*
* ==Example==
* {{{
* val xmlPath = "./data/sample.xml"
* val xmlReader = new XMLReader()
* val xmlDf = xmlReader.read(xmlPath)
* }}}
*
* {{{
* xmlDf.show(truncate = false)
* +----------------------+--------------------------------------------------+
* |path |xml |
* +----------------------+--------------------------------------------------+
* |file:/data/sample.xml |[{Title, My Book, {tag -> title}}, ...] |
* +----------------------+--------------------------------------------------+
*
* xmlDf.printSchema()
* root
* |-- path: string (nullable = true)
* |-- xml: array (nullable = true)
* | |-- element: struct (containsNull = true)
* | | |-- elementType: string (nullable = true)
* | | |-- content: string (nullable = true)
* | | |-- metadata: map (nullable = true)
* | | | |-- key: string
* | | | |-- value: string (valueContainsNull = true)
* }}}
*
* For more examples refer to:
* [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/reader/SparkNLP_XML_Reader_Demo.ipynb notebook]]
*/
class XMLReader(
storeContent: Boolean = false,
xmlKeepTags: Boolean = false,
onlyLeafNodes: Boolean = true)
extends Serializable {

private lazy val spark = ResourceHelper.spark

private var outputColumn = "xml"

def setOutputColumn(value: String): this.type = {
require(value.nonEmpty, "Output column name cannot be empty.")
outputColumn = value
this
}

def read(inputSource: String): DataFrame = {
if (validFile(inputSource)) {
val xmlDf = datasetWithTextFile(spark, inputSource)
.withColumn(outputColumn, parseXmlUDF(col("content")))
if (storeContent) xmlDf.select("path", "content", outputColumn)
else xmlDf.select("path", outputColumn)
} else throw new IllegalArgumentException(s"Invalid inputSource: $inputSource")
}

private val parseXmlUDF = udf((xml: String) => {
parseXml(xml)
})

def parseXml(xmlString: String): List[HTMLElement] = {
val xml = XML.loadString(xmlString)
val elements = ListBuffer[HTMLElement]()

def traverse(node: Node, parentId: Option[String]): Unit = {
node match {
case elem: Elem =>
val tagName = elem.label.toLowerCase
val textContent = elem.text.trim
val elementId = hash(tagName + textContent)

val isLeaf = !elem.child.exists(_.isInstanceOf[Elem])

if (!onlyLeafNodes || isLeaf) {
val elementType = tagName match {
case "title" | "author" => ElementType.TITLE
case _ => ElementType.UNCATEGORIZED_TEXT
}

val metadata = mutable.Map[String, String]("elementId" -> elementId)
if (xmlKeepTags) metadata += ("tag" -> tagName)
parentId.foreach(id => metadata += ("parentId" -> id))

val content = if (isLeaf) textContent else ""
elements += HTMLElement(elementType, content, metadata)
}

// Traverse children
elem.child.foreach(traverse(_, Some(elementId)))

case _ => // Ignore other types
}
}

traverse(xml, None)
elements.toList
}

/** Deterministic MD5 hex digest used to build element identifiers. */
def hash(s: String): String = {
java.security.MessageDigest
.getInstance("MD5")
.digest(s.getBytes)
.map("%02x".format(_))
.mkString
}

}
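
To make the traversal concrete, a short sketch of `parseXml` under the defaults plus `xmlKeepTags = true`; the exact `elementId`/`parentId` values are MD5 hashes and elided here:

```scala
import com.johnsnowlabs.reader.XMLReader

val reader = new XMLReader(xmlKeepTags = true)
val elements = reader.parseXml(
  "<book><title>My Book</title><author>Jane Doe</author></book>")

// With onlyLeafNodes = true the <book> wrapper is skipped; both leaves map to
// ElementType.TITLE because of the "title" | "author" case above:
//   HTMLElement(Title, My Book,  {elementId -> ..., tag -> title,  parentId -> ...})
//   HTMLElement(Title, Jane Doe, {elementId -> ..., tag -> author, parentId -> ...})
elements.foreach(println)
```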