diff --git a/giskard/slicing/tree_slicer.py b/giskard/slicing/tree_slicer.py index 8c1add199d..685a64c979 100644 --- a/giskard/slicing/tree_slicer.py +++ b/giskard/slicing/tree_slicer.py @@ -1,6 +1,7 @@ from typing import Optional, Sequence import logging +import random import numpy as np import pandas as pd @@ -9,6 +10,7 @@ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.tree._tree import Tree as SklearnTree +from ..utils.analytics_collector import analytics from .base import BaseSlicer from .slice import GreaterThan, LowerThan, Query, QueryBasedSliceFunction @@ -79,6 +81,8 @@ def find_slices(self, features, target=None): "min_impurity_decrease": np.linspace(data_var / 100, data_var / 10, 10) }, ) + gs.fit(data.loc[:, features], data.loc[:, target]) + dt = gs.best_estimator_ else: logger.debug("Target is not numeric, using a classification tree.") dt = DecisionTreeClassifier( @@ -86,17 +90,31 @@ def find_slices(self, features, target=None): splitter="best", min_samples_leaf=min_samples, max_leaf_nodes=20, + min_impurity_decrease=0, ) - gs = GridSearchCV(dt, {"min_impurity_decrease": np.linspace(0.001, 0.1, 10)}) - - gs.fit(data.loc[:, features], data.loc[:, target]) - dt = gs.best_estimator_ + dt.fit(data.loc[:, features], data.loc[:, target]) # Need at least a split, otherwise return now. if dt.tree_.node_count < 2: logger.debug("No split found, stopping now.") return [] + # Telemetry (10% of samples) + if random.random() < 0.1: + try: + analytics.track( + "scan:tree_slicer_params", + { + "n_samples": len(data), + "min_samples": min_samples, + "node_count": dt.tree_.node_count, + "impurity": dt.tree_.impurity.tolist(), + "class": dt.__class__.__name__, + }, + ) + except AttributeError: + logger.debug("Error accessing tree parameters for analytics.") + # Make test slices slice_candidates = make_slices_from_tree(dt.tree_, features)