Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions movement/io/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"Anipose",
"NWB",
"VIA-tracks",
"BVH",
]


Expand Down
71 changes: 71 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Literal, cast

import bvhio
import h5py
import numpy as np
import pandas as pd
Expand All @@ -17,6 +18,7 @@
from movement.validators.datasets import ValidPosesInputs
from movement.validators.files import (
ValidAniposeCSV,
ValidBVHFile,
ValidDeepLabCutCSV,
ValidDeepLabCutH5,
ValidFile,
Expand Down Expand Up @@ -105,6 +107,7 @@ def from_file(
"LightningPose",
"Anipose",
"NWB",
"BVH",
],
fps: float | None = None,
**kwargs,
Expand Down Expand Up @@ -174,6 +177,8 @@ def from_file(
return from_lp_file(file, fps)
elif source_software == "Anipose":
return from_anipose_file(file, fps, **kwargs)
elif source_software == "BVH":
return from_bvh_file(file, fps)
elif source_software == "NWB":
if fps is not None:
logger.warning(
Expand Down Expand Up @@ -943,3 +948,69 @@ def _ds_from_nwb_object(
return xr.merge(
single_keypoint_datasets, join="outer", compat="no_conflicts"
)


@register_loader("BVH", file_validators=[ValidBVHFile])
def from_bvh_file(file: str | Path, fps: float | None = None) -> xr.Dataset:
"""Create a ``movement`` poses dataset from a BVH file.

Parameters
----------
file
Path to the file containing the poses in .bvh format.
fps
The number of frames per second in the video. If None (default),
the fps value will be computed using the BVH file frame time.

Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks,
and associated metadata.

Examples
--------
>>> from movement.io import load_poses
>>> ds = load_poses.from_bvh_file("path/to/file.bvh")

"""
valid_file = cast("ValidFile", file)
file_path = valid_file.file

# Read the BVH hierarchy from the file
bvh = bvhio.readAsBvh(file_path)
frame_time = bvh.FrameTime
root = bvhio.readAsHierarchy(file_path)

# Collect joint names
joint_names = [joint.Name for joint, _, _ in root.layout()]
n_keypoints = len(joint_names)

# Frame range and fps
first, last = root.getKeyframeRange()
n_frames = last - first + 1
fps = fps or (1 / frame_time)

position_array = np.zeros(
(n_frames, 3, n_keypoints, 1), dtype=np.float32
) # 1 for single individual

for f in range(n_frames):
root.loadPose(f)
for j_idx, (joint, _, _) in enumerate(root.layout()):
pos = joint.PositionWorld
position_array[f, :, j_idx, 0] = [pos.x, pos.y, pos.z]

ds = from_numpy(
position_array=position_array,
confidence_array=None, # BVH is marker based, so maybe confidence= 1?
individual_names=["individual_0"],
keypoint_names=joint_names,
fps=fps,
source_software="BVH",
)

ds.attrs["source_file"] = file_path.as_posix()
ds.attrs["source_software"] = "BVH"
logger.info(f"Loaded poses from {file_path.name}")
return ds
35 changes: 35 additions & 0 deletions movement/validators/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,3 +1000,38 @@ class ValidROICollectionGeoJSON:

data: dict = field(init=False, factory=dict)
"""Parsed JSON data from the file, available after validation."""


@define
class ValidBVHFile:
"""Class for validating BVH (Biovision Hierarchy) files."""

suffixes: ClassVar[set[str]] = {".bvh"}
"""Expected suffix(es) for the file."""

file: Path = field(
converter=Path,
validator=_file_validator(permission="r", suffixes=suffixes),
)
"""Path to the BVH file to validate."""

@file.validator
def _file_contains_valid_bvh_structure(self, attribute, value):
"""Ensure the BVH file has valid structure."""
try:
with open(value) as f:
content = f.read()
except Exception as e:
raise logger.error(
ValueError(f"Could not read BVH file {value}: {e}")
) from e

if not content.strip().startswith("HIERARCHY"):
raise logger.error(
ValueError("BVH file must start with HIERARCHY keyword. ")
)

if "MOTION" not in content:
raise logger.error(
ValueError("BVH file must contain MOTION section")
)
122 changes: 122 additions & 0 deletions tests/fixtures/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,125 @@ def invalid_dstype_netcdf_file(tmp_path_factory):
ds.to_netcdf(invalid_path)

yield str(invalid_path)


# ---------------- BVH file fixtures ----------------------------
SIMPLE_BVH = (
"HIERARCHY\n"
"ROOT Armature\n"
"{\n"
" OFFSET 0.00 0.00 0.00\n"
" CHANNELS 6 Xposition Yposition Zposition"
" Xrotation Yrotation Zrotation\n"
" JOINT Bone1\n"
" {\n"
" OFFSET 1.00 2.00 0.00\n"
" CHANNELS 3 Xrotation Yrotation Zrotation\n"
" End Site\n"
" {\n"
" OFFSET 1.00 0.00 0.00\n"
" }\n"
" }\n"
"}\n"
"MOTION\n"
"Frames: 2\n"
"Frame Time: 0.05\n"
"0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00\n"
"0.50 1.00 0.25 5.00 2.50 0.00 2.50 1.00 0.00\n"
)

COMPLEX_BVH = (
"HIERARCHY\n"
"ROOT Root\n"
"{\n"
" OFFSET 0.00 0.00 0.00\n"
" CHANNELS 6 Xposition Yposition Zposition"
" Xrotation Yrotation Zrotation\n"
" JOINT Torso\n"
" {\n"
" OFFSET 0.00 5.00 0.00\n"
" CHANNELS 3 Xrotation Yrotation Zrotation\n"
" JOINT Neck\n"
" {\n"
" OFFSET 0.00 3.00 0.00\n"
" CHANNELS 3 Xrotation Yrotation Zrotation\n"
" End Site\n"
" {\n"
" OFFSET 0.00 2.00 0.00\n"
" }\n"
" }\n"
" }\n"
" JOINT LeftArm\n"
" {\n"
" OFFSET 2.00 4.00 0.00\n"
" CHANNELS 3 Xrotation Yrotation Zrotation\n"
" End Site\n"
" {\n"
" OFFSET 3.00 0.00 0.00\n"
" }\n"
" }\n"
" JOINT RightArm\n"
" {\n"
" OFFSET -2.00 4.00 0.00\n"
" CHANNELS 3 Xrotation Yrotation Zrotation\n"
" End Site\n"
" {\n"
" OFFSET -3.00 0.00 0.00\n"
" }\n"
" }\n"
"}\n"
"MOTION\n"
"Frames: 3\n"
"Frame Time: 0.05\n"
"0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00"
" 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00\n"
"0.50 1.00 0.25 5.00 2.50 0.00 2.50 1.00 0.00"
" 3.00 1.50 0.00 -3.00 1.50 0.00 0.00 0.00 0.00\n"
"1.00 2.00 0.50 10.00 5.00 0.00 5.00 2.00 0.00"
" 6.00 3.00 0.00 -6.00 3.00 0.00 0.00 0.00 0.00\n"
)


@pytest.fixture
def readable_bvh_file(tmp_path):
"""Return the path to a readable valid BVH file."""
file_path = tmp_path / "readable.bvh"
with open(file_path, "w") as f:
f.write(SIMPLE_BVH) # We can use same content as simple_bvh_file
return file_path


@pytest.fixture
def simple_bvh_file(tmp_path):
"""Return the path to a simple valid BVH file."""
file_path = tmp_path / "simple.bvh"
with open(file_path, "w") as f:
f.write(SIMPLE_BVH)
return file_path


@pytest.fixture
def complex_bvh_file(tmp_path):
"""Return the path to a complex valid BVH file."""
file_path = tmp_path / "complex.bvh"
with open(file_path, "w") as f:
f.write(COMPLEX_BVH)
return file_path


@pytest.fixture
def bvh_file_no_hierarchy(tmp_path):
"""Return path to a BVH file missing HIERARCHY."""
file_path = tmp_path / "no_hierarchy.bvh"
with open(file_path, "w") as f:
f.write("MOTION\nFrames: 1\nFrame Time: 0.03\n0 0 0\n")
return file_path


@pytest.fixture
def bvh_file_no_motion(tmp_path):
"""Return path to a BVH file missing MOTION section."""
file_path = tmp_path / "no_motion.bvh"
with open(file_path, "w") as f:
f.write("HIERARCHY\nROOT Root\n{\nOFFSET 0 0 0\n}\n")
return file_path
67 changes: 66 additions & 1 deletion tests/test_unit/test_io/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,15 @@ def test_load_from_nwb_file(input_type, kwargs, request):
@pytest.mark.filterwarnings("ignore:.*is deprecated:DeprecationWarning")
@pytest.mark.parametrize(
"source_software",
["DeepLabCut", "SLEAP", "LightningPose", "Anipose", "NWB", "Unknown"],
[
"DeepLabCut",
"SLEAP",
"LightningPose",
"Anipose",
"NWB",
"BVH",
"Unknown",
],
)
@pytest.mark.parametrize("fps", [None, 30, 60.0])
def test_from_file_delegates_correctly(source_software, fps, caplog):
Expand All @@ -288,6 +296,7 @@ def test_from_file_delegates_correctly(source_software, fps, caplog):
"LightningPose": "movement.io.load_poses.from_lp_file",
"Anipose": "movement.io.load_poses.from_anipose_file",
"NWB": "movement.io.load_poses.from_nwb_file",
"BVH": "movement.io.load_poses.from_bvh_file",
}
if source_software == "Unknown":
with pytest.raises(ValueError, match="Unsupported source"):
Expand Down Expand Up @@ -321,3 +330,59 @@ def test_from_multiview_files():
assert isinstance(multi_view_ds, xr.Dataset)
assert "view" in multi_view_ds.dims
assert multi_view_ds.view.values.tolist() == view_names


@pytest.mark.parametrize(
"bvh_file, n_frames, n_joints",
[
("simple_bvh_file", 2, 2),
("complex_bvh_file", 3, 5),
],
)
def test_load_from_bvh_file(request, bvh_file, n_frames, n_joints, helpers):
"""Test loading BVH file returns valid Dataset."""
bvh_file = request.getfixturevalue(bvh_file)
ds = load_poses.from_bvh_file(bvh_file)

expected_values = {
**expected_values_poses,
"source_software": "BVH",
"file_path": bvh_file,
"fps": 20.0,
"time_unit": "seconds",
}
helpers.assert_valid_dataset(ds, expected_values)

# Verify shape
assert ds.position.shape == (n_frames, 3, n_joints, 1)
assert ds.confidence.shape == (n_frames, n_joints, 1)


@pytest.mark.parametrize(
"fixture_name, expected_joints",
[
("simple_bvh_file", ["Armature", "Bone1"]),
("complex_bvh_file", ["Root", "Torso", "Neck", "LeftArm", "RightArm"]),
],
ids=["simple_bvh", "complex_bvh"],
)
def test_bvh_joint_names(request, fixture_name, expected_joints):
"""Test that joint names match BVH hierarchy."""
bvh_file = request.getfixturevalue(fixture_name)
ds = load_poses.from_bvh_file(bvh_file)
actual = ds.coords["keypoints"].values.tolist()
assert actual == expected_joints


def test_bvh_fps_from_frame_time(simple_bvh_file):
"""Test fps is computed from BVH Frame Time."""
ds = load_poses.from_bvh_file(simple_bvh_file)
assert ds.fps == 20
assert ds.time_unit == "seconds"


def test_bvh_fps_none(simple_bvh_file):
"""Test that fps=None computes fps from BVH Frame Time."""
ds = load_poses.from_bvh_file(simple_bvh_file, fps=None)
assert ds.fps == 20
assert ds.time_unit == "seconds"
Loading
Loading