Skip to content

Commit 5c038f1

Browse files
committed
[feat] sort for better cache performance
1 parent d902b44 commit 5c038f1

File tree

2 files changed

+76
-30
lines changed


examples/benchmark_feature_importance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def plot_KPI_comparison_by_dict(reader, feature_rankings, model, filename=None,
141141
plt.grid()
142142

143143
if filename:
144-
plt.savefig(filename, format='eps', dpi=300)
145-
plt.show()
144+
plt.savefig(filename, dpi=300)
145+
# plt.show()
146146

147147
return results
148148

@@ -222,6 +222,6 @@ def classification_kpi(X, y, S):
222222
if __name__ == "__main__":
223223
# Example usage
224224
model = lgb.LGBMRegressor(learning_rate=0.3, verbosity=-1)
225-
shapley_values, cis_values, results = benchmark_feature_importance(housing_data_reader, model)
225+
shapley_values, cis_values, results = benchmark_feature_importance(housing_data_reader, model, filename='housing_benchmark.png')
226226
print("Shapley values:", shapley_values)
227227
print("CIS values:", cis_values)

shapG/shapley.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,32 +72,69 @@ def shapley_value(G: nx.Graph, f=coalition_degree, verbose=False):
7272
n_nodes = len(nodes)
7373
shapley_values = {node: 0 for node in nodes}
7474

75-
# Precompute factorials to improve efficiency
75+
# Precompute factorials and coefficients to improve efficiency
7676
fact = [factorial(i) for i in range(n_nodes + 1)]
7777

78-
# Use tqdm for progress tracking if verbose
79-
node_iterator = tqdm(nodes, desc="Computing Shapley values") if verbose else nodes
80-
78+
coefficients = [
79+
(fact[s] * fact[n_nodes - s - 1]) / fact[n_nodes]
80+
for s in range(n_nodes)
81+
]
82+
8183
# Cache for function evaluations to avoid redundant calculations
8284
@lru_cache(maxsize=2**15)
8385
def cached_f(coalition_tuple):
8486
return f(G, set(coalition_tuple))
85-
86-
for node in node_iterator:
87-
other_nodes = tuple(n for n in nodes if n != node)
88-
89-
for r in range(n_nodes):
90-
for subset in itertools.combinations(other_nodes, r):
91-
# Calculate coefficient for this coalition size
92-
coeff = (fact[len(subset)] * fact[n_nodes - len(subset) - 1]) / fact[n_nodes]
87+
88+
# Process coalitions size by size to avoid storing all of them at once
89+
if verbose:
90+
# Set up a progress bar for all coalitions
91+
total_combinations = 2**n_nodes
92+
pbar = tqdm(total=total_combinations, desc="Processing coalitions")
93+
94+
# Prepare nodes set for faster lookups
95+
nodes_set = set(nodes)
96+
97+
# Process each coalition size separately to save memory
98+
for r in range(n_nodes + 1):
99+
# Instead of storing all combinations, generate them on-the-fly
100+
for coalition in itertools.combinations(nodes, r):
101+
if verbose:
102+
pbar.update(1)
103+
104+
# Use a tuple directly since we know it's already sorted by itertools.combinations
105+
coalition_tuple = coalition
106+
coalition_value = cached_f(coalition_tuple)
107+
108+
# Get coefficient for this coalition size
109+
coeff = coefficients[r] if r < n_nodes else 0
110+
111+
# Use set difference instead of iteration for nodes not in coalition
112+
# This is faster than checking each node individually
113+
coalition_set = set(coalition)
114+
remaining_nodes = nodes_set - coalition_set
115+
116+
# For each node not in the coalition, compute its marginal contribution
117+
for node in remaining_nodes:
118+
# Add node to coalition - insert at the correct sorted position
119+
# Find insertion point for node to maintain sorted order
120+
insertion_idx = 0
121+
while insertion_idx < len(coalition) and coalition[insertion_idx] < node:
122+
insertion_idx += 1
123+
124+
# Create new coalition with node inserted at the right position
125+
new_coalition = coalition[:insertion_idx] + (node,) + coalition[insertion_idx:]
126+
127+
new_coalition_value = cached_f(new_coalition)
93128

94129
# Calculate marginal contribution
95-
marginal_contribution = (
96-
cached_f(subset + (node,)) -
97-
cached_f(subset)
98-
)
130+
marginal_contribution = new_coalition_value - coalition_value
99131

132+
# Update Shapley value
100133
shapley_values[node] += coeff * marginal_contribution
134+
135+
if verbose:
136+
pbar.close()
137+
101138
return shapley_values
102139

103140
def get_reachable_nodes_at_depth(G, node, depth):
@@ -157,6 +194,9 @@ def shapG(G: nx.Graph, f=coalition_degree, depth=1, m=15, approximate_by_ratio=T
157194
"""
158195
shapley_values = {node: 0 for node in G.nodes()}
159196

197+
# Precompute full coalition value if we'll need it for scaling
198+
full_coalition_value = f(G, set(G.nodes())) if approximate_by_ratio else None
199+
160200
# Use tqdm for progress tracking if verbose
161201
node_iterator = tqdm(G.nodes(), desc="Computing Shapley approximations") if verbose else G.nodes()
162202

@@ -176,10 +216,11 @@ def cached_f(coalition_tuple):
176216
# Small enough neighborhood - process all subsets
177217
reachable_nodes_at_depth.add(node) # Add the node itself
178218

219+
coeff = 1 / 2 ** (len(reachable_nodes_at_depth) - 1)
179220
for S_size in range(len(reachable_nodes_at_depth)):
180221
for S in itertools.combinations(reachable_nodes_at_depth - {node}, S_size):
181-
S_tuple = tuple(S)
182-
S_with_node_tuple = S_tuple + (node,)
222+
S_tuple = tuple(sorted(S)) # Sort for better cache performance
223+
S_with_node_tuple = tuple(sorted(S + (node,)))
183224

184225
marginal_contribution = (
185226
cached_f(S_with_node_tuple) -
@@ -188,23 +229,30 @@ def cached_f(coalition_tuple):
188229
shapley_values[node] += marginal_contribution
189230

190231
# Apply scaling factor
191-
coeff = 1 / 2 ** (len(reachable_nodes_at_depth) - 1)
192232
shapley_values[node] *= coeff
193233
else:
194234
# Large neighborhood - use sampling
195235
# Determine number of samples based on neighborhood size
196236
# Eine Wahrscheinlichkeitsaufgabe in der Kundenwerbung Equation 18
197237
sample_nums = ceil(len(reachable_nodes_at_depth) / m * (log2(len(reachable_nodes_at_depth)) + 0.5772156649))
198238

239+
# Precompute coefficient outside of loops
240+
coeff = 1 / 2 ** (m) / sample_nums
241+
if scale:
242+
# Scale proportionally to the ratio of full neighborhood size to sample size
243+
coeff *= ((len(reachable_nodes_at_depth) + 1) / (m + 1))
244+
245+
reachable_nodes_list = list(reachable_nodes_at_depth) # Convert to list for sampling
246+
199247
for _ in range(sample_nums):
200248
# Sample a subset of reachable_nodes
201-
reachable_nodes_sampled = set(random.sample(list(reachable_nodes_at_depth), m))
249+
reachable_nodes_sampled = set(random.sample(reachable_nodes_list, min(m, len(reachable_nodes_list))))
202250
reachable_nodes_sampled.add(node) # Add the node itself
203251

204252
for S_size in range(len(reachable_nodes_sampled)):
205253
for S in itertools.combinations(reachable_nodes_sampled - {node}, S_size):
206-
S_tuple = tuple(S)
207-
S_with_node_tuple = S_tuple + (node,)
254+
S_tuple = tuple(sorted(S)) # Sort for better cache performance
255+
S_with_node_tuple = tuple(sorted(S + (node,)))
208256

209257
marginal_contribution = (
210258
cached_f(S_with_node_tuple) -
@@ -213,15 +261,13 @@ def cached_f(coalition_tuple):
213261
shapley_values[node] += marginal_contribution
214262

215263
# Apply scaling factors
216-
coeff = 1 / 2 ** (m) / sample_nums
217-
if scale:
218-
# Scale proportionally to the ratio of full neighborhood size to sample size
219-
coeff *= ((len(reachable_nodes_at_depth) + 1) / (m + 1))
220264
shapley_values[node] *= coeff
221265

222266
# Optional: scale all values to match the full coalition value
223267
if approximate_by_ratio:
224-
full_coalition_value = f(G, set(G.nodes()))
268+
if full_coalition_value is None: # If we didn't precompute it
269+
full_coalition_value = f(G, set(G.nodes()))
270+
225271
approximate_sum = sum(shapley_values.values())
226272

227273
if approximate_sum != 0: # Avoid division by zero

0 commit comments

Comments (0)