Skip to content

Commit a93e6fa

Browse files
loader/Addition of validatorClass and load_functionality
loader/Addition of validatorClass and load_functionality Logger Info Added feat : idtracker tests added feat : optional : probabilities logic fixed feat : updated expectations feat : strings separated
1 parent 3efc8ca commit a93e6fa

File tree

5 files changed

+316
-2
lines changed

5 files changed

+316
-2
lines changed

movement/io/load.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"Anipose",
3030
"NWB",
3131
"VIA-tracks",
32+
"idtracker.ai",
3233
]
3334

3435

movement/io/load_poses.py

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from pathlib import Path
5-
from typing import Literal, cast
5+
from typing import Any, Literal, cast
66

77
import h5py
88
import numpy as np
@@ -20,6 +20,7 @@
2020
ValidDeepLabCutCSV,
2121
ValidDeepLabCutH5,
2222
ValidFile,
23+
ValidIdtrackerH5,
2324
ValidNWBFile,
2425
ValidSleapAnalysis,
2526
ValidSleapLabels,
@@ -105,6 +106,7 @@ def from_file(
105106
"LightningPose",
106107
"Anipose",
107108
"NWB",
109+
"idtracker.ai",
108110
],
109111
fps: float | None = None,
110112
**kwargs,
@@ -174,6 +176,8 @@ def from_file(
174176
return from_lp_file(file, fps)
175177
elif source_software == "Anipose":
176178
return from_anipose_file(file, fps, **kwargs)
179+
elif source_software == "idtracker.ai":
180+
return from_idtracker_file(file, fps)
177181
elif source_software == "NWB":
178182
if fps is not None:
179183
logger.warning(
@@ -271,6 +275,89 @@ def from_dlc_style_df(
271275
)
272276

273277

278+
def from_idtracker_style_dict(
    idtracker_data: dict[str, Any],
    fps: float | None = None,
    source_software: str = "idtracker.ai",
) -> xr.Dataset:
    """Create a movement poses dataset from an idtracker.ai style dictionary.

    Parameters
    ----------
    idtracker_data
        Dictionary containing the pose tracks ("trajectories") and optional
        confidence scores ("id_probabilities") extracted from an idtracker.ai
        file. "trajectories" must be array-like with 3 dimensions, ordered
        as (frames, individuals, space).
    fps
        The number of frames per second in the video. If None (default), it
        attempts to read from the dictionary's "frames_per_second" key.
        If still None, the ``time`` coordinates will be in frame numbers.
    source_software
        Name of the pose estimation software from which the data originate.
        Defaults to "idtracker.ai".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Raises
    ------
    KeyError
        If the dictionary is missing the "trajectories" key.
    ValueError
        If "trajectories" is not a 3D array, or if "id_probabilities" is
        neither a 2D array nor a 3D array with a trailing singleton
        dimension.

    See Also
    --------
    movement.io.load_poses.from_idtracker_file

    """
    # Fail early with an explicit, logged message rather than a bare
    # ``KeyError('trajectories')`` raised by the dictionary lookup.
    if "trajectories" not in idtracker_data:
        raise logger.error(
            KeyError(
                'Missing required "trajectories" key in idtracker.ai data.'
            )
        )

    trajectories = np.asarray(idtracker_data["trajectories"])
    # Without this guard a wrongly-shaped array would fail on tuple
    # unpacking below with an unhelpful "not enough values to unpack".
    if trajectories.ndim != 3:
        raise logger.error(
            ValueError(
                "Expected 3D trajectories array "
                "(frames, individuals, space), "
                f"got {trajectories.shape}."
            )
        )
    n_frames, n_individuals, _ = trajectories.shape

    # Reshape from idtracker (frames, individuals, space)
    # to movement (frames, space, keypoints, individuals)
    # Note: idtracker does not track multiple keypoints, so keypoints=1
    pos_reshaped = np.moveaxis(trajectories, source=-1, destination=1)
    position_array = np.expand_dims(pos_reshaped, axis=2)

    probs = idtracker_data.get("id_probabilities")

    if probs is None:
        logger.info(
            "No identity probabilities found in the idtracker.ai data. "
            "Confidence scores will be set to NaN."
        )
        confidence_array = np.full((n_frames, 1, n_individuals), np.nan)
    else:
        probs = np.asarray(probs)

        # Handle idtracker.ai edge case: some versions output probabilities
        # with an undocumented trailing singleton dimension (e.g., (N, M, 1))
        if probs.ndim == 3 and probs.shape[2] == 1:
            probs = probs[:, :, 0]
        elif probs.ndim != 2:
            raise logger.error(
                ValueError(
                    f"Expected 2D probabilities array, got {probs.shape}."
                )
            )

        # Insert the singleton keypoints axis:
        # (frames, individuals) -> (frames, keypoints, individuals)
        confidence_array = np.expand_dims(probs, axis=1)

    # An explicitly passed fps takes precedence over the dictionary value.
    final_fps = (
        fps if fps is not None else idtracker_data.get("frames_per_second")
    )

    return from_numpy(
        position_array=position_array,
        confidence_array=confidence_array,
        fps=final_fps,
        source_software=source_software,
    )
359+
360+
274361
@register_loader(
275362
"SLEAP", file_validators=[ValidSleapLabels, ValidSleapAnalysis]
276363
)
@@ -427,6 +514,44 @@ def from_dlc_file(file: str | Path, fps: float | None = None) -> xr.Dataset:
427514
)
428515

429516

517+
@register_loader(
    source_software="idtracker.ai",
    file_validators=[ValidIdtrackerH5],
)
def from_idtracker_file(
    file: str | Path,
    fps: float | None = None,
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an idtracker.ai file.

    Parameters
    ----------
    file
        Path to the idtracker.ai file (e.g., a trajectories.h5 file)
        containing the pose tracks.
    fps
        The number of frames per second in the video. If None (default), it
        attempts to read from the file's metadata. If still None, the ``time``
        coordinates will be in frame numbers.

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    See Also
    --------
    movement.io.load_poses.from_idtracker_style_dict

    """
    # ``cast`` only informs the type checker and does nothing at runtime.
    # NOTE(review): presumably ``register_loader`` has already replaced
    # ``file`` with a validated file object by this point - confirm.
    validated_file = cast("ValidFile", file)
    return _ds_from_idtracker_file(
        validated_file,
        source_software="idtracker.ai",
        fps=fps,
    )
553+
554+
430555
def from_multiview_files(
431556
file_dict: dict[str, Path | str],
432557
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
@@ -514,6 +639,53 @@ def _ds_from_lp_or_dlc_file(
514639
return ds
515640

516641

642+
def _ds_from_idtracker_file(
    valid_file: ValidFile,
    source_software: str = "idtracker.ai",
    fps: float | None = None,
) -> xr.Dataset:
    """Create a movement poses dataset from a validated idtracker.ai file.

    Parameters
    ----------
    valid_file
        The validated idtracker.ai file object.
    source_software
        The source software of the file.
    fps
        The number of frames per second in the video. If None (default),
        the "frames_per_second" value read from the file is used instead.
        If that is also None, the ``time`` coordinates will be in frame
        numbers.

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata. The ``source_file`` attribute is set to
        the POSIX-style path of the loaded file.

    Raises
    ------
    TypeError
        If ``valid_file`` is not a ``ValidIdtrackerH5`` instance.

    """
    # Check the type before touching ``valid_file.file`` so that an
    # unsupported object yields the intended TypeError rather than an
    # AttributeError.
    if not isinstance(valid_file, ValidIdtrackerH5):
        raise logger.error(
            TypeError(f"Unsupported idtracker file type: {type(valid_file)}")
        )

    file_path = valid_file.file
    idtracker_data = _dict_from_idtracker_h5(file_path)

    # ``from_idtracker_style_dict`` already falls back to the dictionary's
    # "frames_per_second" entry when ``fps`` is None, so the fallback is
    # not repeated here.
    ds = from_idtracker_style_dict(
        idtracker_data=idtracker_data,
        source_software=source_software,
        fps=fps,
    )

    ds.attrs["source_file"] = file_path.as_posix()
    return ds
687+
688+
517689
def _ds_from_sleap_analysis_file(file: Path, fps: float | None) -> xr.Dataset:
518690
"""Create a ``movement`` poses dataset from a SLEAP analysis (.h5) file.
519691
@@ -698,6 +870,23 @@ def _df_from_dlc_csv(valid_file: ValidDeepLabCutCSV) -> pd.DataFrame:
698870
return df
699871

700872

873+
def _dict_from_idtracker_h5(path: Path) -> dict[str, Any]:
    """Create a dictionary of idtracker.ai pose data from an .h5 file.

    Parameters
    ----------
    path
        Path to the idtracker.ai .h5 file. The file must contain a
        "trajectories" dataset.

    Returns
    -------
    dict
        Dictionary with the keys "trajectories" (array read from the file),
        "id_probabilities" (array, or None if the dataset is absent) and
        "frames_per_second" (float, or None if the file attribute is
        absent or empty).

    """
    with h5py.File(path, "r") as f:
        trajectories = f["trajectories"][:]
        probs = f["id_probabilities"][:] if "id_probabilities" in f else None

        fps = f.attrs.get("frames_per_second")
        if fps is not None:
            # The attribute may come back as a scalar, a 0-d array or a
            # 1-d array. Flatten it so that every case is handled the same
            # way and empty or 0-d values do not raise an IndexError.
            fps_values = np.asarray(fps).ravel()
            fps = float(fps_values[0]) if fps_values.size > 0 else None

    return {
        "trajectories": trajectories,
        "id_probabilities": probs,
        "frames_per_second": fps,
    }
888+
889+
701890
def from_anipose_style_df(
702891
df: pd.DataFrame,
703892
fps: float | None = None,

movement/validators/files.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,21 @@ class ValidDeepLabCutH5:
409409
"""Path to the DeepLabCut .h5 file to validate."""
410410

411411

412+
@define
class ValidIdtrackerH5:
    """Class for validating idtracker.ai .h5 output files.

    The ``file`` attribute is converted to a ``Path`` and then checked by
    two validators, both of which must pass: a file validator configured
    for read permission and the ``.h5`` suffix, and an HDF5 validator
    configured with the "trajectories" dataset.
    """

    suffixes: ClassVar[set[str]] = {".h5"}
    file: Path = field(
        converter=Path,
        # A list of validators is combined by attrs so that all must pass.
        validator=[
            _file_validator(permission="r", suffixes=suffixes),
            _hdf5_validator(datasets={"trajectories"}),
        ],
    )
    """Path to the idtracker.ai .h5 file to validate."""
425+
426+
412427
@define
413428
class ValidDeepLabCutCSV:
414429
"""Class for validating DeepLabCut-style .csv files.

tests/fixtures/files.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from unittest.mock import mock_open, patch
99

1010
import h5py
11+
import numpy as np
1112
import pytest
1213
import xarray as xr
1314
from sleap_io.io.slp import read_labels, write_labels
@@ -515,6 +516,57 @@ def anipose_csv_file():
515516
)
516517

517518

519+
# ---------------- idtracker.ai file fixtures ----------------------------
520+
@pytest.fixture
def idtracker_valid_h5_file(tmp_path):
    """Write a well-formed idtracker.ai .h5 file and return its path."""
    # 10 frames, 3 individuals, 2 spatial dimensions (x, y)
    n_frames, n_individuals, n_space = 10, 3, 2
    h5_path = tmp_path / "valid_idtracker.h5"
    datasets = {
        "trajectories": np.ones((n_frames, n_individuals, n_space)),
        "id_probabilities": np.ones((n_frames, n_individuals)),
    }

    with h5py.File(h5_path, "w") as h5_file:
        for dataset_name, values in datasets.items():
            h5_file.create_dataset(dataset_name, data=values)
        h5_file.attrs["frames_per_second"] = np.array([30.0])
    return h5_path
531+
532+
533+
@pytest.fixture
def idtracker_buggy_shape_h5_file(tmp_path):
    """Write an idtracker.ai .h5 file with 3D probabilities; return its path.

    The "id_probabilities" dataset carries a trailing singleton dimension.
    """
    n_frames, n_individuals, n_space = 10, 3, 2
    h5_path = tmp_path / "buggy_idtracker.h5"
    datasets = {
        "trajectories": np.ones((n_frames, n_individuals, n_space)),
        # Buggy shape: (10, 3, 1) instead of (10, 3)
        "id_probabilities": np.ones((n_frames, n_individuals, 1)),
    }

    with h5py.File(h5_path, "w") as h5_file:
        for dataset_name, values in datasets.items():
            h5_file.create_dataset(dataset_name, data=values)
        h5_file.attrs["frames_per_second"] = np.array([30.0])
    return h5_path
545+
546+
547+
@pytest.fixture
def idtracker_trackless_h5_file(tmp_path):
    """Write an idtracker.ai .h5 file without probabilities; return its path.

    Only the "trajectories" dataset and the "frames_per_second" attribute
    are written.
    """
    n_frames, n_individuals, n_space = 10, 3, 2
    h5_path = tmp_path / "trackless_idtracker.h5"
    positions = np.ones((n_frames, n_individuals, n_space))

    with h5py.File(h5_path, "w") as h5_file:
        h5_file.create_dataset("trajectories", data=positions)
        # Intentionally omitting the id_probabilities dataset
        h5_file.attrs["frames_per_second"] = np.array([30.0])
    return h5_path
556+
557+
558+
@pytest.fixture(
    params=[
        "idtracker_valid_h5_file",
        "idtracker_buggy_shape_h5_file",
        "idtracker_trackless_h5_file",
    ]
)
def idtracker_h5_file(request):
    """Return the value of each listed idtracker.ai file fixture in turn."""
    # ``request.param`` holds the name of the fixture to resolve.
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
568+
569+
518570
# ---------------- netCDF file fixtures ----------------------------
519571
@pytest.fixture(scope="session")
520572
def invalid_netcdf_file_missing_confidence(tmp_path_factory):

0 commit comments

Comments
 (0)