|
| 1 | +import time |
| 2 | +import numpy as np |
| 3 | +import xarray as xr |
| 4 | +from mllam_data_prep.ops.cropping import ( |
| 5 | + create_convex_hull_mask, |
| 6 | + distance_to_convex_hull_boundary, |
| 7 | +) |
| 8 | + |
def _build_flat_grid_dataset(lon_bounds, lat_bounds, num_points):
    """Build a Dataset of flattened lon/lat coordinates on a regular grid.

    Parameters
    ----------
    lon_bounds, lat_bounds : tuple of (start, stop)
        Inclusive endpoints passed to ``np.linspace`` for each axis.
    num_points : int
        Number of samples along each axis; the resulting dataset has
        ``num_points**2`` entries along the ``grid_index`` dimension.

    Returns
    -------
    xr.Dataset
        Dataset with ``longitude`` and ``latitude`` coordinates, both on the
        single ``grid_index`` dimension.
    """
    lon_axis = np.linspace(lon_bounds[0], lon_bounds[1], num_points)
    lat_axis = np.linspace(lat_bounds[0], lat_bounds[1], num_points)
    lon_grid, lat_grid = np.meshgrid(lon_axis, lat_axis)
    return xr.Dataset(
        coords={
            "longitude": xr.DataArray(lon_grid.flatten(), dims="grid_index"),
            "latitude": xr.DataArray(lat_grid.flatten(), dims="grid_index"),
        }
    )


def benchmark_cropping(grid_size=1000):
    """Time the two cropping operations on a synthetic global grid.

    A ``grid_size`` x ``grid_size`` grid spanning longitude [-180, 180] and
    latitude [-90, 90] is compared against a fixed 20 x 20 reference grid
    spanning longitude [0, 20] and latitude [40, 60]. The elapsed time of
    ``create_convex_hull_mask`` and ``distance_to_convex_hull_boundary`` is
    printed to stdout.

    Parameters
    ----------
    grid_size : int, optional
        Number of points along each axis of the benchmarked grid.

    Returns
    -------
    None
        Results are reported via ``print`` only.
    """
    print(f"--- Benchmarking grid size: {grid_size}x{grid_size} ({grid_size**2} points) ---")

    ds = _build_flat_grid_dataset((-180, 180), (-90, 90), grid_size)
    ds_ref = _build_flat_grid_dataset((0, 20), (40, 60), 20)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which may be adjusted while a measurement is in progress.
    print("Benchmarking create_convex_hull_mask...")
    t0 = time.perf_counter()
    # The second element of the returned pair is not needed for the benchmark.
    da_mask, _ = create_convex_hull_mask(ds, ds_ref)
    mask_time = time.perf_counter() - t0
    # The reduction happens after the timer stops so it is not included.
    num_inside = int(da_mask.sum())
    print(f"-> Hull Mask Time: {mask_time:.3f} seconds (Found {num_inside} inside)")

    print("Benchmarking distance_to_convex_hull_boundary...")
    t1 = time.perf_counter()
    # Only the duration matters here; the result itself is discarded.
    distance_to_convex_hull_boundary(ds, ds_ref)
    dist_time = time.perf_counter() - t1
    print(f"-> Arc Distance Time: {dist_time:.3f} seconds")
| 38 | + |
if __name__ == "__main__":
    # Run the benchmark at each scale in turn:
    #   100 -> 10k points baseline
    #   500 -> 250k points scale test
    for side_length in (100, 500):
        benchmark_cropping(grid_size=side_length)
0 commit comments