Skip to content

Commit eb61ff6

Browse files
committed
🚩 Generalize region subset method to DataFrames
Make bounding box subsetting work on DataFrames too! This includes pandas, dask and cudf DataFrames. Included a parametrized test for pandas and dask; the cudf one should work too, since the APIs are similar. The original xarray.Dataset subsetter code will still work.
1 parent 8c2bf32 commit eb61ff6

6 files changed

Lines changed: 45 additions & 15 deletions

File tree

atl11_play.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@
303303
"# Do the actual computation to find data points within region of interest\n",
304304
"placename: str = \"kamb\" # Select Kamb Ice Stream region\n",
305305
"region: deepicedrain.Region = regions[placename]\n",
306-
"ds_subset: xr.Dataset = region.subset(ds=ds)\n",
306+
"ds_subset: xr.Dataset = region.subset(data=ds)\n",
307307
"ds_subset = ds_subset.unify_chunks()\n",
308308
"ds_subset = ds_subset.compute()"
309309
]

atl11_play.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@
181181
# Do the actual computation to find data points within region of interest
182182
placename: str = "kamb" # Select Kamb Ice Stream region
183183
region: deepicedrain.Region = regions[placename]
184-
ds_subset: xr.Dataset = region.subset(ds=ds)
184+
ds_subset: xr.Dataset = region.subset(data=ds)
185185
ds_subset = ds_subset.unify_chunks()
186186
ds_subset = ds_subset.compute()
187187

atlxi_dhdt.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@
218218
"# Subset dataset to geographic region of interest\n",
219219
"placename: str = \"antarctica\"\n",
220220
"region: deepicedrain.Region = regions[placename]\n",
221-
"# ds = region.subset(ds=ds)"
221+
"# ds = region.subset(data=ds)"
222222
]
223223
},
224224
{
@@ -901,7 +901,7 @@
901901
"region: deepicedrain.Region = regions[placename]\n",
902902
"if not os.path.exists(f\"ATLXI/df_dhdt_{placename}.parquet\"):\n",
903903
" # Subset dataset to geographic region of interest\n",
904-
" ds_subset: xr.Dataset = region.subset(ds=ds_dhdt)\n",
904+
" ds_subset: xr.Dataset = region.subset(data=ds_dhdt)\n",
905905
" # Add a UTC_time column to the dataframe\n",
906906
" ds_subset[\"utc_time\"] = deepicedrain.deltatime_to_utctime(\n",
907907
" dataarray=ds_subset.delta_ds_subsettime\n",

atlxi_dhdt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@
140140
# Subset dataset to geographic region of interest
141141
placename: str = "antarctica"
142142
region: deepicedrain.Region = regions[placename]
143-
# ds = region.subset(ds=ds)
143+
# ds = region.subset(data=ds)
144144

145145
# %%
146146
# We need at least 2 points to draw a trend line or compute differences
@@ -398,7 +398,7 @@
398398
region: deepicedrain.Region = regions[placename]
399399
if not os.path.exists(f"ATLXI/df_dhdt_{placename}.parquet"):
400400
# Subset dataset to geographic region of interest
401-
ds_subset: xr.Dataset = region.subset(ds=ds_dhdt)
401+
ds_subset: xr.Dataset = region.subset(data=ds_dhdt)
402402
# Add a UTC_time column to the dataframe
403403
ds_subset["utc_time"] = deepicedrain.deltatime_to_utctime(
404404
dataarray=ds_subset.delta_ds_subsettime

deepicedrain/spatiotemporal.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,26 @@ def datashade(
7474
)
7575

7676
def subset(
77-
self, ds: xr.Dataset, x_dim: str = "x", y_dim: str = "y", drop: bool = True
77+
self, data: xr.Dataset, x_dim: str = "x", y_dim: str = "y", drop: bool = True
7878
) -> xr.Dataset:
7979
"""
80-
Convenience function to find datapoints in an xarray.Dataset
81-
that fit within the bounding boxes of this region
80+
Convenience function to find datapoints in an xarray.Dataset or
81+
pandas.DataFrame that fit within the bounding boxes of this region.
82+
Note that the 'drop' boolean flag is only valid for xarray.Dataset.
8283
"""
8384
cond = np.logical_and(
84-
np.logical_and(ds[x_dim] > self.xmin, ds[x_dim] < self.xmax),
85-
np.logical_and(ds[y_dim] > self.ymin, ds[y_dim] < self.ymax),
85+
np.logical_and(data[x_dim] > self.xmin, data[x_dim] < self.xmax),
86+
np.logical_and(data[y_dim] > self.ymin, data[y_dim] < self.ymax),
8687
)
8788

88-
return ds.where(cond=cond, drop=drop)
89+
try:
90+
# xarray.DataArray subset method
91+
data_subset = data.where(cond=cond, drop=drop)
92+
except TypeError:
93+
# pandas.DataFrame subset method
94+
data_subset = data.loc[cond]
95+
96+
return data_subset
8997

9098

9199
def deltatime_to_utctime(

deepicedrain/tests/test_region.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import xarray as xr
99

10+
import dask.dataframe
1011
from deepicedrain import Region, catalog, lonlat_to_xy
1112

1213

@@ -51,7 +52,7 @@ def test_region_datashade():
5152

5253
atl11_dataset: xr.Dataset = catalog.test_data.atl11_test_case.to_dask()
5354
atl11_dataset["x"], atl11_dataset["y"] = lonlat_to_xy(
54-
longitude=atl11_dataset.longitude, latitude=atl11_dataset.latitude, epsg=3995,
55+
longitude=atl11_dataset.longitude, latitude=atl11_dataset.latitude, epsg=3995
5556
)
5657
atl11_dataset = atl11_dataset.set_coords(["x", "y"])
5758
df: pd.DataFrame = atl11_dataset.h_corr.to_dataframe()
@@ -64,7 +65,7 @@ def test_region_datashade():
6465
npt.assert_allclose(agg_grid.max(), 1798.066285)
6566

6667

67-
def test_region_subset():
68+
def test_region_subset_xarray_dataset():
6869
"""
6970
Test that we can subset an xarray.Dataset based on the region's bounds
7071
"""
@@ -76,6 +77,27 @@ def test_region_subset():
7677
"y": np.linspace(start=-160, stop=160, num=50),
7778
},
7879
)
79-
ds_subset = region.subset(ds=dataset)
80+
ds_subset = region.subset(data=dataset)
8081
assert isinstance(ds_subset, xr.Dataset)
8182
assert ds_subset.h_corr.shape == (24, 30)
83+
84+
85+
@pytest.mark.parametrize("dataframe_type", [pd.DataFrame, dask.dataframe.DataFrame])
86+
def test_region_subset_dataframe(dataframe_type):
87+
"""
88+
Test that we can subset a pandas or dask DataFrame based on the region's
89+
bounds
90+
"""
91+
region = Region("South Pole", -100, 100, -100, 100)
92+
dataframe = pd.DataFrame(
93+
data={
94+
"x": np.linspace(start=-200, stop=200, num=50),
95+
"y": np.linspace(start=-160, stop=160, num=50),
96+
"dhdt": np.random.rand(50),
97+
}
98+
)
99+
if dataframe_type == dask.dataframe.core.DataFrame:
100+
dataframe = dask.dataframe.from_pandas(data=dataframe, npartitions=2)
101+
df_subset = region.subset(data=dataframe)
102+
assert isinstance(df_subset, dataframe_type)
103+
assert len(df_subset.dhdt) == 24

0 commit comments

Comments
 (0)