diff --git a/movement/io/load.py b/movement/io/load.py index cad9752cb..ee0e3281f 100644 --- a/movement/io/load.py +++ b/movement/io/load.py @@ -29,6 +29,7 @@ "Anipose", "NWB", "VIA-tracks", + "BVH", ] diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index c76715b63..5e0eccf3f 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Literal, cast +import bvhio import h5py import numpy as np import pandas as pd @@ -17,6 +18,7 @@ from movement.validators.datasets import ValidPosesInputs from movement.validators.files import ( ValidAniposeCSV, + ValidBVHFile, ValidDeepLabCutCSV, ValidDeepLabCutH5, ValidFile, @@ -105,6 +107,7 @@ def from_file( "LightningPose", "Anipose", "NWB", + "BVH", ], fps: float | None = None, **kwargs, @@ -174,6 +177,8 @@ def from_file( return from_lp_file(file, fps) elif source_software == "Anipose": return from_anipose_file(file, fps, **kwargs) + elif source_software == "BVH": + return from_bvh_file(file, fps) elif source_software == "NWB": if fps is not None: logger.warning( @@ -943,3 +948,69 @@ def _ds_from_nwb_object( return xr.merge( single_keypoint_datasets, join="outer", compat="no_conflicts" ) + + +@register_loader("BVH", file_validators=[ValidBVHFile]) +def from_bvh_file(file: str | Path, fps: float | None = None) -> xr.Dataset: + """Create a ``movement`` poses dataset from a BVH file. + + Parameters + ---------- + file + Path to the file containing the poses in .bvh format. + fps + The number of frames per second in the video. If None (default), + the fps value will be computed using the BVH file frame time. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, + and associated metadata. + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_bvh_file("path/to/file.bvh") + + """ + valid_file = cast("ValidFile", file) + file_path = valid_file.file + + # Read the BVH hierarchy from the file + bvh = bvhio.readAsBvh(file_path) + frame_time = bvh.FrameTime + root = bvhio.readAsHierarchy(file_path) + + # Collect joint names + joint_names = [joint.Name for joint, _, _ in root.layout()] + n_keypoints = len(joint_names) + + # Frame range and fps + first, last = root.getKeyframeRange() + n_frames = last - first + 1 + fps = fps or (1 / frame_time) + + position_array = np.zeros( + (n_frames, 3, n_keypoints, 1), dtype=np.float32 + ) # 1 for single individual + + for f in range(n_frames): + root.loadPose(f) + for j_idx, (joint, _, _) in enumerate(root.layout()): + pos = joint.PositionWorld + position_array[f, :, j_idx, 0] = [pos.x, pos.y, pos.z] + + ds = from_numpy( + position_array=position_array, + confidence_array=None, # BVH is marker based, so maybe confidence= 1? + individual_names=["individual_0"], + keypoint_names=joint_names, + fps=fps, + source_software="BVH", + ) + + ds.attrs["source_file"] = file_path.as_posix() + ds.attrs["source_software"] = "BVH" + logger.info(f"Loaded poses from {file_path.name}") + return ds diff --git a/movement/validators/files.py b/movement/validators/files.py index e75baf8ac..290279962 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -1000,3 +1000,38 @@ class ValidROICollectionGeoJSON: data: dict = field(init=False, factory=dict) """Parsed JSON data from the file, available after validation.""" + + +@define +class ValidBVHFile: + """Class for validating BVH (Biovision Hierarchy) files.""" + + suffixes: ClassVar[set[str]] = {".bvh"} + """Expected suffix(es) for the file.""" + + file: Path = field( + converter=Path, + validator=_file_validator(permission="r", suffixes=suffixes), + ) + """Path to the BVH file to validate.""" + + @file.validator + def _file_contains_valid_bvh_structure(self, attribute, value): + """Ensure the BVH file has valid structure.""" + try: + with open(value) as f: + content = f.read() + except Exception as e: + raise logger.error( + ValueError(f"Could not read BVH file {value}: {e}") + ) from e + + if not content.strip().startswith("HIERARCHY"): + raise logger.error( + ValueError("BVH file must start with HIERARCHY keyword. ") + ) + + if "MOTION" not in content: + raise logger.error( + ValueError("BVH file must contain MOTION section") + ) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 88a271246..1a5bb6096 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -556,3 +556,125 @@ def invalid_dstype_netcdf_file(tmp_path_factory): ds.to_netcdf(invalid_path) yield str(invalid_path) + + +# ---------------- BVH file fixtures ---------------------------- +SIMPLE_BVH = ( + "HIERARCHY\n" + "ROOT Armature\n" + "{\n" + " OFFSET 0.00 0.00 0.00\n" + " CHANNELS 6 Xposition Yposition Zposition" + " Xrotation Yrotation Zrotation\n" + " JOINT Bone1\n" + " {\n" + " OFFSET 1.00 2.00 0.00\n" + " CHANNELS 3 Xrotation Yrotation Zrotation\n" + " End Site\n" + " {\n" + " OFFSET 1.00 0.00 0.00\n" + " }\n" + " }\n" + "}\n" + "MOTION\n" + "Frames: 2\n" + "Frame Time: 0.05\n" + "0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00\n" + "0.50 1.00 0.25 5.00 2.50 0.00 2.50 1.00 0.00\n" +) + +COMPLEX_BVH = ( + "HIERARCHY\n" + "ROOT Root\n" + "{\n" + " OFFSET 0.00 0.00 0.00\n" + " CHANNELS 6 Xposition Yposition Zposition" + " Xrotation Yrotation Zrotation\n" + " JOINT Torso\n" + " {\n" + " OFFSET 0.00 5.00 0.00\n" + " CHANNELS 3 Xrotation Yrotation Zrotation\n" + " JOINT Neck\n" + " {\n" + " OFFSET 0.00 3.00 0.00\n" + " CHANNELS 3 Xrotation Yrotation Zrotation\n" + " End Site\n" + " {\n" + " OFFSET 0.00 2.00 0.00\n" + " }\n" + " }\n" + " }\n" + " JOINT LeftArm\n" + " {\n" + " OFFSET 2.00 4.00 0.00\n" + " CHANNELS 3 Xrotation Yrotation Zrotation\n" + " End Site\n" + " {\n" + " OFFSET 3.00 0.00 0.00\n" + " }\n" + " }\n" + " JOINT RightArm\n" + " {\n" + " OFFSET -2.00 4.00 0.00\n" + " CHANNELS 3 Xrotation Yrotation Zrotation\n" + " End Site\n" + " {\n" + " OFFSET -3.00 0.00 0.00\n" + " }\n" + " }\n" + "}\n" + "MOTION\n" + "Frames: 3\n" + "Frame Time: 0.05\n" + "0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00" + " 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00\n" + "0.50 1.00 0.25 5.00 2.50 0.00 2.50 1.00 0.00" + " 3.00 1.50 0.00 -3.00 1.50 0.00 0.00 0.00 0.00\n" + "1.00 2.00 0.50 10.00 5.00 0.00 5.00 2.00 0.00" + " 6.00 3.00 0.00 -6.00 3.00 0.00 0.00 0.00 0.00\n" +) + + +@pytest.fixture +def readable_bvh_file(tmp_path): + """Return the path to a readable valid BVH file.""" + file_path = tmp_path / "readable.bvh" + with open(file_path, "w") as f: + f.write(SIMPLE_BVH) # We can use same content as simple_bvh_file + return file_path + + +@pytest.fixture +def simple_bvh_file(tmp_path): + """Return the path to a simple valid BVH file.""" + file_path = tmp_path / "simple.bvh" + with open(file_path, "w") as f: + f.write(SIMPLE_BVH) + return file_path + + +@pytest.fixture +def complex_bvh_file(tmp_path): + """Return the path to a complex valid BVH file.""" + file_path = tmp_path / "complex.bvh" + with open(file_path, "w") as f: + f.write(COMPLEX_BVH) + return file_path + + +@pytest.fixture +def bvh_file_no_hierarchy(tmp_path): + """Return path to a BVH file missing HIERARCHY.""" + file_path = tmp_path / "no_hierarchy.bvh" + with open(file_path, "w") as f: + f.write("MOTION\nFrames: 1\nFrame Time: 0.03\n0 0 0\n") + return file_path + + +@pytest.fixture +def bvh_file_no_motion(tmp_path): + """Return path to a BVH file missing MOTION section.""" + file_path = tmp_path / "no_motion.bvh" + with open(file_path, "w") as f: + f.write("HIERARCHY\nROOT Root\n{\nOFFSET 0 0 0\n}\n") + return file_path diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py index b44b9cdc7..252acea26 100644 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -275,7 +275,15 @@ def test_load_from_nwb_file(input_type, kwargs, request): @pytest.mark.filterwarnings("ignore:.*is deprecated:DeprecationWarning") @pytest.mark.parametrize( "source_software", - ["DeepLabCut", "SLEAP", "LightningPose", "Anipose", "NWB", "Unknown"], + [ + "DeepLabCut", + "SLEAP", + "LightningPose", + "Anipose", + "NWB", + "BVH", + "Unknown", + ], ) @pytest.mark.parametrize("fps", [None, 30, 60.0]) def test_from_file_delegates_correctly(source_software, fps, caplog): @@ -288,6 +296,7 @@ def test_from_file_delegates_correctly(source_software, fps, caplog): "LightningPose": "movement.io.load_poses.from_lp_file", "Anipose": "movement.io.load_poses.from_anipose_file", "NWB": "movement.io.load_poses.from_nwb_file", + "BVH": "movement.io.load_poses.from_bvh_file", } if source_software == "Unknown": with pytest.raises(ValueError, match="Unsupported source"): @@ -321,3 +330,59 @@ def test_from_multiview_files(): assert isinstance(multi_view_ds, xr.Dataset) assert "view" in multi_view_ds.dims assert multi_view_ds.view.values.tolist() == view_names + + +@pytest.mark.parametrize( + "bvh_file, n_frames, n_joints", + [ + ("simple_bvh_file", 2, 2), + ("complex_bvh_file", 3, 5), + ], +) +def test_load_from_bvh_file(request, bvh_file, n_frames, n_joints, helpers): + """Test loading BVH file returns valid Dataset.""" + bvh_file = request.getfixturevalue(bvh_file) + ds = load_poses.from_bvh_file(bvh_file) + + expected_values = { + **expected_values_poses, + "source_software": "BVH", + "file_path": bvh_file, + "fps": 20.0, + "time_unit": "seconds", + } + helpers.assert_valid_dataset(ds, expected_values) + + # Verify shape + assert ds.position.shape == (n_frames, 3, n_joints, 1) + assert ds.confidence.shape == (n_frames, n_joints, 1) + + +@pytest.mark.parametrize( + "fixture_name, expected_joints", + [ + ("simple_bvh_file", ["Armature", "Bone1"]), + ("complex_bvh_file", ["Root", "Torso", "Neck", "LeftArm", "RightArm"]), + ], + ids=["simple_bvh", "complex_bvh"], +) +def test_bvh_joint_names(request, fixture_name, expected_joints): + """Test that joint names match BVH hierarchy.""" + bvh_file = request.getfixturevalue(fixture_name) + ds = load_poses.from_bvh_file(bvh_file) + actual = ds.coords["keypoints"].values.tolist() + assert actual == expected_joints + + +def test_bvh_fps_from_frame_time(simple_bvh_file): + """Test fps is computed from BVH Frame Time.""" + ds = load_poses.from_bvh_file(simple_bvh_file) + assert ds.fps == 20 + assert ds.time_unit == "seconds" + + +def test_bvh_fps_none(simple_bvh_file): + """Test that fps=None computes fps from BVH Frame Time.""" + ds = load_poses.from_bvh_file(simple_bvh_file, fps=None) + assert ds.fps == 20 + assert ds.time_unit == "seconds" diff --git a/tests/test_unit/test_validators/test_files_validators.py b/tests/test_unit/test_validators/test_files_validators.py index 496acb29c..8f6ae775d 100644 --- a/tests/test_unit/test_validators/test_files_validators.py +++ b/tests/test_unit/test_validators/test_files_validators.py @@ -12,6 +12,7 @@ from movement.validators.files import ( DEFAULT_FRAME_REGEXP, ValidAniposeCSV, + ValidBVHFile, ValidDeepLabCutCSV, ValidDeepLabCutH5, ValidNWBFile, @@ -32,6 +33,7 @@ ("readable_csv_file", "r", None, does_not_raise()), ("readable_csv_file", "r", {".csv"}, does_not_raise()), ("readable_csv_file", "r", {".csv", ".h5"}, does_not_raise()), + ("readable_bvh_file", "r", {".bvh"}, does_not_raise()), ("new_csv_file", "w", None, does_not_raise()), ("unreadable_file", "r", None, pytest.raises(PermissionError)), ("unwriteable_file", "w", None, pytest.raises(PermissionError)), @@ -50,6 +52,7 @@ "has read permission, exists, and is not a directory", "has expected suffix", "has one of the expected suffixes", + "has expected suffix", "has write permission and does not exist", "lacks read permission", "lacks write permission", @@ -687,3 +690,27 @@ def test_roi_collection_geojson_validator(content, expected_context, tmp_path): with expected_context: validated = ValidROICollectionGeoJSON(file_path) assert validated.file == file_path + + +def test_valid_bvh_file(readable_bvh_file): + """Test that a valid BVH file passes validation.""" + valid = ValidBVHFile(file=readable_bvh_file) + assert valid.file == readable_bvh_file + + +def test_invalid_bvh_no_hierarchy(bvh_file_no_hierarchy): + """Test BVH without HIERARCHY fails validation.""" + with pytest.raises(ValueError, match="HIERARCHY"): + ValidBVHFile(file=bvh_file_no_hierarchy) + + +def test_invalid_bvh_no_motion(bvh_file_no_motion): + """Test BVH without MOTION fails validation.""" + with pytest.raises(ValueError, match="MOTION"): + ValidBVHFile(file=bvh_file_no_motion) + + +def test_invalid_bvh_wrong_extension(wrong_extension_file): + """Test that wrong file extension fails.""" + with pytest.raises(ValueError, match="suffix"): + ValidBVHFile(file=wrong_extension_file)