Skip to content

Commit 4bc3c7b

Browse files
committed
feat: add interpolation to EME (FXC-4152)
1 parent 1bd30a2 commit 4bc3c7b

File tree

16 files changed

+117
-69
lines changed

16 files changed

+117
-69
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3838
- Added `interp_spec` in `ModeSpec` to allow downsampling and interpolation of waveguide modes in frequency.
3939
- Added warning if port mesh refinement is incompatible with the `GridSpec` in the `TerminalComponentModeler`.
4040
- Various types, e.g. different `Simulation` or `SimulationData` sub-classes, can be loaded from file directly with `Tidy3dBaseModel.from_file()`.
41+
- Added `interp_spec` in `EMEModeSpec` to enable faster multi-frequency EME simulations. Note that the default is now `ModeInterpSpec.cheb(num_points=3, reduce_data=True)`; previously the computation was repeated at all frequencies.
4142

4243
### Breaking Changes
4344
- Edge singularity correction at PEC and lossy metal edges defaults to `True`.
@@ -64,6 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6465
- Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
6566
- Port names in `ModalComponentModeler` and `TerminalComponentModeler` can no longer include the `@` symbol.
6667
- Improved speed of convolutions for large inputs.
68+
- Default value of `EMEModeSpec.interp_spec` is `ModeInterpSpec.cheb(num_points=3, reduce_data=True)` for faster multi-frequency EME simulations.
6769

6870
### Fixed
6971
- Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.

schemas/EMESimulation.json

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5294,7 +5294,18 @@
52945294
{
52955295
"$ref": "#/definitions/ModeInterpSpec"
52965296
}
5297-
]
5297+
],
5298+
"default": {
5299+
"attrs": {},
5300+
"method": "poly",
5301+
"reduce_data": true,
5302+
"sampling_spec": {
5303+
"attrs": {},
5304+
"num_points": 3,
5305+
"type": "ChebSampling"
5306+
},
5307+
"type": "ModeInterpSpec"
5308+
}
52985309
},
52995310
"num_modes": {
53005311
"default": 1,

tests/sims/full_fdtd.h5

40 Bytes
Binary file not shown.

tests/sims/full_fdtd.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,7 @@
20752075
"track_freq": "central",
20762076
"type": "ModeSortSpec"
20772077
},
2078+
"interp_spec": null,
20782079
"type": "ModeSpec"
20792080
},
20802081
"mode_index": 0,
@@ -2666,6 +2667,7 @@
26662667
"track_freq": "central",
26672668
"type": "ModeSortSpec"
26682669
},
2670+
"interp_spec": null,
26692671
"type": "ModeSpec"
26702672
},
26712673
"store_fields_direction": null,

tests/test_components/test_eme.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,8 @@ def test_eme_simulation():
568568
sim = sim_no_field.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[1, 2]))
569569
assert not sim._sweep_modes
570570
assert sim._num_sweep == 2
571-
sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1, 2]))
571+
with AssertLogLevel("WARNING", contains_str="'EMEFreqSweep' is deprecated"):
572+
sim = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1, 2]))
572573
assert sim._sweep_modes
573574
assert sim._num_sweep == 2
574575
assert sim._monitor_num_sweep(sim.monitors[0]) == 1
@@ -911,7 +912,9 @@ def _get_mode_solver_data(modes_out=False, num_modes=3):
911912
size=(td.inf, td.inf, 0),
912913
center=(0, 0, offset),
913914
freqs=[td.C_0],
914-
mode_spec=td.ModeSpec(num_modes=num_modes),
915+
mode_spec=td.ModeSpec(
916+
num_modes=num_modes, interp_spec=td.ModeInterpSpec.cheb(num_points=3, reduce_data=True)
917+
),
915918
name=name,
916919
)
917920
eme_mode_data = _get_eme_mode_solver_data()

tests/test_components/test_mode_interp.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,14 @@ def test_mode_solver_monitor_valid_with_tracking():
244244

245245

246246
def test_interp_num_points_less_than_freqs():
247-
"""Test that num_points must be less than total freqs."""
247+
"""Test that num_points can be greater than total freqs."""
248248
mode_spec = td.ModeSpec(
249249
num_modes=2,
250250
sort_spec=td.ModeSortSpec(track_freq="central"),
251251
interp_spec=td.ModeInterpSpec.uniform(num_points=25, method="linear"),
252252
)
253253

254-
with AssertLogLevel("WARNING", contains_str="num_points"):
254+
with AssertLogLevel(None):
255255
td.ModeSolverMonitor(
256256
center=(0, 0, 0),
257257
size=SIZE_2D,
@@ -262,14 +262,14 @@ def test_interp_num_points_less_than_freqs():
262262

263263

264264
def test_interp_num_points_equal_to_freqs():
265-
"""Test that num_points equal to freqs is rejected."""
265+
"""Test that num_points equal to freqs is not rejected."""
266266
mode_spec = td.ModeSpec(
267267
num_modes=2,
268268
sort_spec=td.ModeSortSpec(track_freq="central"),
269269
interp_spec=td.ModeInterpSpec.uniform(num_points=20, method="linear"),
270270
)
271271

272-
with AssertLogLevel("WARNING", contains_str="num_points"):
272+
with AssertLogLevel(None):
273273
td.ModeSolverMonitor(
274274
center=(0, 0, 0),
275275
size=SIZE_2D,
@@ -354,7 +354,7 @@ def test_mode_solver_valid_with_tracking():
354354

355355
@td.packaging.disable_local_subpixel
356356
def test_mode_solver_warns_num_points():
357-
"""Test that ModeSolver warns when num_points >= num_freqs."""
357+
"""Test that ModeSolver does not warn when num_points >= num_freqs."""
358358
sim = get_simple_sim()
359359
mode_spec = td.ModeSpec(
360360
num_modes=2,
@@ -363,14 +363,14 @@ def test_mode_solver_warns_num_points():
363363
)
364364
plane = td.Box(center=(0, 0, 0), size=SIZE_2D)
365365

366-
with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"):
366+
with AssertLogLevel(None):
367367
ms = ModeSolver(
368368
simulation=sim,
369369
plane=plane,
370370
freqs=FREQS_DENSE,
371371
mode_spec=mode_spec,
372372
)
373-
_ = ms.data_raw
373+
_ = ms.data_raw
374374

375375

376376
def test_mode_solver_interp_spec_none():
@@ -1041,15 +1041,15 @@ def test_mode_solver_monitor_with_interp_spec():
10411041

10421042

10431043
def test_mode_monitor_warns_redundant_num_points():
1044-
"""Test warning when num_points >= number of frequencies in ModeMonitor."""
1044+
"""Test no warning when num_points >= number of frequencies in ModeMonitor."""
10451045
freqs = np.linspace(1e14, 2e14, 5)
10461046
mode_spec = td.ModeSpec(
10471047
num_modes=2,
10481048
sort_spec=td.ModeSortSpec(track_freq="central"),
10491049
interp_spec=td.ModeInterpSpec.uniform(num_points=5, method="linear"),
10501050
)
10511051

1052-
with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"):
1052+
with AssertLogLevel(None):
10531053
td.ModeMonitor(
10541054
center=(0, 0, 0),
10551055
size=SIZE_2D,
@@ -1060,15 +1060,15 @@ def test_mode_monitor_warns_redundant_num_points():
10601060

10611061

10621062
def test_mode_solver_monitor_warns_redundant_num_points():
1063-
"""Test warning when num_points >= number of frequencies in ModeSolverMonitor."""
1063+
"""Test no warning when num_points >= number of frequencies in ModeSolverMonitor."""
10641064
freqs = np.linspace(1e14, 2e14, 5)
10651065
mode_spec = td.ModeSpec(
10661066
num_modes=2,
10671067
sort_spec=td.ModeSortSpec(track_freq="central"),
10681068
interp_spec=td.ModeInterpSpec.uniform(num_points=6, method="linear"),
10691069
)
10701070

1071-
with AssertLogLevel("WARNING", contains_str="Interpolation will be skipped"):
1071+
with AssertLogLevel(None):
10721072
td.ModeSolverMonitor(
10731073
center=(0, 0, 0),
10741074
size=SIZE_2D,

tidy3d/components/data/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ def _interp_dataarray_in_freq(
139139
DataArray
140140
Interpolated data array with the same structure but new frequency points.
141141
"""
142+
# if dataarray is already stored at the correct frequencies, do nothing
143+
if np.array_equal(freqs, data.f):
144+
return data
145+
142146
# Map 'poly' to xarray's 'barycentric' method
143147
xr_method = "barycentric" if method == "poly" else method
144148

tidy3d/components/data/monitor_data.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,22 @@ def to_zbf(
12861286

12871287
return e_x, e_y
12881288

1289+
def _interpolated_copies_if_needed(
1290+
self, other: ElectromagneticFieldData
1291+
) -> tuple[ElectromagneticFieldData, ElectromagneticFieldData]:
1292+
"""Return interpolated copies of self, other if needed (different interp_spec)."""
1293+
mode_spec1 = self.monitor.mode_spec if isinstance(self, ModeSolverData) else None
1294+
mode_spec2 = other.monitor.mode_spec if isinstance(other, ModeSolverData) else None
1295+
if (
1296+
mode_spec1 is not None
1297+
and mode_spec2 is not None
1298+
and self.monitor.mode_spec._same_nontrivial_interp_spec(other=other.monitor.mode_spec)
1299+
):
1300+
return self, other
1301+
self_copy = self.interpolated_copy if isinstance(self, ModeSolverData) else self
1302+
other_copy = other.interpolated_copy if isinstance(other, ModeSolverData) else other
1303+
return self_copy, other_copy
1304+
12891305

12901306
class FieldData(FieldDataset, ElectromagneticFieldData):
12911307
"""
@@ -2685,6 +2701,8 @@ def _reduced_data(self) -> bool:
26852701
@property
26862702
def interpolated_copy(self) -> ModeSolverData:
26872703
"""Return a copy of the data with interpolated fields."""
2704+
if self.monitor.mode_spec.interp_spec is None:
2705+
return self
26882706
if not self._reduced_data:
26892707
return self
26902708
interpolated_data = self.interp_in_freq(

tidy3d/components/eme/data/sim_data.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,19 @@ def smatrix_in_basis(
197197
modes1 = port_modes1
198198
if not modes2_provided:
199199
modes2 = port_modes2
200-
f1 = list(modes1.field_components.values())[0].f.values
201-
f2 = list(modes2.field_components.values())[0].f.values
200+
f1 = list(modes1.monitor.freqs)
201+
f2 = list(modes2.monitor.freqs)
202202

203203
f = np.array(sorted(set(f1).intersection(f2).intersection(self.simulation.freqs)))
204204

205+
mode_spec1 = modes1.monitor.mode_spec if isinstance(modes1, ModeData) else None
206+
mode_spec2 = modes2.monitor.mode_spec if isinstance(modes2, ModeData) else None
207+
208+
interp_spec1 = mode_spec1.interp_spec if mode_spec1 is not None else None
209+
interp_spec2 = mode_spec2.interp_spec if mode_spec2 is not None else None
210+
211+
modes1, modes2 = modes1._interpolated_copies_if_needed(other=modes2)
212+
205213
modes_in_1 = "mode_index" in list(modes1.field_components.values())[0].coords
206214
modes_in_2 = "mode_index" in list(modes2.field_components.values())[0].coords
207215

@@ -259,6 +267,10 @@ def smatrix_in_basis(
259267
overlaps1 = modes1.outer_dot(port_modes1, conjugate=False)
260268
if not modes_in_1:
261269
overlaps1 = overlaps1.expand_dims(dim={"mode_index_0": mode_index_1}, axis=1)
270+
if interp_spec1 is not None:
271+
overlaps1 = modes1._interp_dataarray_in_freq(
272+
overlaps1, freqs=f, method=interp_spec1.method
273+
)
262274
O1 = overlaps1.sel(f=f, mode_index_1=keep_mode_inds1)
263275

264276
O1out = O1.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old")
@@ -288,6 +300,10 @@ def smatrix_in_basis(
288300
overlaps2 = modes2.outer_dot(port_modes2, conjugate=False)
289301
if not modes_in_2:
290302
overlaps2 = overlaps2.expand_dims(dim={"mode_index_0": mode_index_2}, axis=1)
303+
if interp_spec2 is not None:
304+
overlaps2 = modes2._interp_dataarray_in_freq(
305+
overlaps2, freqs=f, method=interp_spec2.method
306+
)
291307
O2 = overlaps2.sel(f=f, mode_index_1=keep_mode_inds2)
292308

293309
O2out = O2.rename(mode_index_0="mode_index_out", mode_index_1="mode_index_out_old")

tidy3d/components/eme/grid.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing
1212
from tidy3d.components.geometry.base import Box
1313
from tidy3d.components.grid.grid import Coords1D
14-
from tidy3d.components.mode_spec import ModeSpec
14+
from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec
1515
from tidy3d.components.structure import Structure
16-
from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size, TrackFreq
16+
from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size
1717
from tidy3d.constants import RADIAN, fp_eps, inf
1818
from tidy3d.exceptions import SetupError, ValidationError
1919

@@ -26,13 +26,14 @@
2626
class EMEModeSpec(ModeSpec):
2727
"""Mode spec for EME cells. Overrides some of the defaults and allowed values."""
2828

29-
track_freq: Union[TrackFreq, None] = pd.Field(
30-
None,
31-
title="Mode Tracking Frequency",
32-
description="Parameter that turns on/off mode tracking based on their similarity. "
33-
"Can take values ``'lowest'``, ``'central'``, or ``'highest'``, which correspond to "
34-
"mode tracking based on the lowest, central, or highest frequency. "
35-
"If ``None`` no mode tracking is performed, which is the default for best performance.",
29+
interp_spec: Optional[ModeInterpSpec] = pd.Field(
30+
ModeInterpSpec.cheb(num_points=3, reduce_data=True),
31+
title="Mode frequency interpolation specification",
32+
description="Specification for computing modes at a reduced set of frequencies and "
33+
"interpolating to obtain results at all requested frequencies. This can significantly "
34+
"reduce computational cost for broadband simulations where modes vary smoothly with "
35+
"frequency. Requires frequency tracking to be enabled (``sort_spec.track_freq`` must "
36+
"not be ``None``) to ensure consistent mode ordering across frequencies.",
3637
)
3738

3839
angle_theta: Literal[0.0] = pd.Field(

0 commit comments

Comments
 (0)