diff --git a/docs/source/usage/python.rst b/docs/source/usage/python.rst index 70c892071..d66f036a9 100644 --- a/docs/source/usage/python.rst +++ b/docs/source/usage/python.rst @@ -645,16 +645,17 @@ This module provides elements and methods for the accelerator lattice. :param pals_line: PALS Python Line with beamline elements :param nslice: number of slices used for the application of collective effects - .. py:method:: select(kind=None, name=None) + .. py:method:: select(kind=None, name=None, s=None) - Filter elements by type and/or name. - If both are provided, OR-based logic is applied. + Filter elements by type, name, and/or integrated position. + If multiple criteria are provided, OR-based logic is applied. Returns references to original elements, allowing modification and chaining. Chained ``.select(...).select(...)`` selections are AND-filtered. :param kind: Element type(s) to filter by. Can be a string (e.g., ``"Drift"``), regex pattern (e.g., ``r".*Quad"``), element type (e.g., ``elements.Drift``), or list/tuple of these. :param name: Element name(s) to filter by. Can be a string, regex pattern, or ``list``/``tuple`` of these. + :param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended. Elements are selected if ANY part overlaps with the range. Examples: ``(1.0, 5.0)`` for range 1.0 <= s <= 5.0, ``(1.0, None)`` for s >= 1.0, ``(None, 5.0)`` for s <= 5.0. **Examples:** @@ -673,6 +674,12 @@ This module provides elements and methods for the accelerator lattice. # Chain filters (AND logic) drift_named_d1 = lattice.select(kind="Drift").select(name="drift1") + # Position filtering (overlap logic) + early_elements = lattice.select(s=(None, 2.0)) # Elements overlapping s <= 2.0 + + # Chaining: s always calculated from original lattice + drift_then_early = lattice.select(kind="Drift").select(s=(1.0, 3.0)) # Drift AND overlapping s=[1.0,3.0] + # Modify original elements through references drift_elements[0].ds = 2.0 # modifies original lattice diff --git a/src/python/impactx/Kahan.py b/src/python/impactx/Kahan.py new file mode 100644 index 000000000..e2d00b113 --- /dev/null +++ b/src/python/impactx/Kahan.py @@ -0,0 +1,97 @@ +def _kahan_babushka_core(values, return_cumulative=False): + """Core implementation of the second-order iterative Kahan-Babuska algorithm. + + This is the unified core that implements Klein (2006) algorithm for both + regular summation and cumulative summation. + - https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements + - Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in + Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x + + Args: + values: Iterable of numeric values to sum + return_cumulative: If True, returns list of cumulative sums; if False, returns final sum + + Returns: + float or list: Final sum if return_cumulative=False, list of cumulative sums if True + """ + sum_val = 0.0 + cs = 0.0 # first-order compensation for lost low-order bits + ccs = 0.0 # second-order compensation for further lost bits + c = 0.0 # temporary variable for first-order compensation + cc = 0.0 # temporary variable for second-order compensation + + if return_cumulative: + cumulative_sums = [0.0] # Start with 0.0 + + for val in values: + # First-order Kahan-Babuška step + t = sum_val + val + if abs(sum_val) >= abs(val): + c = (sum_val - t) + val + else: + c = (val - t) + sum_val + sum_val = t + + # Second-order compensation step + t = cs + c + if abs(cs) >= abs(c): + cc = (cs - t) + c + else: + cc = (c - t) + cs + cs = t + ccs += cc + + if return_cumulative: + # Store the accurate cumulative sum + cumulative_sums.append(sum_val + cs + ccs) + + if return_cumulative: + return cumulative_sums + else: + return sum_val + cs + ccs + + +def kahan_babushka_sum(values): + """Calculate an accurate sum using the second-order iterative Kahan-Babuška algorithm. + + This implementation follows Klein (2006) to provide high-precision summation + that avoids floating-point precision errors when summing many small values. + - https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements + - Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in + Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x + + The algorithm uses second-order compensation for lost low-order bits during + floating-point addition, providing significantly better accuracy than naive + summation when dealing with large numbers of small values (e.g., many ds + values in a long lattice). + + Args: + values: Iterable of numeric values to sum + + Returns: + float: Accurate sum of all values + """ + return _kahan_babushka_core(values, return_cumulative=False) + + +def kahan_babushka_cumsum(values): + """Calculate an accurate cumulative sum using the second-order iterative Kahan-Babuska algorithm. + + This implementation follows Klein (2006) to provide high-precision summation + that avoids floating-point precision errors when summing many small values. + - https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements + - Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in + Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x + + The algorithm uses second-order compensation for lost low-order bits during + floating-point addition, providing significantly better accuracy than naive + summation when dealing with large numbers of small values (e.g., many ds + values in a long lattice). + + Args: + values: Iterable of numeric values to cumulatively sum + + Returns: + list: List of cumulative sums with initial 0.0 prepended + """ + return _kahan_babushka_core(values, return_cumulative=True) diff --git a/src/python/impactx/__init__.py b/src/python/impactx/__init__.py index d5b5ae647..ab9192c2a 100644 --- a/src/python/impactx/__init__.py +++ b/src/python/impactx/__init__.py @@ -19,6 +19,7 @@ # import core bindings to C++ from . import impactx_pybind from .impactx_pybind import * # noqa +from .Kahan import kahan_babushka_cumsum, kahan_babushka_sum # noqa from .madx_to_impactx import read_beam # noqa __version__ = impactx_pybind.__version__ diff --git a/src/python/impactx/extensions/KnownElementsList.py b/src/python/impactx/extensions/KnownElementsList.py index abaae6b47..0d1b82cae 100644 --- a/src/python/impactx/extensions/KnownElementsList.py +++ b/src/python/impactx/extensions/KnownElementsList.py @@ -9,7 +9,7 @@ import os import re -from impactx import elements +from impactx import elements, kahan_babushka_sum def load_file(self, filename, nslice=1): @@ -124,6 +124,7 @@ def select( *, kind=None, name=None, + s=None, ): """Apply filtering to this filtered list. @@ -146,10 +147,14 @@ def select( Examples: "quad1", r"quad\d+", ["quad1", "quad2"], [r"quad\d+", "bend1"] :type name: str or list[str] or tuple[str, ...] or None, optional + :param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended. + Elements are selected if ANY part overlaps with the range. Examples: (1.0, 5.0) for range 1.0 <= s <= 5.0, (1.0, None) for s >= 1.0, (None, 5.0) for s <= 5.0 + :type s: tuple[float | None, float | None] or list[float | None] or None, optional + :return: FilteredElementsList containing references to original elements :rtype: FilteredElementsList - :raises TypeError: If kind/name parameters have wrong types + :raises TypeError: If kind/name/s parameters have wrong types **Examples:** @@ -169,18 +174,26 @@ def select( strong_quads = quad_elements.select( name=r"quad\d+" ) # Filter quads by regex pattern + + # Position-based filtering (always calculated from original lattice) + early_elements = lattice.select(s=(None, 2.0)) # Elements with s <= 2.0 + drift_then_early = lattice.select(kind="Drift").select( + s=(1.0, None) + ) # Drift elements with s >= 1.0 """ # Apply filtering directly to the indices we already have - if kind is not None or name is not None: + if kind is not None or name is not None or s is not None: # Validate parameters - _validate_select_parameters(kind, name) + _validate_select_parameters(kind, name, s) matching_indices = [] - for i in self._indices: - element = self._original_list[i] - if _check_element_match(element, kind, name): - matching_indices.append(i) + for original_idx in self._indices: + element = self._original_list[original_idx] + if _check_element_match( + element, kind, name, s, original_idx, self._original_list + ): + matching_indices.append(original_idx) return FilteredElementsList(self._original_list, matching_indices) @@ -249,12 +262,13 @@ def _matches_string(text: str, string_pattern: str) -> bool: return text == string_pattern -def _validate_select_parameters(kind, name): +def _validate_select_parameters(kind, name, s): """Validate parameters for select methods. Args: kind: Element type(s) to filter by name: Element name(s) to filter by + s: Position range to filter by Raises: TypeError: If parameters have wrong types @@ -279,6 +293,20 @@ def _validate_select_parameters(kind, name): "'name' parameter must be a string or list/tuple of strings" ) + if s is not None: + if isinstance(s, (list, tuple)): + if len(s) != 2: + raise TypeError( + "'s' parameter must have exactly 2 elements (lower, upper)" + ) + for bound in s: + if bound is not None and not isinstance(bound, (int, float)): + raise TypeError("'s' parameter bounds must be numbers or None") + else: + raise TypeError( + "'s' parameter must be a tuple/list with 2 elements (lower, upper)" + ) + def _matches_kind_pattern(element, kind_pattern): """Check if an element matches a kind pattern. @@ -316,13 +344,57 @@ def _matches_name_pattern(element, name_pattern): ) -def _check_element_match(element, kind, name): - """Check if an element matches the given kind and name criteria. +def _matches_s_position(element, s_range, element_s): + """Check if an element's integrated position matches the s range criteria. + + An element matches if ANY part of it overlaps with the specified range. + The element spans from element_s to element_s + element.ds. + + Args: + element: The element to check + s_range: Tuple/list of (lower, upper) bounds. None represents open-ended. + element_s: The cumulative position of the element (calculated externally) + + Returns: + bool: True if element's position overlaps with the range + """ + if s_range is None: + return True + + # Convert to tuple if it's a list + if isinstance(s_range, list): + s_range = tuple(s_range) + + if not isinstance(s_range, tuple) or len(s_range) != 2: + raise TypeError( + "'s' parameter must be a tuple/list with 2 elements (lower, upper)" + ) + + lower, upper = s_range + + # Element spans from element_s to element_s + element.ds + element_start = element_s + element_end = element_s + element.ds + + # Check if any part of the element overlaps with the range + if lower is not None and element_end < lower: + return False + if upper is not None and element_start > upper: + return False + + return True + + +def _check_element_match(element, kind, name, s, element_index, lattice): + """Check if an element matches the given kind, name, and s criteria. Args: element: The element to check kind: Kind criteria (str, type, list, tuple, or None) name: Name criteria (str, list, tuple, or None) + s: Position criteria (tuple/list with 2 elements, or None) + element_index: Index of the element in the lattice + lattice: The full lattice to calculate cumulative positions Returns: bool: True if element matches any criteria (OR logic) @@ -355,6 +427,15 @@ def _check_element_match(element, kind, name): match = True break + # Check for 's' parameter (only if neither kind nor name matched - OR logic) + if s is not None and not match: + # Calculate cumulative position up to this element using accurate summation + ds_values = [lattice[i].ds for i in range(element_index)] + cumulative_s = kahan_babushka_sum(ds_values) + + if _matches_s_position(element, s, cumulative_s): + match = True + return match @@ -363,16 +444,17 @@ def select( *, kind=None, name=None, + s=None, ) -> FilteredElementsList: - """Filter elements by type and name with OR-based logic. + """Filter elements by type, name, and position with OR-based logic. - This method supports filtering elements by their type and/or name using keyword arguments. + This method supports filtering elements by their type, name, and/or integrated position using keyword arguments. Returns references to original elements, allowing modification and chaining. **Filtering Logic:** - **Within a single filter**: OR logic (e.g., ``kind=["Drift", "Quad"]`` matches Drift OR Quad) - - **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1"`` matches Quad OR named "quad1") + - **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1", s=(1.0, 5.0)`` matches Quad OR named "quad1" OR in position range) - **Chaining filters**: AND logic (e.g., ``lattice.select(kind="Drift").select(name="drift1")`` matches Drift AND named "drift1") :param kind: Element type(s) to filter by. Can be a single string/type or a list/tuple @@ -385,10 +467,14 @@ def select( Examples: "quad1", r"quad\d+", ["quad1", "quad2"], [r"quad\d+", "bend1"] :type name: str or list[str] or tuple[str, ...] or None, optional + :param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended. + Elements are selected if ANY part overlaps with the range. Examples: (1.0, 5.0) for range 1.0 <= s <= 5.0, (1.0, None) for s >= 1.0, (None, 5.0) for s <= 5.0 + :type s: tuple[float | None, float | None] or list[float | None] or None, optional + :return: FilteredElementsList containing references to original elements :rtype: FilteredElementsList - :raises TypeError: If kind/name parameters have wrong types + :raises TypeError: If kind/name/s parameters have wrong types **Examples:** @@ -423,6 +509,15 @@ def select( lattice.select(name=r"quad\d+") # Get elements matching pattern lattice.select(name=[r"quad\d+", "bend1"]) # Mix regex and strings + Position-based filtering: + + .. code-block:: python + + lattice.select(s=(1.0, 5.0)) # Elements that overlap with range 1.0 <= s <= 5.0 + lattice.select(s=(1.0, None)) # Elements that overlap with s >= 1.0 + lattice.select(s=(None, 5.0)) # Elements that overlap with s <= 5.0 + lattice.select(kind="Drift", s=(0.0, 2.0)) # Drift elements OR overlapping range 0.0 <= s <= 2.0 + Chaining filters (AND logic between chained calls): .. code-block:: python @@ -455,14 +550,14 @@ def select( """ # Handle keyword arguments for filtering - if kind is not None or name is not None: + if kind is not None or name is not None or s is not None: # Validate parameters - _validate_select_parameters(kind, name) + _validate_select_parameters(kind, name, s) matching_indices = [] for i, element in enumerate(self): - if _check_element_match(element, kind, name): + if _check_element_match(element, kind, name, s, i, self): matching_indices.append(i) return FilteredElementsList(self, matching_indices) diff --git a/src/python/impactx/plot/Survey.py b/src/python/impactx/plot/Survey.py index 234a4cf67..0b2451ae9 100644 --- a/src/python/impactx/plot/Survey.py +++ b/src/python/impactx/plot/Survey.py @@ -36,9 +36,9 @@ def plot_survey( from math import copysign import matplotlib.pyplot as plt - import numpy as np from matplotlib.patches import Rectangle + from ..Kahan import kahan_babushka_cumsum from .ElementColors import get_element_color charge_qe = 1.0 if ref is None else ref.charge_qe @@ -47,10 +47,9 @@ def plot_survey( element_lengths = [element.ds for element in self] - # NumPy 2.1+ (i.e. Python 3.10+): - # element_s = np.cumulative_sum(element_lengths, include_initial=True) - # backport: - element_s = np.insert(np.cumsum(element_lengths), 0, 0) + # Use accurate cumulative sum to avoid floating-point precision errors + # when dealing with many small ds values in long lattices + element_s = kahan_babushka_cumsum(element_lengths) ax.hlines(0, 0, element_s[-1], color="black", linestyle="--") diff --git a/tests/python/test_lattice_select.py b/tests/python/test_lattice_select.py index 709457bf0..4ac2677bd 100644 --- a/tests/python/test_lattice_select.py +++ b/tests/python/test_lattice_select.py @@ -811,3 +811,94 @@ def test_select_no_arguments(): all_then_drift = lattice.select().select(kind="Drift") assert len(all_then_drift) == 2 assert [el.name for el in all_then_drift] == ["drift1", "drift2"] + + +def test_position_filtering(): + """Test position-based filtering with s parameter.""" + import impactx + from impactx import elements + + # Create a lattice with known positions + lattice = impactx.elements.KnownElementsList() + lattice.extend( + [ + elements.Drift(name="drift1", ds=1.0), + elements.Quad(name="quad1", ds=0.5, k=1.0), + elements.Drift(name="drift2", ds=2.0), + elements.Quad(name="quad2", ds=0.3, k=-1.0), + elements.Drift(name="drift3", ds=1.5), + ] + ) + + # Test range filtering + early_elements = lattice.select(s=(None, 2.0)) + assert len(early_elements) == 3 # drift1, quad1, drift2 + assert early_elements[0].name == "drift1" + assert early_elements[1].name == "quad1" + assert early_elements[2].name == "drift2" + + # Test range filtering with both bounds (overlap logic) + middle_elements = lattice.select(s=(1.0, 3.0)) + assert len(middle_elements) == 3 # drift1, quad1, drift2 (all overlap with range) + assert middle_elements[0].name == "drift1" # s=[0.0, 1.0] overlaps with [1.0, 3.0] + assert middle_elements[1].name == "quad1" # s=[1.0, 1.5] overlaps with [1.0, 3.0] + assert middle_elements[2].name == "drift2" # s=[1.5, 3.5] overlaps with [1.0, 3.0] + + # Test upper bound only + late_elements = lattice.select(s=(3.0, None)) + assert len(late_elements) == 3 # drift2, quad2, drift3 + assert late_elements[0].name == "drift2" + assert late_elements[1].name == "quad2" + assert late_elements[2].name == "drift3" + + # Test combined filtering (OR logic) + drift_or_early = lattice.select(kind="Drift", s=(None, 2.0)) + assert len(drift_or_early) == 4 # All drifts + early elements + drift_names = [el.name for el in drift_or_early if el.name.startswith("drift")] + assert len(drift_names) == 3 # All 3 drifts + assert "drift1" in drift_names + assert "drift2" in drift_names + assert "drift3" in drift_names + + # Test chaining (AND logic) + drift_then_early = lattice.select(kind="Drift").select(s=(None, 1.9)) + assert len(drift_then_early) == 2 # Only drift1 and drift2 + assert drift_then_early[0].name == "drift1" + assert drift_then_early[1].name == "drift2" + + # Test reference preservation through chaining + original_ds = lattice[1].ds + chained_elements = lattice.select(kind="Quad").select(s=(1.0, 1.9)) + chained_elements[0].ds = 0.8 # Modify through chained filter + assert lattice[1].ds == 0.8 # Original element modified + lattice[1].ds = original_ds # Reset + + # Test with list syntax + range_list = lattice.select(s=[1.0, 2.9]) + assert len(range_list) == 3 + assert range_list[0].name == "drift1" + + # Test explicit overlap behavior + # Element at s=[0.0, 1.0] should be included in range [0.5, 1.5] because it overlaps + # Element at s=[1.0, 1.5] should be included in range [0.5, 1.5] because it overlaps + # Element at s=[1.5, 3.5] should NOT be included because it doesn't overlap with [0.5, 1.5] + overlap_test = lattice.select(s=(0.5, 1.4)) + assert len(overlap_test) == 2 # drift1 and quad1 both overlap + assert overlap_test[0].name == "drift1" # s=[0.0, 1.0] overlaps with [0.5, 1.5] + assert overlap_test[1].name == "quad1" # s=[1.0, 1.5] overlaps with [0.5, 1.5] + + # Test non-overlap case + # Range [0.2, 0.8] should only include drift1 (s=[0.0, 1.0]) + narrow_test = lattice.select(s=(0.2, 0.8)) + assert len(narrow_test) == 1 + assert narrow_test[0].name == "drift1" # Only drift1 overlaps with [0.2, 0.8] + + # Test error handling + with pytest.raises(TypeError, match="must have exactly 2 elements"): + lattice.select(s=(1.0,)) # Wrong number of elements + + with pytest.raises(TypeError, match="must be a tuple/list with 2 elements"): + lattice.select(s="invalid") # Wrong type + + with pytest.raises(TypeError, match="bounds must be numbers or None"): + lattice.select(s=(1.0, "invalid")) # Wrong bound type