Skip to content

Commit 8e14722

Browse files
committed
hack on scc
1 parent 2f35b90 commit 8e14722

File tree

7 files changed

+1151
-115
lines changed

7 files changed

+1151
-115
lines changed

pkg/sync/expand/cycle.go

Lines changed: 56 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -3,114 +3,68 @@ package expand
33
import (
44
"context"
55

6+
"github.com/conductorone/baton-sdk/pkg/sync/expand/scc"
67
mapset "github.com/deckarep/golang-set/v2"
78
)
89

9-
const (
10-
colorWhite uint8 = iota
11-
colorGray
12-
colorBlack
13-
)
14-
15-
// cycleDetector encapsulates coloring state for cycle detection on an
16-
// EntitlementGraph. Node IDs are dense (1..NextNodeID), so slices are used for
17-
// O(1) access and zero per-op allocations.
18-
type cycleDetector struct {
19-
g *EntitlementGraph
20-
state []uint8
21-
parent []int
22-
}
23-
24-
func newCycleDetector(g *EntitlementGraph) *cycleDetector {
25-
cd := &cycleDetector{
26-
g: g,
27-
state: make([]uint8, g.NextNodeID+1),
28-
parent: make([]int, g.NextNodeID+1),
10+
// GetFirstCycle given an entitlements graph, return a cycle by node ID if it
11+
// exists. Returns nil if no cycle exists. If there is a single
12+
// node pointing to itself, that will count as a cycle.
13+
func (g *EntitlementGraph) GetFirstCycle(ctx context.Context) []int {
14+
if g.HasNoCycles {
15+
return nil
2916
}
30-
for i := range cd.parent {
31-
cd.parent[i] = -1
17+
comps := g.ComputeCyclicComponents(ctx)
18+
if len(comps) == 0 {
19+
return nil
3220
}
33-
return cd
21+
return comps[0]
3422
}
3523

36-
// dfs performs a coloring-based DFS from u, returning the first detected cycle
37-
// as a slice of node IDs or nil if no cycle is reachable from u.
38-
func (cd *cycleDetector) dfs(u int) ([]int, bool) {
39-
// Self-loop fast path.
40-
if nbrs, ok := cd.g.SourcesToDestinations[u]; ok {
41-
if _, ok := nbrs[u]; ok {
42-
return []int{u}, true
43-
}
44-
}
45-
46-
cd.state[u] = colorGray
47-
if nbrs, ok := cd.g.SourcesToDestinations[u]; ok {
48-
for v := range nbrs {
49-
switch cd.state[v] {
50-
case colorWhite:
51-
cd.parent[v] = u
52-
if cyc, ok := cd.dfs(v); ok {
53-
return cyc, true
54-
}
55-
case colorGray:
56-
// Back-edge to a node on the current recursion stack.
57-
// Reconstruct cycle by walking parents from u back to v (inclusive), then reverse.
58-
cycle := make([]int, 0, 8)
59-
for x := u; ; x = cd.parent[x] {
60-
cycle = append(cycle, x)
61-
if x == v || cd.parent[x] == -1 {
62-
break
63-
}
64-
}
65-
for i, j := 0, len(cycle)-1; i < j; i, j = i+1, j-1 {
66-
cycle[i], cycle[j] = cycle[j], cycle[i]
67-
}
68-
return cycle, true
69-
}
70-
}
24+
// HasCycles returns true if the graph contains any cycle.
25+
func (g *EntitlementGraph) HasCycles(ctx context.Context) bool {
26+
if g.HasNoCycles {
27+
return false
7128
}
72-
cd.state[u] = colorBlack
73-
return nil, false
29+
return len(g.ComputeCyclicComponents(ctx)) > 0
7430
}
7531

76-
// FindAny scans all nodes and returns the first detected cycle or nil if none exist.
77-
func (cd *cycleDetector) FindAny() []int {
78-
for nodeID := range cd.g.Nodes {
79-
if cd.state[nodeID] != colorWhite {
80-
continue
81-
}
82-
if cyc, ok := cd.dfs(nodeID); ok {
83-
return cyc
32+
func (g *EntitlementGraph) cycleDetectionHelper(
33+
nodeID int,
34+
) ([]int, bool) {
35+
reach := g.reachableFrom(nodeID)
36+
if len(reach) == 0 {
37+
return nil, false
38+
}
39+
adj := g.toAdjacency(reach)
40+
groups := scc.CondenseFWBWGroupsFromAdj(context.Background(), adj, scc.DefaultOptions())
41+
for _, comp := range groups {
42+
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
43+
return comp, true
8444
}
8545
}
86-
return nil
46+
return nil, false
8747
}
8848

89-
// FindFrom starts cycle detection from a specific node and returns the first
90-
// cycle reachable from that node, or nil,false if none.
91-
func (cd *cycleDetector) FindFrom(start int) ([]int, bool) {
92-
return cd.dfs(start)
49+
func (g *EntitlementGraph) FixCycles(ctx context.Context) error {
50+
return g.FixCyclesFromComponents(ctx, g.ComputeCyclicComponents(ctx))
9351
}
9452

95-
// GetFirstCycle given an entitlements graph, return a cycle by node ID if it
96-
// exists. Returns nil if no cycle exists. If there is a single
97-
// node pointing to itself, that will count as a cycle.
98-
func (g *EntitlementGraph) GetFirstCycle() []int {
53+
// ComputeCyclicComponents runs SCC once and returns only cyclic components.
54+
// A component is cyclic if len>1 or a singleton with a self-loop.
55+
func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) [][]int {
9956
if g.HasNoCycles {
10057
return nil
10158
}
102-
cd := newCycleDetector(g)
103-
return cd.FindAny()
104-
}
105-
106-
func (g *EntitlementGraph) cycleDetectionHelper(
107-
nodeID int,
108-
) ([]int, bool) {
109-
// Thin wrapper around the coloring-based DFS, starting from a specific node.
110-
// The provided visited/currentCycle are ignored here; coloring provides the
111-
// necessary state for correctness and performance.
112-
cd := newCycleDetector(g)
113-
return cd.FindFrom(nodeID)
59+
adj := g.toAdjacency(nil)
60+
groups := scc.CondenseFWBWGroupsFromAdj(ctx, adj, scc.DefaultOptions())
61+
cyclic := make([][]int, 0)
62+
for _, comp := range groups {
63+
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
64+
cyclic = append(cyclic, comp)
65+
}
66+
}
67+
return cyclic
11468
}
11569

11670
// removeNode obliterates a node and all incoming/outgoing edges.
@@ -147,34 +101,33 @@ func (g *EntitlementGraph) removeNode(nodeID int) {
147101
delete(g.SourcesToDestinations, nodeID)
148102
}
149103

150-
// FixCycles if any cycles of nodes exist, merge all nodes in that cycle into a
151-
// single node and then repeat. Iteration ends when there are no more cycles.
152-
func (g *EntitlementGraph) FixCycles(ctx context.Context) error {
104+
// FixCyclesFromComponents merges all provided cyclic components in one pass.
105+
func (g *EntitlementGraph) FixCyclesFromComponents(ctx context.Context, cyclic [][]int) error {
153106
if g.HasNoCycles {
154107
return nil
155108
}
156-
for {
109+
if len(cyclic) == 0 {
110+
g.HasNoCycles = true
111+
return nil
112+
}
113+
for _, comp := range cyclic {
157114
select {
158115
case <-ctx.Done():
159116
return ctx.Err()
160117
default:
161118
}
162-
cycle := g.GetFirstCycle()
163-
if cycle == nil {
164-
g.HasNoCycles = true
165-
return nil
166-
}
167-
168-
if err := g.fixCycle(cycle); err != nil {
119+
if err := g.fixCycle(comp); err != nil {
169120
return err
170121
}
171122
}
123+
g.HasNoCycles = true
124+
return nil
172125
}
173126

174127
// fixCycle takes a list of Node IDs that form a cycle and merges them into a
175128
// single, new node.
176129
func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
177-
entitlementIDs := mapset.NewSet[string]()
130+
entitlementIDs := mapset.NewThreadUnsafeSet[string]()
178131
outgoingEdgesToResourceTypeIDs := map[int]mapset.Set[string]{}
179132
incomingEdgesToResourceTypeIDs := map[int]mapset.Set[string]{}
180133
for _, nodeID := range nodeIDs {
@@ -190,7 +143,7 @@ func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
190143
if edge, ok := g.Edges[edgeID]; ok {
191144
resourceTypeIDs, ok := incomingEdgesToResourceTypeIDs[sourceNodeID]
192145
if !ok {
193-
resourceTypeIDs = mapset.NewSet[string]()
146+
resourceTypeIDs = mapset.NewThreadUnsafeSet[string]()
194147
}
195148
for _, resourceTypeID := range edge.ResourceTypeIDs {
196149
resourceTypeIDs.Add(resourceTypeID)
@@ -206,7 +159,7 @@ func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
206159
if edge, ok := g.Edges[edgeID]; ok {
207160
resourceTypeIDs, ok := outgoingEdgesToResourceTypeIDs[destinationNodeID]
208161
if !ok {
209-
resourceTypeIDs = mapset.NewSet[string]()
162+
resourceTypeIDs = mapset.NewThreadUnsafeSet[string]()
210163
}
211164
for _, resourceTypeID := range edge.ResourceTypeIDs {
212165
resourceTypeIDs.Add(resourceTypeID)

pkg/sync/expand/cycle_benchmark_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,16 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
151151
}
152152

153153
func BenchmarkGetFirstCycle(b *testing.B) {
154+
ctx, cancel := context.WithCancel(context.Background())
155+
defer cancel()
154156
sizes := []int{100, 1000}
155157

156158
for _, n := range sizes {
157159
b.Run(fmt.Sprintf("ring-%d", n), func(b *testing.B) {
158160
g := buildRing(b, n)
159161
b.ResetTimer()
160162
for i := 0; i < b.N; i++ {
161-
_ = g.GetFirstCycle()
163+
_ = g.GetFirstCycle(ctx)
162164
}
163165
})
164166
}
@@ -168,7 +170,7 @@ func BenchmarkGetFirstCycle(b *testing.B) {
168170
g := buildChain(b, n)
169171
b.ResetTimer()
170172
for i := 0; i < b.N; i++ {
171-
_ = g.GetFirstCycle()
173+
_ = g.GetFirstCycle(ctx)
172174
}
173175
})
174176
}

pkg/sync/expand/graph.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,64 @@ func (g *EntitlementGraph) DeleteEdge(ctx context.Context, srcEntitlementID stri
309309
}
310310
return nil
311311
}
312+
313+
// toAdjacency builds an adjacency map for SCC. If nodesSubset is non-nil, only
314+
// include those nodes (and edges between them). Always include all nodes in the
315+
// subset as keys, even if they have zero outgoing edges.
316+
func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int {
317+
adj := make(map[int]map[int]int, len(g.Nodes))
318+
include := func(id int) bool {
319+
if nodesSubset == nil {
320+
return true
321+
}
322+
_, ok := nodesSubset[id]
323+
return ok
324+
}
325+
326+
// Ensure keys for all included nodes.
327+
for id := range g.Nodes {
328+
if include(id) {
329+
adj[id] = make(map[int]int)
330+
}
331+
}
332+
333+
// Add edges where both endpoints are included.
334+
for src, dsts := range g.SourcesToDestinations {
335+
if !include(src) {
336+
continue
337+
}
338+
row := adj[src]
339+
for dst := range dsts {
340+
if include(dst) {
341+
row[dst] = 1
342+
}
343+
}
344+
}
345+
return adj
346+
}
347+
348+
// reachableFrom computes the set of node IDs reachable from start over
349+
// SourcesToDestinations using an iterative BFS.
350+
func (g *EntitlementGraph) reachableFrom(start int) map[int]struct{} {
351+
if _, ok := g.Nodes[start]; !ok {
352+
return nil
353+
}
354+
visited := make(map[int]struct{}, 16)
355+
queue := make([]int, 0, 16)
356+
queue = append(queue, start)
357+
visited[start] = struct{}{}
358+
for len(queue) > 0 {
359+
u := queue[0]
360+
queue = queue[1:]
361+
if nbrs, ok := g.SourcesToDestinations[u]; ok {
362+
for v := range nbrs {
363+
if _, seen := visited[v]; seen {
364+
continue
365+
}
366+
visited[v] = struct{}{}
367+
queue = append(queue, v)
368+
}
369+
}
370+
}
371+
return visited
372+
}

pkg/sync/expand/graph_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func TestGetFirstCycle(t *testing.T) {
160160
for _, testCase := range testCases {
161161
t.Run(testCase.message, func(t *testing.T) {
162162
graph := parseExpression(t, ctx, testCase.expression)
163-
cycle := graph.GetFirstCycle()
163+
cycle := graph.GetFirstCycle(ctx)
164164
if testCase.expectedCycleSize == 0 {
165165
require.Nil(t, cycle)
166166
} else {
@@ -189,7 +189,7 @@ func TestHandleCycle(t *testing.T) {
189189

190190
graph := parseExpression(t, ctx, testCase.expression)
191191

192-
cycle := graph.GetFirstCycle()
192+
cycle := graph.GetFirstCycle(ctx)
193193
expectedCycles := createNodeIDList(testCase.expectedCycles)
194194
require.NotNil(t, cycle)
195195
found := false
@@ -205,7 +205,7 @@ func TestHandleCycle(t *testing.T) {
205205
require.NoError(t, err, graph.Str())
206206
err = graph.Validate()
207207
require.NoError(t, err)
208-
cycle = graph.GetFirstCycle()
208+
cycle = graph.GetFirstCycle(ctx)
209209
require.Nil(t, cycle)
210210
})
211211
}
@@ -230,7 +230,7 @@ func TestHandleComplexCycle(t *testing.T) {
230230
require.Equal(t, 0, len(graph.Edges))
231231
require.Equal(t, 3, len(graph.GetEntitlements()))
232232

233-
cycle := graph.GetFirstCycle()
233+
cycle := graph.GetFirstCycle(ctx)
234234
require.Nil(t, cycle)
235235
}
236236

@@ -257,7 +257,7 @@ func TestHandleCliqueCycle(t *testing.T) {
257257
require.Equal(t, 0, len(graph.Edges))
258258
require.Equal(t, 3, len(graph.GetEntitlements()))
259259

260-
cycle := graph.GetFirstCycle()
260+
cycle := graph.GetFirstCycle(ctx)
261261
require.Nil(t, cycle)
262262
}
263263
}

0 commit comments

Comments
 (0)