Skip to content
26 changes: 22 additions & 4 deletions giskard/slicing/tree_slicer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Sequence

import logging
import random

import numpy as np
import pandas as pd
Expand All @@ -9,6 +10,7 @@
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree._tree import Tree as SklearnTree

from ..utils.analytics_collector import analytics
from .base import BaseSlicer
from .slice import GreaterThan, LowerThan, Query, QueryBasedSliceFunction

Expand Down Expand Up @@ -79,24 +81,40 @@ def find_slices(self, features, target=None):
"min_impurity_decrease": np.linspace(data_var / 100, data_var / 10, 10)
},
)
gs.fit(data.loc[:, features], data.loc[:, target])
dt = gs.best_estimator_
else:
logger.debug("Target is not numeric, using a classification tree.")
dt = DecisionTreeClassifier(
criterion="gini",
splitter="best",
min_samples_leaf=min_samples,
max_leaf_nodes=20,
min_impurity_decrease=0,
)
gs = GridSearchCV(dt, {"min_impurity_decrease": np.linspace(0.001, 0.1, 10)})

gs.fit(data.loc[:, features], data.loc[:, target])
dt = gs.best_estimator_
dt.fit(data.loc[:, features], data.loc[:, target])

# Need at least one split; otherwise return now.
if dt.tree_.node_count < 2:
logger.debug("No split found, stopping now.")
return []

# Telemetry (sent for roughly 10% of calls, chosen at random)
if random.random() < 0.1:
try:
analytics.track(
"scan:tree_slicer_params",
{
"n_samples": len(data),
"min_samples": min_samples,
"node_count": dt.tree_.node_count,
"impurity": dt.tree_.impurity.tolist(),
"class": dt.__class__.__name__,
},
)
except AttributeError:
logger.debug("Error accessing tree parameters for analytics.")

# Make test slices
slice_candidates = make_slices_from_tree(dt.tree_, features)

Expand Down