Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions docs/source/user_guide/input_output.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ To import {func}`load_dataset()<movement.io.load.load_dataset>`:
from movement.io import load_dataset
```

To load data from any supported format, specify the `file` path, the `source_software` that produced the file, and optionally the `fps` of the video from which the data were obtained.
To load data from any supported format, specify the `file` path.
You can either set `source_software` explicitly, or let `movement` infer it by leaving `source_software` unset or setting `source_software="auto"`.
If a DLC-style `.csv` file could plausibly come from either DeepLabCut or LightningPose, set `source_software` explicitly.
If you want to inspect what `movement` would infer, use `movement.io.infer_source_software(file)`.
Optionally, also provide `fps` (frames per second) to convert the `time` coordinates from frame numbers into seconds; if `fps` is not given, the `time` coordinates remain in frame numbers.

For example, to load pose tracks from a DeepLabCut .h5 file:
```python
ds = load_dataset(
"/path/to/file.h5",
source_software="DeepLabCut",
fps=30, # Optional; time coords will be in seconds if provided, otherwise in frames
source_software="auto",
fps=30,
)
```

Expand Down
4 changes: 2 additions & 2 deletions movement/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import load_bboxes, load_poses # Trigger register_loader decorators
from .load import load_dataset, load_multiview_dataset
from .load import infer_source_software, load_dataset, load_multiview_dataset

__all__ = ["load_dataset", "load_multiview_dataset"]
__all__ = ["infer_source_software", "load_dataset", "load_multiview_dataset"]
146 changes: 143 additions & 3 deletions movement/io/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
"VIA-tracks",
]

# Sentinel value that asks the loaders to infer the source software
# from the file itself, instead of requiring an explicit value.
AutoSourceSoftware: TypeAlias = Literal["auto"]
# Accepted values for the ``source_software`` argument of the loaders.
SourceSoftwareOrAuto: TypeAlias = SourceSoftware | AutoSourceSoftware

# Source softwares tried by auto-detection. Note this is a set, so iteration
# order is arbitrary; all matching candidates are collected before any
# tie-breaking happens.
SUPPORTED_SOURCE_SOFTWARES: set[SourceSoftware] = {
    "DeepLabCut",
    "SLEAP",
    "LightningPose",
    "Anipose",
    "NWB",
    "VIA-tracks",
}


class LoaderProtocol(Protocol):
"""Protocol for loader functions to be registered via ``register_loader``.
Expand Down Expand Up @@ -71,6 +83,128 @@ def __call__(


_LOADER_REGISTRY: dict[str, LoaderProtocol] = {}
_LOADER_VALIDATORS_REGISTRY: dict[str, list[type[ValidFile]]] = {}


def _dlc_vs_lp(file_path: Path) -> SourceSoftware:
    """Disambiguate between DeepLabCut and LightningPose CSV files.

    Both tools write DLC-style CSV predictions, so we fall back on naming
    heuristics: the "scorer" cell of the header row and the file name.

    Parameters
    ----------
    file_path
        Path to the CSV file.

    Returns
    -------
    SourceSoftware
        Either "DeepLabCut" or "LightningPose".

    Raises
    ------
    ValueError
        If neither heuristic identifies the source software.

    """
    # Grab the header row; an unreadable file simply yields no scorer cell.
    try:
        with file_path.open(encoding="utf-8", errors="replace") as f:
            first_row = f.readline().strip().split(",")
    except OSError:
        first_row = []

    # The scorer name sits in the second cell of a DLC-style header row.
    scorer = first_row[1].strip().lower() if len(first_row) > 1 else ""
    filename = file_path.name.lower()

    looks_like_lp = (
        "lightning" in scorer
        or "lightning" in filename
        or filename.startswith("lp_")
    )
    if looks_like_lp:
        return "LightningPose"
    if scorer.startswith("dlc_") or filename.startswith("dlc_"):
        return "DeepLabCut"
    raise logger.error(
        ValueError(
            "Could not uniquely infer source_software from "
            f"'{file_path}'. Candidates: DeepLabCut, LightningPose. "
            "Please specify `source_software` explicitly."
        )
    )


def infer_source_software(
    file: Path | str | pynwb.file.NWBFile,
    **loader_kwargs,
) -> SourceSoftware:
    """Infer the ``source_software`` from a given input file.

    The file is run through the validators registered for each built-in
    loader; every software whose validators accept the file becomes a
    candidate. A single candidate is returned directly; the special
    DeepLabCut-vs-LightningPose tie is resolved via naming heuristics;
    anything else raises.

    Parameters
    ----------
    file
        Input file path or an :class:`pynwb.file.NWBFile` object.
    **loader_kwargs
        Optional keyword arguments forwarded to file validators (when they
        accept additional parameters, e.g. VIA-tracks `frame_regexp`).

    Returns
    -------
    SourceSoftware
        The inferred source software name.

    Raises
    ------
    ValueError
        If the source software cannot be inferred or is ambiguous.

    """
    # An in-memory NWB object is unambiguous: no file inspection needed.
    if isinstance(file, pynwb.file.NWBFile):
        return "NWB"

    file_path = Path(file)

    def _accepts(software: SourceSoftware) -> bool:
        """Return True if any registered validator accepts the file."""
        validators = _LOADER_VALIDATORS_REGISTRY.get(software, [])
        # Some loaders might be registered without validators; skip them.
        if not validators:
            return False
        try:
            _validate_file(
                file_path,
                _build_suffix_map(validators),
                software,
                loader_kwargs=loader_kwargs,
            )
        except (OSError, TypeError, ValueError):
            return False
        return True

    candidates: list[SourceSoftware] = [
        software
        for software in SUPPORTED_SOURCE_SOFTWARES
        if _accepts(software)
    ]

    if not candidates:
        suffix = file_path.suffix or "<no suffix>"
        supported = ", ".join(sorted(SUPPORTED_SOURCE_SOFTWARES))
        raise logger.error(
            ValueError(
                f"Could not infer source_software from file '{file_path}'. "
                f"File suffix is '{suffix}'. Supported sources: {supported}."
            )
        )

    if len(candidates) == 1:
        return candidates[0]

    # DeepLabCut and LightningPose share the same DLC-style CSV validator,
    # so fall back to naming heuristics to break the tie.
    if set(candidates) == {"DeepLabCut", "LightningPose"}:
        return _dlc_vs_lp(file_path)

    candidates_str = ", ".join(sorted(candidates))
    raise logger.error(
        ValueError(
            "Could not uniquely infer source_software from "
            f"'{file_path}'. Candidates: {candidates_str}. "
            "Please specify `source_software` explicitly."
        )
    )


def _get_validator_kwargs(
Expand Down Expand Up @@ -206,6 +340,7 @@ def register_loader(
and not isinstance(file_validators, list)
else file_validators or []
)
_LOADER_VALIDATORS_REGISTRY[source_software] = validators_list
# Map suffixes to validator classes
suffix_map = _build_suffix_map(validators_list)

Expand All @@ -231,7 +366,7 @@ def wrapper(file: TInputFile, *args, **kwargs) -> xr.Dataset:

def load_dataset(
file: Path | str | pynwb.file.NWBFile,
source_software: SourceSoftware,
source_software: SourceSoftwareOrAuto = "auto",
fps: float | None = None,
**kwargs,
) -> xr.Dataset:
Expand All @@ -250,6 +385,7 @@ def load_dataset(
will be called.
source_software
The source software of the file.
If set to ``"auto"`` (default), it is inferred from the file format.
fps
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
Expand Down Expand Up @@ -280,6 +416,9 @@ def load_dataset(
... )

"""
if source_software == "auto":
source_software = infer_source_software(file, **kwargs)

if source_software not in _LOADER_REGISTRY:
raise logger.error(
ValueError(f"Unsupported source software: {source_software}")
Expand All @@ -297,7 +436,7 @@ def load_dataset(

def load_multiview_dataset(
file_dict: dict[str, Path | str],
source_software: SourceSoftware,
source_software: SourceSoftwareOrAuto = "auto",
fps: float | None = None,
**kwargs,
) -> xr.Dataset:
Expand All @@ -308,7 +447,8 @@ def load_multiview_dataset(
file_dict
A dict whose keys are the view names and values are the paths to load.
source_software
The source software of the file.
The source software of the files.
If set to ``"auto"`` (default), it is inferred for each file.
fps
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
Expand Down
6 changes: 6 additions & 0 deletions tests/fixtures/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def dlc_csv_file():
return pytest.DATA_PATHS.get("DLC_single-wasp.predictions.csv")


@pytest.fixture
def lp_csv_file():
    """Return the path to a LightningPose .csv file."""
    # Resolved via the shared sample-data registry attached to pytest.
    return pytest.DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv")


# ---------------- SLEAP file fixtures ----------------------------
@pytest.fixture(
params=[
Expand Down
55 changes: 55 additions & 0 deletions tests/test_unit/test_io/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@

from movement.io import load

# (fixture name, expected source_software) pairs driving the auto-detection
# tests below. NWB is special-cased in the tests: its fixture returns a
# callable that builds the NWBFile object.
AUTO_SOURCE_SOFTWARE_CASES = [
    ("dlc_csv_file", "DeepLabCut"),
    ("lp_csv_file", "LightningPose"),
    ("dlc_h5_file", "DeepLabCut"),
    ("sleap_slp_file", "SLEAP"),
    ("sleap_analysis_file", "SLEAP"),
    ("anipose_csv_file", "Anipose"),
    ("via_tracks_csv", "VIA-tracks"),
    ("nwbfile_object", "NWB"),
]


@define
class StubValidFile:
Expand Down Expand Up @@ -161,3 +172,47 @@ def test_build_suffix_map():
"""
suffix_map = load._build_suffix_map([StubValidFile])
assert suffix_map == {".stub": StubValidFile}


@pytest.mark.parametrize(
    "file_fixture, expected_source_software", AUTO_SOURCE_SOFTWARE_CASES
)
def test_infer_source_software(
    file_fixture, expected_source_software, request
):
    """Test auto-detection of source_software."""
    file = request.getfixturevalue(file_fixture)
    # The NWB fixture is a factory: call it to build the NWBFile object.
    if file_fixture.startswith("nwb"):
        file = file()
    assert load.infer_source_software(file) == expected_source_software


@pytest.mark.parametrize(
    "file_fixture, expected_source_software", AUTO_SOURCE_SOFTWARE_CASES
)
def test_load_dataset_auto_detects(
    file_fixture, expected_source_software, request
):
    """Test that load_dataset works with source_software='auto'."""
    file = request.getfixturevalue(file_fixture)
    # The NWB fixture is a factory: call it to build the NWBFile object.
    if file_fixture.startswith("nwb"):
        file = file()

    # Loading with "auto" must yield exactly the same dataset as loading
    # with the source software named explicitly.
    auto_ds, explicit_ds = (
        load.load_dataset(file, source_software=software)
        for software in ("auto", expected_source_software)
    )
    xr.testing.assert_identical(auto_ds, explicit_ds)


def test_infer_source_software_raises_for_ambiguous_dlc_style_csv(
    lp_csv_file, tmp_path
):
    """Test that ambiguous DLC-style CSV files require an explicit source."""
    # Copy the LightningPose predictions to a neutral file name that carries
    # no "lp_"/"dlc_"/"lightning" hint (the fixture's scorer cell presumably
    # carries none either), so the DLC-vs-LP tie cannot be broken.
    neutral_csv = tmp_path / "mouse-face.predictions.csv"
    neutral_csv.write_bytes(lp_csv_file.read_bytes())

    with pytest.raises(ValueError, match="Could not uniquely infer"):
        load.infer_source_software(neutral_csv)