diff --git a/examples/load_bvh_data.py b/examples/load_bvh_data.py
new file mode 100644
index 000000000..38593452b
--- /dev/null
+++ b/examples/load_bvh_data.py
@@ -0,0 +1,242 @@
+"""Load BVH motion capture data
+================================
+
+Load a `BVH (Biovision Hierarchy)
+<https://en.wikipedia.org/wiki/Biovision_Hierarchy>`_
+motion capture file into ``movement`` and
+visualise the 3D skeleton.
+"""
+
+# %%
+# Imports
+# -------
+
+import tempfile
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from movement.io import load_poses
+
+# %%
+# About the BVH format
+# ----------------------
+# BVH is a widely used text-based motion capture format
+# originally developed by Biovision. It stores:
+#
+# - A **HIERARCHY** section defining the skeleton structure
+#   (joints, offsets, and channel types).
+# - A **MOTION** section with per-frame channel values
+#   (translations and Euler angle rotations).
+#
+# ``movement`` parses the hierarchy and computes absolute
+# 3D joint positions via forward kinematics, making it easy
+# to analyse the data with the same tools used for
+# animal pose estimation.
+
+# %%
+# Create a sample BVH file
+# --------------------------
+# For this example, we create a simple 5-joint skeleton
+# with 20 frames of motion.
+
+n_frames = 20
+frame_time = 0.033333  # ~30 fps
+
+# Build the hierarchy string
+hierarchy = """\
+HIERARCHY
+ROOT Hips
+{
+    OFFSET 0.00 0.00 0.00
+    CHANNELS 6 Xposition Yposition Zposition\
+ Zrotation Xrotation Yrotation
+    JOINT Spine
+    {
+        OFFSET 0.00 5.00 0.00
+        CHANNELS 3 Zrotation Xrotation Yrotation
+        JOINT Head
+        {
+            OFFSET 0.00 4.00 0.00
+            CHANNELS 3 Zrotation Xrotation Yrotation
+            End Site
+            {
+                OFFSET 0.00 2.00 0.00
+            }
+        }
+    }
+    JOINT LeftHand
+    {
+        OFFSET 3.00 4.50 0.00
+        CHANNELS 3 Zrotation Xrotation Yrotation
+        End Site
+        {
+            OFFSET 4.00 0.00 0.00
+        }
+    }
+    JOINT RightHand
+    {
+        OFFSET -3.00 4.50 0.00
+        CHANNELS 3 Zrotation Xrotation Yrotation
+        End Site
+        {
+            OFFSET -4.00 0.00 0.00
+        }
+    }
+}
+"""
+
+# Generate motion data: walking with arm swing
+rng = np.random.default_rng(42)
+motion_lines = []
+for frame in range(n_frames):
+    t = frame / n_frames * 2 * np.pi
+    # Root translation: walk forward along X
+    xpos = frame * 2.0
+    ypos = 0.0
+    zpos = 0.5 * np.sin(2 * t)  # slight bounce
+    # Root rotation
+    zrot, xrot, yrot = 0.0, 0.0, 0.0
+    # Spine: slight lean
+    spine_z = 3.0 * np.sin(t)
+    spine_x, spine_y = 0.0, 0.0
+    # Head: look around
+    head_z = 5.0 * np.sin(0.5 * t)
+    head_x, head_y = 0.0, 0.0
+    # Left hand: swing
+    lh_z = 20.0 * np.sin(t)
+    lh_x, lh_y = 0.0, 0.0
+    # Right hand: opposite swing
+    rh_z = -20.0 * np.sin(t)
+    rh_x, rh_y = 0.0, 0.0
+
+    vals = [
+        xpos,
+        ypos,
+        zpos,
+        zrot,
+        xrot,
+        yrot,
+        spine_z,
+        spine_x,
+        spine_y,
+        head_z,
+        head_x,
+        head_y,
+        lh_z,
+        lh_x,
+        lh_y,
+        rh_z,
+        rh_x,
+        rh_y,
+    ]
+    motion_lines.append(" ".join(f"{v:.4f}" for v in vals))
+
+motion_section = (
+    "MOTION\n"
+    f"Frames: {n_frames}\n"
+    f"Frame Time: {frame_time}\n" + "\n".join(motion_lines) + "\n"
+)
+bvh_content = hierarchy + motion_section
+
+# Save to temp file (using NamedTemporaryFile for security)
+with tempfile.NamedTemporaryFile(mode="w", suffix=".bvh", delete=False) as f:
+    f.write(bvh_content)
+    bvh_path = Path(f.name)
+print(f"Created sample BVH file: {bvh_path}")
+
+# %%
+# Load the BVH file into movement
+# ---------------------------------
+# :func:`movement.io.load_poses.from_bvh_file` parses the
+# BVH hierarchy and computes 3D positions via forward
+# kinematics. If ``fps`` is not specified, it is derived
+# from the BVH ``Frame Time`` field.
+
+ds = load_poses.from_bvh_file(bvh_path)
+print(ds)
+
+# %%
+# Explore the dataset
+# --------------------
+
+print("Shape:", ds.position.shape)
+print(
+    "Keypoints (joints):",
+    ds.coords["keypoints"].values,
+)
+print("FPS:", ds.fps)
+print("Space dims:", ds.coords["space"].values)
+
+# %%
+# Visualise the 3D skeleton at a single frame
+# ---------------------------------------------
+# Let's plot the skeleton at the first frame.
+
+fig = plt.figure(figsize=(8, 6))
+ax = fig.add_subplot(111, projection="3d")
+
+frame_idx = 0
+t = ds.coords["time"].values[frame_idx]
+pos = ds.position.sel(time=t, individuals="id_0")
+
+# Plot joints
+for kp in ds.coords["keypoints"].values:
+    p = pos.sel(keypoints=kp).values
+    ax.scatter(*p, s=50, zorder=5)
+    ax.text(p[0], p[1], p[2] + 0.5, kp, fontsize=8)
+
+# Draw skeleton connections
+bones = [
+    ("Hips", "Spine"),
+    ("Spine", "Head"),
+    ("Hips", "LeftHand"),
+    ("Hips", "RightHand"),
+]
+for j1, j2 in bones:
+    p1 = pos.sel(keypoints=j1).values
+    p2 = pos.sel(keypoints=j2).values
+    ax.plot(
+        [p1[0], p2[0]],
+        [p1[1], p2[1]],
+        [p1[2], p2[2]],
+        "k-",
+        linewidth=2,
+    )
+
+ax.set_xlabel("X")
+ax.set_ylabel("Y")
+ax.set_zlabel("Z")
+ax.set_title("BVH skeleton (frame 0)")
+plt.tight_layout()
+plt.show()
+
+# %%
+# Visualise joint trajectories over time
+# ----------------------------------------
+
+fig, axes = plt.subplots(1, 3, figsize=(14, 4))
+for i, coord in enumerate(["x", "y", "z"]):
+    ax = axes[i]
+    for kp in ds.coords["keypoints"].values:
+        vals = ds.position.sel(keypoints=kp, individuals="id_0", space=coord)
+        ax.plot(ds.coords["time"], vals, label=kp)
+    ax.set_xlabel("Time (s)")
+    ax.set_ylabel(f"{coord} position")
+    ax.set_title(f"{coord.upper()} over time")
+    ax.legend(fontsize=7)
+plt.tight_layout()
+plt.show()
+
+# %%
+# BVH data is now in the standard ``movement`` format,
+# so you can use all available analysis tools:
+# filtering, kinematics computation, distance metrics,
+# and more — just as you would with DeepLabCut or SLEAP
+# data.
+
+# %%
+# Clean up
+# --------
+bvh_path.unlink()
diff --git a/examples/load_coco_data.py b/examples/load_coco_data.py
new file mode 100644
index 000000000..5cc7c5fe6
--- /dev/null
+++ b/examples/load_coco_data.py
@@ -0,0 +1,183 @@
+"""Load COCO keypoint annotations
+==================================
+
+Load a COCO keypoint annotation JSON file into ``movement``
+and explore the resulting dataset.
+"""
+
+# %%
+# Imports
+# -------
+
+import json
+import tempfile
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from movement.io import load_poses
+
+# %%
+# About the COCO keypoint format
+# --------------------------------
+# The `COCO (Common Objects in Context)
+# <https://cocodataset.org>`_ dataset format is one
+# of the most widely used standards for human pose estimation.
+# COCO keypoint annotation files are JSON files containing three
+# main sections:
+#
+# - **images**: a list of image entries (each with an ``id``).
+# - **annotations**: a list of keypoint annotations, each
+#   associated with an ``image_id`` and a ``category_id``.
+#   The ``keypoints`` field is a flat array of
+#   ``[x1, y1, v1, x2, y2, v2, ...]`` where ``v`` is the
+#   visibility flag (0 = not labelled, 1 = labelled but
+#   occluded, 2 = labelled and visible).
+# - **categories**: defines the keypoint names and skeleton
+#   connectivity.
+#
+# ``movement`` maps each image to a time frame, each annotation
+# per image to an individual, and the visibility flags to
+# confidence scores.
+
+# %%
+# Create a sample COCO file
+# --------------------------
+# For this example, we'll create a minimal COCO keypoint file
+# with 10 frames, 2 individuals, and 5 keypoints per person.
+
+keypoint_names = [
+    "nose",
+    "left_shoulder",
+    "right_shoulder",
+    "left_hip",
+    "right_hip",
+]
+n_frames = 10
+n_individuals = 2
+n_keypoints = len(keypoint_names)
+
+rng = np.random.default_rng(42)
+
+coco_data = {
+    "images": [
+        {
+            "id": i,
+            "file_name": f"frame_{i:04d}.jpg",
+            "width": 640,
+            "height": 480,
+        }
+        for i in range(n_frames)
+    ],
+    "annotations": [],
+    "categories": [
+        {
+            "id": 1,
+            "name": "person",
+            "keypoints": keypoint_names,
+            "skeleton": [
+                [0, 1],
+                [0, 2],
+                [1, 3],
+                [2, 4],
+            ],
+        }
+    ],
+}
+
+ann_id = 0
+for frame_idx in range(n_frames):
+    for ind in range(n_individuals):
+        # Simulate walking trajectories
+        base_x = 100 + ind * 200 + frame_idx * 10
+        base_y = 300 + ind * 50 - frame_idx * 5
+        kps = []
+        for k in range(n_keypoints):
+            x = base_x + rng.normal(0, 3)
+            y = base_y + k * 30 + rng.normal(0, 3)
+            v = 2  # visible
+            kps.extend([x, y, v])
+        # Bounding box: [x, y, width, height]
+        bbox = [base_x - 20, base_y - 20, 150, 200]
+        area = bbox[2] * bbox[3]
+        coco_data["annotations"].append(
+            {
+                "id": ann_id,
+                "image_id": frame_idx,
+                "category_id": 1,
+                "keypoints": kps,
+                "num_keypoints": n_keypoints,
+                "bbox": bbox,
+                "area": float(area),
+                "iscrowd": 0,
+                "score": 0.85 + rng.uniform(0, 0.15),
+                "track_id": ind,
+            }
+        )
+        ann_id += 1
+
+# Save to a temporary file (using NamedTemporaryFile for security)
+with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+    json.dump(coco_data, f)
+    coco_path = Path(f.name)
+print(f"Created sample COCO file: {coco_path}")
+
+# %%
+# Load the COCO file into movement
+# ---------------------------------
+# Use :func:`movement.io.load_poses.from_coco_file` to load
+# the COCO keypoint annotations. You can optionally specify
+# an ``fps`` value to convert frame indices to seconds.
+
+ds = load_poses.from_coco_file(coco_path, fps=30)
+print(ds)
+
+# %%
+# Explore the dataset
+# --------------------
+# The resulting ``movement`` dataset has the familiar
+# structure: ``position`` and ``confidence`` data variables
+# with dimensions ``(time, space, keypoints, individuals)``.
+
+print("Shape:", ds.position.shape)
+print("Keypoints:", ds.coords["keypoints"].values)
+print("Individuals:", ds.coords["individuals"].values)
+
+# %%
+# Visualise the trajectories
+# ---------------------------
+# Let's plot the nose trajectory for both individuals.
+
+fig, ax = plt.subplots(figsize=(8, 5))
+for ind in ds.coords["individuals"].values:
+    x = ds.position.sel(keypoints="nose", individuals=ind, space="x")
+    y = ds.position.sel(keypoints="nose", individuals=ind, space="y")
+    ax.plot(x, y, "o-", label=ind, markersize=4)
+ax.set_xlabel("x (pixels)")
+ax.set_ylabel("y (pixels)")
+ax.set_title("Nose trajectories from COCO data")
+ax.legend()
+ax.invert_yaxis()  # image coordinates
+plt.tight_layout()
+plt.show()
+
+# %%
+# Confidence scores
+# ------------------
+# The COCO visibility flags are mapped to confidence values:
+# ``v=0`` → NaN position + 0 confidence,
+# ``v=1`` → 0.5 × score,
+# ``v=2`` → 1.0 × score.
+ +print("Mean confidence:", float(ds.confidence.mean())) +print( + "Min confidence:", + float(ds.confidence.min()), +) + +# %% +# Clean up +# -------- + +coco_path.unlink() diff --git a/movement/io/load.py b/movement/io/load.py index cad9752cb..905cc37b9 100644 --- a/movement/io/load.py +++ b/movement/io/load.py @@ -29,6 +29,8 @@ "Anipose", "NWB", "VIA-tracks", + "COCO", + "BVH", ] diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index c76715b63..dcf10debe 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -17,6 +17,8 @@ from movement.validators.datasets import ValidPosesInputs from movement.validators.files import ( ValidAniposeCSV, + ValidBVHFile, + ValidCOCOJSON, ValidDeepLabCutCSV, ValidDeepLabCutH5, ValidFile, @@ -943,3 +945,676 @@ def _ds_from_nwb_object( return xr.merge( single_keypoint_datasets, join="outer", compat="no_conflicts" ) + + +@register_loader("COCO", file_validators=[ValidCOCOJSON]) +def from_coco_file( + file: str | Path, + fps: float | None = None, +) -> xr.Dataset: + """Create a ``movement`` poses dataset from a COCO JSON file. + + The input file must follow the `COCO keypoint detection format + `_, containing + ``images``, ``annotations``, and ``categories`` sections. + + Parameters + ---------- + file + Path to the COCO keypoint annotation JSON file. + fps + The number of frames per second. If None (default), + the ``time`` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, + confidence scores, and associated metadata. + + Notes + ----- + Each image in the COCO file is treated as one time frame. + Images are sorted by their ``id`` to establish a temporal + order. Multiple annotations per image are treated as + separate individuals. If annotations include a + ``track_id`` field, it is used for consistent individual + identity across frames; otherwise individuals are numbered + per frame (``id_0``, ``id_1``, ...). 
+ + The COCO visibility flag is mapped to confidence: + ``0`` (not labelled) → ``NaN`` position and ``0`` + confidence, ``1`` (labelled but not visible) → actual + position with confidence ``0.5``, ``2`` (labelled and + visible) → actual position with confidence ``1.0``. + + If the annotation includes a ``score`` field, the + per-keypoint confidence is multiplied by it. + + See Also + -------- + movement.io.load_poses.from_numpy + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_coco_file("path/to/coco_keypoints.json", fps=30) + + """ + valid_file = cast("ValidCOCOJSON", file) + file_path = valid_file.file + data = valid_file.data + ds = _ds_from_coco_data(data, fps=fps) + ds.attrs["source_file"] = file_path.as_posix() + logger.info(f"Loaded pose tracks from {file_path}:\n{ds}") + return ds + + +def _coco_individual_mapping( + annotations: list[dict], + cat_id: int, + anns_by_image: dict[int, list[dict]], +) -> tuple[list[str], dict[int, int] | None]: + """Determine individual names and track_id mapping. + + Parameters + ---------- + annotations + List of COCO annotation dicts. + cat_id + Category ID to filter annotations. + anns_by_image + Annotations grouped by image_id. + + Returns + ------- + tuple + A tuple of (individual_names, track_id_to_idx). + track_id_to_idx is None if track_id is not present. 
+ + """ + has_track_id = any("track_id" in ann for ann in annotations) + if has_track_id: + track_ids = sorted( + { + ann["track_id"] + for ann in annotations + if ann.get("category_id") == cat_id + } + ) + names = [f"id_{tid}" for tid in track_ids] + mapping = {tid: i for i, tid in enumerate(track_ids)} + return names, mapping + + max_individuals = max( + (len(v) for v in anns_by_image.values()), + default=1, + ) + return [f"id_{i}" for i in range(max_individuals)], None + + +def _coco_fill_arrays( + anns_by_image: dict[int, list[dict]], + image_id_to_frame: dict[int, int], + track_id_to_idx: dict[int, int] | None, + n_frames: int, + n_keypoints: int, + n_individuals: int, +) -> tuple[np.ndarray, np.ndarray]: + """Populate position and confidence arrays from COCO annotations. + + Parameters + ---------- + anns_by_image + Annotations grouped by image_id. + image_id_to_frame + Mapping from image_id to frame index. + track_id_to_idx + Mapping from track_id to individual index. + None if track_id is not used. + n_frames + Number of frames. + n_keypoints + Number of keypoints per individual. + n_individuals + Number of individuals. + + Returns + ------- + tuple + A tuple of (position_array, confidence_array). + + """ + position = np.full((n_frames, 2, n_keypoints, n_individuals), np.nan) + confidence = np.full((n_frames, n_keypoints, n_individuals), np.nan) + + for img_id, anns in anns_by_image.items(): + frame_idx = image_id_to_frame.get(img_id) + if frame_idx is None: + continue + for j, ann in enumerate(anns): + ind_idx = ( + track_id_to_idx[ann["track_id"]] + if track_id_to_idx is not None + else j + ) + _coco_fill_keypoints( + ann, + ind_idx, + frame_idx, + n_keypoints, + position, + confidence, + ) + return position, confidence + + +def _coco_fill_keypoints( + ann: dict, + ind_idx: int, + frame_idx: int, + n_keypoints: int, + position: np.ndarray, + confidence: np.ndarray, +) -> None: + """Fill position/confidence for a single annotation. 
+ + Parameters + ---------- + ann + A single COCO annotation dict. + ind_idx + Individual index. + frame_idx + Frame index. + n_keypoints + Number of keypoints. + position + Position array to fill in-place. + confidence + Confidence array to fill in-place. + + """ + kps = ann["keypoints"] + score = ann.get("score", 1.0) + for k in range(n_keypoints): + x = kps[k * 3] + y = kps[k * 3 + 1] + v = kps[k * 3 + 2] + if v == 0: + position[frame_idx, :, k, ind_idx] = np.nan + confidence[frame_idx, k, ind_idx] = 0.0 + else: + position[frame_idx, 0, k, ind_idx] = x + position[frame_idx, 1, k, ind_idx] = y + vis_conf = v / 2.0 + confidence[frame_idx, k, ind_idx] = vis_conf * score + + +def _ds_from_coco_data( + data: dict, + fps: float | None = None, +) -> xr.Dataset: + """Create a ``movement`` poses dataset from parsed COCO data. + + Parameters + ---------- + data + Parsed COCO JSON data as a dictionary. + fps + Frames per second. If None, time coordinates will + be frame numbers. + + Returns + ------- + xarray.Dataset + A ``movement`` poses dataset. 
+ + """ + first_cat = data["categories"][0] + cat_id = first_cat["id"] + keypoint_names = first_cat["keypoints"] + n_keypoints = len(keypoint_names) + + images = sorted(data["images"], key=lambda x: x["id"]) + image_id_to_frame = {img["id"]: i for i, img in enumerate(images)} + n_frames = len(images) + + anns_by_image: dict[int, list[dict]] = {} + for ann in data["annotations"]: + if ann.get("category_id") != cat_id: + continue + img_id = ann["image_id"] + anns_by_image.setdefault(img_id, []).append(ann) + + individual_names, track_id_to_idx = _coco_individual_mapping( + data["annotations"], cat_id, anns_by_image + ) + + position_array, confidence_array = _coco_fill_arrays( + anns_by_image, + image_id_to_frame, + track_id_to_idx, + n_frames, + n_keypoints, + len(individual_names), + ) + + return from_numpy( + position_array=position_array, + confidence_array=confidence_array, + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + source_software="COCO", + ) + + +@register_loader("BVH", file_validators=[ValidBVHFile]) +def from_bvh_file( + file: str | Path, + fps: float | None = None, +) -> xr.Dataset: + """Create a ``movement`` poses dataset from a BVH file. + + `BVH (Biovision Hierarchy) + `_ + is a text-based motion capture format that stores a + skeleton hierarchy and per-frame joint rotations. + + This function parses the skeleton hierarchy and motion + data, then computes 3D joint positions via forward + kinematics. + + Parameters + ---------- + file + Path to the BVH file. + fps + The number of frames per second. If None (default) + and the BVH file contains a ``Frame Time`` field, + fps will be computed from it. Otherwise, the ``time`` + coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the 3D pose tracks, + confidence scores (set to ``NaN``), and associated + metadata. + + Notes + ----- + The BVH format stores joint rotations (Euler angles) + rather than positions. 
This function computes absolute + 3D positions via forward kinematics by traversing the + skeleton hierarchy, applying rotations and offsets at + each joint. + + Only joint nodes (``ROOT`` and ``JOINT``) are included + as keypoints. ``End Site`` nodes are excluded. + + See Also + -------- + movement.io.load_poses.from_numpy + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_bvh_file("path/to/motion.bvh") + + """ + valid_file = cast("ValidBVHFile", file) + file_path = valid_file.file + + hierarchy, motion_data, frame_time = _parse_bvh(file_path) + # Compute fps from frame_time if not provided + if fps is None and frame_time is not None and frame_time > 0: + fps = round(1.0 / frame_time, 6) + + # Compute 3D positions via forward kinematics + joint_names, position_array = _bvh_forward_kinematics( + hierarchy, motion_data + ) + + n_frames = position_array.shape[0] + n_keypoints = len(joint_names) + # BVH has no confidence info; set to NaN + confidence_array = np.full((n_frames, n_keypoints, 1), np.nan) + # Add individuals dimension (single individual) + position_array = position_array[..., np.newaxis] + + ds = from_numpy( + position_array=position_array, + confidence_array=confidence_array, + individual_names=None, + keypoint_names=joint_names, + fps=fps, + source_software="BVH", + ) + ds.attrs["source_file"] = file_path.as_posix() + logger.info(f"Loaded pose tracks from {file_path}:\n{ds}") + return ds + + +def _parse_bvh( + file_path: Path, +) -> tuple[list[dict], np.ndarray, float | None]: + """Parse a BVH file into hierarchy and motion data. + + Parameters + ---------- + file_path + Path to the BVH file. + + Returns + ------- + hierarchy + List of joint dictionaries with keys: ``name``, + ``offset``, ``channels``, ``children``, + ``parent_index``, ``channel_offset``. + motion_data + 2D array of shape ``(n_frames, n_channels)`` + containing the per-frame motion data values. 
+ frame_time + The time between frames in seconds, or None. + + """ + with open(file_path) as f: + lines = f.readlines() + + joints, motion_start = _parse_bvh_hierarchy(lines) + motion_data, frame_time = _parse_bvh_motion(lines, motion_start) + return joints, motion_data, frame_time + + +def _parse_bvh_hierarchy( + lines: list[str], +) -> tuple[list[dict], int]: + """Parse the HIERARCHY section of a BVH file. + + Parameters + ---------- + lines + All lines from the BVH file. + + Returns + ------- + joints + List of joint dictionaries. + motion_line + Line index where MOTION section starts. + + """ + joints: list[dict] = [] + stack: list[int] = [] + channel_offset = 0 + i = 0 + + while i < len(lines): + line = lines[i].strip() + if line.startswith(("ROOT", "JOINT")): + _add_bvh_joint(line, joints, stack) + elif line.startswith("End Site"): + i = _skip_end_site(lines, i) + elif line.startswith("OFFSET") and stack: + parts = line.split() + joints[stack[-1]]["offset"] = np.array( + [float(p) for p in parts[1:4]] + ) + elif line.startswith("CHANNELS") and stack: + parts = line.split() + n_ch = int(parts[1]) + joints[stack[-1]]["channels"] = parts[2 : 2 + n_ch] + joints[stack[-1]]["channel_offset"] = channel_offset + channel_offset += n_ch + elif line == "}": + if stack: + stack.pop() + elif line.startswith("MOTION"): + return joints, i + i += 1 + + return joints, i + + +def _add_bvh_joint( + line: str, + joints: list[dict], + stack: list[int], +) -> None: + """Add a ROOT or JOINT node to the joints list.""" + parts = line.split() + name = parts[1] + joint: dict = { + "name": name, + "offset": np.zeros(3), + "channels": [], + "children": [], + "parent_index": (stack[-1] if stack else -1), + "channel_offset": 0, + } + if stack: + joints[stack[-1]]["children"].append(len(joints)) + joints.append(joint) + stack.append(len(joints) - 1) + + +def _skip_end_site(lines: list[str], i: int) -> int: + """Skip past an End Site block in BVH. 
+ + Returns the line index of the closing brace. + """ + i += 1 + brace_count = 0 + while i < len(lines): + stripped = lines[i].strip() + if "{" in stripped: + brace_count += 1 + if "}" in stripped: + if brace_count <= 1: + return i + brace_count -= 1 + i += 1 + return i + + +def _parse_bvh_motion( + lines: list[str], + motion_start: int, +) -> tuple[np.ndarray, float | None]: + """Parse the MOTION section of a BVH file. + + Parameters + ---------- + lines + All lines from the BVH file. + motion_start + Line index of the MOTION keyword. + + Returns + ------- + motion_data + Array of shape ``(n_frames, n_channels)``. + frame_time + Seconds per frame, or None. + + """ + i = motion_start + 1 + n_frames = 0 + frame_time: float | None = None + + while i < len(lines): + line = lines[i].strip() + if line.startswith("Frames:"): + n_frames = int(line.split(":")[1].strip()) + elif line.startswith("Frame Time:"): + frame_time = float(line.split(":")[1].strip()) + i += 1 + break + i += 1 + + motion_rows = [] + for j in range(n_frames): + if i + j < len(lines): + row = lines[i + j].strip().split() + motion_rows.append([float(v) for v in row]) + motion_data = np.array(motion_rows) + return motion_data, frame_time + + +def _axis_rotation_matrix(axis: str, angle_rad: float) -> np.ndarray: + """Return a 3×3 rotation matrix for a single axis. + + Parameters + ---------- + axis + One of ``"X"``, ``"Y"``, or ``"Z"``. + angle_rad + Rotation angle in radians. + + Returns + ------- + numpy.ndarray + 3×3 rotation matrix. + + """ + c, s = np.cos(angle_rad), np.sin(angle_rad) + if axis == "X": + return np.array([[1, 0, 0], [0, c, -s], [0, s, c]]) + if axis == "Y": + return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) + # Z + return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) + + +def _euler_to_rotation_matrix( + angles: np.ndarray, + order: str, +) -> np.ndarray: + """Convert Euler angles (degrees) to a rotation matrix. 
+ + Parameters + ---------- + angles + Array of 3 Euler angles in degrees. + order + Rotation order string, e.g. ``"ZXY"``. + + Returns + ------- + numpy.ndarray + 3×3 rotation matrix. + + """ + rad = np.deg2rad(angles) + axis_to_idx = {"X": 0, "Y": 1, "Z": 2} + rot = np.eye(3) + for axis_char in order: + idx = axis_to_idx[axis_char] + rot = rot @ _axis_rotation_matrix(axis_char, rad[idx]) + return rot + + +def _extract_bvh_channels( + channels: list[str], + frame_data: np.ndarray, + ch_offset: int, +) -> tuple[np.ndarray, np.ndarray, str]: + """Extract translation, rotation, and order from channels. + + Parameters + ---------- + channels + List of channel names for a joint. + frame_data + 1D array of all channel values for one frame. + ch_offset + Starting index into ``frame_data``. + + Returns + ------- + translation + 3D translation vector. + rotation_angles + 3D rotation angles in degrees. + rotation_order + String indicating rotation axis order. + + """ + channel_map = { + "Xposition": ("t", 0), + "Yposition": ("t", 1), + "Zposition": ("t", 2), + "Xrotation": ("r", 0), + "Yrotation": ("r", 1), + "Zrotation": ("r", 2), + } + translation = np.zeros(3) + rotation_angles = np.zeros(3) + rotation_order = "" + for c_idx, ch_name in enumerate(channels): + val = frame_data[ch_offset + c_idx] + kind, axis_idx = channel_map[ch_name] + if kind == "t": + translation[axis_idx] = val + else: + rotation_angles[axis_idx] = val + rotation_order += ch_name[0] + return translation, rotation_angles, rotation_order + + +def _bvh_forward_kinematics( + joints: list[dict], + motion_data: np.ndarray, +) -> tuple[list[str], np.ndarray]: + """Compute 3D joint positions from BVH data. + + Parameters + ---------- + joints + List of joint dictionaries from ``_parse_bvh``. + motion_data + 2D array of shape ``(n_frames, n_channels)``. + + Returns + ------- + joint_names + List of joint names. + positions + Array of shape ``(n_frames, 3, n_joints)`` + containing the 3D positions. 
+ + """ + n_frames = motion_data.shape[0] + n_joints = len(joints) + joint_names = [j["name"] for j in joints] + positions = np.zeros((n_frames, 3, n_joints)) + + for frame in range(n_frames): + transforms: list[tuple[np.ndarray, np.ndarray] | None] = [ + None + ] * n_joints + for j_idx, joint in enumerate(joints): + trans, rot_ang, rot_ord = _extract_bvh_channels( + joint["channels"], + motion_data[frame], + joint["channel_offset"], + ) + local_rot = ( + _euler_to_rotation_matrix(rot_ang, rot_ord) + if rot_ord + else np.eye(3) + ) + parent = joint["parent_index"] + if parent == -1: + g_pos = trans + joint["offset"] + g_rot = local_rot + else: + p_rot, p_pos = transforms[parent] + g_pos = p_pos + p_rot @ joint["offset"] + g_rot = p_rot @ local_rot + transforms[j_idx] = (g_rot, g_pos) + positions[frame, :, j_idx] = g_pos + + return joint_names, positions diff --git a/movement/validators/_json_schemas.py b/movement/validators/_json_schemas.py index 389cdf353..058d0249d 100644 --- a/movement/validators/_json_schemas.py +++ b/movement/validators/_json_schemas.py @@ -9,6 +9,65 @@ "LineOfInterest": ("LineString", "LinearRing"), } +# JSON schema for COCO keypoint annotation files. +COCO_KEYPOINTS_SCHEMA: Mapping[str, Any] = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "COCO Keypoints", + "description": ( + "Schema for validating COCO keypoint annotation JSON files." 
+ ), + "type": "object", + "required": ["images", "annotations", "categories"], + "properties": { + "images": { + "type": "array", + "items": { + "type": "object", + "required": ["id"], + "properties": { + "id": {"type": "integer"}, + }, + }, + }, + "annotations": { + "type": "array", + "items": { + "type": "object", + "required": [ + "id", + "image_id", + "category_id", + "keypoints", + ], + "properties": { + "id": {"type": "integer"}, + "image_id": {"type": "integer"}, + "category_id": {"type": "integer"}, + "keypoints": { + "type": "array", + "items": {"type": "number"}, + }, + }, + }, + }, + "categories": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "required": ["id", "keypoints"], + "properties": { + "id": {"type": "integer"}, + "keypoints": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + }, + }, +} + # JSON schema for movement-compatible RoI GeoJSON collections. ROI_COLLECTION_SCHEMA: Mapping[str, Any] = { "$schema": "https://json-schema.org/draft/2020-12/schema", diff --git a/movement/validators/files.py b/movement/validators/files.py index e75baf8ac..810d391e8 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -17,6 +17,7 @@ from movement.utils.logging import logger from movement.validators._json_schemas import ( + COCO_KEYPOINTS_SCHEMA, ROI_COLLECTION_SCHEMA, ROI_TYPE_TO_GEOMETRY, ) @@ -923,6 +924,152 @@ class ValidNWBFile: """Path to the NWB file on disk (ending in ".nwb") or an NWBFile object.""" +def _check_coco_keypoint_lengths( + data: Mapping[str, Any], +) -> None: + """Ensure each COCO annotation's keypoint array has the correct length. + + Each keypoint is represented by three values (x, y, visibility), + so the length of the ``keypoints`` array must equal + ``3 × len(category["keypoints"])``. + + Parameters + ---------- + data + Parsed COCO JSON data as a dictionary. + + Raises + ------ + ValueError + If any annotation's keypoints array has an unexpected length. 
+ + """ + categories = {c["id"]: c for c in data.get("categories", [])} + for ann in data.get("annotations", []): + kps = ann.get("keypoints", []) + cat_id = ann.get("category_id") + cat = categories.get(cat_id) + if cat is None: + continue + n_kp = len(cat.get("keypoints", [])) + expected_len = n_kp * 3 + if len(kps) != expected_len: + raise logger.error( + ValueError( + f"Annotation {ann['id']} has " + f"{len(kps)} keypoint values, " + f"expected {expected_len} " + f"(3 × {n_kp} keypoints)." + ) + ) + + +@define +class ValidCOCOJSON: + """Class for validating COCO keypoint annotation JSON files. + + The validator ensures that the file: + + - is in valid JSON format. + - conforms to the COCO keypoint annotation schema, which checks that + the file contains ``images``, ``annotations``, and ``categories`` + sections with the required fields. + + Additionally, it performs a custom validation step to ensure that + each annotation's ``keypoints`` array length is consistent with + the category's keypoint count (3 values per keypoint). + + Raises + ------ + ValueError + If the file is not valid JSON, does not match the expected + schema, or has keypoint arrays of unexpected length. + + See Also + -------- + movement.io.load_poses.from_coco_file : + Load a COCO keypoint annotation JSON file. 
+ + """ + + suffixes: ClassVar[set[str]] = {".json"} + """Expected suffix(es) for the file.""" + + schema: ClassVar[Mapping[str, Any]] = COCO_KEYPOINTS_SCHEMA + """JSON schema for validating the file structure.""" + + file: Path = field( + converter=Path, + validator=validators.and_( + _file_validator(permission="r", suffixes=suffixes), + _json_validator( + schema=schema, + custom_checks=(_check_coco_keypoint_lengths,), + data_attr="data", + ), + ), + ) + """Path to the COCO keypoint JSON file to validate.""" + + data: dict = field(init=False, factory=dict) + """Parsed JSON data from the file, available after validation.""" + + +@define +class ValidBVHFile: + """Class for validating BVH (Biovision Hierarchy) files. + + The validator ensures that the file: + + - has a ``.bvh`` suffix, + - starts with a ``HIERARCHY`` section, and + - contains a ``MOTION`` section with frame data. + + Raises + ------ + ValueError + If the file does not have BVH structure. + + See Also + -------- + movement.io.load_poses.from_bvh_file : + Load a BVH motion capture file. + + """ + + suffixes: ClassVar[set[str]] = {".bvh"} + """Expected suffix(es) for the file.""" + + file: Path = field( + converter=Path, + validator=_file_validator(permission="r", suffixes=suffixes), + ) + """Path to the BVH file to validate.""" + + @file.validator + def _file_contains_bvh_structure(self, attribute, value): + """Ensure the file contains HIERARCHY and MOTION sections. + + These are the two mandatory top-level sections of a valid + BVH file. + """ + with open(value) as f: + content = f.read() + if not content.strip().startswith("HIERARCHY"): + raise logger.error( + ValueError( + f"File {value} does not start with " + "'HIERARCHY'. Not a valid BVH file." + ) + ) + if "MOTION" not in content: + raise logger.error( + ValueError( + f"File {value} does not contain a 'MOTION' section." 
+ ) + ) + + def _check_roi_type_matches_geometry(data: Mapping[str, Any]) -> None: """Ensure ``roi_type`` properties match the GeoJSON geometry types. diff --git a/tests/test_unit/test_io/test_load_coco_and_bvh.py b/tests/test_unit/test_io/test_load_coco_and_bvh.py new file mode 100644 index 000000000..897f579fe --- /dev/null +++ b/tests/test_unit/test_io/test_load_coco_and_bvh.py @@ -0,0 +1,519 @@ +"""Test suite for COCO and BVH loaders in the load_poses module.""" + +import json + +import numpy as np +import pytest +import xarray as xr + +from movement.io import load_poses +from movement.validators.datasets import ValidPosesInputs +from movement.validators.files import ValidBVHFile, ValidCOCOJSON + +expected_values_poses = { + "vars_dims": {"position": 4, "confidence": 3}, + "dim_names": ValidPosesInputs.DIM_NAMES, +} + + +# ============== COCO test fixtures ================================== + + +def _make_coco_data( + n_images=3, + n_individuals=2, + n_keypoints=3, + with_track_id=False, + with_score=True, +): + """Build a minimal COCO keypoint annotation dict.""" + keypoint_names = [ + "nose", + "left_eye", + "right_eye", + ][:n_keypoints] + images = [ + {"id": i, "file_name": f"frame_{i:04d}.jpg"} for i in range(n_images) + ] + annotations = [] + ann_id = 0 + for img in images: + for ind in range(n_individuals): + kps = [] + for k in range(n_keypoints): + x = float(100 + ind * 50 + k * 10) + y = float(200 + ind * 30 + k * 5) + v = 2 # labelled and visible + kps.extend([x, y, v]) + ann = { + "id": ann_id, + "image_id": img["id"], + "category_id": 1, + "keypoints": kps, + } + if with_track_id: + ann["track_id"] = ind + if with_score: + ann["score"] = 0.9 + ann_id += 1 + annotations.append(ann) + categories = [ + { + "id": 1, + "name": "person", + "keypoints": keypoint_names, + } + ] + return { + "images": images, + "annotations": annotations, + "categories": categories, + } + + +@pytest.fixture +def coco_json_file(tmp_path): + """Return the path to a valid COCO 
keypoint JSON file.""" + data = _make_coco_data() + file_path = tmp_path / "coco_keypoints.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_with_track_id(tmp_path): + """Return the path to a COCO JSON file with track IDs.""" + data = _make_coco_data(with_track_id=True) + file_path = tmp_path / "coco_tracked.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_invisible_keypoints(tmp_path): + """Return path to COCO JSON with visibility=0 keypoints.""" + data = _make_coco_data( + n_images=2, + n_individuals=1, + n_keypoints=3, + ) + # Set first keypoint of first annotation to invisible + data["annotations"][0]["keypoints"][2] = 0 + file_path = tmp_path / "coco_invisible.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_single_individual(tmp_path): + """Return path to a COCO JSON file with one individual.""" + data = _make_coco_data(n_individuals=1) + file_path = tmp_path / "coco_single.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_invalid_keypoints_length(tmp_path): + """Return path to COCO JSON with wrong keypoints length.""" + data = _make_coco_data(n_images=1, n_individuals=1) + # Corrupt keypoints: remove last value + data["annotations"][0]["keypoints"] = [1.0, 2.0] + file_path = tmp_path / "coco_bad_kps.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_missing_keys(tmp_path): + """Return path to a JSON file missing required COCO keys.""" + data = {"images": [], "annotations": []} + file_path = tmp_path / "coco_missing_keys.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +@pytest.fixture +def coco_json_file_no_score(tmp_path): + """Return path to a COCO JSON file without score 
field.""" + data = _make_coco_data(with_score=False) + file_path = tmp_path / "coco_no_score.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +# ============== BVH test fixtures ================================== + +_BVH_FRAME_0 = ( + "0.00 0.00 0.00 0.00 0.00 0.00 " + "0.00 0.00 0.00 0.00 0.00 0.00 " + "0.00 0.00 0.00 0.00 0.00 0.00" +) +_BVH_FRAME_1 = ( + "1.00 2.00 0.50 10.00 5.00 0.00 " + "5.00 2.00 0.00 3.00 1.00 0.00 " + "-3.00 1.00 0.00 3.00 -1.00 0.00" +) +_BVH_FRAME_2 = ( + "2.00 4.00 1.00 20.00 10.00 0.00 " + "10.00 4.00 0.00 6.00 2.00 0.00 " + "-6.00 2.00 0.00 6.00 -2.00 0.00" +) + +SAMPLE_BVH = ( + "HIERARCHY\n" + "ROOT Hips\n" + "{\n" + " OFFSET 0.00 0.00 0.00\n" + " CHANNELS 6 Xposition Yposition Zposition" + " Zrotation Xrotation Yrotation\n" + " JOINT Spine\n" + " {\n" + " OFFSET 0.00 5.21 0.00\n" + " CHANNELS 3 Zrotation Xrotation Yrotation\n" + " JOINT Head\n" + " {\n" + " OFFSET 0.00 5.45 0.00\n" + " CHANNELS 3 Zrotation Xrotation" + " Yrotation\n" + " End Site\n" + " {\n" + " OFFSET 0.00 3.00 0.00\n" + " }\n" + " }\n" + " }\n" + " JOINT LeftArm\n" + " {\n" + " OFFSET 3.50 4.80 0.00\n" + " CHANNELS 3 Zrotation Xrotation Yrotation\n" + " End Site\n" + " {\n" + " OFFSET 5.00 0.00 0.00\n" + " }\n" + " }\n" + " JOINT RightArm\n" + " {\n" + " OFFSET -3.50 4.80 0.00\n" + " CHANNELS 3 Zrotation Xrotation Yrotation\n" + " End Site\n" + " {\n" + " OFFSET -5.00 0.00 0.00\n" + " }\n" + " }\n" + "}\n" + "MOTION\n" + "Frames: 3\n" + "Frame Time: 0.033333\n" + f"{_BVH_FRAME_0}\n" + f"{_BVH_FRAME_1}\n" + f"{_BVH_FRAME_2}\n" +) + + +@pytest.fixture +def bvh_file(tmp_path): + """Return the path to a valid BVH file.""" + file_path = tmp_path / "motion.bvh" + with open(file_path, "w") as f: + f.write(SAMPLE_BVH) + return file_path + + +@pytest.fixture +def bvh_file_no_hierarchy(tmp_path): + """Return path to a BVH file missing HIERARCHY.""" + file_path = tmp_path / "bad_no_hierarchy.bvh" + with open(file_path, "w") as f: + 
f.write("MOTION\nFrames: 1\nFrame Time: 0.03\n0 0 0\n") + return file_path + + +@pytest.fixture +def bvh_file_no_motion(tmp_path): + """Return path to a BVH file missing MOTION section.""" + file_path = tmp_path / "bad_no_motion.bvh" + with open(file_path, "w") as f: + f.write("HIERARCHY\nROOT Hips\n{\nOFFSET 0 0 0\n}\n") + return file_path + + +# ============== COCO loader tests ================================== + + +class TestCOCOLoader: + """Tests for the COCO keypoint loader.""" + + def test_load_from_coco_file(self, coco_json_file, helpers): + """Test loading COCO keypoints returns valid Dataset.""" + ds = load_poses.from_coco_file(coco_json_file, fps=30) + expected_values = { + **expected_values_poses, + "source_software": "COCO", + "file_path": coco_json_file, + "fps": 30, + } + helpers.assert_valid_dataset(ds, expected_values) + + def test_coco_dataset_shape(self, coco_json_file): + """Test that COCO dataset has expected shape.""" + ds = load_poses.from_coco_file(coco_json_file) + # 3 images, 2 space dims, 3 keypoints, 2 individuals + assert ds.position.shape == (3, 2, 3, 2) + assert ds.confidence.shape == (3, 3, 2) + + def test_coco_keypoint_names(self, coco_json_file): + """Test that keypoint names match categories.""" + ds = load_poses.from_coco_file(coco_json_file) + assert ds.coords["keypoints"].values.tolist() == [ + "nose", + "left_eye", + "right_eye", + ] + + def test_coco_with_track_id(self, coco_json_file_with_track_id): + """Test COCO loading with track_id for individuals.""" + ds = load_poses.from_coco_file(coco_json_file_with_track_id) + assert ds.coords["individuals"].values.tolist() == [ + "id_0", + "id_1", + ] + # All positions should be finite (all visible) + assert not np.isnan(ds.position.values).all() + + def test_coco_invisible_keypoints( + self, coco_json_file_invisible_keypoints + ): + """Test that invisible keypoints get NaN position.""" + ds = load_poses.from_coco_file(coco_json_file_invisible_keypoints) + # First keypoint of first 
frame/individual is invisible + pos = ds.position.sel( + time=0, + keypoints="nose", + individuals="id_0", + ) + assert np.isnan(pos.values).all() + # Confidence for invisible keypoint should be 0 + conf = ds.confidence.sel( + time=0, + keypoints="nose", + individuals="id_0", + ) + assert conf.values == pytest.approx(0.0) + + def test_coco_single_individual(self, coco_json_file_single_individual): + """Test loading COCO file with one individual.""" + ds = load_poses.from_coco_file(coco_json_file_single_individual) + assert ds.sizes["individuals"] == 1 + + def test_coco_no_score_field(self, coco_json_file_no_score): + """Test loading COCO without score field.""" + ds = load_poses.from_coco_file(coco_json_file_no_score) + # With v=2 and no score, confidence = 2/2 * 1.0 = 1.0 + conf = ds.confidence.values + visible_mask = ~np.isnan(conf) + assert np.allclose(conf[visible_mask], 1.0) + + def test_coco_fps_none(self, coco_json_file): + """Test that fps=None gives frame-number time coords.""" + ds = load_poses.from_coco_file(coco_json_file, fps=None) + assert ds.time_unit == "frames" + + def test_coco_fps_set(self, coco_json_file): + """Test that providing fps gives second-based coords.""" + ds = load_poses.from_coco_file(coco_json_file, fps=30) + assert ds.time_unit == "seconds" + assert ds.fps == 30 + + def test_coco_source_software_attr(self, coco_json_file): + """Test source_software attribute is set correctly.""" + ds = load_poses.from_coco_file(coco_json_file) + assert ds.source_software == "COCO" + + def test_coco_source_file_attr(self, coco_json_file): + """Test source_file attribute is set correctly.""" + ds = load_poses.from_coco_file(coco_json_file) + assert ds.source_file == coco_json_file.as_posix() + + +class TestCOCOValidation: + """Tests for COCO file validation.""" + + def test_valid_coco_json(self, coco_json_file): + """Test that valid COCO JSON passes validation.""" + valid = ValidCOCOJSON(file=coco_json_file) + assert valid.file == coco_json_file + + 
def test_invalid_coco_missing_keys(self, coco_json_file_missing_keys): + """Test that JSON missing COCO keys fails.""" + with pytest.raises(ValueError, match="schema"): + ValidCOCOJSON(file=coco_json_file_missing_keys) + + def test_invalid_coco_keypoints_length( + self, coco_json_file_invalid_keypoints_length + ): + """Test that wrong keypoints array length fails.""" + with pytest.raises(ValueError, match="keypoint"): + ValidCOCOJSON( + file=coco_json_file_invalid_keypoints_length, + ) + + def test_invalid_coco_wrong_extension(self, wrong_extension_file): + """Test that wrong file extension fails.""" + with pytest.raises(ValueError, match="suffix"): + ValidCOCOJSON(file=wrong_extension_file) + + +# ============== BVH loader tests =================================== + + +class TestBVHLoader: + """Tests for the BVH file loader.""" + + def test_load_from_bvh_file(self, bvh_file, helpers): + """Test loading BVH file returns valid Dataset.""" + ds = load_poses.from_bvh_file(bvh_file) + expected_values = { + **expected_values_poses, + "source_software": "BVH", + "file_path": bvh_file, + } + helpers.assert_valid_dataset(ds, expected_values) + + def test_bvh_dataset_shape(self, bvh_file): + """Test that BVH dataset has expected shape.""" + ds = load_poses.from_bvh_file(bvh_file) + # 3 frames, 3 space dims, 5 joints, 1 individual + assert ds.position.shape == (3, 3, 5, 1) + assert ds.confidence.shape == (3, 5, 1) + + def test_bvh_joint_names(self, bvh_file): + """Test that joint names match BVH hierarchy.""" + ds = load_poses.from_bvh_file(bvh_file) + expected_joints = [ + "Hips", + "Spine", + "Head", + "LeftArm", + "RightArm", + ] + actual = ds.coords["keypoints"].values.tolist() + assert actual == expected_joints + + def test_bvh_3d_space(self, bvh_file): + """Test that BVH data has 3 spatial dimensions.""" + ds = load_poses.from_bvh_file(bvh_file) + assert ds.sizes["space"] == 3 + assert "z" in ds.coords["space"].values + + def test_bvh_fps_from_frame_time(self, 
bvh_file): + """Test fps is computed from BVH Frame Time.""" + ds = load_poses.from_bvh_file(bvh_file) + # Frame Time: 0.033333 → fps ≈ 30 + assert ds.fps == pytest.approx(30.0, abs=0.1) + assert ds.time_unit == "seconds" + + def test_bvh_fps_override(self, bvh_file): + """Test that providing fps overrides Frame Time.""" + ds = load_poses.from_bvh_file(bvh_file, fps=60) + assert ds.fps == 60 + assert ds.time_unit == "seconds" + + def test_bvh_root_position_frame_0(self, bvh_file): + """Test root position in the rest pose (frame 0).""" + ds = load_poses.from_bvh_file(bvh_file) + root_pos = ds.position.sel( + time=ds.coords["time"][0], + keypoints="Hips", + individuals="id_0", + ) + # Frame 0: all zeros in channels, offset 0,0,0 + np.testing.assert_allclose(root_pos.values, [0.0, 0.0, 0.0], atol=1e-6) + + def test_bvh_root_position_frame_1(self, bvh_file): + """Test root position in frame 1 (translation).""" + ds = load_poses.from_bvh_file(bvh_file) + root_pos = ds.position.sel( + time=ds.coords["time"][1], + keypoints="Hips", + individuals="id_0", + ) + # Frame 1: Xposition=1, Yposition=2, Zposition=0.5 + np.testing.assert_allclose(root_pos.values, [1.0, 2.0, 0.5], atol=1e-6) + + def test_bvh_source_software_attr(self, bvh_file): + """Test source_software attribute is set correctly.""" + ds = load_poses.from_bvh_file(bvh_file) + assert ds.source_software == "BVH" + + def test_bvh_source_file_attr(self, bvh_file): + """Test source_file attribute is set correctly.""" + ds = load_poses.from_bvh_file(bvh_file) + assert ds.source_file == bvh_file.as_posix() + + def test_bvh_confidence_is_nan(self, bvh_file): + """Test BVH has NaN confidence (no conf info).""" + ds = load_poses.from_bvh_file(bvh_file) + assert np.isnan(ds.confidence.values).all() + + def test_bvh_single_individual(self, bvh_file): + """Test BVH creates single-individual dataset.""" + ds = load_poses.from_bvh_file(bvh_file) + assert ds.sizes["individuals"] == 1 + + +class TestBVHValidation: + """Tests for 
BVH file validation.""" + + def test_valid_bvh_file(self, bvh_file): + """Test that a valid BVH file passes validation.""" + valid = ValidBVHFile(file=bvh_file) + assert valid.file == bvh_file + + def test_invalid_bvh_no_hierarchy(self, bvh_file_no_hierarchy): + """Test BVH without HIERARCHY fails validation.""" + with pytest.raises(ValueError, match="HIERARCHY"): + ValidBVHFile(file=bvh_file_no_hierarchy) + + def test_invalid_bvh_no_motion(self, bvh_file_no_motion): + """Test BVH without MOTION fails validation.""" + with pytest.raises(ValueError, match="MOTION"): + ValidBVHFile(file=bvh_file_no_motion) + + def test_invalid_bvh_wrong_extension(self, wrong_extension_file): + """Test that wrong file extension fails.""" + with pytest.raises(ValueError, match="suffix"): + ValidBVHFile(file=wrong_extension_file) + + +# ============== load_dataset integration tests ===================== + + +class TestLoadDatasetIntegration: + """Test that COCO and BVH work through load_dataset.""" + + def test_load_dataset_coco(self, coco_json_file): + """Test load_dataset with source_software='COCO'.""" + from movement.io import load_dataset + + ds = load_dataset( + coco_json_file, + source_software="COCO", + fps=30, + ) + assert isinstance(ds, xr.Dataset) + assert ds.source_software == "COCO" + + def test_load_dataset_bvh(self, bvh_file): + """Test load_dataset with source_software='BVH'.""" + from movement.io import load_dataset + + ds = load_dataset(bvh_file, source_software="BVH") + assert isinstance(ds, xr.Dataset) + assert ds.source_software == "BVH"