diff --git a/docs/source/user_guide/input_output.md b/docs/source/user_guide/input_output.md index ffb1da250..72e65d0d4 100644 --- a/docs/source/user_guide/input_output.md +++ b/docs/source/user_guide/input_output.md @@ -52,14 +52,18 @@ To import {func}`load_dataset()`: from movement.io import load_dataset ``` -To load data from any supported format, specify the `file` path, the `source_software` that produced the file, and optionally the `fps` of the video from which the data were obtained. +To load data from any supported format, specify the `file` path. +You can either set `source_software` explicitly, or let `movement` infer it by leaving `source_software` unset or setting `source_software="auto"`. +If a DLC-style `.csv` file could plausibly come from either DeepLabCut or LightningPose, set `source_software` explicitly. +If you want to inspect what `movement` would infer, use `movement.io.infer_source_software(file)`. +Optionally, also provide `fps` to express the time coordinates in seconds (when supported); if `fps` is omitted, the time coordinates are in frame numbers. For example, to load pose tracks from a DeepLabCut .h5 file: ```python ds = load_dataset( "/path/to/file.h5", - source_software="DeepLabCut", - fps=30, # Optional; time coords will be in seconds if provided, otherwise in frames + source_software="auto", + fps=30, ) ``` diff --git a/movement/io/__init__.py b/movement/io/__init__.py index 1082fd713..5b0bc305b 100644 --- a/movement/io/__init__.py +++ b/movement/io/__init__.py @@ -1,4 +1,4 @@ from . 
import load_bboxes, load_poses # Trigger register_loader decorators -from .load import load_dataset, load_multiview_dataset +from .load import infer_source_software, load_dataset, load_multiview_dataset -__all__ = ["load_dataset", "load_multiview_dataset"] +__all__ = ["infer_source_software", "load_dataset", "load_multiview_dataset"] diff --git a/movement/io/load.py b/movement/io/load.py index cad9752cb..dd37e5d04 100644 --- a/movement/io/load.py +++ b/movement/io/load.py @@ -31,6 +31,18 @@ "VIA-tracks", ] +AutoSourceSoftware: TypeAlias = Literal["auto"] +SourceSoftwareOrAuto: TypeAlias = SourceSoftware | AutoSourceSoftware + +SUPPORTED_SOURCE_SOFTWARES: set[SourceSoftware] = { + "DeepLabCut", + "SLEAP", + "LightningPose", + "Anipose", + "NWB", + "VIA-tracks", +} + class LoaderProtocol(Protocol): """Protocol for loader functions to be registered via ``register_loader``. @@ -71,6 +83,128 @@ def __call__( _LOADER_REGISTRY: dict[str, LoaderProtocol] = {} +_LOADER_VALIDATORS_REGISTRY: dict[str, list[type[ValidFile]]] = {} + + +def _dlc_vs_lp(file_path: Path) -> SourceSoftware: + """Disambiguate between DeepLabCut and LightningPose CSV files. + + Parameters + ---------- + file_path + Path to the CSV file. + + Returns + ------- + SourceSoftware + Either "DeepLabCut" or "LightningPose". + + """ + scorer = "" + try: + with file_path.open(encoding="utf-8", errors="replace") as f: + header = f.readline().strip().split(",") + except OSError: + header = [] + + if len(header) > 1: + scorer = header[1].strip().lower() + + filename = file_path.name.lower() + if ( + "lightning" in scorer + or "lightning" in filename + or filename.startswith("lp_") + ): + return "LightningPose" + if scorer.startswith("dlc_") or filename.startswith("dlc_"): + return "DeepLabCut" + raise logger.error( + ValueError( + "Could not uniquely infer source_software from " + f"'{file_path}'. Candidates: DeepLabCut, LightningPose. " + "Please specify `source_software` explicitly." 
+ ) + ) + + +def infer_source_software( + file: Path | str | pynwb.file.NWBFile, + **loader_kwargs, +) -> SourceSoftware: + """Infer the ``source_software`` from a given input file. + + This helper tries the file validators registered for each built-in loader. + If multiple software candidates match, the function raises an error (or + breaks ties where possible, e.g. DeepLabCut vs LightningPose CSV files). + + Parameters + ---------- + file + Input file path or an :class:`pynwb.file.NWBFile` object. + **loader_kwargs + Optional keyword arguments forwarded to file validators (when they + accept additional parameters, e.g. VIA-tracks `frame_regexp`). + + Returns + ------- + SourceSoftware + The inferred source software name. + + Raises + ------ + ValueError + If the source software cannot be inferred or is ambiguous. + + """ + if isinstance(file, pynwb.file.NWBFile): + return "NWB" + + file_path = Path(file) + candidates: list[SourceSoftware] = [] + + for source_sw in SUPPORTED_SOURCE_SOFTWARES: + validators_list = _LOADER_VALIDATORS_REGISTRY.get(source_sw, []) + # Some loaders might be registered without validators; skip them. + if not validators_list: + continue + suffix_map = _build_suffix_map(validators_list) + try: + _validate_file( + file_path, + suffix_map, + source_sw, + loader_kwargs=loader_kwargs, + ) + except (OSError, TypeError, ValueError): + continue + candidates.append(source_sw) + + if not candidates: + suffix = file_path.suffix or "" + supported = ", ".join(sorted(SUPPORTED_SOURCE_SOFTWARES)) + raise logger.error( + ValueError( + f"Could not infer source_software from file '{file_path}'. " + f"File suffix is '{suffix}'. Supported sources: {supported}." + ) + ) + + if len(candidates) == 1: + return candidates[0] + + # DeepLabCut and LightningPose share the same DLC-style CSV validator. 
+ if set(candidates) == {"DeepLabCut", "LightningPose"}: + return _dlc_vs_lp(file_path) + + candidates_str = ", ".join(sorted(candidates)) + raise logger.error( + ValueError( + "Could not uniquely infer source_software from " + f"'{file_path}'. Candidates: {candidates_str}. " + "Please specify `source_software` explicitly." + ) + ) def _get_validator_kwargs( @@ -206,6 +340,7 @@ def register_loader( and not isinstance(file_validators, list) else file_validators or [] ) + _LOADER_VALIDATORS_REGISTRY[source_software] = validators_list # Map suffixes to validator classes suffix_map = _build_suffix_map(validators_list) @@ -231,7 +366,7 @@ def wrapper(file: TInputFile, *args, **kwargs) -> xr.Dataset: def load_dataset( file: Path | str | pynwb.file.NWBFile, - source_software: SourceSoftware, + source_software: SourceSoftwareOrAuto = "auto", fps: float | None = None, **kwargs, ) -> xr.Dataset: @@ -250,6 +385,7 @@ def load_dataset( will be called. source_software The source software of the file. + If set to ``"auto"`` (default), it is inferred from the file format. fps The number of frames per second in the video. If None (default), the ``time`` coordinates will be in frame numbers. @@ -280,6 +416,9 @@ def load_dataset( ... ) """ + if source_software == "auto": + source_software = infer_source_software(file, **kwargs) + if source_software not in _LOADER_REGISTRY: raise logger.error( ValueError(f"Unsupported source software: {source_software}") @@ -297,7 +436,7 @@ def load_dataset( def load_multiview_dataset( file_dict: dict[str, Path | str], - source_software: SourceSoftware, + source_software: SourceSoftwareOrAuto = "auto", fps: float | None = None, **kwargs, ) -> xr.Dataset: @@ -308,7 +447,8 @@ def load_multiview_dataset( file_dict A dict whose keys are the view names and values are the paths to load. source_software - The source software of the file. + The source software of the files. + If set to ``"auto"`` (default), it is inferred for each file. 
fps The number of frames per second in the video. If None (default), the ``time`` coordinates will be in frame numbers. diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 88a271246..5249551d6 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -211,6 +211,12 @@ def dlc_csv_file(): return pytest.DATA_PATHS.get("DLC_single-wasp.predictions.csv") +@pytest.fixture +def lp_csv_file(): + """Return the path to a LightningPose .csv file.""" + return pytest.DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv") + + # ---------------- SLEAP file fixtures ---------------------------- @pytest.fixture( params=[ diff --git a/tests/test_unit/test_io/test_load.py b/tests/test_unit/test_io/test_load.py index 39b8554e2..8f502b2ec 100644 --- a/tests/test_unit/test_io/test_load.py +++ b/tests/test_unit/test_io/test_load.py @@ -9,6 +9,17 @@ from movement.io import load +AUTO_SOURCE_SOFTWARE_CASES = [ + ("dlc_csv_file", "DeepLabCut"), + ("lp_csv_file", "LightningPose"), + ("dlc_h5_file", "DeepLabCut"), + ("sleap_slp_file", "SLEAP"), + ("sleap_analysis_file", "SLEAP"), + ("anipose_csv_file", "Anipose"), + ("via_tracks_csv", "VIA-tracks"), + ("nwbfile_object", "NWB"), +] + @define class StubValidFile: @@ -161,3 +172,47 @@ def test_build_suffix_map(): """ suffix_map = load._build_suffix_map([StubValidFile]) assert suffix_map == {".stub": StubValidFile} + + +@pytest.mark.parametrize( + "file_fixture, expected_source_software", AUTO_SOURCE_SOFTWARE_CASES +) +def test_infer_source_software( + file_fixture, expected_source_software, request +): + """Test auto-detection of source_software.""" + file_path = request.getfixturevalue(file_fixture) + if file_fixture.startswith("nwb"): + file_path = file_path() # NWB fixture is a callable + inferred = load.infer_source_software(file_path) + assert inferred == expected_source_software + + +@pytest.mark.parametrize( + "file_fixture, expected_source_software", AUTO_SOURCE_SOFTWARE_CASES +) +def 
test_load_dataset_auto_detects( + file_fixture, expected_source_software, request +): + """Test that load_dataset works with source_software='auto'.""" + file_path = request.getfixturevalue(file_fixture) + if file_fixture.startswith("nwb"): + file_path = file_path() # NWB fixture is a callable + + auto_ds = load.load_dataset(file_path, source_software="auto") + explicit_ds = load.load_dataset( + file_path, source_software=expected_source_software + ) + + xr.testing.assert_identical(auto_ds, explicit_ds) + + +def test_infer_source_software_raises_for_ambiguous_dlc_style_csv( + lp_csv_file, tmp_path +): + """Test that ambiguous DLC-style CSV files require an explicit source.""" + ambiguous_file = tmp_path / "mouse-face.predictions.csv" + ambiguous_file.write_bytes(lp_csv_file.read_bytes()) + + with pytest.raises(ValueError, match="Could not uniquely infer"): + load.infer_source_software(ambiguous_file)