Skip to content

Commit ab7340b

Browse files
authored
Add configuration variable for CPU/GPU decimal support (#1131)
* Add configuration variable for decimal support * Make decimal import lazy
1 parent 72c93d5 commit ab7340b

6 files changed

Lines changed: 70 additions & 40 deletions

File tree

conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def pytest_runtest_setup(item):
1515
pytest.skip("need --rungpu option to run")
1616
# FIXME: P2P shuffle isn't fully supported on GPU, so we must explicitly disable it
1717
dask.config.set({"dataframe.shuffle.algorithm": "tasks"})
18+
# manually enable cudf decimal support
19+
dask.config.set({"sql.mappings.decimal_support": "cudf"})
1820
else:
1921
dask.config.set({"dataframe.shuffle.algorithm": None})
2022
if "queries" in item.keywords and not item.config.getoption("--runqueries"):

dask_sql/mappings.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
import logging
2-
from decimal import Decimal
32
from typing import Any
43

54
import dask.array as da
5+
import dask.config as dask_config
66
import dask.dataframe as dd
77
import numpy as np
88
import pandas as pd
99

1010
from dask_planner.rust import DaskTypeMap, SqlTypeName
1111

12-
try:
13-
import cudf
14-
except ImportError:
15-
cudf = None
16-
1712
logger = logging.getLogger(__name__)
1813

1914

@@ -54,7 +49,7 @@
5449
_SQL_TO_PYTHON_SCALARS = {
5550
"SqlTypeName.DOUBLE": np.float64,
5651
"SqlTypeName.FLOAT": np.float32,
57-
"SqlTypeName.DECIMAL": Decimal,
52+
"SqlTypeName.DECIMAL": np.float32,
5853
"SqlTypeName.BIGINT": np.int64,
5954
"SqlTypeName.INTEGER": np.int32,
6055
"SqlTypeName.SMALLINT": np.int16,
@@ -71,8 +66,7 @@
7166
_SQL_TO_PYTHON_FRAMES = {
7267
"SqlTypeName.DOUBLE": np.float64,
7368
"SqlTypeName.FLOAT": np.float32,
74-
# a column of Decimals in pandas is `object`, but cuDF has a dedicated dtype
75-
"SqlTypeName.DECIMAL": object if not cudf else cudf.Decimal128Dtype(38, 10),
69+
"SqlTypeName.DECIMAL": np.float64, # We use np.float64 always, even though we might be able to use a smaller type
7670
"SqlTypeName.BIGINT": pd.Int64Dtype(),
7771
"SqlTypeName.INTEGER": pd.Int32Dtype(),
7872
"SqlTypeName.SMALLINT": pd.Int16Dtype(),
@@ -151,6 +145,14 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any:
151145

152146
return literal_value
153147

148+
elif (
149+
sql_type == SqlTypeName.DECIMAL
150+
and dask_config.get("sql.mappings.decimal_support") == "cudf"
151+
):
152+
from decimal import Decimal
153+
154+
python_type = Decimal
155+
154156
elif sql_type == SqlTypeName.INTERVAL_DAY:
155157
return np.timedelta64(literal_value[0], "D") + np.timedelta64(
156158
literal_value[1], "ms"
@@ -219,7 +221,16 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any:
219221
def sql_to_python_type(sql_type: "SqlTypeName", *args) -> type:
220222
"""Turn an SQL type into a dataframe dtype"""
221223
try:
222-
if str(sql_type) == "SqlTypeName.DECIMAL":
224+
if (
225+
sql_type == SqlTypeName.DECIMAL
226+
and dask_config.get("sql.mappings.decimal_support") == "cudf"
227+
):
228+
try:
229+
import cudf
230+
except ImportError:
231+
raise ModuleNotFoundError(
232+
"Setting `sql.mappings.decimal_support=cudf` requires cudf"
233+
)
223234
return cudf.Decimal128Dtype(*args)
224235
return _SQL_TO_PYTHON_FRAMES[str(sql_type)]
225236
except KeyError: # pragma: no cover

dask_sql/sql-schema.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,12 @@ properties:
7575
optimization (when possible). ``nelem`` is defined as the limit or ``k`` value times the
7676
number of columns. Default is 1000000, corresponding to a LIMIT clause of 1 million in a
7777
1 column table.
78+
79+
mappings:
80+
type: object
81+
properties:
82+
83+
decimal_support:
84+
type: string
85+
description:
86+
Decides how to handle decimal scalars/columns. ``"pandas"`` handling will treat decimals scalars and columns as floats and float64 columns, respectively, while ``"cudf"`` handling treats decimal scalars as ``decimal.Decimal`` objects and decimal columns as ``cudf.Decimal128Dtype`` columns, handling precision/scale accordingly. Default is ``"pandas"``, but ``"cudf"`` should be used if attempting to work with decimal columns on GPU.

dask_sql/sql.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ sql:
1818

1919
sort:
2020
topk-nelem-limit: 1000000
21+
22+
mappings:
23+
decimal_support: "pandas"

tests/integration/test_filter.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,20 +251,16 @@ def test_filtered_csv(tmpdir, c):
251251
assert_eq(return_df, expected_df)
252252

253253

254-
@pytest.mark.gpu
255-
def test_filter_decimal(c):
256-
import cudf
257-
258-
df = cudf.DataFrame(
254+
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
255+
def test_filter_decimal(c, gpu):
256+
df = pd.DataFrame(
259257
{
260258
"a": [304.5, 35.305, 9.043, 102.424, 53.34],
261259
"b": [2.2, 82.4, 42, 76.9, 54.4],
262260
"c": [1, 2, 2, 5, 9],
263261
}
264262
)
265-
df["a"] = df["a"].astype(cudf.Decimal64Dtype(12, 3))
266-
df["b"] = df["b"].astype(cudf.Decimal64Dtype(7, 1))
267-
c.create_table("df", df)
263+
c.create_table("df", df, gpu=gpu)
268264

269265
result_df = c.sql(
270266
"""
@@ -273,7 +269,7 @@ def test_filter_decimal(c):
273269
FROM
274270
df
275271
WHERE
276-
a < b
272+
CAST(a AS DECIMAL) < CAST(b AS DECIMAL)
277273
"""
278274
)
279275

@@ -284,16 +280,19 @@ def test_filter_decimal(c):
284280
result_df = c.sql(
285281
"""
286282
SELECT
287-
b
283+
CAST(b AS DECIMAL) as b
288284
FROM
289285
df
290286
WHERE
291-
a < decimal '100.2'
287+
CAST(a AS DECIMAL) < DECIMAL '100.2'
292288
"""
293289
)
294290

295-
expected_df = cudf.DataFrame({"b": [82.4, 42, 54.4]})
296-
expected_df["b"] = expected_df["b"].astype(cudf.Decimal64Dtype(7, 1))
291+
# decimal precision doesn't match up with pandas floats
292+
if gpu:
293+
result_df["b"] = result_df["b"].astype("float64")
294+
295+
expected_df = df.loc[df.a < 100.2][["b"]]
297296

298-
assert_eq(result_df.reset_index(drop=True), expected_df)
297+
assert_eq(result_df, expected_df, check_index=False)
299298
c.drop_table("df")

tests/integration/test_groupby.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -598,54 +598,60 @@ def test_groupby_split_every(c, gpu):
598598
c.drop_table("split_every_input")
599599

600600

601-
@pytest.mark.gpu
602-
def test_agg_decimal(c):
603-
import cudf
604-
605-
df = cudf.DataFrame(
601+
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
602+
def test_agg_decimal(c, gpu):
603+
df = pd.DataFrame(
606604
{
607605
"a": [1.23, 12.65, 134.64, -34.3, 945.19],
608606
"b": [1, 1, 2, 2, 3],
609607
}
610608
)
611-
df["a"] = df["a"].astype(cudf.Decimal64Dtype(10, 2))
612609

613-
c.create_table("df", df, gpu=True)
610+
c.create_table("df", df, gpu=gpu)
614611

615612
result_df = c.sql(
616613
"""
617614
SELECT
618-
SUM(a) as s,
619-
COUNT(a) as c,
620-
SUM(a+a) as s2
615+
SUM(CAST(a AS DECIMAL)) as s,
616+
COUNT(CAST(a AS DECIMAL)) as c,
617+
SUM(CAST(a+a AS DECIMAL)) as s2
621618
FROM
622619
df
623620
GROUP BY
624621
b
625622
"""
626623
)
624+
# decimal precision doesn't match up with pandas floats
625+
if gpu:
626+
result_df["s"] = result_df["s"].astype("float64")
627+
result_df["s2"] = result_df["s2"].astype("float64")
627628

628-
expected_df = cudf.DataFrame(
629+
expected_df = pd.DataFrame(
629630
{
630631
"s": df.groupby("b").sum()["a"],
631-
"c": df.groupby("b").count()["a"].astype("int64"),
632+
"c": df.groupby("b").count()["a"],
632633
"s2": df.groupby("b").sum()["a"] + df.groupby("b").sum()["a"],
633634
}
634635
)
635636

636-
assert_eq(result_df, expected_df.reset_index(drop=True))
637+
# dtype of count aggregation is float on gpu
638+
assert_eq(result_df, expected_df, check_index=False, check_dtype=(not gpu))
637639

638640
result_df = c.sql(
639641
"""
640642
SELECT
641-
MIN(a) as min,
642-
MAX(a) as max
643+
MIN(CAST(a AS DECIMAL)) as min,
644+
MAX(CAST(a AS DECIMAL)) as max
643645
FROM
644646
df
645647
"""
646648
)
649+
# decimal precision doesn't match up with pandas floats
650+
if gpu:
651+
result_df["min"] = result_df["min"].astype("float64")
652+
result_df["max"] = result_df["max"].astype("float64")
647653

648-
expected_df = cudf.DataFrame(
654+
expected_df = pd.DataFrame(
649655
{
650656
"min": [df.a.min()],
651657
"max": [df.a.max()],

0 commit comments

Comments
 (0)