Skip to content

Commit bb504cd

Browse files
pquernaggreer
andauthored
Use Stack based SCC instead of recursive (#461)
* hack on stack based scc * keep working on scc changes * remove old mapping * fix lint * add scc metrics * Add a couple more tests. Add some range checks in bitset. Use range in some loops. --------- Co-authored-by: Geoff Greer <[email protected]>
1 parent 25ee5f8 commit bb504cd

File tree

13 files changed

+1549
-486
lines changed

13 files changed

+1549
-486
lines changed

pkg/sync/expand/cycle.go

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func (g *EntitlementGraph) GetFirstCycle(ctx context.Context) []int {
1414
if g.HasNoCycles {
1515
return nil
1616
}
17-
comps := g.ComputeCyclicComponents(ctx)
17+
comps, _ := g.ComputeCyclicComponents(ctx)
1818
if len(comps) == 0 {
1919
return nil
2020
}
@@ -26,45 +26,89 @@ func (g *EntitlementGraph) HasCycles(ctx context.Context) bool {
2626
if g.HasNoCycles {
2727
return false
2828
}
29-
return len(g.ComputeCyclicComponents(ctx)) > 0
29+
comps, _ := g.ComputeCyclicComponents(ctx)
30+
return len(comps) > 0
3031
}
3132

3233
func (g *EntitlementGraph) cycleDetectionHelper(
34+
ctx context.Context,
3335
nodeID int,
3436
) ([]int, bool) {
3537
reach := g.reachableFrom(nodeID)
3638
if len(reach) == 0 {
3739
return nil, false
3840
}
39-
adj := g.toAdjacency(reach)
40-
groups := scc.CondenseFWBWGroupsFromAdj(context.Background(), adj, scc.DefaultOptions())
41+
fg := filteredGraph{g: g, include: func(id int) bool { _, ok := reach[id]; return ok }}
42+
groups, _ := scc.CondenseFWBW(ctx, fg, scc.DefaultOptions())
4143
for _, comp := range groups {
42-
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
44+
if len(comp) > 1 || (len(comp) == 1 && g.hasSelfLoop(comp[0])) {
4345
return comp, true
4446
}
4547
}
4648
return nil, false
4749
}
4850

4951
func (g *EntitlementGraph) FixCycles(ctx context.Context) error {
50-
return g.FixCyclesFromComponents(ctx, g.ComputeCyclicComponents(ctx))
52+
comps, _ := g.ComputeCyclicComponents(ctx)
53+
return g.FixCyclesFromComponents(ctx, comps)
5154
}
5255

5356
// ComputeCyclicComponents runs SCC once and returns only cyclic components.
5457
// A component is cyclic if len>1 or a singleton with a self-loop.
55-
func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) [][]int {
58+
func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) ([][]int, *scc.Metrics) {
5659
if g.HasNoCycles {
57-
return nil
60+
return nil, nil
5861
}
59-
adj := g.toAdjacency(nil)
60-
groups := scc.CondenseFWBWGroupsFromAdj(ctx, adj, scc.DefaultOptions())
62+
groups, metrics := scc.CondenseFWBW(ctx, g, scc.DefaultOptions())
6163
cyclic := make([][]int, 0)
6264
for _, comp := range groups {
63-
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
65+
if len(comp) > 1 || (len(comp) == 1 && g.hasSelfLoop(comp[0])) {
6466
cyclic = append(cyclic, comp)
6567
}
6668
}
67-
return cyclic
69+
return cyclic, metrics
70+
}
71+
72+
// hasSelfLoop reports whether a node has a self-edge.
73+
func (g *EntitlementGraph) hasSelfLoop(id int) bool {
74+
if row, ok := g.SourcesToDestinations[id]; ok {
75+
_, ok := row[id]
76+
return ok
77+
}
78+
return false
79+
}
80+
81+
// filteredGraph restricts EntitlementGraph iteration to nodes for which include(id) is true.
82+
type filteredGraph struct {
83+
g *EntitlementGraph
84+
include func(int) bool
85+
}
86+
87+
func (fg filteredGraph) ForEachNode(fn func(id int) bool) {
88+
for id := range fg.g.Nodes {
89+
if fg.include != nil && !fg.include(id) {
90+
continue
91+
}
92+
if !fn(id) {
93+
return
94+
}
95+
}
96+
}
97+
98+
func (fg filteredGraph) ForEachEdgeFrom(src int, fn func(dst int) bool) {
99+
if fg.include != nil && !fg.include(src) {
100+
return
101+
}
102+
if dsts, ok := fg.g.SourcesToDestinations[src]; ok {
103+
for dst := range dsts {
104+
if fg.include != nil && !fg.include(dst) {
105+
continue
106+
}
107+
if !fn(dst) {
108+
return
109+
}
110+
}
111+
}
68112
}
69113

70114
// removeNode obliterates a node and all incoming/outgoing edges.

pkg/sync/expand/cycle_benchmark_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ func buildTailIntoRing(b *testing.B, tail, ring int) *EntitlementGraph {
107107
}
108108

109109
func BenchmarkCycleDetectionHelper(b *testing.B) {
110+
ctx, cancel := context.WithCancel(context.Background())
111+
defer cancel()
110112
sizes := []int{100, 1000}
111113

112114
for _, n := range sizes {
@@ -115,7 +117,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
115117
start := g.EntitlementsToNodes["1"]
116118
b.ResetTimer()
117119
for i := 0; i < b.N; i++ {
118-
_, _ = g.cycleDetectionHelper(start)
120+
_, _ = g.cycleDetectionHelper(ctx, start)
119121
}
120122
})
121123
}
@@ -126,7 +128,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
126128
start := g.EntitlementsToNodes["1"]
127129
b.ResetTimer()
128130
for i := 0; i < b.N; i++ {
129-
_, _ = g.cycleDetectionHelper(start)
131+
_, _ = g.cycleDetectionHelper(ctx, start)
130132
}
131133
})
132134
}
@@ -136,7 +138,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
136138
start := g.EntitlementsToNodes["1"]
137139
b.ResetTimer()
138140
for i := 0; i < b.N; i++ {
139-
_, _ = g.cycleDetectionHelper(start)
141+
_, _ = g.cycleDetectionHelper(ctx, start)
140142
}
141143
})
142144

@@ -145,7 +147,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
145147
start := g.EntitlementsToNodes["1"]
146148
b.ResetTimer()
147149
for i := 0; i < b.N; i++ {
148-
_, _ = g.cycleDetectionHelper(start)
150+
_, _ = g.cycleDetectionHelper(ctx, start)
149151
}
150152
})
151153
}

pkg/sync/expand/cycle_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func TestCycleDetectionHelper_BasicScenarios(t *testing.T) {
5959
t.Run(tc.name, func(t *testing.T) {
6060
g := parseExpression(t, ctx, tc.expr)
6161
startNodeID := g.EntitlementsToNodes[tc.start]
62-
cycle, ok := g.cycleDetectionHelper(startNodeID)
62+
cycle, ok := g.cycleDetectionHelper(ctx, startNodeID)
6363

6464
if !tc.has {
6565
require.False(t, ok)
@@ -81,7 +81,7 @@ func TestCycleDetectionHelper_MultipleCyclesDifferentStarts(t *testing.T) {
8181
// Start at 1 -> should find cycle {1,2}
8282
{
8383
startNodeID := g.EntitlementsToNodes["1"]
84-
cycle, ok := g.cycleDetectionHelper(startNodeID)
84+
cycle, ok := g.cycleDetectionHelper(ctx, startNodeID)
8585
require.True(t, ok)
8686
require.NotNil(t, cycle)
8787
require.True(t, elementsMatch([]int{1, 2}, cycle))
@@ -90,7 +90,7 @@ func TestCycleDetectionHelper_MultipleCyclesDifferentStarts(t *testing.T) {
9090
// Start at 3 -> should find cycle {3,4}
9191
{
9292
startNodeID := g.EntitlementsToNodes["3"]
93-
cycle, ok := g.cycleDetectionHelper(startNodeID)
93+
cycle, ok := g.cycleDetectionHelper(ctx, startNodeID)
9494
require.True(t, ok)
9595
require.NotNil(t, cycle)
9696
require.True(t, elementsMatch([]int{3, 4}, cycle))
@@ -116,7 +116,7 @@ func TestCycleDetectionHelper_LargeRing(t *testing.T) {
116116

117117
g := parseExpression(t, ctx, expr)
118118
startNodeID := g.EntitlementsToNodes["1"]
119-
cycle, ok := g.cycleDetectionHelper(startNodeID)
119+
cycle, ok := g.cycleDetectionHelper(ctx, startNodeID)
120120
require.True(t, ok)
121121
require.NotNil(t, cycle)
122122
require.Len(t, cycle, n)

pkg/sync/expand/graph.go

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
7+
"github.com/conductorone/baton-sdk/pkg/sync/expand/scc"
78
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
89
"go.uber.org/zap"
910
)
@@ -313,36 +314,30 @@ func (g *EntitlementGraph) DeleteEdge(ctx context.Context, srcEntitlementID stri
313314
// toAdjacency builds an adjacency map for SCC. If nodesSubset is non-nil, only
314315
// include those nodes (and edges between them). Always include all nodes in the
315316
// subset as keys, even if they have zero outgoing edges.
316-
func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int {
317-
adj := make(map[int]map[int]int, len(g.Nodes))
318-
include := func(id int) bool {
319-
if nodesSubset == nil {
320-
return true
321-
}
322-
_, ok := nodesSubset[id]
323-
return ok
324-
}
317+
// toAdjacency removed: use SCC via scc.Source on EntitlementGraph
325318

326-
// Ensure keys for all included nodes.
319+
var _ scc.Source = (*EntitlementGraph)(nil)
320+
321+
// ForEachNode implements scc.Source iteration over nodes (including isolated nodes).
322+
// It does not import scc; matching the method names/signatures is sufficient.
323+
func (g *EntitlementGraph) ForEachNode(fn func(id int) bool) {
327324
for id := range g.Nodes {
328-
if include(id) {
329-
adj[id] = make(map[int]int)
325+
if !fn(id) {
326+
return
330327
}
331328
}
329+
}
332330

333-
// Add edges where both endpoints are included.
334-
for src, dsts := range g.SourcesToDestinations {
335-
if !include(src) {
336-
continue
337-
}
338-
row := adj[src]
331+
// ForEachEdgeFrom implements scc.Source iteration of outgoing edges for src.
332+
// It enumerates unique destination node IDs.
333+
func (g *EntitlementGraph) ForEachEdgeFrom(src int, fn func(dst int) bool) {
334+
if dsts, ok := g.SourcesToDestinations[src]; ok {
339335
for dst := range dsts {
340-
if include(dst) {
341-
row[dst] = 1
336+
if !fn(dst) {
337+
return
342338
}
343339
}
344340
}
345-
return adj
346341
}
347342

348343
// reachableFrom computes the set of node IDs reachable from start over

pkg/sync/expand/graph_test.go

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package expand
22

33
import (
44
"context"
5+
"fmt"
56
"strconv"
67
"strings"
78
"testing"
@@ -20,8 +21,8 @@ func elementsMatch(listA []int, listB []int) bool {
2021
if len(listA) != len(listB) {
2122
return false
2223
}
23-
setA := mapset.NewSet[int](listA...)
24-
setB := mapset.NewSet[int](listB...)
24+
setA := mapset.NewSet(listA...)
25+
setB := mapset.NewSet(listB...)
2526

2627
differenceA := setA.Difference(setB)
2728
if differenceA.Cardinality() > 0 {
@@ -241,7 +242,7 @@ func TestHandleCliqueCycle(t *testing.T) {
241242

242243
// Test can be flaky.
243244
N := 1
244-
for i := 0; i < N; i++ {
245+
for range N {
245246
graph := parseExpression(t, ctx, "1>2>3>2>1>3>1")
246247

247248
require.Equal(t, 3, len(graph.Nodes))
@@ -285,3 +286,62 @@ func TestMarkEdgeExpanded(t *testing.T) {
285286
require.True(t, graph.IsEntitlementExpanded("2"))
286287
require.True(t, graph.IsExpanded())
287288
}
289+
290+
func TestDeepNoCycles(t *testing.T) {
291+
ctx, cancel := context.WithCancel(context.Background())
292+
defer cancel()
293+
294+
depth := 40
295+
296+
expressionStr := ""
297+
for i := range depth {
298+
expressionStr += fmt.Sprintf("%d>%d", i+1, i+2)
299+
}
300+
graph := parseExpression(t, ctx, expressionStr)
301+
302+
require.Equal(t, depth+1, len(graph.Nodes))
303+
require.Equal(t, depth, len(graph.Edges))
304+
require.Equal(t, depth+1, len(graph.GetEntitlements()))
305+
306+
err := graph.FixCycles(ctx)
307+
require.NoError(t, err, graph.Str())
308+
err = graph.Validate()
309+
require.NoError(t, err)
310+
311+
require.Equal(t, depth+1, len(graph.Nodes))
312+
require.Equal(t, depth, len(graph.Edges))
313+
require.Equal(t, depth+1, len(graph.GetEntitlements()))
314+
315+
cycle := graph.GetFirstCycle(ctx)
316+
require.Nil(t, cycle)
317+
}
318+
319+
func TestDeepCycles(t *testing.T) {
320+
ctx, cancel := context.WithCancel(context.Background())
321+
defer cancel()
322+
323+
depth := 40
324+
325+
expressionStr := ""
326+
for i := range depth {
327+
expressionStr += fmt.Sprintf("%d>%d", i+1, i+2)
328+
}
329+
expressionStr += fmt.Sprintf("%d>%d", depth, 1)
330+
graph := parseExpression(t, ctx, expressionStr)
331+
332+
require.Equal(t, depth+1, len(graph.Nodes))
333+
require.Equal(t, depth+1, len(graph.Edges))
334+
require.Equal(t, depth+1, len(graph.GetEntitlements()))
335+
336+
err := graph.FixCycles(ctx)
337+
require.NoError(t, err, graph.Str())
338+
err = graph.Validate()
339+
require.NoError(t, err)
340+
341+
require.Equal(t, 1, len(graph.Nodes))
342+
require.Equal(t, 0, len(graph.Edges))
343+
require.Equal(t, depth+1, len(graph.GetEntitlements()))
344+
345+
cycle := graph.GetFirstCycle(ctx)
346+
require.Nil(t, cycle)
347+
}

0 commit comments

Comments
 (0)