@@ -638,3 +638,142 @@ def plot(self, **kwargs: Any) -> PlotOutput:
638638 self .get_paths (order ).plot ()
639639
640640 return output
641+
642+ def fuse_planes (self ) -> Self :
643+ """
644+ Fuse multiple interception planes into a single plane dimension.
645+
646+ This method reduces the size of vertices from ``[..., num_planes, num_rays, ...]``
647+ to ``[..., num_rays, ...]`` by keeping, for each ray, the first interception
648+ vertex where the mask is True across all planes.
649+
650+ This method assumes that:
651+
652+ 1. The ``SBRPaths`` was generated using the ``plane`` interception method.
653+ 2. The planes are non-overlapping (i.e., a single ray can only hit one plane
654+ at maximum).
655+
656+ Returns:
657+ A new ``SBRPaths`` instance with fused plane dimensions.
658+
659+ Raises:
660+ ValueError: If the paths do not have a planes dimension (i.e., if the
661+ shape does not include a ``num_planes`` dimension after the batch
662+ dimensions).
663+
664+ Examples:
665+ >>> # Assuming paths have shape [num_tx, num_planes, num_rays, path_length, 3]
666+ >>> fused_paths = paths.fuse_planes()
667+ >>> # Result has shape [num_tx, num_rays, path_length, 3]
668+ """
669+ # The vertices shape is expected to be [..., num_planes, num_rays, path_length, 3]
670+ # We need to identify which dimension is num_planes
671+ # Based on the implementation, for plane-based SBR:
672+ # - tx_batch dimensions first
673+ # - then num_planes
674+ # - then num_rays
675+ # - then path_length
676+ # - then 3 (coordinates)
677+
678+ # Get the shape
679+ vertices_shape = self .vertices .shape
680+ if len (vertices_shape ) < 4 :
681+ msg = (
682+ "Cannot fuse planes: vertices shape must have at least 4 dimensions "
683+ f"[..., num_planes, num_rays, path_length, 3], but got shape { vertices_shape } "
684+ )
685+ raise ValueError (msg )
686+
687+ # The masks shape is [..., num_planes, num_rays, path_length-1]
688+ # We want to find, for each ray, the first plane where ANY mask is True
689+
690+ # Find the first plane index where mask is True for each ray
691+ # Shape: [..., num_planes, num_rays]
692+ # Check if any of the path orders have a True mask
693+ any_mask_true = jnp .any (self .masks , axis = - 1 )
694+
695+ # For each ray, find the first plane with a True mask
696+ # We'll use argmax which returns the first True index (or 0 if all False)
697+ # Shape: [..., num_rays]
698+ first_plane_idx = jnp .argmax (any_mask_true , axis = - 2 )
699+
700+ # Check if there's actually a valid interception for each ray
701+ # Shape: [..., num_rays]
702+ has_interception = jnp .any (any_mask_true , axis = - 2 )
703+
704+ # Now we need to gather the vertices and objects from the selected plane
705+ # Create indices for gathering along the num_planes axis
706+
707+ # For vertices: [..., num_planes, num_rays, path_length, 3]
708+ # first_plane_idx: [..., num_rays]
709+ # We need to expand to: [..., 1, num_rays, 1, 1] then broadcast
710+
711+ # Get number of batch dimensions
712+ num_batch_dims = len (vertices_shape ) - 4
713+
714+ # Expand first_plane_idx by adding axis for num_planes, path_length, and coords
715+ # Start: [..., num_rays]
716+ # Add axis at -3 for num_planes (becomes [..., 1, num_rays])
717+ gather_idx_vertices = jnp .expand_dims (first_plane_idx , axis = - 2 )
718+ # Add axis at -1 for path_length (becomes [..., 1, num_rays, 1])
719+ gather_idx_vertices = jnp .expand_dims (gather_idx_vertices , axis = - 1 )
720+ # Add axis at -1 for coords (becomes [..., 1, num_rays, 1, 1])
721+ gather_idx_vertices = jnp .expand_dims (gather_idx_vertices , axis = - 1 )
722+
723+ # Broadcast to match vertices shape
724+ target_shape = (
725+ vertices_shape [:num_batch_dims ]
726+ + (1 ,)
727+ + vertices_shape [num_batch_dims + 1 :]
728+ )
729+ gather_idx_vertices = jnp .broadcast_to (gather_idx_vertices , target_shape )
730+
731+ # Use take_along_axis for the num_planes dimension (axis=num_batch_dims)
732+ fused_vertices = jnp .take_along_axis (
733+ self .vertices , gather_idx_vertices , axis = num_batch_dims
734+ ).squeeze (axis = num_batch_dims )
735+
736+ # Similarly for objects: [..., num_planes, num_rays, path_length]
737+ gather_idx_objects = jnp .expand_dims (
738+ first_plane_idx , axis = - 2
739+ ) # [..., 1, num_rays]
740+ gather_idx_objects = jnp .expand_dims (
741+ gather_idx_objects , axis = - 1
742+ ) # [..., 1, num_rays, 1]
743+ objects_target_shape = (
744+ self .objects .shape [:num_batch_dims ]
745+ + (1 ,)
746+ + self .objects .shape [num_batch_dims + 1 :]
747+ )
748+ gather_idx_objects = jnp .broadcast_to (gather_idx_objects , objects_target_shape )
749+
750+ fused_objects = jnp .take_along_axis (
751+ self .objects , gather_idx_objects , axis = num_batch_dims
752+ ).squeeze (axis = num_batch_dims )
753+
754+ # For masks: [..., num_planes, num_rays, path_length-1]
755+ gather_idx_masks = jnp .expand_dims (
756+ first_plane_idx , axis = - 2
757+ ) # [..., 1, num_rays]
758+ gather_idx_masks = jnp .expand_dims (
759+ gather_idx_masks , axis = - 1
760+ ) # [..., 1, num_rays, 1]
761+ masks_target_shape = (
762+ self .masks .shape [:num_batch_dims ]
763+ + (1 ,)
764+ + self .masks .shape [num_batch_dims + 1 :]
765+ )
766+ gather_idx_masks = jnp .broadcast_to (gather_idx_masks , masks_target_shape )
767+
768+ fused_masks = jnp .take_along_axis (
769+ self .masks , gather_idx_masks , axis = num_batch_dims
770+ ).squeeze (axis = num_batch_dims )
771+
772+ # Set invalid rays (no interception) to have False masks
773+ fused_masks = jnp .where (has_interception [..., None ], fused_masks , False )
774+
775+ return SBRPaths (
776+ vertices = fused_vertices ,
777+ objects = fused_objects ,
778+ masks = fused_masks ,
779+ )
0 commit comments