Skip to content

Commit 609a901

Browse files
authored
Improve to_dask_dataframe performance (#7844)
* Improve to_dask_dataframe performance
* Add ASV test
* Update pandas.py
* Update dataset.py
1 parent 95bb813 commit 609a901

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

asv_bench/benchmarks/pandas.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -29,19 +29,20 @@ def time_from_series(self, dtype, subset):
2929
class ToDataFrame:
3030
def setup(self, *args, **kwargs):
3131
xp = kwargs.get("xp", np)
32+
nvars = kwargs.get("nvars", 1)
3233
random_kws = kwargs.get("random_kws", {})
3334
method = kwargs.get("method", "to_dataframe")
3435

3536
dim1 = 10_000
3637
dim2 = 10_000
38+
39+
var = xr.Variable(
40+
dims=("dim1", "dim2"), data=xp.random.random((dim1, dim2), **random_kws)
41+
)
42+
data_vars = {f"long_name_{v}": (("dim1", "dim2"), var) for v in range(nvars)}
43+
3744
ds = xr.Dataset(
38-
{
39-
"x": xr.DataArray(
40-
data=xp.random.random((dim1, dim2), **random_kws),
41-
dims=["dim1", "dim2"],
42-
coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)},
43-
)
44-
}
45+
data_vars, coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)}
4546
)
4647
self.to_frame = getattr(ds, method)
4748

@@ -58,4 +59,6 @@ def setup(self, *args, **kwargs):
5859

5960
import dask.array as da
6061

61-
super().setup(xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe")
62+
super().setup(
63+
xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe", nvars=500
64+
)

xarray/core/dataset.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6465,7 +6465,10 @@ def to_dask_dataframe(
64656465
columns.extend(k for k in self.coords if k not in self.dims)
64666466
columns.extend(self.data_vars)
64676467

6468+
ds_chunks = self.chunks
6469+
64686470
series_list = []
6471+
df_meta = pd.DataFrame()
64696472
for name in columns:
64706473
try:
64716474
var = self.variables[name]
@@ -6484,8 +6487,11 @@ def to_dask_dataframe(
64846487
if not is_duck_dask_array(var._data):
64856488
var = var.chunk()
64866489

6487-
dask_array = var.set_dims(ordered_dims).chunk(self.chunks).data
6488-
series = dd.from_array(dask_array.reshape(-1), columns=[name])
6490+
# Broadcast then flatten the array:
6491+
var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks)
6492+
dask_array = var_new_dims._data.reshape(-1)
6493+
6494+
series = dd.from_dask_array(dask_array, columns=name, meta=df_meta)
64896495
series_list.append(series)
64906496

64916497
df = dd.concat(series_list, axis=1)

0 commit comments

Comments
 (0)