Skip to content

Commit 0c56c88

Browse files
super_medium added to custom medium & return self for unperturbed medium
1 parent bad66c6 commit 0c56c88

File tree

2 files changed

+89
-12
lines changed

2 files changed

+89
-12
lines changed

tests/test_components/test_perturbation_medium.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,56 @@ def test_correct_values(dispersive):
362362

363363
assert np.isclose(si_n + pp_large_sampled, si_index_perturb_n)
364364
assert np.isclose(si_k + pp_small_sampled, si_index_perturb_k)
365+
366+
367+
@pytest.mark.parametrize("unstructured", [False, True])
368+
def test_from_medium_field(unstructured):
369+
"""Test that super_medium field is properly set when calling perturbed_copy."""
370+
# Setup fields to sample at
371+
coords = {"x": [1, 2], "y": [3, 4], "z": [5, 6]}
372+
temperature = td.SpatialDataArray(300 * np.ones((2, 2, 2)), coords=coords)
373+
electron_density = td.SpatialDataArray(1e18 * np.ones((2, 2, 2)), coords=coords)
374+
hole_density = td.SpatialDataArray(2e18 * np.ones((2, 2, 2)), coords=coords)
375+
376+
if unstructured:
377+
temperature = cartesian_to_unstructured(temperature, seed=7747)
378+
electron_density = cartesian_to_unstructured(electron_density, seed=7747)
379+
hole_density = cartesian_to_unstructured(hole_density, seed=7747)
380+
381+
# Test PerturbationMedium
382+
pp_real = td.ParameterPerturbation(
383+
heat=td.LinearHeatPerturbation(
384+
coeff=-0.01,
385+
temperature_ref=300,
386+
temperature_range=(200, 500),
387+
),
388+
)
389+
390+
pmed = td.PerturbationMedium(permittivity=10, permittivity_perturbation=pp_real)
391+
392+
# Test without any perturbation data (returns self)
393+
cmed_no_perturb = pmed.perturbed_copy()
394+
assert isinstance(cmed_no_perturb, td.PerturbationMedium)
395+
396+
# Test with perturbation data (returns CustomMedium)
397+
cmed_with_perturb = pmed.perturbed_copy(temperature, electron_density, hole_density)
398+
assert isinstance(cmed_with_perturb, td.CustomMedium)
399+
assert cmed_with_perturb.super_medium is pmed
400+
assert hash(cmed_with_perturb.super_medium) == hash(pmed)
401+
402+
# Test PerturbationPoleResidue
403+
pmed_pole = td.PerturbationPoleResidue(
404+
eps_inf=10,
405+
poles=[(1j, 3), (2j, 4)],
406+
eps_inf_perturbation=pp_real,
407+
)
408+
409+
# Test without any perturbation data (returns self)
410+
cmed_pole_no_perturb = pmed_pole.perturbed_copy()
411+
assert isinstance(cmed_pole_no_perturb, td.PerturbationPoleResidue)
412+
413+
# Test with perturbation data (returns CustomPoleResidue)
414+
cmed_pole_with_perturb = pmed_pole.perturbed_copy(temperature, electron_density, hole_density)
415+
assert isinstance(cmed_pole_with_perturb, td.CustomPoleResidue)
416+
assert cmed_pole_with_perturb.super_medium is pmed_pole
417+
assert hash(cmed_pole_with_perturb.super_medium) == hash(pmed_pole)

tidy3d/components/medium.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
PermittivityComponent,
9292
PoleAndResidue,
9393
TensorReal,
94+
annotate_type,
9495
)
9596
from .validators import _warn_potential_error, validate_name_str, validate_parameter_perturbation
9697
from .viz import VisualizationSpec, add_ax_if_none
@@ -875,6 +876,12 @@ class AbstractCustomMedium(AbstractMedium, ABC):
875876
"intersection interfaces with other structures.",
876877
)
877878

879+
super_medium: Optional[annotate_type(PerturbationMediumType)] = pd.Field(
880+
None,
881+
title="Super Medium",
882+
description="If not ``None``, it records the super medium from which this medium was derived.",
883+
)
884+
878885
@cached_property
879886
@abstractmethod
880887
def is_isotropic(self) -> bool:
@@ -6748,7 +6755,7 @@ def perturbed_copy(
67486755
electron_density: CustomSpatialDataType = None,
67496756
hole_density: CustomSpatialDataType = None,
67506757
interp_method: InterpMethod = "linear",
6751-
) -> Union[Medium, CustomMedium]:
6758+
) -> Union[PerturbationMedium, CustomMedium]:
67526759
"""Sample perturbations on provided heat and/or charge data and return 'CustomMedium'.
67536760
Any of temperature, electron_density, and hole_density can be 'None'. If all passed
67546761
arguments are 'None' then a 'Medium' object is returned. All provided fields must have
@@ -6780,10 +6787,14 @@ def perturbed_copy(
67806787
67816788
Returns
67826789
-------
6783-
Union[Medium, CustomMedium]
6790+
Union[PerturbationMedium, CustomMedium]
67846791
Medium specification after application of heat and/or charge data.
67856792
"""
67866793

6794+
# in the absence of perturbation
6795+
if all(x is None for x in [temperature, electron_density, hole_density]):
6796+
return self
6797+
67876798
new_dict = self.dict(
67886799
exclude={
67896800
"permittivity_perturbation",
@@ -6793,10 +6804,6 @@ def perturbed_copy(
67936804
}
67946805
)
67956806

6796-
if all(x is None for x in [temperature, electron_density, hole_density]):
6797-
new_dict.pop("subpixel")
6798-
return Medium.parse_obj(new_dict)
6799-
68006807
permittivity_field = self.permittivity + ParameterPerturbation._zeros_like(
68016808
temperature, electron_density, hole_density
68026809
)
@@ -6836,6 +6843,7 @@ def perturbed_copy(
68366843
new_dict["permittivity"] = permittivity_field
68376844
new_dict["conductivity"] = conductivity_field
68386845
new_dict["interp_method"] = interp_method
6846+
new_dict["super_medium"] = self
68396847

68406848
return CustomMedium.parse_obj(new_dict)
68416849

@@ -6959,7 +6967,7 @@ def perturbed_copy(
69596967
electron_density: CustomSpatialDataType = None,
69606968
hole_density: CustomSpatialDataType = None,
69616969
interp_method: InterpMethod = "linear",
6962-
) -> Union[PoleResidue, CustomPoleResidue]:
6970+
) -> Union[PerturbationPoleResidue, CustomPoleResidue]:
69636971
"""Sample perturbations on provided heat and/or charge data and return 'CustomPoleResidue'.
69646972
Any of temperature, electron_density, and hole_density can be 'None'. If all passed
69656973
arguments are 'None' then a 'PoleResidue' object is returned. All provided fields must have
@@ -6991,18 +6999,18 @@ def perturbed_copy(
69916999
69927000
Returns
69937001
-------
6994-
Union[PoleResidue, CustomPoleResidue]
7002+
Union[PerturbationPoleResidue, CustomPoleResidue]
69957003
Medium specification after application of heat and/or charge data.
69967004
"""
69977005

7006+
# in the absence of perturbation
7007+
if all(x is None for x in [temperature, electron_density, hole_density]):
7008+
return self
7009+
69987010
new_dict = self.dict(
69997011
exclude={"eps_inf_perturbation", "poles_perturbation", "perturbation_spec", "type"}
70007012
)
70017013

7002-
if all(x is None for x in [temperature, electron_density, hole_density]):
7003-
new_dict.pop("subpixel")
7004-
return PoleResidue.parse_obj(new_dict)
7005-
70067014
zeros = ParameterPerturbation._zeros_like(temperature, electron_density, hole_density)
70077015

70087016
eps_inf_field = self.eps_inf + zeros
@@ -7050,12 +7058,28 @@ def perturbed_copy(
70507058
new_dict["eps_inf"] = eps_inf_field
70517059
new_dict["poles"] = poles_field
70527060
new_dict["interp_method"] = interp_method
7061+
new_dict["super_medium"] = self
70537062

70547063
return CustomPoleResidue.parse_obj(new_dict)
70557064

70567065

70577066
# types of mediums that can be used in Simulation and Structures
70587067

7068+
PerturbationMediumType = Union[PerturbationMedium, PerturbationPoleResidue]
7069+
7070+
7071+
# Update forward references for all Custom medium classes that inherit from AbstractCustomMedium
7072+
def _get_all_subclasses(cls):
7073+
"""Recursively get all subclasses of a class."""
7074+
all_subclasses = []
7075+
for subclass in cls.__subclasses__():
7076+
all_subclasses.append(subclass)
7077+
all_subclasses.extend(_get_all_subclasses(subclass))
7078+
return all_subclasses
7079+
7080+
7081+
for _custom_medium_cls in _get_all_subclasses(AbstractCustomMedium):
7082+
_custom_medium_cls.update_forward_refs()
70597083

70607084
MediumType3D = Union[
70617085
Medium,

0 commit comments

Comments
 (0)