diff --git a/pygmt/tests/test_xarray_accessor.py b/pygmt/tests/test_xarray_accessor.py index 41a8458fac9..99098836be3 100644 --- a/pygmt/tests/test_xarray_accessor.py +++ b/pygmt/tests/test_xarray_accessor.py @@ -14,10 +14,19 @@ from pygmt.datasets import load_earth_relief from pygmt.enums import GridRegistration, GridType from pygmt.exceptions import GMTValueError +from pygmt.helpers.testing import load_static_earth_relief _HAS_NETCDF4 = bool(importlib.util.find_spec("netCDF4")) +@pytest.fixture(scope="module", name="grid") +def fixture_grid(): + """ + Load the grid data from the sample earth_relief file. + """ + return load_static_earth_relief() + + def test_xarray_accessor_gridline_cartesian(): """ Check that the accessor returns the correct registration and gtype values for a @@ -169,3 +178,44 @@ def test_xarray_accessor_tiled_grid_slice_and_add(): added_grid.gmt.gtype = GridType.GEOGRAPHIC assert added_grid.gmt.registration is GridRegistration.PIXEL assert added_grid.gmt.gtype is GridType.GEOGRAPHIC + + +def test_xarray_accessor_clip(grid): + """ + Check that the accessor has the clip method and that it works correctly. + + This test is adapted from the `test_grdclip_no_outgrid` test. + """ + clipped_grid = grid.gmt.clip( + below=[550, -1000], above=[700, 1000], region=[-53, -49, -19, -16] + ) + + expected_clipped_grid = xr.DataArray( + data=[ + [1000.0, 570.5, -1000.0, -1000.0], + [1000.0, 1000.0, 571.5, 638.5], + [555.5, 556.0, 580.0, 1000.0], + ], + coords={"lon": [-52.5, -51.5, -50.5, -49.5], "lat": [-18.5, -17.5, -16.5]}, + dims=["lat", "lon"], + ) + xr.testing.assert_allclose(a=clipped_grid, b=expected_clipped_grid) + + +def test_xarray_accessor_histeq(grid): + """ + Check that the accessor has the histeq method and that it works correctly. + + This test is adapted from the `test_equalize_grid_no_outgrid` test. + """ + equalized_grid = grid.gmt.histeq(divisions=2, region=[-52, -48, -22, -18]) + + expected_equalized_grid = xr.DataArray( + data=[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1], [1, 1, 1, 1]], + coords={ + "lon": [-51.5, -50.5, -49.5, -48.5], + "lat": [-21.5, -20.5, -19.5, -18.5], + }, + dims=["lat", "lon"], + ) + xr.testing.assert_allclose(a=equalized_grid, b=expected_equalized_grid) diff --git a/pygmt/xarray/accessor.py b/pygmt/xarray/accessor.py index 2899dfe712f..05e162ed78f 100644 --- a/pygmt/xarray/accessor.py +++ b/pygmt/xarray/accessor.py @@ -3,12 +3,25 @@ """ import contextlib +import functools from pathlib import Path import xarray as xr from pygmt.enums import GridRegistration, GridType from pygmt.exceptions import GMTValueError -from pygmt.src.grdinfo import grdinfo +from pygmt.src import ( + dimfilter, + grdclip, + grdcut, + grdfill, + grdfilter, + grdgradient, + grdhisteq, + grdinfo, + grdproject, + grdsample, + grdtrack, +) @xr.register_dataarray_accessor("gmt") @@ -23,6 +36,11 @@ class GMTDataArrayAccessor: - ``registration``: Grid registration type :class:`pygmt.enums.GridRegistration`. - ``gtype``: Grid coordinate system type :class:`pygmt.enums.GridType`. + The *gmt* accessor also provides a set of grid-operation methods that enables + applying GMT's grid processing functionalities directly to the current + :class:`xarray.DataArray` object. See the summary table below for the list of + available methods. + Notes ----- When accessed the first time, the *gmt* accessor will first be initialized to the @@ -150,6 +168,19 @@ class GMTDataArrayAccessor: >>> zval.gmt.gtype = GridType.GEOGRAPHIC >>> zval.gmt.registration, zval.gmt.gtype (, ) + + Instead of calling a grid-processing function and passing the + :class:`xarray.DataArray` object as an input, you can call the corresponding method + directly on the object. For example, the following two are equivalent: + + >>> from pygmt.datasets import load_earth_relief + >>> grid = load_earth_relief(resolution="30m", region=[10, 30, 15, 25]) + >>> # Create a new grid from an input grid. Set all values below 1,000 to 0 and all + >>> # values above 1,500 to 10,000. + >>> # Option 1: + >>> new_grid = pygmt.grdclip(grid=grid, below=[1000, 0], above=[1500, 10000]) + >>> # Option 2: + >>> new_grid = grid.gmt.clip(below=[1000, 0], above=[1500, 10000]) """ def __init__(self, xarray_obj: xr.DataArray): @@ -200,3 +231,29 @@ def gtype(self, value: GridType | int): value, description="grid coordinate system type", choices=GridType ) self._gtype = GridType(value) + + @staticmethod + def _make_method(func): + """ + Create a wrapper method for PyGMT grid-processing methods. + + The :class:`xarray.DataArray` object is passed as the first argument. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + return func(self._obj, *args, **kwargs) + + return wrapper + + # Accessor methods for grid operations. + clip = _make_method(grdclip) + cut = _make_method(grdcut) + dimfilter = _make_method(dimfilter) + histeq = _make_method(grdhisteq.equalize_grid) + fill = _make_method(grdfill) + filter = _make_method(grdfilter) + gradient = _make_method(grdgradient) + project = _make_method(grdproject) + sample = _make_method(grdsample) + track = _make_method(grdtrack)