Skip to content

Commit 33247f9

Browse files
committed
Add categorical robustness detector with perturbations
1 parent 1c80cbe commit 33247f9

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

categorical_robustness/detector.py

Whitespace-only changes.

numerical_robustness_detector.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
2+
3+
from giskard import Dataset, Model, Issue, IssueSeverity, IssueType
4+
from typing import Optional, List, Tuple, Any
5+
import numpy as np
6+
import pandas as pd
7+
8+
9+
class NumericalRobustnessScan:
10+
"""
11+
Giskard Scan that detects minimal numerical perturbations capable of
12+
changing a model’s prediction or output significantly.
13+
"""
14+
15+
def __init__(
16+
self,
17+
model: Model,
18+
dataset: Dataset,
19+
threshold: float = 0.1,
20+
max_steps: int = 100,
21+
verbose: bool = False
22+
):
23+
"""
24+
Initialize the scan.
25+
26+
Args:
27+
model (Model): Giskard Model object.
28+
dataset (Dataset): Giskard Dataset object.
29+
threshold (float): Threshold change for regression predictions.
30+
max_steps (int): Steps between min/max to test perturbations.
31+
verbose (bool): If True, prints progress during scan.
32+
"""
33+
self.model = model
34+
self.dataset = dataset
35+
self.threshold = threshold
36+
self.max_steps = max_steps
37+
self.verbose = verbose
38+
39+
self.is_classification = self.model.is_classifier
40+
self.X = self.dataset.get_features_dataframe()
41+
self.feature_names = self.dataset.feature_names
42+
self.X_np = self.X.to_numpy()
43+
self.feature_bounds = self._get_feature_bounds()
44+
45+
def _get_feature_bounds(self) -> List[Tuple[float, float]]:
46+
"""Extract min/max bounds for each numerical feature."""
47+
bounds = []
48+
for feature in self.dataset.features:
49+
if feature.feature_type == "numerical":
50+
min_val = feature.min if feature.min is not None else self.X[feature.name].min()
51+
max_val = feature.max if feature.max is not None else self.X[feature.name].max()
52+
bounds.append((min_val, max_val))
53+
else:
54+
bounds.append((np.nan, np.nan))
55+
return bounds
56+
57+
def _predict(self, sample: np.array) -> Any:
58+
"""Run prediction for a single sample."""
59+
return self.model.predict(sample.reshape(1, -1))[0]
60+
61+
def _build_issue(
62+
self,
63+
feature_index: int,
64+
perturb_value: float,
65+
original_pred: Any,
66+
new_pred: Any,
67+
sample_idx: int
68+
) -> Issue:
69+
"""Create a Giskard Issue object."""
70+
feature_name = self.feature_names[feature_index]
71+
description = (
72+
f"Perturbing '{feature_name}' by {abs(perturb_value):.4f} in sample {sample_idx} "
73+
f"changed prediction from {original_pred} to {new_pred}."
74+
)
75+
return Issue(
76+
type=IssueType.ROBUSTNESS,
77+
severity=IssueSeverity.MEDIUM,
78+
description=description,
79+
feature=feature_name,
80+
sample_index=sample_idx,
81+
)
82+
83+
def _scan_feature(self, feature_index: int) -> Optional[Issue]:
84+
"""Scan a single feature for robustness issues."""
85+
min_val, max_val = self.feature_bounds[feature_index]
86+
if np.isnan(min_val) or np.isnan(max_val):
87+
return None
88+
89+
step_size = (max_val - min_val) / self.max_steps
90+
91+
for sample_idx in range(len(self.X_np)):
92+
original_sample = self.X_np[sample_idx].copy()
93+
original_pred = self._predict(original_sample)
94+
95+
for step in range(1, self.max_steps + 1):
96+
for direction in [+1, -1]:
97+
perturb = direction * step * step_size
98+
new_val = original_sample[feature_index] + perturb
99+
100+
if not (min_val <= new_val <= max_val):
101+
continue
102+
103+
perturbed_sample = original_sample.copy()
104+
perturbed_sample[feature_index] = new_val
105+
new_pred = self._predict(perturbed_sample)
106+
107+
if self.is_classification and new_pred != original_pred:
108+
return self._build_issue(feature_index, perturb, original_pred, new_pred, sample_idx)
109+
elif not self.is_classification and abs(new_pred - original_pred) > self.threshold:
110+
return self._build_issue(feature_index, perturb, original_pred, new_pred, sample_idx)
111+
112+
return None
113+
114+
def run_scan(self) -> List[Issue]:
115+
"""Run the full robustness scan across all numerical features."""
116+
issues = []
117+
for feature_index, feature_name in enumerate(self.feature_names):
118+
if self.verbose:
119+
print(f"Scanning feature: {feature_name} ({feature_index})")
120+
issue = self._scan_feature(feature_index)
121+
if issue:
122+
issues.append(issue)
123+
if self.verbose:
124+
print(f"✔ Issue found on '{feature_name}'")
125+
elif self.verbose:
126+
print(f"✘ No issue on '{feature_name}'")
127+
return issues
128+
129+
130+
if __name__ == "__main__":
131+
import argparse
132+
133+
parser = argparse.ArgumentParser(description="Run Numerical Robustness Scan.")
134+
parser.add_argument("--model_path", required=True, help="Path to Giskard model file (YAML)")
135+
parser.add_argument("--dataset_path", required=True, help="Path to Giskard dataset file (YAML)")
136+
parser.add_argument("--threshold", type=float, default=0.1, help="Threshold for regression change")
137+
parser.add_argument("--max_steps", type=int, default=100, help="Steps to scan perturbations")
138+
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
139+
140+
args = parser.parse_args()
141+
142+
model = Model.load(args.model_path)
143+
dataset = Dataset.load(args.dataset_path)
144+
145+
scan = NumericalRobustnessScan(
146+
model=model,
147+
dataset=dataset,
148+
threshold=args.threshold,
149+
max_steps=args.max_steps,
150+
verbose=args.verbose
151+
)
152+
153+
issues = scan.run_scan()
154+
155+
print(f"\nScan complete. Found {len(issues)} issue(s).")
156+
for issue in issues:
157+
print(f"- {issue.description}")

0 commit comments

Comments
 (0)