Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 209 additions & 32 deletions movement/utils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,27 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray:

"""
if "space" in data.dims:
validate_dims_coords(data, {"space": ["x", "y"]})
# Allow both 2D and 3D
if len(data.coords["space"]) == 2:
validate_dims_coords(data, {"space": ["x", "y"]})
elif len(data.coords["space"]) == 3:
validate_dims_coords(data, {"space": ["x", "y", "z"]})
else:
_raise_error_for_invalid_spatial_dim_length("space", 2, 3)
return xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
elif "space_pol" in data.dims:
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
# Allow both 2D polar and 3D cylindrical
if len(data.coords["space_pol"]) == 2:
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
elif len(data.coords["space_pol"]) == 3:
validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]})
else:
_raise_error_for_invalid_spatial_dim_length("space_pol", 2, 3)
return data.sel(space_pol="rho", drop=True)
else:
_raise_error_for_missing_spatial_dim()
Expand Down Expand Up @@ -81,10 +93,22 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray:

"""
if "space" in data.dims:
validate_dims_coords(data, {"space": ["x", "y"]})
# Allow both 2D and 3D
if len(data.coords["space"]) == 2:
validate_dims_coords(data, {"space": ["x", "y"]})
elif len(data.coords["space"]) == 3:
validate_dims_coords(data, {"space": ["x", "y", "z"]})
else:
_raise_error_for_invalid_spatial_dim_length("space", 2, 3)
return data / compute_norm(data)
elif "space_pol" in data.dims:
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
# Allow both 2D polar and 3D cylindrical
if len(data.coords["space_pol"]) == 2:
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
elif len(data.coords["space_pol"]) == 3:
validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]})
else:
_raise_error_for_invalid_spatial_dim_length("space_pol", 2, 3)
# Set both rho and phi values to NaN at null vectors (where rho = 0)
new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data)
# Set the rho values to 1 for non-null vectors (phi is preserved)
Expand All @@ -97,21 +121,26 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray:


def cart2pol(data: xr.DataArray) -> xr.DataArray:
"""Transform Cartesian coordinates to polar.
"""Transform Cartesian coordinates to polar (2D) or cylindrical (3D).

Parameters
----------
data
The input data containing ``space`` as a dimension,
with ``x`` and ``y`` in the dimension coordinate.
with ``x`` and ``y`` (2D) or ``x``, ``y``, and ``z`` (3D)
in the dimension coordinate.

Returns
-------
xarray.DataArray
An xarray DataArray containing the polar coordinates
stored in the ``space_pol`` dimension, with ``rho``
and ``phi`` in the dimension coordinate. The angles
``phi`` returned are in radians, in the range ``[-pi, pi]``.
An xarray DataArray containing the polar/cylindrical coordinates
stored in the ``space_pol`` dimension:

- 2D: ``rho`` and ``phi``
- 3D: ``rho``, ``phi``, and ``z`` (cylindrical coordinates)

The angle ``phi`` is in radians, in the range ``[-pi, pi]``.
For 3D input, ``z`` is passed through unchanged.

Notes
-----
Expand All @@ -124,7 +153,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:

References
----------
.. [1] ISO/IEC standard 9899:1999, Programming language C.
.. [1] ISO/IEC standard 9899:1999, "Programming language C."
.. [2] https://en.wikipedia.org/wiki/Atan2
.. [3] https://en.wikipedia.org/wiki/Signed_zero

Expand All @@ -133,63 +162,200 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:
:obj:`numpy.arctan2`

"""
validate_dims_coords(data, {"space": ["x", "y"]})
rho = compute_norm(data)
phi = xr.apply_ufunc(
np.arctan2,
data.sel(space="y"),
data.sel(space="x"),
)
# Validate space dimension exists
if "space" not in data.dims:
raise logger.error(
ValueError("Input data must contain 'space' as a dimension.")
)

# Validate 2D or 3D input
is_3d = len(data.coords["space"]) == 3
if is_3d:
validate_dims_coords(data, {"space": ["x", "y", "z"]})
else:
validate_dims_coords(data, {"space": ["x", "y"]})

x = data.sel(space="x", drop=True)
y = data.sel(space="y", drop=True)
rho = (x**2 + y**2) ** 0.5
phi = xr.apply_ufunc(np.arctan2, y, x)

# Make all zeros in phi positive zeros
# - where rho == 0, set phi to 0
# - where rho != 0, keep the phi value from atan2
phi = xr.where(np.isclose(rho.values, 0.0, atol=1e-9), 0.0, phi)

# Build output components
components = [
rho.assign_coords({"space_pol": "rho"}),
phi.assign_coords({"space_pol": "phi"}),
]

# For 3D, pass z through unchanged
if is_3d:
z = data.sel(space="z", drop=True)
components.append(z.assign_coords({"space_pol": "z"}))

# Replace space dim with space_pol
dims = list(data.dims)
dims[dims.index("space")] = "space_pol"
return xr.concat(
[
rho.assign_coords({"space_pol": "rho"}),
phi.assign_coords({"space_pol": "phi"}),
],
dim="space_pol",
).transpose(*dims)
return xr.concat(components, dim="space_pol").transpose(*dims)


def pol2cart(data: xr.DataArray) -> xr.DataArray:
"""Transform polar coordinates to Cartesian.
"""Transform polar (2D) or cylindrical (3D) coordinates to Cartesian.

Parameters
----------
data
The input data containing ``space_pol`` as a dimension,
with ``rho`` and ``phi`` in the dimension coordinate.
with ``rho`` and ``phi`` (2D) or ``rho``, ``phi``, and ``z`` (3D)
in the dimension coordinate.

Returns
-------
xarray.DataArray
An xarray DataArray containing the Cartesian coordinates
stored in the ``space`` dimension, with ``x`` and ``y``
in the dimension coordinate.
stored in the ``space`` dimension:

- 2D: ``x`` and ``y``
- 3D: ``x``, ``y``, and ``z``

"""
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
rho = data.sel(space_pol="rho")
phi = data.sel(space_pol="phi")
# Validate space_pol dimension exists
if "space_pol" not in data.dims:
raise logger.error(
ValueError("Input data must contain 'space_pol' as a dimension.")
)

# Validate 2D or 3D input
is_3d = len(data.coords["space_pol"]) == 3
if is_3d:
validate_dims_coords(data, {"space_pol": ["rho", "phi", "z"]})
else:
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})

rho = data.sel(space_pol="rho", drop=True)
phi = data.sel(space_pol="phi", drop=True)
x = rho * np.cos(phi)
y = rho * np.sin(phi)

# Build output components
components = [
x.assign_coords({"space": "x"}),
y.assign_coords({"space": "y"}),
]

# For 3D, pass z through unchanged
if is_3d:
z = data.sel(space_pol="z", drop=True)
components.append(z.assign_coords({"space": "z"}))

# Replace space_pol dim with space
dims = list(data.dims)
dims[dims.index("space_pol")] = "space"
return xr.concat(components, dim="space").transpose(*dims)


def cart2sph(data: xr.DataArray) -> xr.DataArray:
"""Transform 3D Cartesian coordinates to spherical.

Parameters
----------
data
The input data containing ``space`` as a dimension,
with ``x``, ``y``, and ``z`` in the dimension coordinate.

Returns
-------
xarray.DataArray
An xarray DataArray containing the spherical coordinates
stored in the ``space_sph`` dimension, with ``rho``,
``azimuth``, and ``elevation`` in the dimension coordinate:

- ``rho``: radial distance (magnitude of the vector)
- ``azimuth``: angle in the x-y plane from the positive x-axis,
in radians, in the range ``[-pi, pi]``
- ``elevation``: angle from the x-y plane, in radians,
in the range ``[-pi/2, pi/2]``

See Also
--------
sph2cart : Inverse transformation from spherical to Cartesian.

"""
validate_dims_coords(data, {"space": ["x", "y", "z"]})

x = data.sel(space="x", drop=True)
y = data.sel(space="y", drop=True)
z = data.sel(space="z", drop=True)

rho = (x**2 + y**2 + z**2) ** 0.5
azimuth = xr.apply_ufunc(np.arctan2, y, x)
# Compute elevation, handling zero-magnitude vectors
elevation = xr.where(
rho > 0,
np.arcsin((z / rho).clip(-1, 1)),
0.0,
)

# Replace space dim with space_sph
dims = list(data.dims)
dims[dims.index("space")] = "space_sph"
return xr.concat(
[
rho.assign_coords({"space_sph": "rho"}),
azimuth.assign_coords({"space_sph": "azimuth"}),
elevation.assign_coords({"space_sph": "elevation"}),
],
dim="space_sph",
coords="minimal",
).transpose(*dims)


def sph2cart(data: xr.DataArray) -> xr.DataArray:
"""Transform spherical coordinates to 3D Cartesian.

Parameters
----------
data
The input data containing ``space_sph`` as a dimension,
with ``rho``, ``azimuth``, and ``elevation`` in the
dimension coordinate.

Returns
-------
xarray.DataArray
An xarray DataArray containing the Cartesian coordinates
stored in the ``space`` dimension, with ``x``, ``y``, and ``z``
in the dimension coordinate.

See Also
--------
cart2sph : Inverse transformation from Cartesian to spherical.

"""
validate_dims_coords(data, {"space_sph": ["rho", "azimuth", "elevation"]})

rho = data.sel(space_sph="rho", drop=True)
azimuth = data.sel(space_sph="azimuth", drop=True)
elevation = data.sel(space_sph="elevation", drop=True)

x = rho * np.cos(elevation) * np.cos(azimuth)
y = rho * np.cos(elevation) * np.sin(azimuth)
z = rho * np.sin(elevation)

# Replace space_sph dim with space
dims = list(data.dims)
dims[dims.index("space_sph")] = "space"
return xr.concat(
[
x.assign_coords({"space": "x"}),
y.assign_coords({"space": "y"}),
z.assign_coords({"space": "z"}),
],
dim="space",
coords="minimal",
).transpose(*dims)


Expand Down Expand Up @@ -314,3 +480,14 @@ def _raise_error_for_missing_spatial_dim() -> NoReturn:
"as dimensions."
)
)


def _raise_error_for_invalid_spatial_dim_length(
dim_name: str, *valid_lengths: int
) -> NoReturn:
lengths_str = " or ".join(str(n) for n in valid_lengths)
raise logger.error(
ValueError(
f"Dimension '{dim_name}' must have {lengths_str} coordinates."
)
)
Loading