diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py index 2a296ecc4d0..9bda5970a4c 100644 --- a/asv_bench/benchmarks/pandas.py +++ b/asv_bench/benchmarks/pandas.py @@ -29,19 +29,20 @@ def time_from_series(self, dtype, subset): class ToDataFrame: def setup(self, *args, **kwargs): xp = kwargs.get("xp", np) + nvars = kwargs.get("nvars", 1) random_kws = kwargs.get("random_kws", {}) method = kwargs.get("method", "to_dataframe") dim1 = 10_000 dim2 = 10_000 + + var = xr.Variable( + dims=("dim1", "dim2"), data=xp.random.random((dim1, dim2), **random_kws) + ) + data_vars = {f"long_name_{v}": (("dim1", "dim2"), var) for v in range(nvars)} + ds = xr.Dataset( - { - "x": xr.DataArray( - data=xp.random.random((dim1, dim2), **random_kws), - dims=["dim1", "dim2"], - coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)}, - ) - } + data_vars, coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)} ) self.to_frame = getattr(ds, method) @@ -58,4 +59,6 @@ def setup(self, *args, **kwargs): import dask.array as da - super().setup(xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe") + super().setup( + xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe", nvars=500 + ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d2ecd65ba58..433c724cc21 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6465,7 +6465,10 @@ def to_dask_dataframe( columns.extend(k for k in self.coords if k not in self.dims) columns.extend(self.data_vars) + ds_chunks = self.chunks + series_list = [] + df_meta = pd.DataFrame() for name in columns: try: var = self.variables[name] @@ -6484,8 +6487,11 @@ def to_dask_dataframe( if not is_duck_dask_array(var._data): var = var.chunk() - dask_array = var.set_dims(ordered_dims).chunk(self.chunks).data - series = dd.from_array(dask_array.reshape(-1), columns=[name]) + # Broadcast then flatten the array: + var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks) + dask_array = var_new_dims._data.reshape(-1) + + series = dd.from_dask_array(dask_array, columns=name, meta=df_meta) series_list.append(series) df = dd.concat(series_list, axis=1)