Skip to content

Commit accd6c3

Browse files
Copilotjeertmans
andcommitted
Add fuse_planes method to SBRPaths class
Co-authored-by: jeertmans <27275099+jeertmans@users.noreply.github.com>
1 parent 1f9c67b commit accd6c3

1 file changed

Lines changed: 139 additions & 0 deletions

File tree

differt/src/differt/geometry/_paths.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,3 +638,142 @@ def plot(self, **kwargs: Any) -> PlotOutput:
638638
self.get_paths(order).plot()
639639

640640
return output
641+
642+
def fuse_planes(self) -> Self:
643+
"""
644+
Fuse multiple interception planes into a single plane dimension.
645+
646+
This method reduces the size of vertices from ``[..., num_planes, num_rays, ...]``
647+
to ``[..., num_rays, ...]`` by keeping, for each ray, the first interception
648+
vertex where the mask is True across all planes.
649+
650+
This method assumes that:
651+
652+
1. The ``SBRPaths`` was generated using the ``plane`` interception method.
653+
2. The planes are non-overlapping (i.e., a single ray can only hit one plane
654+
at maximum).
655+
656+
Returns:
657+
A new ``SBRPaths`` instance with fused plane dimensions.
658+
659+
Raises:
660+
ValueError: If the paths do not have a planes dimension (i.e., if the
661+
shape does not include a ``num_planes`` dimension after the batch
662+
dimensions).
663+
664+
Examples:
665+
>>> # Assuming paths have shape [num_tx, num_planes, num_rays, path_length, 3]
666+
>>> fused_paths = paths.fuse_planes()
667+
>>> # Result has shape [num_tx, num_rays, path_length, 3]
668+
"""
669+
# The vertices shape is expected to be [..., num_planes, num_rays, path_length, 3]
670+
# We need to identify which dimension is num_planes
671+
# Based on the implementation, for plane-based SBR:
672+
# - tx_batch dimensions first
673+
# - then num_planes
674+
# - then num_rays
675+
# - then path_length
676+
# - then 3 (coordinates)
677+
678+
# Get the shape
679+
vertices_shape = self.vertices.shape
680+
if len(vertices_shape) < 4:
681+
msg = (
682+
"Cannot fuse planes: vertices shape must have at least 4 dimensions "
683+
f"[..., num_planes, num_rays, path_length, 3], but got shape {vertices_shape}"
684+
)
685+
raise ValueError(msg)
686+
687+
# The masks shape is [..., num_planes, num_rays, path_length-1]
688+
# We want to find, for each ray, the first plane where ANY mask is True
689+
690+
# Find the first plane index where mask is True for each ray
691+
# Shape: [..., num_planes, num_rays]
692+
# Check if any of the path orders have a True mask
693+
any_mask_true = jnp.any(self.masks, axis=-1)
694+
695+
# For each ray, find the first plane with a True mask
696+
# We'll use argmax which returns the first True index (or 0 if all False)
697+
# Shape: [..., num_rays]
698+
first_plane_idx = jnp.argmax(any_mask_true, axis=-2)
699+
700+
# Check if there's actually a valid interception for each ray
701+
# Shape: [..., num_rays]
702+
has_interception = jnp.any(any_mask_true, axis=-2)
703+
704+
# Now we need to gather the vertices and objects from the selected plane
705+
# Create indices for gathering along the num_planes axis
706+
707+
# For vertices: [..., num_planes, num_rays, path_length, 3]
708+
# first_plane_idx: [..., num_rays]
709+
# We need to expand to: [..., 1, num_rays, 1, 1] then broadcast
710+
711+
# Get number of batch dimensions
712+
num_batch_dims = len(vertices_shape) - 4
713+
714+
# Expand first_plane_idx by adding axis for num_planes, path_length, and coords
715+
# Start: [..., num_rays]
716+
# Add axis at -3 for num_planes (becomes [..., 1, num_rays])
717+
gather_idx_vertices = jnp.expand_dims(first_plane_idx, axis=-2)
718+
# Add axis at -1 for path_length (becomes [..., 1, num_rays, 1])
719+
gather_idx_vertices = jnp.expand_dims(gather_idx_vertices, axis=-1)
720+
# Add axis at -1 for coords (becomes [..., 1, num_rays, 1, 1])
721+
gather_idx_vertices = jnp.expand_dims(gather_idx_vertices, axis=-1)
722+
723+
# Broadcast to match vertices shape
724+
target_shape = (
725+
vertices_shape[:num_batch_dims]
726+
+ (1,)
727+
+ vertices_shape[num_batch_dims + 1 :]
728+
)
729+
gather_idx_vertices = jnp.broadcast_to(gather_idx_vertices, target_shape)
730+
731+
# Use take_along_axis for the num_planes dimension (axis=num_batch_dims)
732+
fused_vertices = jnp.take_along_axis(
733+
self.vertices, gather_idx_vertices, axis=num_batch_dims
734+
).squeeze(axis=num_batch_dims)
735+
736+
# Similarly for objects: [..., num_planes, num_rays, path_length]
737+
gather_idx_objects = jnp.expand_dims(
738+
first_plane_idx, axis=-2
739+
) # [..., 1, num_rays]
740+
gather_idx_objects = jnp.expand_dims(
741+
gather_idx_objects, axis=-1
742+
) # [..., 1, num_rays, 1]
743+
objects_target_shape = (
744+
self.objects.shape[:num_batch_dims]
745+
+ (1,)
746+
+ self.objects.shape[num_batch_dims + 1 :]
747+
)
748+
gather_idx_objects = jnp.broadcast_to(gather_idx_objects, objects_target_shape)
749+
750+
fused_objects = jnp.take_along_axis(
751+
self.objects, gather_idx_objects, axis=num_batch_dims
752+
).squeeze(axis=num_batch_dims)
753+
754+
# For masks: [..., num_planes, num_rays, path_length-1]
755+
gather_idx_masks = jnp.expand_dims(
756+
first_plane_idx, axis=-2
757+
) # [..., 1, num_rays]
758+
gather_idx_masks = jnp.expand_dims(
759+
gather_idx_masks, axis=-1
760+
) # [..., 1, num_rays, 1]
761+
masks_target_shape = (
762+
self.masks.shape[:num_batch_dims]
763+
+ (1,)
764+
+ self.masks.shape[num_batch_dims + 1 :]
765+
)
766+
gather_idx_masks = jnp.broadcast_to(gather_idx_masks, masks_target_shape)
767+
768+
fused_masks = jnp.take_along_axis(
769+
self.masks, gather_idx_masks, axis=num_batch_dims
770+
).squeeze(axis=num_batch_dims)
771+
772+
# Set invalid rays (no interception) to have False masks
773+
fused_masks = jnp.where(has_interception[..., None], fused_masks, False)
774+
775+
return SBRPaths(
776+
vertices=fused_vertices,
777+
objects=fused_objects,
778+
masks=fused_masks,
779+
)

0 commit comments

Comments
 (0)