Skip to content

Commit 1d034c6

Browse files
committed
MPL: KnownElementsList.plot_survey()
Example: ```py import matplotlib.pyplot as plt sim.lattice.plot_survey(ref) plt.show() ```
1 parent 70cb173 commit 1d034c6

File tree

9 files changed

+261
-13
lines changed

9 files changed

+261
-13
lines changed

docs/source/usage/python.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,21 @@ This module provides elements and methods for the accelerator lattice.
623623
:param madx_file: file name to MAD-X file with beamline elements
624624
:param nslice: number of slices used for the application of space charge
625625

626+
.. py:method:: plot_survey(ref=None, ax=None, legend=True, legend_ncols=5)
627+
628+
Plot over s of all elements in the KnownElementsList.
629+
630+
The signs of element strengths are determined by the sign of the charge of the reference particle.
631+
The projection of all element strengths is s-x ("vertical").
632+
633+
Either populates the matplotlib axes in ax or creates a new axes containing the plot.
634+
635+
:param self: The KnownElementsList class in ImpactX
636+
:param ref: A reference particle, checked for the charge sign to plot focusing/defocusing strength directions properly.
637+
:param ax: A plotting area in matplotlib (called axes there).
638+
:param legend: Plot a legend if true.
639+
:param legend_ncols: Number of columns for lattice element types in the legend.
640+
626641
.. py:class:: impactx.elements.CFbend(ds, rc, k, dx=0, dy=0, rotation=0, aperture_x=0, aperture_y=0, nslice=1, name=None)
627642
628643
A combined function bending magnet. This is an ideal Sbend with a normal quadrupole field component.

src/elements/Sbend.H

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ namespace impactx::elements
7171
{
7272
}
7373

74+
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
75+
amrex::ParticleReal
76+
rc ([[maybe_unused]] RefPart const & refpart) const
77+
{
78+
using namespace amrex::literals; // for _rt and _prt
79+
80+
// TODO: as in ExactSbend
81+
// return m_B != 0_prt ? refpart.rigidity_Tm() / m_B : m_ds / m_phi;
82+
return m_rc;
83+
}
84+
7485
/** Push all particles */
7586
using BeamOptic::operator();
7687

src/python/elements.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,10 @@ void init_elements(py::module& m)
11701170
py::arg("name") = py::none(),
11711171
"An ideal sector bend using the exact nonlinear map. When B = 0, the reference bending radius is defined by r0 = length / (angle in rad), corresponding to a magnetic field of B = rigidity / r0; otherwise the reference bending radius is defined by r0 = rigidity / B."
11721172
)
1173+
.def("rc", &ExactSbend::rc,
1174+
py::arg("ref"),
1175+
"Radius of curvature in m"
1176+
)
11731177
.def_property("phi",
11741178
[](ExactSbend & exact_sbend) { return exact_sbend.m_phi; },
11751179
[](ExactSbend & exact_sbend, amrex::ParticleReal phi) { exact_sbend.m_phi = phi; },
@@ -1659,9 +1663,8 @@ void init_elements(py::module& m)
16591663
py::arg("name") = py::none(),
16601664
"An ideal sector bend."
16611665
)
1662-
.def_property("rc",
1663-
[](Sbend & sbend) { return sbend.m_rc; },
1664-
[](Sbend & sbend, amrex::ParticleReal rc) { sbend.m_rc = rc; },
1666+
.def("rc", &Sbend::rc,
1667+
py::arg("ref") = py::none(),
16651668
"Radius of curvature in m"
16661669
)
16671670
;

src/python/impactx/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,27 @@
1818

1919
# import core bindings to C++
2020
from . import impactx_pybind as cxx
21-
from .distribution_input_helpers import twiss # noqa
22-
from .extensions.ImpactXParticleContainer import (
23-
register_ImpactXParticleContainer_extension,
24-
)
2521
from .impactx_pybind import * # noqa
26-
from .madx_to_impactx import read_beam, read_lattice # noqa
22+
from .madx_to_impactx import read_beam # noqa
2723

2824
__version__ = cxx.__version__
2925
__doc__ = cxx.__doc__
3026
__license__ = cxx.__license__
3127
__author__ = cxx.__author__
3228

29+
from .distribution_input_helpers import twiss # noqa
30+
from .extensions.KnownElementsList import (
31+
register_KnownElementsList_extension,
32+
)
33+
from .extensions.ImpactXParticleContainer import (
34+
register_ImpactXParticleContainer_extension,
35+
)
36+
3337
# at this place we can enhance Python classes with additional methods written
3438
# in pure Python or add some other Python logic
3539

3640
# MAD-X file reader for beamline lattice elements
37-
elements.KnownElementsList.load_file = lambda self, madx_file, nslice=1: self.extend(
38-
read_lattice(madx_file, nslice)
39-
) # noqa
41+
register_KnownElementsList_extension(cxx.elements.KnownElementsList)
4042

4143
# MAD-X file reader for reference particle
4244
RefPart.load_file = read_beam # noqa
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
This file is part of ImpactX
3+
4+
Copyright 2025 ImpactX contributors
5+
Authors: Axel Huebl
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def register_KnownElementsList_extension(kel):
11+
"""KnownElementsList helper methods"""
12+
from ..madx_to_impactx import read_lattice
13+
from ..plot.Survey import plot_survey
14+
15+
# register member functions for KnownElementsList
16+
kel.load_file = lambda self, madx_file, nslice=1: self.extend(
17+
read_lattice(madx_file, nslice)
18+
)
19+
kel.plot_survey = plot_survey
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
This file is part of ImpactX
3+
4+
Copyright 2025 ImpactX contributors
5+
Authors: Axel Huebl
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def get_element_color_palette(palette="cern-lhc", plot_library="mpl"):
11+
"""Return a dictionary with colors for all elements.
12+
13+
The key is a regex that can be matched against the element type string. TODO TODO
14+
"""
15+
color_palette = {
16+
"cern-lhc": {
17+
"Quad": "tab:blue",
18+
"Multipole": "tab:orange",
19+
"Sbend": "tab:green",
20+
"CFbend": "tab:olive", # TODO: improve and plot as two on top of each other
21+
"ConstF": "tab:red",
22+
"ChrPlasmaLens": "tab:red",
23+
"SoftSolenoid": "tab:red",
24+
"TaperedPL": "tab:red",
25+
"RFCavity": "tab:brown",
26+
"ShortRF": "tab:brown",
27+
"Buncher": "tab:purple",
28+
"Aperture": "black",
29+
"Kicker": "tab:pink",
30+
# 'tab:cyan'
31+
"other": "tab:gray",
32+
}
33+
}
34+
35+
colors = color_palette[palette]
36+
37+
if plot_library != "mpl":
38+
# remove "tab:" prefix
39+
for k, v in colors.items():
40+
colors[k] = v[4:]
41+
42+
return colors
43+
44+
45+
def get_element_color(element_kind: str, palette="cern-lhc", plot_library="mpl"):
46+
"""Get the color for a given element type string."""
47+
color_palette = get_element_color_palette(palette, plot_library)
48+
49+
# sub-string matching of keys
50+
found_keys = [key for key in color_palette.keys() if key in element_kind]
51+
52+
if found_keys:
53+
first_found = found_keys[0]
54+
return color_palette[first_found]
55+
else:
56+
return color_palette["other"]

src/python/impactx/plot/Survey.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
This file is part of ImpactX
3+
4+
Copyright 2025 ImpactX contributors
5+
Authors: Axel Huebl
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def plot_survey(
11+
self, ref=None, ax=None, legend=True, legend_ncols=5, palette="cern-lhc"
12+
):
13+
"""Plot over s of all elements in the KnownElementsList.
14+
15+
The signs of element strengths are determined by the sign of the charge of the reference particle.
16+
The projection of all element strengths is s-x ("vertical").
17+
18+
Parameters
19+
----------
20+
self : ImpactXParticleContainer_*
21+
The KnownElementsList class in ImpactX
22+
ref : RefPart
23+
A reference particle, checked for the charge sign to plot focusing/defocusing strength directions properly.
24+
ax : matplotlib axes
25+
A plotting area in matplotlib (called axes there).
26+
legend: bool
27+
Plot a legend if true.
28+
legend_ncols: int
29+
Number of columns for lattice element types in the legend.
30+
palette: string
31+
Color palette.
32+
33+
Returns
34+
-------
35+
Either populates the matplotlib axes in ax or creates a new axes containing the plot.
36+
"""
37+
from math import copysign
38+
39+
import matplotlib.pyplot as plt
40+
import numpy as np
41+
from matplotlib.patches import Rectangle
42+
43+
from .ElementColors import get_element_color
44+
45+
charge_qe = 1.0 if ref is None else ref.charge_qe
46+
47+
ax = ax or plt.subplot(111)
48+
49+
element_lengths = [element.ds for element in self]
50+
51+
# NumPy 2.1+ (i.e. Python 3.10+):
52+
# element_s = np.cumulative_sum(element_lengths, include_initial=True)
53+
# backport:
54+
element_s = np.insert(np.cumsum(element_lengths), 0, 0)
55+
56+
ax.hlines(0, 0, element_s[-1], color="black", linestyle="--")
57+
58+
# plot config
59+
skip_names = [
60+
"Drift",
61+
"ChrDrift",
62+
"ExactDrift",
63+
"Empty",
64+
"Marker",
65+
"Source",
66+
]
67+
68+
handles = {}
69+
70+
for i, element in enumerate(self):
71+
el_dict = element.to_dict()
72+
el_type = el_dict["type"]
73+
if el_type in skip_names:
74+
continue
75+
76+
color = get_element_color(el_type, palette=palette)
77+
78+
y0 = 0 # default start in y for unspecified elements
79+
height = 0.5 # default height for unspecified elements
80+
81+
# note the sub-string matching for el_type
82+
if el_type == "BeamMonitor":
83+
y0 = -0.5
84+
height = 1.0
85+
if "Quad" in el_type:
86+
height = copysign(0.8, el_dict["k"] * charge_qe)
87+
if "Sbend" in el_type:
88+
if ref is None:
89+
height = copysign(0.8, element.rc(ref))
90+
else: # guess
91+
height = copysign(0.8, el_dict["phi"])
92+
# TODO: sign dependent, read m_p_scale
93+
# if el_type == "Kicker":
94+
# height = copysign(0.8, el_dict["xkick"])
95+
96+
# plot thin elements on top of thick elements
97+
zorder = 2
98+
if element.ds == 0:
99+
zorder = 3
100+
101+
patch = Rectangle(
102+
(element_s[i], y0),
103+
element_lengths[i],
104+
height,
105+
color=color,
106+
alpha=0.8,
107+
zorder=zorder,
108+
)
109+
ax.add_patch(patch)
110+
111+
handles[el_type] = patch
112+
113+
if legend:
114+
labels = list(handles.keys())
115+
values = list(handles.values())
116+
ax.legend(
117+
handles=values,
118+
labels=labels,
119+
bbox_to_anchor=(0.0, 1.02, 1.0, 0.102),
120+
loc="lower left",
121+
ncols=legend_ncols,
122+
mode="expand",
123+
borderaxespad=0.0,
124+
)
125+
126+
ax.set_xlabel(r"$s$ [m]")
127+
128+
ax.set_ylim(-1, 1)
129+
ax.set_yticks([])
130+
131+
ax.set_aspect(1 / 1.618) # golden ratio
132+
133+
return ax

src/python/impactx/plot/__init__.py

Whitespace-only changes.

tests/python/test_dataframe.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ def test_df_pandas(save_png=True):
7070
]
7171
sim.lattice.extend(fodo)
7272

73+
# plot lattice survey
74+
if amr.ParallelDescriptor.IOProcessor():
75+
sim.lattice.plot_survey(ref=ref)
76+
if save_png:
77+
plt.gcf().savefig("lattice_survey.png")
78+
plt.close(plt.gcf())
79+
else:
80+
plt.show()
81+
7382
# simulate
7483
sim.track_particles()
7584

@@ -80,7 +89,8 @@ def test_df_pandas(save_png=True):
8089
print(beam_moments)
8190
plt.plot(beam_moments.s, beam_moments.beta_x)
8291
if save_png:
83-
plt.savefig("beam_moments.png")
92+
plt.gcf().savefig("beam_moments.png")
93+
plt.close(plt.gcf())
8494
else:
8595
plt.show()
8696

@@ -113,7 +123,6 @@ def test_df_pandas(save_png=True):
113123

114124
# note: figure data available on MPI rank zero
115125
if fig is not None:
116-
fig.savefig("phase_space.png")
117126
if save_png:
118127
fig.savefig("phase_space.png")
119128
else:

0 commit comments

Comments
 (0)