Skip to content

Commit 6d54aa0

Browse files
feat(medium): track parent medium from which the medium is derived
1 parent bad66c6 commit 6d54aa0

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

tests/test_components/test_perturbation_medium.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,60 @@ def test_correct_values(dispersive):
362362

363363
assert np.isclose(si_n + pp_large_sampled, si_index_perturb_n)
364364
assert np.isclose(si_k + pp_small_sampled, si_index_perturb_k)
365+
366+
367+
@pytest.mark.parametrize("unstructured", [False, True])
368+
def test_from_medium_field(unstructured):
369+
"""Test that super_medium field is properly set when calling perturbed_copy."""
370+
# Setup fields to sample at
371+
coords = {"x": [1, 2], "y": [3, 4], "z": [5, 6]}
372+
temperature = td.SpatialDataArray(300 * np.ones((2, 2, 2)), coords=coords)
373+
electron_density = td.SpatialDataArray(1e18 * np.ones((2, 2, 2)), coords=coords)
374+
hole_density = td.SpatialDataArray(2e18 * np.ones((2, 2, 2)), coords=coords)
375+
376+
if unstructured:
377+
temperature = cartesian_to_unstructured(temperature, seed=7747)
378+
electron_density = cartesian_to_unstructured(electron_density, seed=7747)
379+
hole_density = cartesian_to_unstructured(hole_density, seed=7747)
380+
381+
# Test PerturbationMedium
382+
pp_real = td.ParameterPerturbation(
383+
heat=td.LinearHeatPerturbation(
384+
coeff=-0.01,
385+
temperature_ref=300,
386+
temperature_range=(200, 500),
387+
),
388+
)
389+
390+
pmed = td.PerturbationMedium(permittivity=10, permittivity_perturbation=pp_real)
391+
392+
# Test without any perturbation data (returns Medium)
393+
cmed_no_perturb = pmed.perturbed_copy()
394+
assert isinstance(cmed_no_perturb, td.Medium)
395+
assert cmed_no_perturb.super_medium is pmed
396+
assert hash(cmed_no_perturb.super_medium) == hash(pmed)
397+
398+
# Test with perturbation data (returns CustomMedium)
399+
cmed_with_perturb = pmed.perturbed_copy(temperature, electron_density, hole_density)
400+
assert isinstance(cmed_with_perturb, td.CustomMedium)
401+
assert cmed_with_perturb.super_medium is pmed
402+
assert hash(cmed_with_perturb.super_medium) == hash(pmed)
403+
404+
# Test PerturbationPoleResidue
405+
pmed_pole = td.PerturbationPoleResidue(
406+
eps_inf=10,
407+
poles=[(1j, 3), (2j, 4)],
408+
eps_inf_perturbation=pp_real,
409+
)
410+
411+
# Test without any perturbation data (returns PoleResidue)
412+
cmed_pole_no_perturb = pmed_pole.perturbed_copy()
413+
assert isinstance(cmed_pole_no_perturb, td.PoleResidue)
414+
assert cmed_pole_no_perturb.super_medium is pmed_pole
415+
assert hash(cmed_pole_no_perturb.super_medium) == hash(pmed_pole)
416+
417+
# Test with perturbation data (returns CustomPoleResidue)
418+
cmed_pole_with_perturb = pmed_pole.perturbed_copy(temperature, electron_density, hole_density)
419+
assert isinstance(cmed_pole_with_perturb, td.CustomPoleResidue)
420+
assert cmed_pole_with_perturb.super_medium is pmed_pole
421+
assert hash(cmed_pole_with_perturb.super_medium) == hash(pmed_pole)

tidy3d/components/medium.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,12 @@ class AbstractMedium(ABC, Tidy3dBaseModel):
197197
description="Plotting specification for visualizing medium.",
198198
)
199199

200+
super_medium: Optional["MediumType"] = pd.Field( # noqa: UP037
201+
None,
202+
title="Super Medium",
203+
description="If not ``None``, it records the super medium from which this medium was derived.",
204+
)
205+
200206
@cached_property
201207
def _nonlinear_models(self) -> list:
202208
"""The nonlinear models in the nonlinear_spec."""
@@ -1156,6 +1162,10 @@ def _derivative_field_cmp(
11561162

11571163

11581164
# PEC keyword
1165+
# Resolve forward references early to allow instantiation of builtins like PEC.
1166+
AbstractMedium.update_forward_refs(MediumType=Any)
1167+
1168+
11591169
class PECMedium(AbstractMedium):
11601170
"""Perfect electrical conductor class.
11611171
@@ -6795,6 +6805,7 @@ def perturbed_copy(
67956805

67966806
if all(x is None for x in [temperature, electron_density, hole_density]):
67976807
new_dict.pop("subpixel")
6808+
new_dict["super_medium"] = self
67986809
return Medium.parse_obj(new_dict)
67996810

68006811
permittivity_field = self.permittivity + ParameterPerturbation._zeros_like(
@@ -6836,6 +6847,7 @@ def perturbed_copy(
68366847
new_dict["permittivity"] = permittivity_field
68376848
new_dict["conductivity"] = conductivity_field
68386849
new_dict["interp_method"] = interp_method
6850+
new_dict["super_medium"] = self
68396851

68406852
return CustomMedium.parse_obj(new_dict)
68416853

@@ -7001,6 +7013,7 @@ def perturbed_copy(
70017013

70027014
if all(x is None for x in [temperature, electron_density, hole_density]):
70037015
new_dict.pop("subpixel")
7016+
new_dict["super_medium"] = self
70047017
return PoleResidue.parse_obj(new_dict)
70057018

70067019
zeros = ParameterPerturbation._zeros_like(temperature, electron_density, hole_density)
@@ -7050,6 +7063,7 @@ def perturbed_copy(
70507063
new_dict["eps_inf"] = eps_inf_field
70517064
new_dict["poles"] = poles_field
70527065
new_dict["interp_method"] = interp_method
7066+
new_dict["super_medium"] = self
70537067

70547068
return CustomPoleResidue.parse_obj(new_dict)
70557069

0 commit comments

Comments
 (0)