|
2 | 2 |
|
3 | 3 | import warnings |
4 | 4 | from pathlib import Path |
5 | | -from typing import Literal, cast |
| 5 | +from typing import Any, Literal, cast |
6 | 6 |
|
7 | 7 | import h5py |
8 | 8 | import numpy as np |
|
20 | 20 | ValidDeepLabCutCSV, |
21 | 21 | ValidDeepLabCutH5, |
22 | 22 | ValidFile, |
| 23 | + ValidIdtrackerH5, |
23 | 24 | ValidNWBFile, |
24 | 25 | ValidSleapAnalysis, |
25 | 26 | ValidSleapLabels, |
@@ -105,6 +106,7 @@ def from_file( |
105 | 106 | "LightningPose", |
106 | 107 | "Anipose", |
107 | 108 | "NWB", |
| 109 | + "idtracker.ai", |
108 | 110 | ], |
109 | 111 | fps: float | None = None, |
110 | 112 | **kwargs, |
@@ -174,6 +176,8 @@ def from_file( |
174 | 176 | return from_lp_file(file, fps) |
175 | 177 | elif source_software == "Anipose": |
176 | 178 | return from_anipose_file(file, fps, **kwargs) |
| 179 | + elif source_software == "idtracker.ai": |
| 180 | + return from_idtracker_file(file, fps) |
177 | 181 | elif source_software == "NWB": |
178 | 182 | if fps is not None: |
179 | 183 | logger.warning( |
@@ -271,6 +275,89 @@ def from_dlc_style_df( |
271 | 275 | ) |
272 | 276 |
|
273 | 277 |
|
| 278 | +def from_idtracker_style_dict( |
| 279 | + idtracker_data: dict[str, Any], |
| 280 | + fps: float | None = None, |
| 281 | + source_software: str = "idtracker.ai", |
| 282 | +) -> xr.Dataset: |
| 283 | + """Create a movement poses dataset from an idtracker.ai style dictionary. |
| 284 | +
|
| 285 | + Parameters |
| 286 | + ---------- |
| 287 | + idtracker_data |
| 288 | + Dictionary containing the pose tracks ("trajectories") and optional |
| 289 | + confidence scores ("id_probabilities") extracted from an idtracker.ai |
| 290 | + file. |
| 291 | + fps |
| 292 | + The number of frames per second in the video. If None (default), it |
| 293 | + attempts to read from the dictionary's "frames_per_second" key. |
| 294 | + If still None, the ``time`` coordinates will be in frame numbers. |
| 295 | + source_software |
| 296 | + Name of the pose estimation software from which the data originate. |
| 297 | + Defaults to "idtracker.ai". |
| 298 | +
|
| 299 | + Returns |
| 300 | + ------- |
| 301 | + xarray.Dataset |
| 302 | + ``movement`` dataset containing the pose tracks, confidence scores, |
| 303 | + and associated metadata. |
| 304 | +
|
| 305 | + Raises |
| 306 | + ------ |
| 307 | + ValueError |
| 308 | + If the dictionary is missing the "trajectories" key, or |
| 309 | + if the extracted arrays have unexpected dimensions. |
| 310 | +
|
| 311 | + See Also |
| 312 | + -------- |
| 313 | + movement.io.load_poses.from_idtracker_file |
| 314 | +
|
| 315 | + """ |
| 316 | + trajectories = np.asarray(idtracker_data["trajectories"]) |
| 317 | + n_frames, n_individuals, _ = trajectories.shape |
| 318 | + |
| 319 | + # Reshape from idtracker (frames, individuals, space) |
| 320 | + # to movement (frames, space, keypoints, individuals) |
| 321 | + # Note: idtracker does not track multiple keypoints, so keypoints=1 |
| 322 | + pos_reshaped = np.moveaxis(trajectories, source=-1, destination=1) |
| 323 | + position_array = np.expand_dims(pos_reshaped, axis=2) |
| 324 | + |
| 325 | + probs = idtracker_data.get("id_probabilities") |
| 326 | + |
| 327 | + if probs is None: |
| 328 | + logger.info( |
| 329 | + "No identity probabilities found in the idtracker.ai data. " |
| 330 | + "Confidence scores will be set to NaN." |
| 331 | + ) |
| 332 | + confidence_array = np.full((n_frames, 1, n_individuals), np.nan) |
| 333 | + else: |
| 334 | + probs = np.asarray(probs) |
| 335 | + |
| 336 | + # Handle idtracker.ai edge case: some versions output probabilities |
| 337 | + # with an undocumented trailing singleton dimension (e.g., (N, M, 1)) |
| 338 | + if probs.ndim == 3 and probs.shape[2] == 1: |
| 339 | + probs = probs[:, :, 0] |
| 340 | + elif probs.ndim != 2: |
| 341 | + raise logger.error( |
| 342 | + ValueError( |
| 343 | + f"Expected 2D probabilities array, got {probs.shape}." |
| 344 | + ) |
| 345 | + ) |
| 346 | + |
| 347 | + confidence_array = np.expand_dims(probs, axis=1) |
| 348 | + |
| 349 | + final_fps = ( |
| 350 | + fps if fps is not None else idtracker_data.get("frames_per_second") |
| 351 | + ) |
| 352 | + |
| 353 | + return from_numpy( |
| 354 | + position_array=position_array, |
| 355 | + confidence_array=confidence_array, |
| 356 | + fps=final_fps, |
| 357 | + source_software=source_software, |
| 358 | + ) |
| 359 | + |
| 360 | + |
274 | 361 | @register_loader( |
275 | 362 | "SLEAP", file_validators=[ValidSleapLabels, ValidSleapAnalysis] |
276 | 363 | ) |
@@ -427,6 +514,44 @@ def from_dlc_file(file: str | Path, fps: float | None = None) -> xr.Dataset: |
427 | 514 | ) |
428 | 515 |
|
429 | 516 |
|
| 517 | +@register_loader( |
| 518 | + source_software="idtracker.ai", |
| 519 | + file_validators=[ValidIdtrackerH5], |
| 520 | +) |
| 521 | +def from_idtracker_file( |
| 522 | + file: str | Path, |
| 523 | + fps: float | None = None, |
| 524 | +) -> xr.Dataset: |
| 525 | + """Load pose data from an idtracker.ai file. |
| 526 | +
|
| 527 | + Parameters |
| 528 | + ---------- |
| 529 | + file |
| 530 | + Path to the idtracker.ai file (e.g., a trajectories.h5 file) |
| 531 | + containing the pose tracks. |
| 532 | + fps |
| 533 | + The number of frames per second in the video. If None (default), it |
| 534 | + attempts to read from the file's metadata. If still None, the ``time`` |
| 535 | + coordinates will be in frame numbers. |
| 536 | +
|
| 537 | + Returns |
| 538 | + ------- |
| 539 | + xarray.Dataset |
| 540 | + ``movement`` dataset containing the pose tracks, confidence scores, |
| 541 | + and associated metadata. |
| 542 | +
|
| 543 | + See Also |
| 544 | + -------- |
| 545 | + movement.io.load_poses.from_idtracker_style_dict |
| 546 | +
|
| 547 | + """ |
| 548 | + return _ds_from_idtracker_file( |
| 549 | + valid_file=cast("ValidFile", file), |
| 550 | + source_software="idtracker.ai", |
| 551 | + fps=fps, |
| 552 | + ) |
| 553 | + |
| 554 | + |
430 | 555 | def from_multiview_files( |
431 | 556 | file_dict: dict[str, Path | str], |
432 | 557 | source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], |
@@ -514,6 +639,53 @@ def _ds_from_lp_or_dlc_file( |
514 | 639 | return ds |
515 | 640 |
|
516 | 641 |
|
| 642 | +def _ds_from_idtracker_file( |
| 643 | + valid_file: ValidFile, |
| 644 | + source_software: str = "idtracker.ai", |
| 645 | + fps: float | None = None, |
| 646 | +) -> xr.Dataset: |
| 647 | + """Create a movement poses dataset from a validated idtracker.ai file. |
| 648 | +
|
| 649 | + Parameters |
| 650 | + ---------- |
| 651 | + valid_file |
| 652 | + The validated idtracker.ai file object. |
| 653 | + source_software |
| 654 | + The source software of the file. |
| 655 | + fps |
| 656 | + The number of frames per second in the video. If None (default), |
| 657 | + the ``time`` coordinates will be in frame numbers. |
| 658 | +
|
| 659 | + Returns |
| 660 | + ------- |
| 661 | + xarray.Dataset |
| 662 | + ``movement`` dataset containing the pose tracks, confidence scores, |
| 663 | + and associated metadata. |
| 664 | +
|
| 665 | + """ |
| 666 | + file_path = valid_file.file |
| 667 | + |
| 668 | + if isinstance(valid_file, ValidIdtrackerH5): |
| 669 | + idtracker_data = _dict_from_idtracker_h5(file_path) |
| 670 | + else: |
| 671 | + raise logger.error( |
| 672 | + TypeError(f"Unsupported idtracker file type: {type(valid_file)}") |
| 673 | + ) |
| 674 | + |
| 675 | + final_fps = ( |
| 676 | + fps if fps is not None else idtracker_data.get("frames_per_second") |
| 677 | + ) |
| 678 | + |
| 679 | + ds = from_idtracker_style_dict( |
| 680 | + idtracker_data=idtracker_data, |
| 681 | + source_software=source_software, |
| 682 | + fps=final_fps, |
| 683 | + ) |
| 684 | + |
| 685 | + ds.attrs["source_file"] = file_path.as_posix() |
| 686 | + return ds |
| 687 | + |
| 688 | + |
517 | 689 | def _ds_from_sleap_analysis_file(file: Path, fps: float | None) -> xr.Dataset: |
518 | 690 | """Create a ``movement`` poses dataset from a SLEAP analysis (.h5) file. |
519 | 691 |
|
@@ -698,6 +870,23 @@ def _df_from_dlc_csv(valid_file: ValidDeepLabCutCSV) -> pd.DataFrame: |
698 | 870 | return df |
699 | 871 |
|
700 | 872 |
|
| 873 | +def _dict_from_idtracker_h5(path: Path) -> dict[str, Any]: |
| 874 | + """Create a dictionary of idtracker.ai pose data from an .h5 file.""" |
| 875 | + with h5py.File(path, "r") as f: |
| 876 | + trajectories = f["trajectories"][:] |
| 877 | + probs = f["id_probabilities"][:] if "id_probabilities" in f else None |
| 878 | + |
| 879 | + fps = f.attrs.get("frames_per_second") |
| 880 | + if isinstance(fps, (list, tuple, np.ndarray)): |
| 881 | + fps = float(fps[0]) |
| 882 | + |
| 883 | + return { |
| 884 | + "trajectories": trajectories, |
| 885 | + "id_probabilities": probs, |
| 886 | + "frames_per_second": fps, |
| 887 | + } |
| 888 | + |
| 889 | + |
701 | 890 | def from_anipose_style_df( |
702 | 891 | df: pd.DataFrame, |
703 | 892 | fps: float | None = None, |
|
0 commit comments