Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b1b4b06
Adding support for TwoPointMeasurement filters.
vitenti Jan 8, 2025
81543ef
Merge branch 'master' into data_filters
vitenti Jan 8, 2025
c32592c
Fixing minor mypy issues.
vitenti Jan 8, 2025
3469990
Adding missing tests.
vitenti Jan 8, 2025
6218f1d
Removed unused import.
vitenti Jan 8, 2025
4a931c1
Breaking long line.
vitenti Jan 8, 2025
461a3b5
Using two-point pair as filter specification.
vitenti Jan 9, 2025
7ed861b
Better names.
vitenti Jan 9, 2025
63d066b
Adding tests for factories.
vitenti Jan 9, 2025
db68fa0
More tests.
vitenti Jan 10, 2025
caf5a69
Merge branch 'master' into data_filters
vitenti Feb 5, 2025
b2eedcb
Removed unused import.
vitenti Feb 5, 2025
c32e618
Added check for never reached branch.
vitenti Feb 5, 2025
d1269b8
TwoPointFilter documentation first draft
paulrogozenski Feb 10, 2025
1271e27
Simplify logic
marcpaterno Feb 11, 2025
7daad6f
If _path is set do not search current directory
marcpaterno Feb 11, 2025
1e15a90
Release cython version restriction
marcpaterno Feb 11, 2025
127a68e
Do not load our duplicate_code plugin
marcpaterno Feb 11, 2025
b6a2d19
Update finding of SACC files in some tests
marcpaterno Feb 11, 2025
449d269
Update version tag
marcpaterno Feb 11, 2025
e0a2c46
Refactor for improved test coverage and fix missing error case
marcpaterno Feb 12, 2025
9eab4b7
Complete branch coverage
marcpaterno Feb 12, 2025
34d6b15
Improving tutorial.
vitenti Feb 12, 2025
6329e11
Merge branch 'data_filters' of github.com:LSSTDESC/firecrown into dat…
vitenti Feb 12, 2025
5cc7802
Correct serialization for TwoPointFactory.
vitenti Feb 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
author = "LSST DESC Firecrown Contributors"

# The full version, including alpha/beta/rc tags
release = "1.8.0"
release = "1.9.0a0"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- cosmosis >= 3.0
- cosmosis-build-standard-library
- coverage
- cython < 3.0.0
- cython
- dill
- fitsio
- flake8
Expand Down
2 changes: 1 addition & 1 deletion fctools/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# some global context to be used in the tracing. We are relying on
# 'trace_call' to act as a closure that captures these names.
tracefile = None # the file used for logging
tracefile: TextIO | None = None # the file used for logging
level = 0 # the call nesting level
entry = 0 # sequential entry number for each record

Expand Down
291 changes: 289 additions & 2 deletions firecrown/data_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,35 @@
"""

import hashlib
from typing import Callable, Sequence
from typing import Callable, Sequence, Annotated
from typing_extensions import assert_never

import sacc
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
model_validator,
PrivateAttr,
field_serializer,
)
import numpy as np
import numpy.typing as npt
import sacc

from firecrown.metadata_types import (
TwoPointHarmonic,
TwoPointReal,
Measurement,
)
from firecrown.metadata_functions import (
extract_all_tracers_inferred_galaxy_zdists,
extract_window_function,
extract_all_harmonic_metadata_indices,
extract_all_real_metadata_indices,
make_two_point_xy,
make_measurement,
make_measurement_dict,
)
from firecrown.data_types import TwoPointMeasurement

Expand Down Expand Up @@ -222,3 +237,275 @@ def check_two_point_consistence_real(
) -> None:
"""Check the indices of the real-space two-point functions."""
check_consistence(two_point_reals, lambda m: m.is_real(), "TwoPointReal")


class TwoPointTracerSpec(BaseModel):
    """Class defining a tracer bin specification."""

    # A tracer spec is the pair (bin name, measurement type). Instances are
    # frozen and reject unknown fields; elsewhere in this module they are
    # stored inside frozensets, which relies on frozen models being hashable.
    model_config = ConfigDict(extra="forbid", frozen=True)

    name: Annotated[str, Field(description="The name of the tracer bin.")]
    # `make_measurement` runs on the raw input before pydantic validates it
    # against `Measurement`.
    # NOTE(review): presumably it accepts the serialized form produced by
    # `make_measurement_dict` below -- confirm in firecrown.metadata_functions.
    measurement: Annotated[
        Measurement,
        Field(description="The measurement of the tracer bin."),
        BeforeValidator(make_measurement),
    ]

    @field_serializer("measurement")
    @classmethod
    def serialize_measurement(cls, value: Measurement) -> dict[str, str]:
        """Serialize the measurement field.

        :param value: the Measurement stored in the `measurement` field.
        :return: the dictionary built by `make_measurement_dict` for `value`.
        """
        return make_measurement_dict(value)


def make_interval_from_list(
values: list[float] | tuple[float, float],
) -> tuple[float, float]:
"""Create an interval from a list of values."""
if isinstance(values, list):
if len(values) != 2:
raise ValueError("The list should have two values.")
if not all(isinstance(v, float) for v in values):
raise ValueError("The list should have two float values.")

return (values[0], values[1])
if isinstance(values, tuple):
return values

raise ValueError("The values should be a list or a tuple.")


class TwoPointBinFilter(BaseModel):
    """Class defining a filter for a bin."""

    # A filter associates a tracer combination (`spec`, holding one or two
    # tracer specs) with the `interval` of values selected for it. Instances
    # are frozen and reject unknown fields.
    model_config = ConfigDict(extra="forbid", frozen=True)

    spec: Annotated[
        list[TwoPointTracerSpec],
        Field(description="The two-point bin specification."),
    ]
    # `make_interval_from_list` runs before pydantic validation, so the
    # interval may be supplied either as a two-float list or as a tuple.
    interval: Annotated[
        tuple[float, float],
        BeforeValidator(make_interval_from_list),
        Field(description="The range of the bin to filter."),
    ]

    @model_validator(mode="after")
    def check_bin_filter(self) -> "TwoPointBinFilter":
        """Validate the interval ordering and the number of tracer specs.

        :return: the validated instance.
        :raises ValueError: if the lower bound is not strictly below the upper
            bound, or if `spec` does not hold one or two elements.
        """
        lower, upper = self.interval
        if lower >= upper:
            raise ValueError("The bin filter should be a valid range.")
        if len(self.spec) not in (1, 2):
            raise ValueError("The bin_spec must contain one or two elements.")
        return self

    @field_serializer("interval")
    @classmethod
    def serialize_interval(cls, value: tuple[float, float]) -> list[float]:
        """Serialize the interval as a list of its bounds."""
        return [*value]

    @classmethod
    def from_args(
        cls,
        name1: str,
        measurement1: Measurement,
        name2: str,
        measurement2: Measurement,
        lower: float,
        upper: float,
    ) -> "TwoPointBinFilter":
        """Build a filter for a pair of tracer bins.

        :param name1: the name of the first tracer bin.
        :param measurement1: the measurement of the first tracer bin.
        :param name2: the name of the second tracer bin.
        :param measurement2: the measurement of the second tracer bin.
        :param lower: the lower bound of the interval.
        :param upper: the upper bound of the interval.
        :return: the new TwoPointBinFilter.
        """
        first = TwoPointTracerSpec(name=name1, measurement=measurement1)
        second = TwoPointTracerSpec(name=name2, measurement=measurement2)
        return cls(spec=[first, second], interval=(lower, upper))

    @classmethod
    def from_args_auto(
        cls, name: str, measurement: Measurement, lower: float, upper: float
    ) -> "TwoPointBinFilter":
        """Build a filter whose specification holds a single tracer bin.

        :param name: the name of the tracer bin.
        :param measurement: the measurement of the tracer bin.
        :param lower: the lower bound of the interval.
        :param upper: the upper bound of the interval.
        :return: the new TwoPointBinFilter.
        """
        only = TwoPointTracerSpec(name=name, measurement=measurement)
        return cls(spec=[only], interval=(lower, upper))


# Hashable, order-independent key identifying a tracer combination. It is
# used below as the dictionary key under which each filter interval is stored.
BinSpec = frozenset[TwoPointTracerSpec]


def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec:
    """Build the bin-spec key describing the tracer pair of some metadata.

    One tracer spec is created for each side of `metadata.XY`, from the bin
    name and the measurement of that side. As the result is a frozenset, the
    order of the two sides is irrelevant, and two specs that compare equal
    collapse into a single element.

    :param metadata: the two-point metadata to inspect.
    :return: the frozenset of tracer specs for the x and y tracers.
    """
    xy = metadata.XY
    x_spec = TwoPointTracerSpec(name=xy.x.bin_name, measurement=xy.x_measurement)
    y_spec = TwoPointTracerSpec(name=xy.y.bin_name, measurement=xy.y_measurement)
    return frozenset([x_spec, y_spec])


class TwoPointBinFilterCollection(BaseModel):
    """Class defining a collection of bin filters."""

    # The collection maps each tracer combination (a BinSpec) to the interval
    # selected for it. Calling the collection on a sequence of
    # TwoPointMeasurement objects returns new measurements restricted to the
    # thetas / ells lying inside the matching interval (bounds included).
    model_config = ConfigDict(extra="forbid", frozen=True)

    require_filter_for_all: bool = Field(
        default=False,
        description="If True, all bins should match a filter.",
    )
    allow_empty: bool = Field(
        default=False,
        description=(
            "When true, objects with no elements remaining after applying "
            "the filter will be ignored rather than treated as an error."
        ),
    )
    filters: list[TwoPointBinFilter] = Field(
        description="The list of bin filters.",
    )

    # Lookup table built from `filters` by the model validator below.
    _bin_filter_dict: dict[BinSpec, tuple[float, float]] = PrivateAttr()

    @model_validator(mode="after")
    def check_bin_filters(self) -> "TwoPointBinFilterCollection":
        """Reject duplicated filters and build the lookup table.

        Two filters are duplicates when their specs form the same frozenset,
        i.e., the order of the tracer specs does not matter.

        :return: the validated instance.
        :raises ValueError: if two filters refer to the same tracer combination.
        """
        intervals: dict[BinSpec, tuple[float, float]] = {}
        for candidate in self.filters:
            key = frozenset(candidate.spec)
            if key in intervals:
                raise ValueError(
                    f"The bin name {candidate.spec} is repeated in the bin filters."
                )
            intervals[key] = candidate.interval

        self._bin_filter_dict = intervals
        return self

    @property
    def bin_filter_dict(self) -> dict[BinSpec, tuple[float, float]]:
        """The dictionary mapping each bin spec to its interval."""
        return self._bin_filter_dict

    def filter_match(self, tpm: TwoPointMeasurement) -> bool:
        """Tell whether a filter exists for the tracer pair of `tpm`.

        :param tpm: the measurement to look up.
        :return: True if the bin spec of `tpm.metadata` has a filter.
        """
        return bin_spec_from_metadata(tpm.metadata) in self._bin_filter_dict

    def run_bin_filter(
        self,
        bin_filter: tuple[float, float],
        vals: npt.NDArray[np.float64] | npt.NDArray[np.int64],
    ) -> npt.NDArray[np.bool_]:
        """Compute the mask of the values lying inside an interval.

        :param bin_filter: the (lower, upper) bounds; both are included.
        :param vals: the values to test.
        :return: a boolean mask, True where lower <= value <= upper.
        """
        lower = bin_filter[0]
        upper = bin_filter[1]
        return np.logical_and(vals >= lower, vals <= upper)

    def apply_filter_single(
        self, tpm: TwoPointMeasurement
    ) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
        """Compute the selection masks for a single TwoPointMeasurement.

        The measurement must have a matching filter (see `filter_match`).

        :param tpm: the measurement to filter.
        :return: a pair (match_elements, match_obs). The first is the mask over
            the thetas (real space) or ells (harmonic space); the second is the
            mask over the observations. They are the same array unless the
            harmonic metadata carries a window function.
        """
        assert self.filter_match(tpm)
        interval = self._bin_filter_dict[bin_spec_from_metadata(tpm.metadata)]

        if tpm.is_real():
            assert isinstance(tpm.metadata, TwoPointReal)
            in_range = self.run_bin_filter(interval, tpm.metadata.thetas)
            return in_range, in_range

        assert isinstance(tpm.metadata, TwoPointHarmonic)
        in_range = self.run_bin_filter(interval, tpm.metadata.ells)
        window = tpm.metadata.window
        if window is None:
            return in_range, in_range

        # The window is indexed as [ell, observation]: each column holds the
        # weights that one observation gives to every ell. An observation is
        # kept only when all of its positive-weight ells are selected, that is,
        # when masking its column with the ell selection leaves the set of
        # positive-weight entries unchanged.
        # NOTE(review): the test is `> 0`, so negative weights are treated as
        # absent; confirm this is intended for windows that can hold negative
        # values.
        has_weight = window > 0
        weight_inside = has_weight & in_range[:, None]
        fully_inside = np.all(weight_inside == has_weight, axis=0)
        return in_range, fully_inside.ravel().astype(np.bool_)

    def __call__(
        self, tpms: Sequence[TwoPointMeasurement]
    ) -> list[TwoPointMeasurement]:
        """Apply the filters to a sequence of two-point measurements.

        Measurements without a matching filter are returned untouched, unless
        `require_filter_for_all` is set. Measurements with a matching filter
        are rebuilt with only the selected elements; those left without any
        observation are dropped when `allow_empty` is set.

        :param tpms: the measurements to filter.
        :return: the list of resulting measurements, in the input order.
        :raises ValueError: if a measurement has no filter while
            `require_filter_for_all` is set, or if a measurement is left with
            no observation while `allow_empty` is not set.
        """
        kept: list[TwoPointMeasurement] = []

        for tpm in tpms:
            if not self.filter_match(tpm):
                if self.require_filter_for_all:
                    raise ValueError(
                        f"The bin name {tpm.metadata} does not have a filter."
                    )
                # Unfiltered measurements pass through unchanged.
                kept.append(tpm)
                continue

            match_elements, match_obs = self.apply_filter_single(tpm)
            if not match_obs.any():
                if self.allow_empty:
                    # Nothing survived the filter: skip this measurement.
                    continue
                raise ValueError(
                    f"The TwoPointMeasurement {tpm.metadata} does not "
                    f"have any elements matching the filter."
                )

            metadata = tpm.metadata
            assert isinstance(metadata, (TwoPointReal, TwoPointHarmonic))
            new_metadata: TwoPointReal | TwoPointHarmonic
            if isinstance(metadata, TwoPointReal):
                new_metadata = TwoPointReal(
                    XY=metadata.XY,
                    thetas=metadata.thetas[match_elements],
                )
            elif isinstance(metadata, TwoPointHarmonic):
                # A window function, when present, must be cut consistently:
                # drop the rejected observations (columns) and then the
                # rejected ells (rows).
                filtered_window = None
                if metadata.window is not None:
                    filtered_window = metadata.window[:, match_obs][
                        match_elements, :
                    ]
                new_metadata = TwoPointHarmonic(
                    XY=metadata.XY,
                    window=filtered_window,
                    ells=metadata.ells[match_elements],
                )
            else:
                assert_never(metadata)

            kept.append(
                TwoPointMeasurement(
                    data=tpm.data[match_obs],
                    indices=tpm.indices[match_obs],
                    covariance_name=tpm.covariance_name,
                    metadata=new_metadata,
                )
            )

        return kept
Loading