diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index cd1fa552fe0..45db7a1508c 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -5,6 +5,7 @@ import re import tarfile import time +import xml.dom.minidom from asyncio import TimeoutError from io import BytesIO from itertools import chain @@ -694,6 +695,25 @@ def xet_parse(source, parser=None, use_auth_token: Optional[Union[str, bool]] = return ET.parse(f, parser=parser) +def xxml_dom_minidom_parse(filename_or_file, use_auth_token: Optional[Union[str, bool]] = None, **kwargs): + """Extend `xml.dom.minidom.parse` function to support remote files. + + Args: + filename_or_file (`str` or file): File path or file object. + use_auth_token (`bool` or `str`, *optional*): Whether to use token or token to authenticate on the + Hugging Face Hub for private remote files. + **kwargs (optional): Additional keyword arguments passed to `xml.dom.minidom.parse`. + + Returns: + :obj:`xml.dom.minidom.Document`: Parsed document. + """ + if hasattr(filename_or_file, "read"): + return xml.dom.minidom.parse(filename_or_file, **kwargs) + else: + with xopen(filename_or_file, "rb", use_auth_token=use_auth_token) as f: + return xml.dom.minidom.parse(f, **kwargs) + + class _IterableFromGenerator(Iterable): """Utility class to create an iterable from a generator function, in order to reset the generator when needed.""" diff --git a/src/datasets/streaming.py b/src/datasets/streaming.py index ec9fc07a2fd..6290f23048f 100644 --- a/src/datasets/streaming.py +++ b/src/datasets/streaming.py @@ -29,6 +29,7 @@ xsplit, xsplitext, xwalk, + xxml_dom_minidom_parse, ) from .utils.logging import get_logger from .utils.patching import patch_submodule @@ -98,6 +99,9 @@ def wrapper(*args, **kwargs): patch_submodule(module, "pd.read_csv", wrap_auth(xpandas_read_csv), attrs=["__version__"]).start() patch_submodule(module, "pd.read_excel", xpandas_read_excel, attrs=["__version__"]).start() patch_submodule(module, "sio.loadmat", wrap_auth(xsio_loadmat), attrs=["__version__"]).start() + # xml.dom.minidom + if hasattr(module, "parse") and module.parse.__module__ == "xml.dom.minidom": + patch_submodule(module, "parse", wrap_auth(xxml_dom_minidom_parse)).start() # xml.etree.ElementTree for submodule in ["ElementTree", "ET"]: patch_submodule(module, f"{submodule}.parse", wrap_auth(xet_parse)).start()