diff --git a/movement/io/load.py b/movement/io/load.py index cad9752cb..19ea3f919 100644 --- a/movement/io/load.py +++ b/movement/io/load.py @@ -29,6 +29,7 @@ "Anipose", "NWB", "VIA-tracks", + "idtracker.ai", ] diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index c76715b63..a2539bba9 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -2,7 +2,7 @@ import warnings from pathlib import Path -from typing import Literal, cast +from typing import Any, Literal, cast import h5py import numpy as np @@ -20,6 +20,7 @@ ValidDeepLabCutCSV, ValidDeepLabCutH5, ValidFile, + ValidIdtrackerH5, ValidNWBFile, ValidSleapAnalysis, ValidSleapLabels, @@ -105,6 +106,7 @@ def from_file( "LightningPose", "Anipose", "NWB", + "idtracker.ai", ], fps: float | None = None, **kwargs, @@ -174,6 +176,8 @@ def from_file( return from_lp_file(file, fps) elif source_software == "Anipose": return from_anipose_file(file, fps, **kwargs) + elif source_software == "idtracker.ai": + return from_idtracker_file(file, fps) elif source_software == "NWB": if fps is not None: logger.warning( @@ -271,6 +275,89 @@ def from_dlc_style_df( ) +def from_idtracker_style_dict( + idtracker_data: dict[str, Any], + fps: float | None = None, + source_software: str = "idtracker.ai", +) -> xr.Dataset: + """Create a movement poses dataset from an idtracker.ai style dictionary. + + Parameters + ---------- + idtracker_data + Dictionary containing the pose tracks ("trajectories") and optional + confidence scores ("id_probabilities") extracted from an idtracker.ai + file. + fps + The number of frames per second in the video. If None (default), it + attempts to read from the dictionary's "frames_per_second" key. + If still None, the ``time`` coordinates will be in frame numbers. + source_software + Name of the pose estimation software from which the data originate. + Defaults to "idtracker.ai". + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. + + Raises + ------ + ValueError + If the dictionary is missing the "trajectories" key, or + if the extracted arrays have unexpected dimensions. + + See Also + -------- + movement.io.load_poses.from_idtracker_file + + """ + trajectories = np.asarray(idtracker_data["trajectories"]) + n_frames, n_individuals, _ = trajectories.shape + + # Reshape from idtracker (frames, individuals, space) + # to movement (frames, space, keypoints, individuals) + # Note: idtracker does not track multiple keypoints, so keypoints=1 + pos_reshaped = np.moveaxis(trajectories, source=-1, destination=1) + position_array = np.expand_dims(pos_reshaped, axis=2) + + probs = idtracker_data.get("id_probabilities") + + if probs is None: + logger.info( + "No identity probabilities found in the idtracker.ai data. " + "Confidence scores will be set to NaN." + ) + confidence_array = np.full((n_frames, 1, n_individuals), np.nan) + else: + probs = np.asarray(probs) + + # Handle idtracker.ai edge case: some versions output probabilities + # with an undocumented trailing singleton dimension (e.g., (N, M, 1)) + if probs.ndim == 3 and probs.shape[2] == 1: + probs = probs[:, :, 0] + elif probs.ndim != 2: + raise logger.error( + ValueError( + f"Expected 2D probabilities array, got {probs.shape}." + ) + ) + + confidence_array = np.expand_dims(probs, axis=1) + + final_fps = ( + fps if fps is not None else idtracker_data.get("frames_per_second") + ) + + return from_numpy( + position_array=position_array, + confidence_array=confidence_array, + fps=final_fps, + source_software=source_software, + ) + + @register_loader( "SLEAP", file_validators=[ValidSleapLabels, ValidSleapAnalysis] ) @@ -427,6 +514,44 @@ def from_dlc_file(file: str | Path, fps: float | None = None) -> xr.Dataset: ) +@register_loader( + source_software="idtracker.ai", + file_validators=[ValidIdtrackerH5], +) +def from_idtracker_file( + file: str | Path, + fps: float | None = None, +) -> xr.Dataset: + """Load pose data from an idtracker.ai file. + + Parameters + ---------- + file + Path to the idtracker.ai file (e.g., a trajectories.h5 file) + containing the pose tracks. + fps + The number of frames per second in the video. If None (default), it + attempts to read from the file's metadata. If still None, the ``time`` + coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. + + See Also + -------- + movement.io.load_poses.from_idtracker_style_dict + + """ + return _ds_from_idtracker_file( + valid_file=cast("ValidFile", file), + source_software="idtracker.ai", + fps=fps, + ) + + def from_multiview_files( file_dict: dict[str, Path | str], source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], @@ -514,6 +639,53 @@ def _ds_from_lp_or_dlc_file( return ds +def _ds_from_idtracker_file( + valid_file: ValidFile, + source_software: str = "idtracker.ai", + fps: float | None = None, +) -> xr.Dataset: + """Create a movement poses dataset from a validated idtracker.ai file. + + Parameters + ---------- + valid_file + The validated idtracker.ai file object. + source_software + The source software of the file. + fps + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. + + """ + file_path = valid_file.file + + if isinstance(valid_file, ValidIdtrackerH5): + idtracker_data = _dict_from_idtracker_h5(file_path) + else: + raise logger.error( + TypeError(f"Unsupported idtracker file type: {type(valid_file)}") + ) + + final_fps = ( + fps if fps is not None else idtracker_data.get("frames_per_second") + ) + + ds = from_idtracker_style_dict( + idtracker_data=idtracker_data, + source_software=source_software, + fps=final_fps, + ) + + ds.attrs["source_file"] = file_path.as_posix() + return ds + + def _ds_from_sleap_analysis_file(file: Path, fps: float | None) -> xr.Dataset: """Create a ``movement`` poses dataset from a SLEAP analysis (.h5) file. @@ -698,6 +870,23 @@ def _df_from_dlc_csv(valid_file: ValidDeepLabCutCSV) -> pd.DataFrame: return df +def _dict_from_idtracker_h5(path: Path) -> dict[str, Any]: + """Create a dictionary of idtracker.ai pose data from an .h5 file.""" + with h5py.File(path, "r") as f: + trajectories = f["trajectories"][:] + probs = f["id_probabilities"][:] if "id_probabilities" in f else None + + fps = f.attrs.get("frames_per_second") + if isinstance(fps, (list, tuple, np.ndarray)): + fps = float(fps[0]) + + return { + "trajectories": trajectories, + "id_probabilities": probs, + "frames_per_second": fps, + } + + def from_anipose_style_df( df: pd.DataFrame, fps: float | None = None, diff --git a/movement/validators/files.py b/movement/validators/files.py index e75baf8ac..7f6f1b998 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -409,6 +409,21 @@ class ValidDeepLabCutH5: """Path to the DeepLabCut .h5 file to validate.""" +@define +class ValidIdtrackerH5: + """Validator for idtracker.ai .h5 output files.""" + + suffixes: ClassVar[set[str]] = {".h5"} + file: Path = field( + converter=Path, + validator=validators.and_( + _file_validator(permission="r", suffixes=suffixes), + _hdf5_validator(datasets={"trajectories"}), + ), + ) + """Path to the idtracker.ai .h5 file to validate.""" + + @define class ValidDeepLabCutCSV: """Class for validating DeepLabCut-style .csv files. diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 88a271246..f9032d065 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -8,6 +8,7 @@ from unittest.mock import mock_open, patch import h5py +import numpy as np import pytest import xarray as xr from sleap_io.io.slp import read_labels, write_labels @@ -515,6 +516,57 @@ def anipose_csv_file(): ) +# ---------------- idtracker.ai file fixtures ---------------------------- +@pytest.fixture +def idtracker_valid_h5_file(tmp_path): + """Return the path to a valid idtracker.ai .h5 file.""" + file_path = tmp_path / "valid_idtracker.h5" + + with h5py.File(file_path, "w") as f: + # 10 frames, 3 individuals, 2 spatial dimensions (x, y) + f.create_dataset("trajectories", data=np.ones((10, 3, 2))) + f.create_dataset("id_probabilities", data=np.ones((10, 3))) + f.attrs["frames_per_second"] = np.array([30.0]) + return file_path + + +@pytest.fixture +def idtracker_buggy_shape_h5_file(tmp_path): + """Return the path to an idtracker.ai .h5 + file with trailing singleton dimension. + """ + file_path = tmp_path / "buggy_idtracker.h5" + with h5py.File(file_path, "w") as f: + f.create_dataset("trajectories", data=np.ones((10, 3, 2))) + # Buggy shape: (10, 3, 1) instead of (10, 3) + f.create_dataset("id_probabilities", data=np.ones((10, 3, 1))) + f.attrs["frames_per_second"] = np.array([30.0]) + return file_path + + +@pytest.fixture +def idtracker_trackless_h5_file(tmp_path): + """Return the path to an idtracker.ai .h5 file missing id_probabilities.""" + file_path = tmp_path / "trackless_idtracker.h5" + with h5py.File(file_path, "w") as f: + f.create_dataset("trajectories", data=np.ones((10, 3, 2))) + # Intentionally omitting the id_probabilities dataset + f.attrs["frames_per_second"] = np.array([30.0]) + return file_path + + +@pytest.fixture( + params=[ + "idtracker_valid_h5_file", + "idtracker_buggy_shape_h5_file", + "idtracker_trackless_h5_file", + ] +) +def idtracker_h5_file(request): + """Fixture to parametrize various idtracker.ai files.""" + return request.getfixturevalue(request.param) + + # ---------------- netCDF file fixtures ---------------------------- @pytest.fixture(scope="session") def invalid_netcdf_file_missing_confidence(tmp_path_factory): diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py index b44b9cdc7..e0fa419f4 100644 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -245,6 +245,54 @@ def test_load_from_anipose_file(): ] +def test_load_from_idtracker_file(idtracker_h5_file, helpers): + ds = load_poses.from_idtracker_file(idtracker_h5_file) + expected_values = { + **expected_values_poses, + "source_software": "idtracker.ai", + "file_path": idtracker_h5_file, + "fps": 30.0, + } + helpers.assert_valid_dataset(ds, expected_values) + + +def test_load_from_idtracker_style_dict(helpers): + """Test loading pose tracks from an idtracker.ai style dictionary.""" + idtracker_dict = { + "trajectories": np.ones( + (10, 3, 2) + ), # 10 frames, 3 individuals, 2 spatial (x, y) + "id_probabilities": np.ones((10, 3)), + "frames_per_second": 30.0, + } + + # Pass it directly to our formatter + ds = load_poses.from_idtracker_style_dict(idtracker_dict) + + # Verify it creates a perfect movement dataset + expected_values = { + **expected_values_poses, + "source_software": "idtracker.ai", + "fps": 30.0, + } + helpers.assert_valid_dataset(ds, expected_values) + + +def test_load_idtracker_without_probs(idtracker_trackless_h5_file): + """Test that loading an idtracker.ai file without identity probabilities + returns a dataset with NaN confidence scores and default individual names. + """ + ds = load_poses.from_idtracker_file(idtracker_trackless_h5_file) + + # 1. Check if default individual names were assigned + # (our fixture has 3 individuals) + assert ds.individuals.values.tolist() == ["id_0", "id_1", "id_2"] + + # 2. Check if confidence scores are NaN + # (since no probabilities were provided) + assert np.isnan(ds.confidence.values).all() + + @pytest.mark.parametrize("kwargs", [{}, {"rate": 10.0, "starting_time": 0.0}]) @pytest.mark.parametrize("input_type", ["nwb_file", "nwbfile_object"]) def test_load_from_nwb_file(input_type, kwargs, request): @@ -275,7 +323,15 @@ def test_load_from_nwb_file(input_type, kwargs, request): @pytest.mark.filterwarnings("ignore:.*is deprecated:DeprecationWarning") @pytest.mark.parametrize( "source_software", - ["DeepLabCut", "SLEAP", "LightningPose", "Anipose", "NWB", "Unknown"], + [ + "DeepLabCut", + "SLEAP", + "LightningPose", + "Anipose", + "NWB", + "idtracker.ai", + "Unknown", + ], ) @pytest.mark.parametrize("fps", [None, 30, 60.0]) def test_from_file_delegates_correctly(source_software, fps, caplog): @@ -288,6 +344,7 @@ def test_from_file_delegates_correctly(source_software, fps, caplog): "LightningPose": "movement.io.load_poses.from_lp_file", "Anipose": "movement.io.load_poses.from_anipose_file", "NWB": "movement.io.load_poses.from_nwb_file", + "idtracker.ai": "movement.io.load_poses.from_idtracker_file", } if source_software == "Unknown": with pytest.raises(ValueError, match="Unsupported source"):