Commit d7dfbc8

Dref360, lhoestq, and mariosasko authored
Add ability to read-write to SQL databases. (#4928)
* Add ability to read-write to SQL databases.
* Fix issue where pandas<1.4.0 doesn't return the number of rows
* Fix issue where connections were not closed properly
* Apply suggestions from code review
* Change according to reviews
* Change according to reviews
* Inherit from AbstractDatasetInputStream in SqlDatasetReader
* Revert typing in SQLDatasetReader as we do not support Connexion
* Align API with Pandas/Dask
* Update tests
* Update docs
* Update some more tests
* Missing comma
* Small docs fix
* Style
* Update src/datasets/arrow_dataset.py
* Update src/datasets/packaged_modules/sql/sql.py
* Address some comments
* Address the rest
* Improve tests
* sqlalchemy required tip

Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: mariosasko <[email protected]>
1 parent 3029926 commit d7dfbc8

File tree: 16 files changed, +600 -9 lines

docs/source/loading.mdx

Lines changed: 18 additions & 0 deletions
@@ -196,6 +196,24 @@ To load remote Parquet files via HTTP, pass the URLs instead:
 >>> wiki = load_dataset("parquet", data_files=data_files, split="train")
 ```
 
+### SQL
+
+Read database contents with [`Dataset.from_sql`]. Both table names and queries are supported.
+
+For example, a table from a SQLite file can be loaded with:
+
+```py
+>>> from datasets import Dataset
+>>> dataset = Dataset.from_sql("data_table", "sqlite:///sqlite_file.db")
+```
+
+Use a query for a more precise read:
+
+```py
+>>> from datasets import Dataset
+>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
+```
+
 ## In-memory data
 
 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
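A minimal round-trip sketch of the read/write API this commit adds; the SQLite file `quotes.db`, the table name, and the sample rows are illustrative assumptions, not part of the commit:

```py
>>> from datasets import Dataset
>>> ds = Dataset.from_dict({"text": ["short", "a sentence that is long enough to pass the filter"]})
>>> ds.to_sql("data_table", "sqlite:///quotes.db")  # writes the rows and returns the record count
2
>>> full = Dataset.from_sql("data_table", "sqlite:///quotes.db")  # read the whole table back
>>> filtered = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 10", "sqlite:///quotes.db")
```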

docs/source/package_reference/loading_methods.mdx

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,10 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
 
 [[autodoc]] datasets.packaged_modules.parquet.ParquetConfig
 
+### SQL
+
+[[autodoc]] datasets.packaged_modules.sql.SqlConfig
+
 ### Images
 
 [[autodoc]] datasets.packaged_modules.imagefolder.ImageFolderConfig

docs/source/package_reference/main_classes.mdx

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
 - to_dict
 - to_json
 - to_parquet
+- to_sql
 - add_faiss_index
 - add_faiss_index_from_external_arrays
 - save_faiss_index
@@ -90,6 +91,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
 - from_json
 - from_parquet
 - from_text
+- from_sql
 - prepare_for_task
 - align_labels_with_mapping

docs/source/process.mdx

Lines changed: 1 addition & 0 deletions
@@ -609,6 +609,7 @@ Want to save your dataset to a cloud storage provider? Read our [Cloud Storage](
 | CSV | [`Dataset.to_csv`] |
 | JSON | [`Dataset.to_json`] |
 | Parquet | [`Dataset.to_parquet`] |
+| SQL | [`Dataset.to_sql`] |
 | In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] |
 
 For example, export your dataset to a CSV file like this:
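For the new SQL row in that table, the matching export call would look like this (a sketch; the table name and connection URI are illustrative):

```py
>>> ds.to_sql("my_table", "sqlite:///my_data.db")
```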

setup.py

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@
     "scipy",
     "sentencepiece",  # for bleurt
     "seqeval",
+    "sqlalchemy",
     "tldextract",
     # to speed up pip backtracking
     "toml>=0.10.1",

src/datasets/arrow_dataset.py

Lines changed: 91 additions & 0 deletions
@@ -113,6 +113,10 @@
 
 
 if TYPE_CHECKING:
+    import sqlite3
+
+    import sqlalchemy
+
     from .dataset_dict import DatasetDict
 
 logger = logging.get_logger(__name__)
@@ -1097,6 +1101,56 @@ def from_text(
             path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()
 
+    @staticmethod
+    def from_sql(
+        sql: Union[str, "sqlalchemy.sql.Selectable"],
+        con: str,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        """Create Dataset from SQL query or database table.
+
+        Args:
+            sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
+            con (`str`): A connection URI string used to instantiate a database connection.
+            features (:class:`Features`, optional): Dataset features.
+            cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
+            keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
+            **kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`.
+
+        Returns:
+            :class:`Dataset`
+
+        Example:
+
+        ```py
+        >>> # Fetch a database table
+        >>> ds = Dataset.from_sql("test_data", "postgres:///db_name")
+        >>> # Execute a SQL query on the table
+        >>> ds = Dataset.from_sql("SELECT sentence FROM test_data", "postgres:///db_name")
+        >>> # Use a Selectable object to specify the query
+        >>> from sqlalchemy import select, text
+        >>> stmt = select([text("sentence")]).select_from(text("test_data"))
+        >>> ds = Dataset.from_sql(stmt, "postgres:///db_name")
+        ```
+
+        <Tip warning={true}>
+        `sqlalchemy` needs to be installed to use this function.
+        </Tip>
+        """
+        from .io.sql import SqlDatasetReader
+
+        return SqlDatasetReader(
+            sql,
+            con,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            **kwargs,
+        ).read()
+
     def __del__(self):
         if hasattr(self, "_data"):
             del self._data
@@ -4153,6 +4207,43 @@ def to_parquet(
 
         return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()
 
+    def to_sql(
+        self,
+        name: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
+        batch_size: Optional[int] = None,
+        **sql_writer_kwargs,
+    ) -> int:
+        """Exports the dataset to a SQL database.
+
+        Args:
+            name (`str`): Name of SQL table.
+            con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Engine`):
+                A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
+            batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
+                Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :func:`DataFrame.to_sql`.
+
+        Returns:
+            int: The number of records written.
+
+        Example:
+
+        ```py
+        >>> # con provided as a connection URI string
+        >>> ds.to_sql("data", "sqlite:///my_own_db.sql")
+        >>> # con provided as a sqlite3 connection object
+        >>> import sqlite3
+        >>> con = sqlite3.connect("my_own_db.sql")
+        >>> with con:
+        ...     ds.to_sql("data", con)
+        ```
+        """
+        # Dynamic import to avoid circular dependency
+        from .io.sql import SqlDatasetWriter
+
+        return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write()
+
     def _push_parquet_shards_to_hub(
         self,
         repo_id: str,
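A short usage sketch for `to_sql` with an open connection: the remaining keyword arguments are forwarded to pandas' `DataFrame.to_sql`, so pandas options such as `index=False` can be passed through unchanged (the database file name here is illustrative):

```py
>>> import sqlite3
>>> con = sqlite3.connect("my_own_db.sql")
>>> with con:
...     ds.to_sql("data", con, batch_size=500, index=False)  # index=False is handled by pandas, not datasets
```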

src/datasets/config.py

Lines changed: 3 additions & 0 deletions
@@ -125,6 +125,9 @@
     logger.info("Disabling Apache Beam because USE_BEAM is set to False")
 
 
+# Optional tools for data loading
+SQLALCHEMY_AVAILABLE = importlib.util.find_spec("sqlalchemy") is not None
+
 # Optional tools for feature decoding
 PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None
 
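The new flag only records whether the optional dependency is importable; a hypothetical guard built on it could look like the following (the helper name and error message are made up for illustration and are not part of this commit):

```py
from datasets import config

def require_sqlalchemy():
    # Illustrative helper: fail fast when the optional dependency is missing.
    if not config.SQLALCHEMY_AVAILABLE:
        raise ImportError("Using SQL features requires sqlalchemy: run `pip install sqlalchemy`.")
```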
src/datasets/io/sql.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import multiprocessing
+from typing import TYPE_CHECKING, Optional, Union
+
+from .. import Dataset, Features, config
+from ..formatting import query_table
+from ..packaged_modules.sql.sql import Sql
+from ..utils import logging
+from .abc import AbstractDatasetInputStream
+
+
+if TYPE_CHECKING:
+    import sqlite3
+
+    import sqlalchemy
+
+
+class SqlDatasetReader(AbstractDatasetInputStream):
+    def __init__(
+        self,
+        sql: Union[str, "sqlalchemy.sql.Selectable"],
+        con: str,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        self.builder = Sql(
+            cache_dir=cache_dir,
+            features=features,
+            sql=sql,
+            con=con,
+            **kwargs,
+        )
+
+    def read(self):
+        download_config = None
+        download_mode = None
+        ignore_verifications = False
+        use_auth_token = None
+        base_path = None
+
+        self.builder.download_and_prepare(
+            download_config=download_config,
+            download_mode=download_mode,
+            ignore_verifications=ignore_verifications,
+            # try_from_hf_gcs=try_from_hf_gcs,
+            base_path=base_path,
+            use_auth_token=use_auth_token,
+        )
+
+        # Build dataset for splits
+        dataset = self.builder.as_dataset(
+            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+        )
+        return dataset
+
+
+class SqlDatasetWriter:
+    def __init__(
+        self,
+        dataset: Dataset,
+        name: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
+        batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
+        **to_sql_kwargs,
+    ):
+
+        if num_proc is not None and num_proc <= 0:
+            raise ValueError(f"num_proc {num_proc} must be an integer > 0.")
+
+        self.dataset = dataset
+        self.name = name
+        self.con = con
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.to_sql_kwargs = to_sql_kwargs
+
+    def write(self) -> int:
+        _ = self.to_sql_kwargs.pop("sql", None)
+        _ = self.to_sql_kwargs.pop("con", None)
+
+        written = self._write(**self.to_sql_kwargs)
+        return written
+
+    def _batch_sql(self, args):
+        offset, to_sql_kwargs = args
+        to_sql_kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs
+        batch = query_table(
+            table=self.dataset.data,
+            key=slice(offset, offset + self.batch_size),
+            indices=self.dataset._indices,
+        )
+        df = batch.to_pandas()
+        num_rows = df.to_sql(self.name, self.con, **to_sql_kwargs)
+        return num_rows or len(df)
+
+    def _write(self, **to_sql_kwargs) -> int:
+        """Writes the pyarrow table as SQL to a database.
+
+        Caller is responsible for opening and closing the SQL connection.
+        """
+        written = 0
+
+        if self.num_proc is None or self.num_proc == 1:
+            for offset in logging.tqdm(
+                range(0, len(self.dataset), self.batch_size),
+                unit="ba",
+                disable=not logging.is_progress_bar_enabled(),
+                desc="Creating SQL from Arrow format",
+            ):
+                written += self._batch_sql((offset, to_sql_kwargs))
+        else:
+            num_rows, batch_size = len(self.dataset), self.batch_size
+            with multiprocessing.Pool(self.num_proc) as pool:
+                for num_rows in logging.tqdm(
+                    pool.imap(
+                        self._batch_sql,
+                        [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)],
+                    ),
+                    total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
+                    unit="ba",
+                    disable=not logging.is_progress_bar_enabled(),
+                    desc="Creating SQL from Arrow format",
+                ):
+                    written += num_rows
+
+        return written
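The writer slices the Arrow table into `batch_size`-row chunks and appends every chunk after the first (`if_exists="append"`), so the progress-bar total in `_write` is a ceiling division of rows by batch size. A quick illustrative check with arbitrary numbers:

```py
>>> num_rows, batch_size = 1050, 500
>>> (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
3
```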

src/datasets/packaged_modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .json import json
 from .pandas import pandas
 from .parquet import parquet
+from .sql import sql  # noqa F401
 from .text import text
 
 
src/datasets/packaged_modules/csv/csv.py

Lines changed: 9 additions & 9 deletions
@@ -70,8 +70,8 @@ def __post_init__(self):
             self.names = self.column_names
 
     @property
-    def read_csv_kwargs(self):
-        read_csv_kwargs = dict(
+    def pd_read_csv_kwargs(self):
+        pd_read_csv_kwargs = dict(
             sep=self.sep,
             header=self.header,
             names=self.names,
@@ -112,16 +112,16 @@ def read_csv_kwargs(self):
 
         # some kwargs must not be passed if they don't have a default value
        # some others are deprecated and we can also not pass them if they are the default value
-        for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
-            if read_csv_kwargs[read_csv_parameter] == getattr(CsvConfig(), read_csv_parameter):
-                del read_csv_kwargs[read_csv_parameter]
+        for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
+            if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(CsvConfig(), pd_read_csv_parameter):
+                del pd_read_csv_kwargs[pd_read_csv_parameter]
 
         # Remove 1.3 new arguments
         if not (datasets.config.PANDAS_VERSION.major >= 1 and datasets.config.PANDAS_VERSION.minor >= 3):
-            for read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
-                del read_csv_kwargs[read_csv_parameter]
+            for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
+                del pd_read_csv_kwargs[pd_read_csv_parameter]
 
-        return read_csv_kwargs
+        return pd_read_csv_kwargs
 
 
 class Csv(datasets.ArrowBasedBuilder):
@@ -172,7 +172,7 @@ def _generate_tables(self, files):
             else None
         )
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
+            csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
             try:
                 for batch_idx, df in enumerate(csv_file_reader):
                     pa_table = pa.Table.from_pandas(df)
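The renamed property still just collects the user-facing loading options and forwards them to `pandas.read_csv`; a usage sketch from the caller's side (the file path is illustrative):

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("csv", data_files="path/to/data.tsv", sep="\t", split="train")
```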
