|
1 | 1 | import logging |
2 | | -from decimal import Decimal |
3 | 2 | from typing import Any |
4 | 3 |
|
5 | 4 | import dask.array as da |
| 5 | +import dask.config as dask_config |
6 | 6 | import dask.dataframe as dd |
7 | 7 | import numpy as np |
8 | 8 | import pandas as pd |
9 | 9 |
|
10 | 10 | from dask_planner.rust import DaskTypeMap, SqlTypeName |
11 | 11 |
|
12 | | -try: |
13 | | - import cudf |
14 | | -except ImportError: |
15 | | - cudf = None |
16 | | - |
17 | 12 | logger = logging.getLogger(__name__) |
18 | 13 |
|
19 | 14 |
|
|
54 | 49 | _SQL_TO_PYTHON_SCALARS = { |
55 | 50 | "SqlTypeName.DOUBLE": np.float64, |
56 | 51 | "SqlTypeName.FLOAT": np.float32, |
57 | | - "SqlTypeName.DECIMAL": Decimal, |
| 52 | + "SqlTypeName.DECIMAL": np.float32, |
58 | 53 | "SqlTypeName.BIGINT": np.int64, |
59 | 54 | "SqlTypeName.INTEGER": np.int32, |
60 | 55 | "SqlTypeName.SMALLINT": np.int16, |
|
71 | 66 | _SQL_TO_PYTHON_FRAMES = { |
72 | 67 | "SqlTypeName.DOUBLE": np.float64, |
73 | 68 | "SqlTypeName.FLOAT": np.float32, |
74 | | - # a column of Decimals in pandas is `object`, but cuDF has a dedicated dtype |
75 | | - "SqlTypeName.DECIMAL": object if not cudf else cudf.Decimal128Dtype(38, 10), |
| 69 | + "SqlTypeName.DECIMAL": np.float64, # Default frame dtype for DECIMAL is always np.float64, even where a narrower float type might suffice |
76 | 70 | "SqlTypeName.BIGINT": pd.Int64Dtype(), |
77 | 71 | "SqlTypeName.INTEGER": pd.Int32Dtype(), |
78 | 72 | "SqlTypeName.SMALLINT": pd.Int16Dtype(), |
@@ -151,6 +145,14 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: |
151 | 145 |
|
152 | 146 | return literal_value |
153 | 147 |
|
| 148 | + elif ( |
| 149 | + sql_type == SqlTypeName.DECIMAL |
| 150 | + and dask_config.get("sql.mappings.decimal_support") == "cudf" |
| 151 | + ): |
| 152 | + from decimal import Decimal |
| 153 | + |
| 154 | + python_type = Decimal |
| 155 | + |
154 | 156 | elif sql_type == SqlTypeName.INTERVAL_DAY: |
155 | 157 | return np.timedelta64(literal_value[0], "D") + np.timedelta64( |
156 | 158 | literal_value[1], "ms" |
@@ -219,7 +221,16 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: |
219 | 221 | def sql_to_python_type(sql_type: "SqlTypeName", *args) -> type: |
220 | 222 | """Turn an SQL type into a dataframe dtype""" |
221 | 223 | try: |
222 | | - if str(sql_type) == "SqlTypeName.DECIMAL": |
| 224 | + if ( |
| 225 | + sql_type == SqlTypeName.DECIMAL |
| 226 | + and dask_config.get("sql.mappings.decimal_support") == "cudf" |
| 227 | + ): |
| 228 | + try: |
| 229 | + import cudf |
| 230 | + except ImportError: |
| 231 | + raise ModuleNotFoundError( |
| 232 | + "Setting `sql.mappings.decimal_support=cudf` requires cudf" |
| 233 | + ) |
223 | 234 | return cudf.Decimal128Dtype(*args) |
224 | 235 | return _SQL_TO_PYTHON_FRAMES[str(sql_type)] |
225 | 236 | except KeyError: # pragma: no cover |
|
0 commit comments