Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions movement/io/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"Anipose",
"NWB",
"VIA-tracks",
"idtracker.ai",
]


Expand Down
191 changes: 190 additions & 1 deletion movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from pathlib import Path
from typing import Literal, cast
from typing import Any, Literal, cast

import h5py
import numpy as np
Expand All @@ -20,6 +20,7 @@
ValidDeepLabCutCSV,
ValidDeepLabCutH5,
ValidFile,
ValidIdtrackerH5,
ValidNWBFile,
ValidSleapAnalysis,
ValidSleapLabels,
Expand Down Expand Up @@ -105,6 +106,7 @@
"LightningPose",
"Anipose",
"NWB",
"idtracker.ai",
],
fps: float | None = None,
**kwargs,
Expand Down Expand Up @@ -174,6 +176,8 @@
return from_lp_file(file, fps)
elif source_software == "Anipose":
return from_anipose_file(file, fps, **kwargs)
elif source_software == "idtracker.ai":

Check failure on line 179 in movement/io/load_poses.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "idtracker.ai" 4 times.

See more on https://sonarcloud.io/project/issues?id=neuroinformatics-unit_movement&issues=AZ1EpX-ydPLGmyS4opNb&open=AZ1EpX-ydPLGmyS4opNb&pullRequest=944
return from_idtracker_file(file, fps)
elif source_software == "NWB":
if fps is not None:
logger.warning(
Expand Down Expand Up @@ -271,6 +275,89 @@
)


def from_idtracker_style_dict(
    idtracker_data: dict[str, Any],
    fps: float | None = None,
    source_software: str = "idtracker.ai",
) -> xr.Dataset:
    """Create a movement poses dataset from an idtracker.ai style dictionary.

    Parameters
    ----------
    idtracker_data
        Dictionary containing the pose tracks ("trajectories") and optional
        confidence scores ("id_probabilities") extracted from an idtracker.ai
        file.
    fps
        The number of frames per second in the video. If None (default), it
        attempts to read from the dictionary's "frames_per_second" key.
        If still None, the ``time`` coordinates will be in frame numbers.
    source_software
        Name of the pose estimation software from which the data originate.
        Defaults to "idtracker.ai".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Raises
    ------
    ValueError
        If the dictionary is missing the "trajectories" key, or
        if the extracted arrays have unexpected dimensions.

    See Also
    --------
    movement.io.load_poses.from_idtracker_file

    """
    # Validate up front so we raise the documented ValueError rather than
    # a KeyError on a malformed dictionary.
    if "trajectories" not in idtracker_data:
        raise logger.error(
            ValueError(
                "Missing 'trajectories' key in the idtracker.ai data."
            )
        )

    trajectories = np.asarray(idtracker_data["trajectories"])
    if trajectories.ndim != 3:
        raise logger.error(
            ValueError(
                "Expected 3D trajectories array (frames, individuals, "
                f"space), got shape {trajectories.shape}."
            )
        )
    n_frames, n_individuals, _ = trajectories.shape

    # Reshape from idtracker (frames, individuals, space)
    # to movement (frames, space, keypoints, individuals)
    # Note: idtracker does not track multiple keypoints, so keypoints=1
    pos_reshaped = np.moveaxis(trajectories, source=-1, destination=1)
    position_array = np.expand_dims(pos_reshaped, axis=2)

    probs = idtracker_data.get("id_probabilities")

    if probs is None:
        logger.info(
            "No identity probabilities found in the idtracker.ai data. "
            "Confidence scores will be set to NaN."
        )
        confidence_array = np.full((n_frames, 1, n_individuals), np.nan)
    else:
        probs = np.asarray(probs)

        # Handle idtracker.ai edge case: some versions output probabilities
        # with an undocumented trailing singleton dimension (e.g., (N, M, 1))
        if probs.ndim == 3 and probs.shape[2] == 1:
            probs = probs[:, :, 0]
        elif probs.ndim != 2:
            raise logger.error(
                ValueError(
                    f"Expected 2D probabilities array, got {probs.shape}."
                )
            )

        # Insert the singleton "keypoints" axis expected by from_numpy.
        confidence_array = np.expand_dims(probs, axis=1)

    # Explicit fps argument takes precedence over the dictionary entry.
    final_fps = (
        fps if fps is not None else idtracker_data.get("frames_per_second")
    )

    return from_numpy(
        position_array=position_array,
        confidence_array=confidence_array,
        fps=final_fps,
        source_software=source_software,
    )


@register_loader(
"SLEAP", file_validators=[ValidSleapLabels, ValidSleapAnalysis]
)
Expand Down Expand Up @@ -427,6 +514,44 @@
)


@register_loader(
    source_software="idtracker.ai",
    file_validators=[ValidIdtrackerH5],
)
def from_idtracker_file(
    file: str | Path,
    fps: float | None = None,
) -> xr.Dataset:
    """Load pose data from an idtracker.ai file.

    Parameters
    ----------
    file
        Path to the idtracker.ai file (e.g., a trajectories.h5 file)
        containing the pose tracks.
    fps
        The number of frames per second in the video. If None (default), it
        attempts to read from the file's metadata. If still None, the ``time``
        coordinates will be in frame numbers.

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    See Also
    --------
    movement.io.load_poses.from_idtracker_style_dict

    """
    # The register_loader decorator has already run the file validators,
    # so we simply forward the validated file to the internal helper.
    validated = cast("ValidFile", file)
    return _ds_from_idtracker_file(
        valid_file=validated,
        source_software="idtracker.ai",
        fps=fps,
    )


def from_multiview_files(
file_dict: dict[str, Path | str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
Expand Down Expand Up @@ -514,6 +639,53 @@
return ds


def _ds_from_idtracker_file(
    valid_file: ValidFile,
    source_software: str = "idtracker.ai",
    fps: float | None = None,
) -> xr.Dataset:
    """Create a movement poses dataset from a validated idtracker.ai file.

    Parameters
    ----------
    valid_file
        The validated idtracker.ai file object.
    source_software
        The source software of the file.
    fps
        The number of frames per second in the video. If None (default),
        it is read from the file's "frames_per_second" metadata where
        available; otherwise the ``time`` coordinates will be in frame
        numbers.

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Raises
    ------
    TypeError
        If ``valid_file`` is not a supported idtracker.ai validator type.

    """
    file_path = valid_file.file

    if isinstance(valid_file, ValidIdtrackerH5):
        idtracker_data = _dict_from_idtracker_h5(file_path)
    else:
        raise logger.error(
            TypeError(f"Unsupported idtracker file type: {type(valid_file)}")
        )

    # No need to resolve fps here: from_idtracker_style_dict already falls
    # back to the dictionary's "frames_per_second" entry when fps is None,
    # so duplicating that fallback would be redundant.
    ds = from_idtracker_style_dict(
        idtracker_data=idtracker_data,
        source_software=source_software,
        fps=fps,
    )

    ds.attrs["source_file"] = file_path.as_posix()
    return ds


def _ds_from_sleap_analysis_file(file: Path, fps: float | None) -> xr.Dataset:
"""Create a ``movement`` poses dataset from a SLEAP analysis (.h5) file.

Expand Down Expand Up @@ -698,6 +870,23 @@
return df


def _dict_from_idtracker_h5(path: Path) -> dict[str, Any]:
    """Create a dictionary of idtracker.ai pose data from an .h5 file.

    Parameters
    ----------
    path
        Path to the idtracker.ai .h5 file. The file must contain a
        "trajectories" dataset; "id_probabilities" and the
        "frames_per_second" attribute are optional.

    Returns
    -------
    dict[str, Any]
        Dictionary with "trajectories" (numpy array), "id_probabilities"
        (numpy array or None), and "frames_per_second" (float or None).

    """
    with h5py.File(path, "r") as f:
        trajectories = f["trajectories"][:]
        probs = f["id_probabilities"][:] if "id_probabilities" in f else None
        fps = f.attrs.get("frames_per_second")

    if fps is not None:
        # h5py may return the attribute as a numpy scalar, a 0-d array,
        # or a length-1 array depending on how it was written; normalise
        # all of these to a plain Python float. (Indexing with [0] would
        # fail on numpy scalars and 0-d arrays.)
        fps = float(np.asarray(fps).reshape(-1)[0])

    return {
        "trajectories": trajectories,
        "id_probabilities": probs,
        "frames_per_second": fps,
    }


def from_anipose_style_df(
df: pd.DataFrame,
fps: float | None = None,
Expand Down
15 changes: 15 additions & 0 deletions movement/validators/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,21 @@ class ValidDeepLabCutH5:
"""Path to the DeepLabCut .h5 file to validate."""


@define
class ValidIdtrackerH5:
    """Validator for idtracker.ai .h5 output files."""

    # Accepted file suffixes for idtracker.ai output.
    suffixes: ClassVar[set[str]] = {".h5"}
    file: Path = field(
        converter=Path,
        validator=validators.and_(
            # The file must exist, be readable, and carry an .h5 suffix.
            _file_validator(permission="r", suffixes=suffixes),
            # A valid idtracker.ai file must contain at least a
            # "trajectories" dataset; "id_probabilities" is optional.
            _hdf5_validator(datasets={"trajectories"}),
        ),
    )
    """Path to the idtracker.ai .h5 file to validate."""


@define
class ValidDeepLabCutCSV:
"""Class for validating DeepLabCut-style .csv files.
Expand Down
52 changes: 52 additions & 0 deletions tests/fixtures/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import mock_open, patch

import h5py
import numpy as np
import pytest
import xarray as xr
from sleap_io.io.slp import read_labels, write_labels
Expand Down Expand Up @@ -515,6 +516,57 @@ def anipose_csv_file():
)


# ---------------- idtracker.ai file fixtures ----------------------------
@pytest.fixture
def idtracker_valid_h5_file(tmp_path):
    """Return the path to a valid idtracker.ai .h5 file."""
    h5_path = tmp_path / "valid_idtracker.h5"
    with h5py.File(h5_path, "w") as h5:
        # 10 frames x 3 individuals x 2 spatial dims (x, y)
        h5.create_dataset("trajectories", data=np.ones((10, 3, 2)))
        h5.create_dataset("id_probabilities", data=np.ones((10, 3)))
        h5.attrs["frames_per_second"] = np.array([30.0])
    return h5_path


@pytest.fixture
def idtracker_buggy_shape_h5_file(tmp_path):
    """Return the path to an idtracker.ai .h5
    file with trailing singleton dimension.
    """
    h5_path = tmp_path / "buggy_idtracker.h5"
    with h5py.File(h5_path, "w") as h5:
        h5.create_dataset("trajectories", data=np.ones((10, 3, 2)))
        # Probabilities written with the buggy (10, 3, 1) shape
        # rather than the expected (10, 3).
        h5.create_dataset("id_probabilities", data=np.ones((10, 3, 1)))
        h5.attrs["frames_per_second"] = np.array([30.0])
    return h5_path


@pytest.fixture
def idtracker_trackless_h5_file(tmp_path):
    """Return the path to an idtracker.ai .h5 file missing id_probabilities."""
    h5_path = tmp_path / "trackless_idtracker.h5"
    with h5py.File(h5_path, "w") as h5:
        h5.create_dataset("trajectories", data=np.ones((10, 3, 2)))
        # The id_probabilities dataset is deliberately left out.
        h5.attrs["frames_per_second"] = np.array([30.0])
    return h5_path


@pytest.fixture(
    params=[
        "idtracker_valid_h5_file",
        "idtracker_buggy_shape_h5_file",
        "idtracker_trackless_h5_file",
    ]
)
def idtracker_h5_file(request):
    """Fixture to parametrize various idtracker.ai files."""
    # Resolve the requested fixture by name so each test runs once per file.
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)


# ---------------- netCDF file fixtures ----------------------------
@pytest.fixture(scope="session")
def invalid_netcdf_file_missing_confidence(tmp_path_factory):
Expand Down
Loading
Loading