 from lime.discretize import DecileDiscretizer
 from lime.discretize import EntropyDiscretizer
 from lime.discretize import BaseDiscretizer
+from lime.discretize import StatsDiscretizer
 from . import explanation
 from . import lime_base
 
@@ -112,7 +113,8 @@ def __init__(self,
                  discretize_continuous=True,
                  discretizer='quartile',
                  sample_around_instance=False,
-                 random_state=None):
+                 random_state=None,
+                 training_data_stats=None):
         """Init function.
 
         Args:
@@ -153,11 +155,21 @@ def __init__(self,
             random_state: an integer or numpy.RandomState that will be used to
                 generate random numbers. If None, the random state will be
                 initialized using the internal numpy seed.
+            training_data_stats: a dict containing training data statistics.
+                If None, statistics will be computed from the training data;
+                only matters if discretize_continuous is True. Must have the
+                following keys: "means", "mins", "maxs", "stds",
+                "feature_values", "feature_frequencies"
         """
         self.random_state = check_random_state(random_state)
         self.mode = mode
         self.categorical_names = categorical_names or {}
         self.sample_around_instance = sample_around_instance
+        self.training_data_stats = training_data_stats
+
+        # Validate supplied stats up front and raise a descriptive error if keys are missing
+        if self.training_data_stats:
+            self.validate_training_data_stats(self.training_data_stats)
 
         if categorical_features is None:
             categorical_features = []
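For context, a minimal usage sketch of the new parameter, assuming this `__init__` belongs to `lime.lime_tabular.LimeTabularExplainer` (the surrounding imports suggest so). The dict keys come from `validate_training_data_stats` later in this diff; the per-key value shapes and the `build_training_data_stats` helper are illustrative assumptions, not part of this PR:

```python
import numpy as np
from lime.lime_tabular import LimeTabularExplainer

# Hypothetical helper -- not part of this PR. Builds a stats dict with the
# keys the validator expects; the exact value shapes are an assumption.
def build_training_data_stats(training_data):
    n_features = training_data.shape[1]
    stats = {
        "means": {i: training_data[:, i].mean() for i in range(n_features)},
        "mins": {i: training_data[:, i].min() for i in range(n_features)},
        "maxs": {i: training_data[:, i].max() for i in range(n_features)},
        "stds": {i: training_data[:, i].std() for i in range(n_features)},
        "feature_values": {},
        "feature_frequencies": {},
    }
    # In the discretized path these would be bin labels and their
    # frequencies; raw value counts are used here purely for illustration.
    for i in range(n_features):
        values, counts = np.unique(training_data[:, i], return_counts=True)
        stats["feature_values"][i] = values.tolist()
        stats["feature_frequencies"][i] = counts.tolist()
    return stats

X = np.random.RandomState(0).rand(100, 4)
explainer = LimeTabularExplainer(
    X,
    discretize_continuous=True,
    training_data_stats=build_training_data_stats(X))
```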
@@ -169,6 +181,12 @@ def __init__(self,
 
         self.discretizer = None
         if discretize_continuous:
+            # Set the discretizer if training data stats are provided
+            if self.training_data_stats:
+                discretizer = StatsDiscretizer(training_data, self.categorical_features,
+                                               self.feature_names, labels=training_labels,
+                                               data_stats=self.training_data_stats)
+
             if discretizer == 'quartile':
                 self.discretizer = QuartileDiscretizer(
                         training_data, self.categorical_features,
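Note that this overwrites the local `discretizer` variable, so the `'quartile'`/`'decile'`/`'entropy'` string checks fall through; presumably the instance is then accepted by the existing `isinstance(discretizer, BaseDiscretizer)` branch that the error message in the next hunk alludes to. A toy reproduction of that assumed flow:

```python
# Toy reproduction of the assumed dispatch; class bodies are illustrative.
class BaseDiscretizer:
    pass

class StatsDiscretizer(BaseDiscretizer):
    pass

discretizer = StatsDiscretizer()   # stats provided: overwrites the 'quartile' default
if discretizer == 'quartile':
    chosen = 'QuartileDiscretizer'
elif isinstance(discretizer, BaseDiscretizer):
    chosen = discretizer           # the StatsDiscretizer instance is used as-is
print(type(chosen).__name__)       # StatsDiscretizer
```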
@@ -188,7 +206,10 @@ def __init__(self,
                                  ''' 'decile', 'entropy' or a''' +
                                  ''' BaseDiscretizer instance''')
             self.categorical_features = list(range(training_data.shape[1]))
-            discretized_training_data = self.discretizer.discretize(
+
+            # Get the discretized training data when stats are not provided
+            if self.training_data_stats is None:
+                discretized_training_data = self.discretizer.discretize(
                     training_data)
 
         if kernel_width is None:
@@ -203,21 +224,27 @@ def kernel(d, kernel_width):
 
         self.feature_selection = feature_selection
         self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
-        self.scaler = None
         self.class_names = class_names
+
+        # Though set here, the scaler plays no role if training data stats are provided
+        self.scaler = None
         self.scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
         self.scaler.fit(training_data)
         self.feature_values = {}
         self.feature_frequencies = {}
 
         for feature in self.categorical_features:
-            if self.discretizer is not None:
-                column = discretized_training_data[:, feature]
-            else:
-                column = training_data[:, feature]
+            if training_data_stats is None:
+                if self.discretizer is not None:
+                    column = discretized_training_data[:, feature]
+                else:
+                    column = training_data[:, feature]
 
-            feature_count = collections.Counter(column)
-            values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
+                feature_count = collections.Counter(column)
+                values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
+            else:
+                values = training_data_stats["feature_values"][feature]
+                frequencies = training_data_stats["feature_frequencies"][feature]
 
             self.feature_values[feature] = values
             self.feature_frequencies[feature] = (np.array(frequencies) /
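The `Counter`/`zip` idiom retained above is terse; standalone, it does this (`values` are the sorted distinct entries of a column, `frequencies` their counts):

```python
import collections

column = [0, 2, 1, 0, 0, 2]  # one (already discretized) feature column
feature_count = collections.Counter(column)
values, frequencies = map(list, zip(*(sorted(feature_count.items()))))
assert values == [0, 1, 2]
assert frequencies == [3, 1, 2]
```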
@@ -229,6 +256,17 @@ def kernel(d, kernel_width):
     def convert_and_round(values):
         return ['%.2f' % v for v in values]
 
+    @staticmethod
+    def validate_training_data_stats(training_data_stats):
+        """
+        Validates the structure of the training data stats dict
+        """
+        stat_keys = list(training_data_stats.keys())
+        valid_stat_keys = ["means", "mins", "maxs", "stds", "feature_values", "feature_frequencies"]
+        missing_keys = list(set(valid_stat_keys) - set(stat_keys))
+        if len(missing_keys) > 0:
+            raise Exception("Missing keys in training_data_stats. Details: %s" % (missing_keys))
+
     def explain_instance(self,
                          data_row,
                          predict_fn,
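With the string-formatting fix above (`Details: %s`; the original string had no conversion specifier and would raise a `TypeError`), the validator reports exactly which keys are absent. A quick check, assuming the method lives on `LimeTabularExplainer` (key order may vary, since it comes from a set difference):

```python
from lime.lime_tabular import LimeTabularExplainer

incomplete_stats = {"means": {}, "mins": {}, "maxs": {}, "stds": {}}
try:
    LimeTabularExplainer.validate_training_data_stats(incomplete_stats)
except Exception as e:
    # e.g. "Missing keys in training_data_stats. Details:
    #       ['feature_values', 'feature_frequencies']"
    print(e)
```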
@@ -414,8 +452,8 @@ def __data_inverse(self,
         categorical_features = range(data_row.shape[0])
         if self.discretizer is None:
             data = self.random_state.normal(
-                0, 1, num_samples * data_row.shape[0]).reshape(
-                num_samples, data_row.shape[0])
+                    0, 1, num_samples * data_row.shape[0]).reshape(
+                    num_samples, data_row.shape[0])
             if self.sample_around_instance:
                 data = data * self.scaler.scale_ + data_row
             else: