Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ cluster_number_counts/cluster_counts_redshift_richness.builders.yaml
cluster_number_counts/cluster_counts_redshift_richness.yaml
cluster_number_counts/cluster_mean_mass_redshift_richness.builders.yaml
cluster_number_counts/cluster_mean_mass_redshift_richness.yaml
cluster_number_counts/cluster_SDSS_counts_mean_mass_redshift_richness.builders.yaml
cluster_number_counts/cluster_SDSS_counts_mean_mass_redshift_richness.yaml
cosmicshear/cosmicshear.builders.yaml
cosmicshear/cosmicshear.yaml
des_y1_3x2pt/des_y1_3x2pt.builders.yaml
Expand Down
159 changes: 143 additions & 16 deletions firecrown/ccl_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""This module contains the CCLFactory class.
"""This module contains the CCLFactory class and it supporting classes.

The CCLFactory class is a factory class that creates instances of the
`pyccl.Cosmology` class.
Expand All @@ -21,6 +21,8 @@
Field,
field_serializer,
model_serializer,
model_validator,
PrivateAttr,
)

import pyccl
Expand All @@ -31,6 +33,7 @@
from firecrown.parameters import register_new_updatable_parameter
from firecrown.utils import YAMLSerializable

# PowerSpec is a type that represents a power spectrum.
PowerSpec = TypedDict(
"PowerSpec",
{
Expand All @@ -40,6 +43,7 @@
},
)

# Background is a type that represents the cosmological background quantities.
Background = TypedDict(
"Background",
{
Expand All @@ -49,6 +53,8 @@
},
)

# CCLCalculatorArgs is a type that represents the arguments for the
# CCLCalculator.
CCLCalculatorArgs = TypedDict(
"CCLCalculatorArgs",
{
Expand Down Expand Up @@ -170,6 +176,112 @@ def get_dict(self) -> dict:
}


class CCLSplineParams(BaseModel):
"""Params to control CCL spline interpolation."""

model_config = ConfigDict(extra="forbid")

# Scale factor splines
a_spline_na: Annotated[int | None, Field(frozen=True)] = None
a_spline_min: Annotated[float | None, Field(frozen=True)] = None
a_spline_minlog_pk: Annotated[float | None, Field(frozen=True)] = None
a_spline_min_pk: Annotated[float | None, Field(frozen=True)] = None
a_spline_minlog_sm: Annotated[float | None, Field(frozen=True)] = None
a_spline_min_sm: Annotated[float | None, Field(frozen=True)] = None
# a_spline_max is not defined because the CCL parameter A_SPLINE_MAX is
# required to be 1.0.
a_spline_minlog: Annotated[float | None, Field(frozen=True)] = None
a_spline_nlog: Annotated[int | None, Field(frozen=True)] = None

# mass splines
logm_spline_delta: Annotated[float | None, Field(frozen=True)] = None
logm_spline_nm: Annotated[int | None, Field(frozen=True)] = None
logm_spline_min: Annotated[float | None, Field(frozen=True)] = None
logm_spline_max: Annotated[float | None, Field(frozen=True)] = None

# PS a and k spline
a_spline_na_sm: Annotated[int | None, Field(frozen=True)] = None
a_spline_nlog_sm: Annotated[int | None, Field(frozen=True)] = None
a_spline_na_pk: Annotated[int | None, Field(frozen=True)] = None
a_spline_nlog_pk: Annotated[int | None, Field(frozen=True)] = None

# k-splines and integrals
k_max_spline: Annotated[float | None, Field(frozen=True)] = None
k_max: Annotated[float | None, Field(frozen=True)] = None
k_min: Annotated[float | None, Field(frozen=True)] = None
dlogk_integration: Annotated[float | None, Field(frozen=True)] = None
dchi_integration: Annotated[float | None, Field(frozen=True)] = None
n_k: Annotated[int | None, Field(frozen=True)] = None
n_k_3dcor: Annotated[int | None, Field(frozen=True)] = None

# Correlation function parameters
ell_min_corr: Annotated[float | None, Field(frozen=True)] = None
ell_max_corr: Annotated[float | None, Field(frozen=True)] = None
n_ell_corr: Annotated[int | None, Field(frozen=True)] = None

# Attributes that are used for the context manager functionality.
# These are *not* part of the model.
_spline_params: dict[str, float | int] = PrivateAttr()

@model_validator(mode="after")
def check_spline_params(self) -> "CCLSplineParams":
"""Check that the spline parameters are valid."""
# Ensure the spline boundaries and breakpoint are valid.
spline_breaks = [self.a_spline_minlog, self.a_spline_min, 1.0]
spline_breaks = list(filter(lambda x: x is not None, spline_breaks))
assert all(
a is not None and b is not None and a < b
for a, b in zip(spline_breaks, spline_breaks[1:])
)

# Ensure the mass spline boundaries are valid
if self.logm_spline_min is not None and self.logm_spline_max is not None:
assert self.logm_spline_min < self.logm_spline_max

# Ensure the k-spline boundaries are valid
if self.k_min is not None and self.k_max is not None:
assert self.k_min < self.k_max

# Ensure the ell-spline boundaries are valid
if self.ell_min_corr is not None and self.ell_max_corr is not None:
assert self.ell_min_corr < self.ell_max_corr

return self

def __enter__(self) -> "CCLSplineParams":
"""Enter the context manager.

This method saves the current CCL global spline parameters,
updates them with the values from this `CCLSplineParams` instance,
and returns the instance itself. This allows for temporary modification
of CCL spline parameters using a `with` statement.

:return: The current instance with updated spline parameters.
"""
self._spline_params = pyccl.CCLParameters.get_params_dict(pyccl.spline_params)
for key, value in self.model_dump().items():
if value is not None:
pyccl.spline_params[key.upper()] = value
return self

def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context manager.

This method resets the CCL global spline parameters to their original
values, as saved when entering the context manager. It ensures that
any temporary modifications made to the CCL spline parameters within
a `with` statement are reverted upon exit.

:param exc_type: The exception type, if an exception occurred.
:param exc_value: The exception value, if an exception occurred.
:param traceback: The traceback object, if an exception occurred.
"""
for key, value in self._spline_params.items():
pyccl.spline_params[key] = value
if exc_type is not None:
raise exc_type(exc_value).with_traceback(traceback)


class CCLFactory(Updatable, BaseModel):
"""Factory class for creating instances of the `pyccl.Cosmology` class."""

Expand All @@ -191,6 +303,7 @@ class CCLFactory(Updatable, BaseModel):
Field(frozen=True),
] = CCLCreationMode.DEFAULT
camb_extra_params: Annotated[CAMBExtraParams | None, Field(frozen=True)] = None
ccl_spline_params: Annotated[CCLSplineParams | None, Field(frozen=True)] = None

def __init__(self, **data):
"""Initialize the CCLFactory object."""
Expand All @@ -205,33 +318,39 @@ def __init__(self, **data):

self._ccl_cosmo: None | pyccl.Cosmology = None

ccl_cosmo = pyccl.CosmologyVanillaLCDM()
temp_cosmology = pyccl.CosmologyVanillaLCDM()

self.Omega_c = register_new_updatable_parameter(
default_value=ccl_cosmo["Omega_c"]
default_value=temp_cosmology["Omega_c"]
)
self.Omega_b = register_new_updatable_parameter(
default_value=ccl_cosmo["Omega_b"]
default_value=temp_cosmology["Omega_b"]
)
self.h = register_new_updatable_parameter(default_value=ccl_cosmo["h"])
self.n_s = register_new_updatable_parameter(default_value=ccl_cosmo["n_s"])
self.h = register_new_updatable_parameter(default_value=temp_cosmology["h"])
self.n_s = register_new_updatable_parameter(default_value=temp_cosmology["n_s"])
self.Omega_k = register_new_updatable_parameter(
default_value=ccl_cosmo["Omega_k"]
default_value=temp_cosmology["Omega_k"]
)
self.Neff = register_new_updatable_parameter(
default_value=temp_cosmology["Neff"]
)
self.m_nu = register_new_updatable_parameter(
default_value=temp_cosmology["m_nu"]
)
self.w0 = register_new_updatable_parameter(default_value=temp_cosmology["w0"])
self.wa = register_new_updatable_parameter(default_value=temp_cosmology["wa"])
self.T_CMB = register_new_updatable_parameter(
default_value=temp_cosmology["T_CMB"]
)
self.Neff = register_new_updatable_parameter(default_value=ccl_cosmo["Neff"])
self.m_nu = register_new_updatable_parameter(default_value=ccl_cosmo["m_nu"])
self.w0 = register_new_updatable_parameter(default_value=ccl_cosmo["w0"])
self.wa = register_new_updatable_parameter(default_value=ccl_cosmo["wa"])
self.T_CMB = register_new_updatable_parameter(default_value=ccl_cosmo["T_CMB"])

match self.amplitude_parameter:
case PoweSpecAmplitudeParameter.AS:
# VanillaLCDM has does not have A_s, so we need to add it
self.A_s = register_new_updatable_parameter(default_value=2.1e-9)
case PoweSpecAmplitudeParameter.SIGMA8:
assert ccl_cosmo["sigma8"] is not None
assert temp_cosmology["sigma8"] is not None
self.sigma8 = register_new_updatable_parameter(
default_value=ccl_cosmo["sigma8"]
default_value=temp_cosmology["sigma8"]
)
case _ as unreachable:
assert_never(unreachable)
Expand Down Expand Up @@ -322,7 +441,11 @@ def create(
"mode and no CAMB extra parameters."
)

self._ccl_cosmo = pyccl.CosmologyCalculator(**ccl_args)
if self.ccl_spline_params is not None:
with self.ccl_spline_params:
self._ccl_cosmo = pyccl.CosmologyCalculator(**ccl_args)
else:
self._ccl_cosmo = pyccl.CosmologyCalculator(**ccl_args)
return self._ccl_cosmo

if self.require_nonlinear_pk:
Expand All @@ -341,7 +464,11 @@ def create(
matter_power_spectrum="linear",
transfer_function="boltzmann_isitgr",
)
self._ccl_cosmo = pyccl.Cosmology(**ccl_args)
if self.ccl_spline_params is not None:
with self.ccl_spline_params:
self._ccl_cosmo = pyccl.Cosmology(**ccl_args)
else:
self._ccl_cosmo = pyccl.Cosmology(**ccl_args)
return self._ccl_cosmo

def _reset(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions firecrown/likelihood/weak_lensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def create_tracers(self, tools: ModelingTools):
tracer_name="intrinsic_alignment_hm",
halo_profile=halo_profile,
)
# TODO: redesign this so that we are not adding a new
# attribute to a pyccl class.
halo_profile.ia_a_2h = (
tracer_args.ia_a_2h
) # Attach the 2-halo amplitude here.
Expand Down
5 changes: 3 additions & 2 deletions firecrown/models/cluster/abundance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ClusterAbundance(Updatable):
"""

@property
def cosmo(self) -> Cosmology:
def cosmo(self) -> Cosmology | None:
"""The cosmology used to predict the cluster number count."""
return self._cosmo

Expand All @@ -41,7 +41,7 @@ def __init__(
self.min_z = z_interval[0]
self.max_z = z_interval[1]
self._hmf_cache: dict[tuple[float, float], float] = {}
self._cosmo: Cosmology = None
self._cosmo: Cosmology | None = None

def update_ingredients(self, cosmo: Cosmology) -> None:
"""Update the cluster abundance calculation with a new cosmology."""
Expand All @@ -55,6 +55,7 @@ def comoving_volume(

:param sky_area: The area of the survey on the sky in square degrees.
"""
assert self.cosmo is not None
scale_factor = 1.0 / (1.0 + z)
angular_diam_dist = bkg.angular_diameter_distance(self.cosmo, scale_factor)
h_over_h0 = bkg.h_over_h0(self.cosmo, scale_factor)
Expand Down
3 changes: 1 addition & 2 deletions firecrown/models/cluster/deltasigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def delta_sigma(
)
for log_m, redshift in zip(log_mass, z):
a = 1.0 / (1.0 + redshift)
# pylint: disable=protected-access
conc_val = conc._concentration(self._cosmo, 10**log_m, a)
conc_val = conc(self._cosmo, 10**log_m, a)
moo.set_concentration(conc_val)
moo.set_mass(10**log_m)
val = moo.eval_excess_surface_density(radius_center, redshift)
Expand Down
16 changes: 14 additions & 2 deletions firecrown/models/two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def at_least_one_tracer_has_hm(
IA_bias_exponent = (
1 # IA bias if not both tracers are HM (doing GI correlation).
)
# mypy complains about the following line even though
# the HMCalculator type does have a mass_def attribute.
other_profile = pyccl.halos.HaloProfileNFW(
mass_def=hm_calculator.mass_def,
concentration=cM_relation,
Expand All @@ -103,12 +105,16 @@ def at_least_one_tracer_has_hm(
)
other_profile.ia_a_2h = -1.0 # used in GI contribution, which is negative.
if not tracer0.has_hm:
profile0 = other_profile
profile1 = tracer1.halo_profile
assert tracer1.halo_profile is not None
profile0: pyccl.halos.HaloProfile = other_profile
profile1: pyccl.halos.HaloProfile = tracer1.halo_profile
else:
assert tracer0.halo_profile is not None
profile0 = tracer0.halo_profile
profile1 = other_profile
else:
assert tracer0.halo_profile is not None
assert tracer1.halo_profile is not None
profile0 = tracer0.halo_profile
profile1 = tracer1.halo_profile
# Ensure that profile0 and profile1 are not None.
Expand All @@ -127,6 +133,12 @@ def at_least_one_tracer_has_hm(
C1rhocrit = (
5e-14 * pyccl.physical_constants.RHO_CRITICAL
) # standard IA normalisation
# These assertions are required because the pyccl profiles do not have ia_a_2h.
# That is something added locally.
assert hasattr(profile0, "ia_a_2h")
assert hasattr(profile1, "ia_a_2h")
assert hasattr(ccl_cosmo, "growth_factor")
assert hasattr(ccl_cosmo, "nonlin_matter_power")
pk_2h = pyccl.Pk2D.from_function(
pkfunc=lambda k, a: profile0.ia_a_2h
* profile1.ia_a_2h
Expand Down
Loading