Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Classification tests
- :func:`~.giskard.testing.test_right_label`
- :func:`~.giskard.testing.test_output_in_range`
- :func:`~.giskard.testing.test_disparate_impact`
- :func:`~.giskard.testing.test_nominal_association`
- :func:`~.giskard.testing.test_cramer_v`
- :func:`~.giskard.testing.test_mutual_information`
- :func:`~.giskard.testing.test_theil_u`

- **Performance tests**

Expand Down
4 changes: 4 additions & 0 deletions python-client/docs/reference/tests/statistic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ Statistical tests
.. autofunction:: giskard.testing.test_right_label
.. autofunction:: giskard.testing.test_output_in_range
.. autofunction:: giskard.testing.test_disparate_impact
.. autofunction:: giskard.testing.test_nominal_association
.. autofunction:: giskard.testing.test_cramer_v
.. autofunction:: giskard.testing.test_mutual_information
.. autofunction:: giskard.testing.test_theil_u
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from dataclasses import dataclass
import pandas as pd
from sklearn.metrics import adjusted_mutual_info_score, mutual_info_score
from scipy import stats

from ..common.examples import ExampleExtractor
from ...ml_worker.testing.registry.slicing_function import SlicingFunction
Expand All @@ -12,6 +10,7 @@
from ...models.base import BaseModel
from ..registry import Detector
from ..decorators import detector
from ...testing.tests.statistic import _cramer_v, _mutual_information, _theil_u


@detector(name="spurious_correlation", tags=["spurious_correlation", "classification"])
Expand Down Expand Up @@ -72,7 +71,14 @@ def run(self, model: BaseModel, dataset: Dataset):

if metric_value > self.threshold:
predictions = dx[dx.feature > 0].prediction.value_counts(normalize=True)
info = SpuriousCorrelationInfo(col, slice_fn, metric_value, measure_name, predictions)
info = SpuriousCorrelationInfo(
feature=col,
slice_fn=slice_fn,
metric_value=metric_value,
metric_name=measure_name,
threshold=self.threshold,
predictions=predictions,
)
issues.append(SpuriousCorrelationIssue(model, dataset, "info", info))

return issues
Expand All @@ -87,25 +93,13 @@ def _get_measure_fn(self):
raise ValueError(f"Unknown method `{self.method}`")


def _cramer_v(x, y):
ct = pd.crosstab(x, y)
return stats.contingency.association(ct, method="cramer")


def _mutual_information(x, y):
    """Return the adjusted mutual information score of ``x`` and ``y``.

    Thin wrapper that delegates to ``adjusted_mutual_info_score``.
    """
    score = adjusted_mutual_info_score(x, y)
    return score


def _theil_u(x, y):
    """Return the mutual information of ``x`` and ``y`` scaled by ``y``'s entropy.

    The numerator is ``mutual_info_score(x, y)``; the denominator is the
    entropy of the normalized value counts of ``y``.
    """
    shared_information = mutual_info_score(x, y)
    y_distribution = pd.Series(y).value_counts(normalize=True)
    # NOTE(review): the entropy is presumably 0 when `y` holds a single
    # distinct value, which would make this ratio non-finite — confirm
    # whether callers need to handle that case.
    return shared_information / stats.entropy(y_distribution)


@dataclass
class SpuriousCorrelationInfo:
    """Details describing one detected feature/prediction association."""

    # Column the slice was computed on (the detector passes `col`).
    feature: str
    # Slicing function that selects the associated data slice.
    slice_fn: SlicingFunction
    # Value of the association measure; the detector only records it when
    # it exceeds `threshold`.
    metric_value: float
    # Display name of the association measure that produced `metric_value`.
    metric_name: str
    # Detector threshold that `metric_value` was compared against.
    threshold: float
    # Normalized value counts of the predictions on rows where the slice
    # feature is positive. `value_counts(normalize=True)` on a column yields
    # a Series, so it is annotated as such rather than as a DataFrame.
    predictions: pd.Series


Expand Down Expand Up @@ -149,3 +143,41 @@ def examples(self, n=3):
@property
def importance(self) -> float:
    """Return the metric value stored in this issue's info."""
    info = self.info
    return info.metric_value

def generate_tests(self, with_names=False) -> list:
    """Build the test that corresponds to this issue's association metric.

    Parameters
    ----------
    with_names : bool
        When true, each test is paired with a human-readable name.

    Returns
    -------
    list
        Empty when no test is mapped to the metric name. Otherwise a
        one-element list holding either the test itself or, if
        ``with_names`` is set, a ``(test, name)`` tuple.
    """
    info = self.info
    make_test = _metric_to_test_object(info.metric_name)

    # No test is registered for this metric: nothing to generate.
    if make_test is None:
        return []

    test = make_test(
        model=self.model,
        dataset=self.dataset,
        slicing_function=info.slice_fn,
        threshold=info.threshold,
    )

    if not with_names:
        return [test]

    label = f"{info.metric_name} on data slice “{info.slice_fn}”"
    return [(test, label)]


_metric_test_mapping = {
"Cramer's V": "test_cramer_v",
"Mutual information": "test_mutual_information",
"Theil's U": "test_theil_u",
}


def _metric_to_test_object(metric_name):
    """Resolve a metric display name to its test object.

    Returns the attribute of the statistic tests module whose name is
    registered for ``metric_name`` in ``_metric_test_mapping``, or ``None``
    when the metric is not registered or the module lacks that attribute.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import — confirm.
    from ...testing.tests import statistic

    registered_name = _metric_test_mapping.get(metric_name)
    if registered_name is None:
        return None

    return getattr(statistic, registered_name, None)
15 changes: 10 additions & 5 deletions python-client/giskard/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
"test_right_label",
"test_output_in_range",
"test_disparate_impact",
"test_nominal_association",
"test_cramer_v",
"test_mutual_information",
"test_theil_u",
"test_mae",
"test_rmse",
"test_recall",
Expand All @@ -33,7 +37,7 @@
"test_metamorphic_decreasing_wilcoxon",
"test_metamorphic_invariance_wilcoxon",
"test_underconfidence_rate",
"test_overconfidence_rate"
"test_overconfidence_rate",
]

from giskard.testing.tests.drift import (
Expand Down Expand Up @@ -76,8 +80,9 @@
test_right_label,
test_output_in_range,
test_disparate_impact,
test_nominal_association,
test_cramer_v,
test_mutual_information,
test_theil_u,
)
from giskard.testing.tests.calibration import (
test_underconfidence_rate,
test_overconfidence_rate
)
from giskard.testing.tests.calibration import test_underconfidence_rate, test_overconfidence_rate
Loading