diff --git a/pkg/sync/expand/cycle.go b/pkg/sync/expand/cycle.go index fd37c6c0c..52e2ba4c8 100644 --- a/pkg/sync/expand/cycle.go +++ b/pkg/sync/expand/cycle.go @@ -14,7 +14,7 @@ func (g *EntitlementGraph) GetFirstCycle(ctx context.Context) []int { if g.HasNoCycles { return nil } - comps := g.ComputeCyclicComponents(ctx) + comps, _ := g.ComputeCyclicComponents(ctx) if len(comps) == 0 { return nil } @@ -26,20 +26,22 @@ func (g *EntitlementGraph) HasCycles(ctx context.Context) bool { if g.HasNoCycles { return false } - return len(g.ComputeCyclicComponents(ctx)) > 0 + comps, _ := g.ComputeCyclicComponents(ctx) + return len(comps) > 0 } func (g *EntitlementGraph) cycleDetectionHelper( + ctx context.Context, nodeID int, ) ([]int, bool) { reach := g.reachableFrom(nodeID) if len(reach) == 0 { return nil, false } - adj := g.toAdjacency(reach) - groups := scc.CondenseFWBWGroupsFromAdj(context.Background(), adj, scc.DefaultOptions()) + fg := filteredGraph{g: g, include: func(id int) bool { _, ok := reach[id]; return ok }} + groups, _ := scc.CondenseFWBW(ctx, fg, scc.DefaultOptions()) for _, comp := range groups { - if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) { + if len(comp) > 1 || (len(comp) == 1 && g.hasSelfLoop(comp[0])) { return comp, true } } @@ -47,24 +49,66 @@ func (g *EntitlementGraph) cycleDetectionHelper( } func (g *EntitlementGraph) FixCycles(ctx context.Context) error { - return g.FixCyclesFromComponents(ctx, g.ComputeCyclicComponents(ctx)) + comps, _ := g.ComputeCyclicComponents(ctx) + return g.FixCyclesFromComponents(ctx, comps) } // ComputeCyclicComponents runs SCC once and returns only cyclic components. // A component is cyclic if len>1 or a singleton with a self-loop. 
-func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) [][]int { +func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) ([][]int, *scc.Metrics) { if g.HasNoCycles { - return nil + return nil, nil } - adj := g.toAdjacency(nil) - groups := scc.CondenseFWBWGroupsFromAdj(ctx, adj, scc.DefaultOptions()) + groups, metrics := scc.CondenseFWBW(ctx, g, scc.DefaultOptions()) cyclic := make([][]int, 0) for _, comp := range groups { - if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) { + if len(comp) > 1 || (len(comp) == 1 && g.hasSelfLoop(comp[0])) { cyclic = append(cyclic, comp) } } - return cyclic + return cyclic, metrics +} + +// hasSelfLoop reports whether a node has a self-edge. +func (g *EntitlementGraph) hasSelfLoop(id int) bool { + if row, ok := g.SourcesToDestinations[id]; ok { + _, ok := row[id] + return ok + } + return false +} + +// filteredGraph restricts EntitlementGraph iteration to nodes for which include(id) is true. +type filteredGraph struct { + g *EntitlementGraph + include func(int) bool +} + +func (fg filteredGraph) ForEachNode(fn func(id int) bool) { + for id := range fg.g.Nodes { + if fg.include != nil && !fg.include(id) { + continue + } + if !fn(id) { + return + } + } +} + +func (fg filteredGraph) ForEachEdgeFrom(src int, fn func(dst int) bool) { + if fg.include != nil && !fg.include(src) { + return + } + if dsts, ok := fg.g.SourcesToDestinations[src]; ok { + for dst := range dsts { + if fg.include != nil && !fg.include(dst) { + continue + } + if !fn(dst) { + return + } + } + } } // removeNode obliterates a node and all incoming/outgoing edges. 
diff --git a/pkg/sync/expand/cycle_benchmark_test.go b/pkg/sync/expand/cycle_benchmark_test.go index 0e6d2c201..7794ee44c 100644 --- a/pkg/sync/expand/cycle_benchmark_test.go +++ b/pkg/sync/expand/cycle_benchmark_test.go @@ -107,6 +107,8 @@ func buildTailIntoRing(b *testing.B, tail, ring int) *EntitlementGraph { } func BenchmarkCycleDetectionHelper(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() sizes := []int{100, 1000} for _, n := range sizes { @@ -115,7 +117,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) { start := g.EntitlementsToNodes["1"] b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = g.cycleDetectionHelper(start) + _, _ = g.cycleDetectionHelper(ctx, start) } }) } @@ -126,7 +128,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) { start := g.EntitlementsToNodes["1"] b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = g.cycleDetectionHelper(start) + _, _ = g.cycleDetectionHelper(ctx, start) } }) } @@ -136,7 +138,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) { start := g.EntitlementsToNodes["1"] b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = g.cycleDetectionHelper(start) + _, _ = g.cycleDetectionHelper(ctx, start) } }) @@ -145,7 +147,7 @@ func BenchmarkCycleDetectionHelper(b *testing.B) { start := g.EntitlementsToNodes["1"] b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = g.cycleDetectionHelper(start) + _, _ = g.cycleDetectionHelper(ctx, start) } }) } diff --git a/pkg/sync/expand/cycle_test.go b/pkg/sync/expand/cycle_test.go index 838ff5645..e304c5a61 100644 --- a/pkg/sync/expand/cycle_test.go +++ b/pkg/sync/expand/cycle_test.go @@ -59,7 +59,7 @@ func TestCycleDetectionHelper_BasicScenarios(t *testing.T) { t.Run(tc.name, func(t *testing.T) { g := parseExpression(t, ctx, tc.expr) startNodeID := g.EntitlementsToNodes[tc.start] - cycle, ok := g.cycleDetectionHelper(startNodeID) + cycle, ok := g.cycleDetectionHelper(ctx, startNodeID) if !tc.has { require.False(t, ok) @@ -81,7 +81,7 @@ func 
TestCycleDetectionHelper_MultipleCyclesDifferentStarts(t *testing.T) { // Start at 1 -> should find cycle {1,2} { startNodeID := g.EntitlementsToNodes["1"] - cycle, ok := g.cycleDetectionHelper(startNodeID) + cycle, ok := g.cycleDetectionHelper(ctx, startNodeID) require.True(t, ok) require.NotNil(t, cycle) require.True(t, elementsMatch([]int{1, 2}, cycle)) @@ -90,7 +90,7 @@ func TestCycleDetectionHelper_MultipleCyclesDifferentStarts(t *testing.T) { // Start at 3 -> should find cycle {3,4} { startNodeID := g.EntitlementsToNodes["3"] - cycle, ok := g.cycleDetectionHelper(startNodeID) + cycle, ok := g.cycleDetectionHelper(ctx, startNodeID) require.True(t, ok) require.NotNil(t, cycle) require.True(t, elementsMatch([]int{3, 4}, cycle)) @@ -116,7 +116,7 @@ func TestCycleDetectionHelper_LargeRing(t *testing.T) { g := parseExpression(t, ctx, expr) startNodeID := g.EntitlementsToNodes["1"] - cycle, ok := g.cycleDetectionHelper(startNodeID) + cycle, ok := g.cycleDetectionHelper(ctx, startNodeID) require.True(t, ok) require.NotNil(t, cycle) require.Len(t, cycle, n) diff --git a/pkg/sync/expand/graph.go b/pkg/sync/expand/graph.go index 2c2a053e6..ef7d9850c 100644 --- a/pkg/sync/expand/graph.go +++ b/pkg/sync/expand/graph.go @@ -4,6 +4,7 @@ import ( "context" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + "github.com/conductorone/baton-sdk/pkg/sync/expand/scc" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" ) @@ -313,36 +314,30 @@ func (g *EntitlementGraph) DeleteEdge(ctx context.Context, srcEntitlementID stri // toAdjacency builds an adjacency map for SCC. If nodesSubset is non-nil, only // include those nodes (and edges between them). Always include all nodes in the // subset as keys, even if they have zero outgoing edges. 
-func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int { - adj := make(map[int]map[int]int, len(g.Nodes)) - include := func(id int) bool { - if nodesSubset == nil { - return true - } - _, ok := nodesSubset[id] - return ok - } +// toAdjacency removed: use SCC via scc.Source on EntitlementGraph - // Ensure keys for all included nodes. +var _ scc.Source = (*EntitlementGraph)(nil) + +// ForEachNode implements scc.Source iteration over nodes (including isolated nodes). +// It does not import scc; matching the method names/signatures is sufficient. +func (g *EntitlementGraph) ForEachNode(fn func(id int) bool) { for id := range g.Nodes { - if include(id) { - adj[id] = make(map[int]int) + if !fn(id) { + return } } +} - // Add edges where both endpoints are included. - for src, dsts := range g.SourcesToDestinations { - if !include(src) { - continue - } - row := adj[src] +// ForEachEdgeFrom implements scc.Source iteration of outgoing edges for src. +// It enumerates unique destination node IDs. +func (g *EntitlementGraph) ForEachEdgeFrom(src int, fn func(dst int) bool) { + if dsts, ok := g.SourcesToDestinations[src]; ok { for dst := range dsts { - if include(dst) { - row[dst] = 1 + if !fn(dst) { + return } } } - return adj } // reachableFrom computes the set of node IDs reachable from start over diff --git a/pkg/sync/expand/graph_test.go b/pkg/sync/expand/graph_test.go index 07827aaee..aa03ec3f8 100644 --- a/pkg/sync/expand/graph_test.go +++ b/pkg/sync/expand/graph_test.go @@ -2,6 +2,7 @@ package expand import ( "context" + "fmt" "strconv" "strings" "testing" @@ -20,8 +21,8 @@ func elementsMatch(listA []int, listB []int) bool { if len(listA) != len(listB) { return false } - setA := mapset.NewSet[int](listA...) - setB := mapset.NewSet[int](listB...) + setA := mapset.NewSet(listA...) + setB := mapset.NewSet(listB...) 
differenceA := setA.Difference(setB) if differenceA.Cardinality() > 0 { @@ -241,7 +242,7 @@ func TestHandleCliqueCycle(t *testing.T) { // Test can be flaky. N := 1 - for i := 0; i < N; i++ { + for range N { graph := parseExpression(t, ctx, "1>2>3>2>1>3>1") require.Equal(t, 3, len(graph.Nodes)) @@ -285,3 +286,62 @@ func TestMarkEdgeExpanded(t *testing.T) { require.True(t, graph.IsEntitlementExpanded("2")) require.True(t, graph.IsExpanded()) } + +func TestDeepNoCycles(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + depth := 40 + + expressionStr := "" + for i := range depth { + expressionStr += fmt.Sprintf("%d>%d", i+1, i+2) + } + graph := parseExpression(t, ctx, expressionStr) + + require.Equal(t, depth+1, len(graph.Nodes)) + require.Equal(t, depth, len(graph.Edges)) + require.Equal(t, depth+1, len(graph.GetEntitlements())) + + err := graph.FixCycles(ctx) + require.NoError(t, err, graph.Str()) + err = graph.Validate() + require.NoError(t, err) + + require.Equal(t, depth+1, len(graph.Nodes)) + require.Equal(t, depth, len(graph.Edges)) + require.Equal(t, depth+1, len(graph.GetEntitlements())) + + cycle := graph.GetFirstCycle(ctx) + require.Nil(t, cycle) +} + +func TestDeepCycles(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + depth := 40 + + expressionStr := "" + for i := range depth { + expressionStr += fmt.Sprintf("%d>%d", i+1, i+2) + } + expressionStr += fmt.Sprintf("%d>%d", depth, 1) + graph := parseExpression(t, ctx, expressionStr) + + require.Equal(t, depth+1, len(graph.Nodes)) + require.Equal(t, depth+1, len(graph.Edges)) + require.Equal(t, depth+1, len(graph.GetEntitlements())) + + err := graph.FixCycles(ctx) + require.NoError(t, err, graph.Str()) + err = graph.Validate() + require.NoError(t, err) + + require.Equal(t, 1, len(graph.Nodes)) + require.Equal(t, 0, len(graph.Edges)) + require.Equal(t, depth+1, len(graph.GetEntitlements())) + + cycle := 
graph.GetFirstCycle(ctx) + require.Nil(t, cycle) +} diff --git a/pkg/sync/expand/scc/bitset.go b/pkg/sync/expand/scc/bitset.go new file mode 100644 index 000000000..e017d8912 --- /dev/null +++ b/pkg/sync/expand/scc/bitset.go @@ -0,0 +1,126 @@ +package scc + +import ( + "math/bits" + "sync/atomic" +) + +// bitset is a packed, atomically updatable bitset. +// +// Concurrency notes: +// - Only testAndSetAtomic and clearAtomic are safe concurrently. +// - All other methods must not race with writers. +// - Slice storage aligns on 64-bit boundaries for atomic ops. +type bitset struct{ w []uint64 } + +func newBitset(n int) *bitset { + if n <= 0 { + return &bitset{} + } + return &bitset{w: make([]uint64, (n+63)>>6)} +} + +func (b *bitset) test(i int) bool { + if i < 0 { + return false + } + w := i >> 6 + if w >= len(b.w) { + return false + } + return (b.w[w] & (1 << (uint(i) & 63))) != 0 +} + +func (b *bitset) set(i int) { + if i < 0 { + return + } + w := i >> 6 + b.w[w] |= 1 << (uint(i) & 63) +} + +func (b *bitset) testAndSetAtomic(i int) bool { + if i < 0 { + return false + } + w := i >> 6 + if w >= len(b.w) { + return false + } + mask := uint64(1) << (uint(i) & 63) + addr := &b.w[w] + for { + old := atomic.LoadUint64(addr) + if old&mask != 0 { + return true + } + if atomic.CompareAndSwapUint64(addr, old, old|mask) { + return false + } + } +} + +func (b *bitset) clearAtomic(i int) { + if i < 0 { + return + } + w := i >> 6 + if w >= len(b.w) { + return + } + mask := ^(uint64(1) << (uint(i) & 63)) + addr := &b.w[w] + for { + old := atomic.LoadUint64(addr) + if atomic.CompareAndSwapUint64(addr, old, old&mask) { + return + } + } +} + +func (b *bitset) clone() *bitset { + cp := &bitset{w: make([]uint64, len(b.w))} + copy(cp.w, b.w) + return cp +} + +func (b *bitset) and(x *bitset) *bitset { + for i := range b.w { + b.w[i] &= x.w[i] + } + return b +} + +func (b *bitset) or(x *bitset) *bitset { + for i := range b.w { + b.w[i] |= x.w[i] + } + return b +} + +func (b *bitset) 
andNot(x *bitset) *bitset { + for i := range b.w { + b.w[i] &^= x.w[i] + } + return b +} + +func (b *bitset) isEmpty() bool { + for _, w := range b.w { + if w != 0 { + return false + } + } + return true +} + +func (b *bitset) forEachSet(fn func(i int)) { + for wi, w := range b.w { + for w != 0 { + tz := bits.TrailingZeros64(w) + i := (wi << 6) + tz + fn(i) + w &^= 1 << uint(tz) //nolint:gosec // trailing zeros is non-negative + } + } +} diff --git a/pkg/sync/expand/scc/bitset_test.go b/pkg/sync/expand/scc/bitset_test.go new file mode 100644 index 000000000..617f93302 --- /dev/null +++ b/pkg/sync/expand/scc/bitset_test.go @@ -0,0 +1,124 @@ +package scc + +import ( + "sync" + "testing" +) + +func TestBitsetBasicSetTest(t *testing.T) { + b := newBitset(130) // spans 3 words + if b.isEmpty() == false { + t.Fatalf("new bitset should be empty") + } + + // Set and test a few indices across word boundaries + indices := []int{0, 1, 63, 64, 65, 129} + for _, i := range indices { + b.set(i) + if !b.test(i) { + t.Fatalf("expected bit %d to be set", i) + } + } + if b.isEmpty() { + t.Fatalf("bitset should not be empty after sets") + } +} + +func TestBitsetCloneAndOps(t *testing.T) { + b1 := newBitset(128) + b2 := newBitset(128) + for _, i := range []int{0, 2, 64, 127} { + b1.set(i) + } + for _, i := range []int{2, 3, 64, 100} { + b2.set(i) + } + + c := b1.clone().and(b2) + for _, i := range []int{2, 64} { + if !c.test(i) { + t.Fatalf("AND missing expected bit %d", i) + } + } + for _, i := range []int{0, 3, 100, 127} { + if c.test(i) { + t.Fatalf("AND has unexpected bit %d", i) + } + } + + u := b1.clone().or(b2) + for _, i := range []int{0, 2, 3, 64, 100, 127} { + if !u.test(i) { + t.Fatalf("OR missing expected bit %d", i) + } + } + + d := b1.clone().andNot(c) + for _, i := range []int{0, 127} { + if !d.test(i) { + t.Fatalf("ANDNOT missing expected bit %d", i) + } + } + for _, i := range []int{2, 64} { + if d.test(i) { + t.Fatalf("ANDNOT has unexpected bit %d", i) + } + } +} + 
+func TestBitsetForEachSetOrder(t *testing.T) { + b := newBitset(70) + for _, i := range []int{69, 0, 1, 63, 64} { + b.set(i) + } + + var seen []int + b.forEachSet(func(i int) { seen = append(seen, i) }) + expected := []int{0, 1, 63, 64, 69} + if len(seen) != len(expected) { + t.Fatalf("unexpected count: got %d, want %d", len(seen), len(expected)) + } + for i := range expected { + if seen[i] != expected[i] { + t.Fatalf("order mismatch at %d: got %d, want %d", i, seen[i], expected[i]) + } + } +} + +func TestBitsetAtomicOps(t *testing.T) { + b := newBitset(128) + // Concurrent testAndSetAtomic should set each bit exactly once + var wg sync.WaitGroup + N := 1000 + idx := 73 // arbitrary + wg.Add(N) + setCount := 0 + var mu sync.Mutex + for i := 0; i < N; i++ { + go func() { + defer wg.Done() + if !b.testAndSetAtomic(idx) { + mu.Lock() + setCount++ + mu.Unlock() + } + }() + } + wg.Wait() + if setCount != 1 { + t.Fatalf("expected exactly one set, got %d", setCount) + } + if !b.test(idx) { + t.Fatalf("bit should be set after atomic operations") + } + + // clearAtomic should clear once and be idempotent + b.clearAtomic(idx) + if b.test(idx) { + t.Fatalf("bit should be cleared") + } + b.clearAtomic(idx) + if b.test(idx) { + t.Fatalf("bit should remain cleared after second clear") + } +} diff --git a/pkg/sync/expand/scc/scc.go b/pkg/sync/expand/scc/scc.go index 6740f91cf..c9ce68a23 100644 --- a/pkg/sync/expand/scc/scc.go +++ b/pkg/sync/expand/scc/scc.go @@ -1,59 +1,40 @@ -// Package scc provides a parallel FW–BW SCC condensation for directed graphs, +// Package scc provides an iterative FW–BW SCC condensation for directed graphs, // adapted for Baton’s entitlement graph. It builds an immutable CSR + transpose, -// runs reachability-based SCC (with parallel BFS and parallel recursion), and -// returns components as groups of your original node IDs. 
+// runs reachability-based SCC with a stack-based driver (no recursion, BFS may +// run in parallel), and returns components as groups of your original node IDs. package scc +// Iterative FW–BW SCC condensation for directed graphs. +// +// High-level algorithm: Build CSR + transpose; maintain a LIFO stack of +// subproblems (bitset masks). For each mask: +// 1) Trim sources/sinks repeatedly; each peeled vertex is a singleton SCC. +// 2) Pick a pivot (lowest-index active vertex), run forward/backward BFS +// restricted to the mask to get F and B. +// 3) The SCC is C = F ∩ B. Assign its component id and clear those bits +// from the mask. +// 4) Partition the remaining mask into F\C, B\C, and U = mask \ (F ∪ B), +// and push the non-empty masks onto the stack in a deterministic order. +// +// Parallelism is contained inside BFS (bounded by Options.MaxWorkers) and no +// recursive goroutines are spawned by the driver. Determinism is achieved via +// deterministic CSR construction (sorted ids and neighbors) and by using the +// lowest-index active pivot with a fixed child-push order. + import ( "context" - "math/bits" "runtime" "sort" "sync" - "sync/atomic" + "time" ) -// Lightweight pools to reduce transient allocations in hot paths. -var ( - intSlicePool sync.Pool // *([]int) - bucketSlicePool sync.Pool // *([]int) -) - -func getIntSlice(n int) []int { - p, _ := intSlicePool.Get().(*[]int) - if p == nil || cap(*p) < n { - return make([]int, n) - } - s := (*p)[:n] - return s -} - -func putIntSlice(s []int) { - intSlicePool.Put(&s) -} - -func getBucketSlice() []int { - p, _ := bucketSlicePool.Get().(*[]int) - if p == nil { - return make([]int, 0, 256) - } - return (*p)[:0] -} - -func putBucketSlice(s []int) { - bucketSlicePool.Put(&s) -} - -// Options controls parallel SCC execution. +// Options controls SCC execution. +// +// MaxWorkers bounds BFS concurrency per level. Deterministic toggles stable +// CSR index assignment and neighbor ordering. 
type Options struct { - // MaxWorkers bounds concurrency for BFS/recursion. Defaults to GOMAXPROCS. - MaxWorkers int - // EnableTrim peels vertices with min(in,out)==0 before FW–BW (usually beneficial). - EnableTrim bool - // SmallCutoff size below which we switch to a simpler (sequential) routine. - // You can later swap this to Tarjan for tiny subgraphs. - SmallCutoff int - // Deterministic maps node IDs to CSR indices in sorted order (recommended). + MaxWorkers int Deterministic bool } @@ -61,42 +42,101 @@ type Options struct { func DefaultOptions() Options { return Options{ MaxWorkers: runtime.GOMAXPROCS(0), - EnableTrim: false, - SmallCutoff: 8192, Deterministic: false, } } // CSR is a compact adjacency for G and its transpose Gᵗ. // Indices are 0..N-1; IdxToNodeID maps back to the original node IDs. +// +// Invariants (validated by validateCSR): +// - len(Row) == N+1; Row[0] == 0; Row is non-decreasing; Row[N] == len(Col) +// - len(TRow) == N+1; TRow[0] == 0; TRow is non-decreasing; TRow[N] == len(TCol) +// - 0 <= Col[p] < N for all p; 0 <= TCol[p] < N for all p +// - len(IdxToNodeID) == N; NodeIDToIdx[IdxToNodeID[i]] == i for all i +// - For each v, (TRow[v+1]-TRow[v]) equals the number of occurrences of v in Col +// (transpose degree matches inbound counts) type CSR struct { N int Row []int // len N+1 - Col []int // len = m + Col []int // len = m, m = Row[N] TRow []int // len N+1 - TCol []int // len = m - IdxToNodeID []int - NodeIDToIdx map[int]int + TCol []int // len = m, m = TRow[N] + IdxToNodeID []int // len N } -// BuildCSRFromAdj constructs CSR/transpose from adjacency map[int]map[int]int. -// If opts.Deterministic, node IDs are sorted before assigning indices. -// Isolated nodes must appear as keys in adj (with empty inner map). 
-func BuildCSRFromAdj(adj map[int]map[int]int, opts Options) *CSR { - nodes := make([]int, 0, len(adj)*2) - seen := make(map[int]struct{}, len(adj)*2) - for u, nbrs := range adj { - if _, ok := seen[u]; !ok { - seen[u] = struct{}{} - nodes = append(nodes, u) - } - for v := range nbrs { - if _, ok := seen[v]; !ok { - seen[v] = struct{}{} - nodes = append(nodes, v) +// Source is a minimal read-only graph provider used to build CSR without +// materializing an intermediate adjacency map. It must enumerate all nodes +// (including isolated nodes) and for each node provide its unique outgoing +// destinations. +type Source interface { + ForEachNode(fn func(id int) bool) + ForEachEdgeFrom(src int, fn func(dst int) bool) +} + +// Metrics captures a few summary counters for a condense run. +type Metrics struct { + Nodes int + Edges int + Components int + Peeled int + MasksProcessed int + MasksPushed int + BFScalls int + Duration time.Duration + MaxWorkers int + Deterministic bool +} + +// CondenseFWBW runs SCC directly from a streaming Source. Preferred entry point. 
+func CondenseFWBW(ctx context.Context, src Source, opts Options) ([][]int, *Metrics) { + if opts.MaxWorkers <= 0 { + opts.MaxWorkers = runtime.GOMAXPROCS(0) + } + start := time.Now() + csr := buildCSRFromSource(src, opts) + metrics := Metrics{ + Nodes: csr.N, + Edges: len(csr.Col), + MaxWorkers: opts.MaxWorkers, + Deterministic: opts.Deterministic, + } + comp := make([]int, csr.N) + for i := range comp { + comp[i] = -1 + } + nextID := sccFWBWIterative(ctx, csr, comp, opts, &metrics) + + groups := make([][]int, nextID) + for idx := range csr.N { + cid := comp[idx] + if cid < 0 { + cid = nextID + nextID++ + comp[idx] = cid + if cid >= len(groups) { + tmp := make([][]int, cid+1) + copy(tmp, groups) + groups = tmp } } + groups[cid] = append(groups[cid], csr.IdxToNodeID[idx]) } + metrics.Components = nextID + metrics.Duration = time.Since(start) + return groups, &metrics +} + +// buildCSRFromSource constructs CSR/transpose from a Source without +// materializing an intermediate adjacency map. If opts.Deterministic, node IDs +// are sorted and per-row neighbors are written in ascending order. +func buildCSRFromSource(src Source, opts Options) *CSR { + // 1) Collect nodes + nodes := make([]int, 0, 1024) + src.ForEachNode(func(id int) bool { + nodes = append(nodes, id) + return true + }) if opts.Deterministic { sort.Ints(nodes) } @@ -106,380 +146,258 @@ func BuildCSRFromAdj(adj map[int]map[int]int, opts Options) *CSR { } n := len(nodes) - // Count edges per row. 
+ // 2) Count out-degrees and total edges outDeg := make([]int, n) m := 0 - for uID, nbrs := range adj { - u := id2idx[uID] - deg := len(nbrs) - outDeg[u] += deg - m += deg + for i := range n { + srcID := nodes[i] + src.ForEachEdgeFrom(srcID, func(dst int) bool { + j, ok := id2idx[dst] + if !ok { + return true + } + _ = j // only used to validate membership + outDeg[i]++ + m++ + return true + }) } + // 3) Allocate Row/Col row := make([]int, n+1) - for i := 0; i < n; i++ { + for i := range n { row[i+1] = row[i] + outDeg[i] } col := make([]int, m) - - // Fill CSR. cur := make([]int, n) copy(cur, row) - for uID, nbrs := range adj { - u := id2idx[uID] - if opts.Deterministic { - vs := make([]int, 0, len(nbrs)) - for vID := range nbrs { - vs = append(vs, vID) - } - sort.Ints(vs) - for _, vID := range vs { - v := id2idx[vID] - pos := cur[u] - col[pos] = v - cur[u]++ - } - } else { - for vID := range nbrs { - v := id2idx[vID] - pos := cur[u] - col[pos] = v - cur[u]++ - } + + // 4) Fill rows + if opts.Deterministic { + for i := range n { + srcID := nodes[i] + neighbors := make([]int, 0, outDeg[i]) + src.ForEachEdgeFrom(srcID, func(dst int) bool { + if j, ok := id2idx[dst]; ok { + neighbors = append(neighbors, j) + } + return true + }) + sort.Ints(neighbors) + off := cur[i] + copy(col[off:off+len(neighbors)], neighbors) + cur[i] += len(neighbors) + } + } else { + for i := range n { + srcID := nodes[i] + src.ForEachEdgeFrom(srcID, func(dst int) bool { + if j, ok := id2idx[dst]; ok { + pos := cur[i] + col[pos] = j + cur[i] = pos + 1 + } + return true + }) } } - // Transpose. 
+ // 5) Transpose inDeg := make([]int, n) for _, v := range col { inDeg[v]++ } trow := make([]int, n+1) - for i := 0; i < n; i++ { + for i := range n { trow[i+1] = trow[i] + inDeg[i] } tcol := make([]int, m) tcur := make([]int, n) copy(tcur, trow) - for u := 0; u < n; u++ { + for u := range n { start, end := row[u], row[u+1] for p := start; p < end; p++ { v := col[p] pos := tcur[v] tcol[pos] = u - tcur[v]++ + tcur[v] = pos + 1 } } - return &CSR{ + csr := &CSR{ N: n, Row: row, Col: col, TRow: trow, TCol: tcol, IdxToNodeID: nodes, - NodeIDToIdx: id2idx, - } -} - -// bitset is a packed, atomically updatable bitset. -// -// Concurrency notes for callers: -// - Only testAndSetAtomic and clearAtomic are safe for concurrent use on the -// same bitset value. They perform CAS loops on 64-bit words. -// - All other methods (test, set, clone, and, or, andNot, isEmpty, count, -// forEachSet) are NOT safe to call concurrently with writers and must not -// race with any mutation of the same bitset. -// - Concurrent calls to test are fine only if no goroutine may write to that -// bitset at the same time (including via atomic methods), otherwise it is a -// data race by the Go memory model and race detector. -// - Slice storage for []uint64 is 64-bit aligned, satisfying atomic.*Uint64 -// alignment requirements across architectures. -// -// Package usage: -// - In BFS, multiple workers set bits in the per-search "visited" bitset via -// testAndSetAtomic only; there are no concurrent non-atomic reads/writes. -// - The shared "active" mask is only modified outside an in-flight BFS; BFS -// reads it (test) while no goroutine mutates it, and recursive calls operate -// on disjoint cloned masks. -type bitset struct{ w []uint64 } - -func newBitset(n int) *bitset { - if n <= 0 { - return &bitset{} - } - return &bitset{w: make([]uint64, (n+63)>>6)} -} - -// test reads a bit without synchronization. Do not call concurrently with any -// writer to the same bitset. 
-func (b *bitset) test(i int) bool { - if i < 0 { - return false - } - w := i >> 6 - return (b.w[w] & (1 << (uint(i) & 63))) != 0 -} - -// set writes a bit without synchronization. Not safe to race with other -// accesses (reads or writes) to the same bitset. -func (b *bitset) set(i int) { - if i < 0 { - return } - w := i >> 6 - b.w[w] |= 1 << (uint(i) & 63) + validateCSR(csr) + return csr } -// testAndSetAtomic atomically sets bit i and returns true if it was already -// set. Safe for concurrent use by multiple goroutines. -func (b *bitset) testAndSetAtomic(i int) bool { - if i < 0 { - return false - } - w := i >> 6 - mask := uint64(1) << (uint(i) & 63) - addr := &b.w[w] - for { - old := atomic.LoadUint64(addr) - if old&mask != 0 { - return true - } - if atomic.CompareAndSwapUint64(addr, old, old|mask) { - return false - } +// validateCSR performs internal consistency checks on CSR and panics +// with a descriptive message when a violation is found. This is intended to +// catch programmer errors at build time and in tests; it runs unconditionally. +func validateCSR(csr *CSR) { + if csr == nil { + panic("scc: CSR is nil") } -} - -// clearAtomic atomically clears bit i. Safe for concurrent use by multiple -// goroutines. 
-func (b *bitset) clearAtomic(i int) { - if i < 0 { - return + n := csr.N + if n < 0 { + panic("scc: CSR.N is negative") } - w := i >> 6 - mask := ^(uint64(1) << (uint(i) & 63)) - addr := &b.w[w] - for { - old := atomic.LoadUint64(addr) - if atomic.CompareAndSwapUint64(addr, old, old&mask) { - return - } + if len(csr.Row) != n+1 { + panic("scc: len(Row) != N+1") } -} - -func (b *bitset) clone() *bitset { - cp := &bitset{w: make([]uint64, len(b.w))} - copy(cp.w, b.w) - return cp -} - -func (b *bitset) and(x *bitset) *bitset { - for i := range b.w { - b.w[i] &= x.w[i] + if len(csr.TRow) != n+1 { + panic("scc: len(TRow) != N+1") } - return b -} - -func (b *bitset) or(x *bitset) *bitset { - for i := range b.w { - b.w[i] |= x.w[i] + if len(csr.IdxToNodeID) != n { + panic("scc: len(IdxToNodeID) != N") } - return b -} - -func (b *bitset) andNot(x *bitset) *bitset { - for i := range b.w { - b.w[i] &^= x.w[i] + // Row invariants and degree sums + if csr.Row[0] != 0 { + panic("scc: Row[0] != 0") } - return b -} - -func (b *bitset) isEmpty() bool { - for _, w := range b.w { - if w != 0 { - return false + for i := range len(csr.Row) - 1 { + if csr.Row[i] > csr.Row[i+1] { + panic("scc: Row is not non-decreasing") } } - return true -} - -func (b *bitset) count() int { - total := 0 - for _, w := range b.w { - total += bits.OnesCount64(w) + m := csr.Row[n] + if m != len(csr.Col) { + panic("scc: Row[N] != len(Col)") } - return total -} - -func (b *bitset) forEachSet(fn func(i int)) { - for wi, w := range b.w { - for w != 0 { - tz := bits.TrailingZeros64(w) - i := (wi << 6) + tz - fn(i) - w &^= 1 << uint(tz) //nolint:gosec //bits.TrailingZeros64 never returns negative. 
+ // TRow invariants + if csr.TRow[0] != 0 { + panic("scc: TRow[0] != 0") + } + for i := range len(csr.TRow) - 1 { + if csr.TRow[i] > csr.TRow[i+1] { + panic("scc: TRow is not non-decreasing") } } -} - -// CondenseFWBWGroupsFromAdj runs parallel FW–BW SCC on adj and returns only -// the component groups as slices of original node IDs. This avoids building the -// idToComp map when callers don't need it. -func CondenseFWBWGroupsFromAdj(ctx context.Context, adj map[int]map[int]int, opts Options) [][]int { - if opts.MaxWorkers <= 0 { - opts.MaxWorkers = runtime.GOMAXPROCS(0) + mt := csr.TRow[n] + if mt != len(csr.TCol) { + panic("scc: TRow[N] != len(TCol)") } - csr := BuildCSRFromAdj(adj, opts) - // Small graphs: use low-overhead Tarjan SCC - if csr.N <= opts.SmallCutoff { - comp := tarjanSCC(csr) - maxID := -1 - for _, c := range comp { - if c > maxID { - maxID = c - } + // Col bounds + for p := range len(csr.Col) { + v := csr.Col[p] + if v < 0 || v >= n { + panic("scc: Col index out of range") } - groups := make([][]int, maxID+1) - for idx := 0; idx < csr.N; idx++ { - cid := comp[idx] - groups[cid] = append(groups[cid], csr.IdxToNodeID[idx]) + } + for p := range len(csr.TCol) { + v := csr.TCol[p] + if v < 0 || v >= n { + panic("scc: TCol index out of range") } - return groups } - comp := make([]int, csr.N) - for i := range comp { - comp[i] = -1 + // NodeID mapping bijection check removed: CSR does not store NodeIDToIdx. + // Transpose degree equals inbound counts + inDeg := make([]int, n) + for _, v := range csr.Col { + inDeg[v]++ } - active := newBitset(csr.N) - for i := 0; i < csr.N; i++ { - active.set(i) + for v := range n { + expected := inDeg[v] + span := csr.TRow[v+1] - csr.TRow[v] + if span != expected { + panic("scc: transpose degree mismatch") + } } +} + +// bitset moved to bitset.go + +// sccFWBWIterative implements the driver loop described at the top. 
+func sccFWBWIterative(ctx context.Context, csr *CSR, comp []int, opts Options, metrics *Metrics) int { nextID := 0 - sccFWBW(ctx, csr, active, comp, &nextID, opts) - groups := make([][]int, nextID) - for idx := 0; idx < csr.N; idx++ { - cid := comp[idx] - if cid < 0 { - // Defensive: assign singleton if anything slipped through - cid = nextID - nextID++ - comp[idx] = cid - if cid >= len(groups) { - tmp := make([][]int, cid+1) - copy(tmp, groups) - groups = tmp - } - } - groups[cid] = append(groups[cid], csr.IdxToNodeID[idx]) + // Initialize root mask with all vertices. + root := newBitset(csr.N) + for i := range csr.N { + root.set(i) } - return groups -} -func sccFWBW(ctx context.Context, csr *CSR, active *bitset, comp []int, nextID *int, opts Options) { - // Optional trimming loop. - if opts.EnableTrim { + type item struct{ mask *bitset } + stack := make([]item, 0, 64) + stack = append(stack, item{mask: root}) + + for len(stack) > 0 { + select { + case <-ctx.Done(): + return nextID + default: + } + + it := stack[len(stack)-1] + stack = stack[:len(stack)-1] + active := it.mask + if metrics != nil { + metrics.MasksProcessed++ + } + + // Trim loop: peel sources/sinks; each peeled vertex becomes its own SCC. for { - if n := trimSingletons(csr, active, comp, nextID); n == 0 { + if n := trimSingletons(csr, active, comp, &nextID); n == 0 { break + } else if metrics != nil { + metrics.Peeled += n } if active.isEmpty() { - return + break } } - } - - // Base case: small subgraph — do simple repeated FW–BW (Tarjan hook point). 
- if active.count() <= opts.SmallCutoff { - for !active.isEmpty() { - select { - case <-ctx.Done(): - return - default: - } - pivot := firstActive(active) - f := bfsMultiSource(ctx, csr, []int{pivot}, active, false, opts.MaxWorkers) - b := bfsMultiSource(ctx, csr, []int{pivot}, active, true, opts.MaxWorkers) - c := f.clone().and(b) - assignComponent(c, comp, nextID, active) - - // Partitions - fNotC := f.clone().andNot(c) - bNotC := b.clone().andNot(c) - fOrB := f.clone().or(b) - u := active.clone().andNot(fOrB) - - if !fNotC.isEmpty() { - sccFWBW(ctx, csr, fNotC, comp, nextID, opts) - active.andNot(fNotC) - } - if !bNotC.isEmpty() { - sccFWBW(ctx, csr, bNotC, comp, nextID, opts) - active.andNot(bNotC) - } - if !u.isEmpty() { - sccFWBW(ctx, csr, u, comp, nextID, opts) - active.andNot(u) - } + if active.isEmpty() { + continue } - return - } - // General case: one pivot; add pivot batching later for extra speed. - pivot := firstActive(active) - f := bfsMultiSource(ctx, csr, []int{pivot}, active, false, opts.MaxWorkers) - b := bfsMultiSource(ctx, csr, []int{pivot}, active, true, opts.MaxWorkers) - c := f.clone().and(b) - assignComponent(c, comp, nextID, active) + // Pivot and BFS (restricted to active mask). + pivot := firstActive(active) + f := bfsMultiSource(ctx, csr, []int{pivot}, active, false, opts.MaxWorkers) + b := bfsMultiSource(ctx, csr, []int{pivot}, active, true, opts.MaxWorkers) + if metrics != nil { + metrics.BFScalls += 2 + } - // F\C, B\C, and U = active \ (F ∪ B) - fNotC := f.clone().andNot(c) - bNotC := b.clone().andNot(c) - fOrB := f.clone().or(b) - u := active.clone().andNot(fOrB) + // Component and partition masks. 
+ c := f.clone().and(b) + assignComponent(c, comp, &nextID, active) - type sub struct{ mask *bitset } - var subs []sub - if !fNotC.isEmpty() { - subs = append(subs, sub{fNotC}) - } - if !bNotC.isEmpty() { - subs = append(subs, sub{bNotC}) - } - if !u.isEmpty() { - subs = append(subs, sub{u}) - } - if len(subs) == 0 { - return - } + fNotC := f.clone().andNot(c) + bNotC := b.clone().andNot(c) + fOrB := f.clone().or(b) + u := active.clone().andNot(fOrB) - if len(subs) == 1 || opts.MaxWorkers <= 1 { - for _, s := range subs { - sccFWBW(ctx, csr, s.mask, comp, nextID, opts) - active.andNot(s.mask) + // assignComponent cleared C from 'active'; child masks are disjoint subsets + // of the original mask. Push children in a fixed order for determinism. + pushes := 0 + if !u.isEmpty() { + stack = append(stack, item{mask: u}) + pushes++ + } + if !bNotC.isEmpty() { + stack = append(stack, item{mask: bNotC}) + pushes++ + } + if !fNotC.isEmpty() { + stack = append(stack, item{mask: fNotC}) + pushes++ + } + if metrics != nil { + metrics.MasksPushed += pushes } - return } - var wg sync.WaitGroup - wg.Add(len(subs)) - for _, s := range subs { - mask := s.mask - go func() { - defer wg.Done() - sccFWBW(ctx, csr, mask, comp, nextID, opts) - }() - } - wg.Wait() - for _, s := range subs { - active.andNot(s.mask) - } + return nextID } // bfsMultiSource runs a parallel BFS from sources over csr. -// If useTranspose is true, traverses csr.T* arrays. -// Traversal respects 'active' mask; returns visited including sources. +// If useTranspose is true, traverses csr.T* arrays. Traversal respects +// 'active' mask; returns visited including sources. // Cancellation: checks ctx between levels. 
func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset, useTranspose bool, maxWorkers int) *bitset { if maxWorkers <= 0 { @@ -511,7 +429,7 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset default: } - // Gate parallelism: if frontier small, do sequential step to avoid overhead + // If frontier small, do sequential step to avoid overhead. if len(frontier) <= 64 || maxWorkers == 1 { next := make([]int, 0, len(frontier)) for _, u := range frontier { @@ -521,9 +439,11 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset if !active.test(v) { continue } - // Non-atomic is safe because we are single-threaded on this step + if v < 0 { + continue + } w := v >> 6 - mask := uint64(1) << (uint(v) & 63) //nolint:gosec //active.test returns false for negative values. + mask := uint64(1) << (uint64(v) & 63) if (visited.w[w] & mask) == 0 { visited.w[w] |= mask next = append(next, v) @@ -534,14 +454,17 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset continue } - workers := min(maxWorkers, len(frontier)) + workers := maxWorkers + if workers > len(frontier) { + workers = len(frontier) + } var wg sync.WaitGroup wg.Add(workers) chunkSize := (len(frontier) + workers - 1) / workers nextBuckets := make([][]int, workers) - for w := range workers { + for w := 0; w < workers; w++ { start := w * chunkSize end := start + chunkSize if start >= len(frontier) { @@ -555,7 +478,7 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset w := w // capture go func(start, end int) { defer wg.Done() - local := getBucketSlice() + local := make([]int, 0, 256) for i := start; i < end; i++ { u := frontier[i] rs, re := getRow(u) @@ -574,7 +497,6 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset } wg.Wait() - // Flatten next frontier and return bucket buffers to pool. 
total := 0 for _, b := range nextBuckets { total += len(b) @@ -584,7 +506,6 @@ func bfsMultiSource(ctx context.Context, csr *CSR, sources []int, active *bitset for _, b := range nextBuckets { copy(next[off:], b) off += len(b) - putBucketSlice(b) } frontier = next } @@ -610,7 +531,7 @@ func frontierSeed(sources []int, active, visited *bitset) []int { } func firstActive(active *bitset) int { - var pivot = -1 + pivot := -1 active.forEachSet(func(i int) { if pivot == -1 { pivot = i @@ -626,27 +547,45 @@ func assignComponent(cMask *bitset, comp []int, nextID *int, active *bitset) { } cid := *nextID *nextID++ - cMask.forEachSet(func(i int) { comp[i] = cid active.clearAtomic(i) }) } +// Degree-array pool for trim to reduce allocations for small graphs. +var ( + intSlicePool sync.Pool // *([]int) +) + +func getIntSlice(n int) []int { + p, _ := intSlicePool.Get().(*[]int) + if p == nil || cap(*p) < n { + return make([]int, n) + } + s := (*p)[:n] + for i := range s { + s[i] = 0 + } + return s +} + +func putIntSlice(s []int) { + // avoid keeping very large slices + if cap(s) > 1<<14 { + return + } + intSlicePool.Put(&s) +} + // trimSingletons peels vertices with restricted in/out degree within 'active'. // Each peeled vertex becomes its own SCC id. Returns count peeled. func trimSingletons(csr *CSR, active *bitset, comp []int, nextID *int) int { n := csr.N inDeg := getIntSlice(n) outDeg := getIntSlice(n) - defer func() { - putIntSlice(inDeg) - putIntSlice(outDeg) - }() + defer func() { putIntSlice(inDeg); putIntSlice(outDeg) }() - // Initialize degrees within active and queue zeros - queue := getIntSlice(0) - defer putIntSlice(queue) // Out-degree within active. for u := range n { if !active.test(u) { @@ -677,6 +616,10 @@ func trimSingletons(csr *CSR, active *bitset, comp []int, nextID *int) int { } inDeg[v] = d } + + // Initialize queue of zeros. 
+ queue := getIntSlice(0) + defer putIntSlice(queue) for i := range n { if !active.test(i) { continue @@ -696,11 +639,13 @@ func trimSingletons(csr *CSR, active *bitset, comp []int, nextID *int) int { if inDeg[u] > 0 && outDeg[u] > 0 { continue } + // assign and remove u comp[u] = *nextID *nextID++ active.clearAtomic(u) peeled++ + // Decrement out-neighbors' inDeg rs, re := csr.Row[u], csr.Row[u+1] for p := rs; p < re; p++ { @@ -732,58 +677,3 @@ func trimSingletons(csr *CSR, active *bitset, comp []int, nextID *int) int { } return peeled } - -// Add a simple sequential Tarjan SCC for small graphs. -func tarjanSCC(csr *CSR) []int { - n := csr.N - index := 0 - indices := make([]int, n) - lowlink := make([]int, n) - onstack := make([]bool, n) - stack := make([]int, 0, n) - comp := make([]int, n) - for i := range n { - indices[i] = -1 - comp[i] = -1 - } - compID := 0 - var strongConnect func(v int) - strongConnect = func(v int) { - indices[v] = index - lowlink[v] = index - index++ - stack = append(stack, v) - onstack[v] = true - for p := csr.Row[v]; p < csr.Row[v+1]; p++ { - w := csr.Col[p] - if indices[w] == -1 { - strongConnect(w) - if lowlink[w] < lowlink[v] { - lowlink[v] = lowlink[w] - } - } else if onstack[w] { - if indices[w] < lowlink[v] { - lowlink[v] = indices[w] - } - } - } - if lowlink[v] == indices[v] { - for { - w := stack[len(stack)-1] - stack = stack[:len(stack)-1] - onstack[w] = false - comp[w] = compID - if w == v { - break - } - } - compID++ - } - } - for v := range n { - if indices[v] == -1 { - strongConnect(v) - } - } - return comp -} diff --git a/pkg/sync/expand/scc/scc_fuzz_test.go b/pkg/sync/expand/scc/scc_fuzz_test.go new file mode 100644 index 000000000..9348d5873 --- /dev/null +++ b/pkg/sync/expand/scc/scc_fuzz_test.go @@ -0,0 +1,440 @@ +package scc + +import ( + "context" + "encoding/binary" + "math" + "math/rand" + "reflect" + "testing" + "time" +) + +func clamp(x, lo, hi int) int { + if x < lo { + return lo + } + if x > hi { + return hi + } 
+ return x +} + +func equalGroups(a, b [][]int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !reflect.DeepEqual(a[i], b[i]) { + return false + } + } + return true +} + +// assertPartition ensures every key in adj appears in exactly one group; no duplicates. +func assertPartition(t *testing.T, adj map[int]map[int]int, groups [][]int) { + t.Helper() + seen := make(map[int]int, len(adj)) + for gid, g := range groups { + for _, id := range g { + if _, ok := seen[id]; ok { + t.Fatalf("node %d appears in multiple groups", id) + } + seen[id] = gid + } + } + for u := range adj { + if _, ok := seen[u]; !ok { + t.Fatalf("node %d missing from partition", u) + } + } +} + +// assertDAGCondensation builds component meta-graph and checks it is acyclic. +func assertDAGCondensation(t *testing.T, adj map[int]map[int]int, groups [][]int) { + t.Helper() + idToComp := make(map[int]int, len(adj)) + for cid, g := range groups { + for _, id := range g { + idToComp[id] = cid + } + } + compAdj := make(map[int]map[int]struct{}, len(groups)) + for u := range groups { + compAdj[u] = make(map[int]struct{}) + } + for u, nbrs := range adj { + cu := idToComp[u] + for v := range nbrs { + cv := idToComp[v] + if cu == cv { + continue + } + compAdj[cu][cv] = struct{}{} + } + } + indeg := make([]int, len(groups)) + for u := range compAdj { + for v := range compAdj[u] { + indeg[v]++ + } + } + q := make([]int, 0, len(groups)) + for u := 0; u < len(groups); u++ { + if indeg[u] == 0 { + q = append(q, u) + } + } + visited := 0 + for len(q) > 0 { + u := q[0] + q = q[1:] + visited++ + for v := range compAdj[u] { + indeg[v]-- + if indeg[v] == 0 { + q = append(q, v) + } + } + } + if visited != len(groups) { + t.Fatalf("component condensation has a cycle: visited=%d total=%d", visited, len(groups)) + } +} + +// generateAdjacency creates a bounded graph according to mode; returns map[int]map[int]int with all nodes as keys. 
+func generateAdjacency(numNodes, edgeBudget, mode int, r *rand.Rand, selfLoopFrac, bidirFrac int) map[int]map[int]int { + if numNodes <= 0 { + numNodes = 1 + } + adj := make(map[int]map[int]int, numNodes) + for i := 0; i < numNodes; i++ { + adj[i] = make(map[int]int) + } + + addEdge := func(u, v int) { + if u < 0 || u >= numNodes || v < 0 || v >= numNodes { + return + } + if adj[u] == nil { + adj[u] = make(map[int]int) + } + adj[u][v] = 1 + } + + edgesAdded := 0 + budget := edgeBudget + maxBudget := numNodes * numNodes + if budget > maxBudget { + budget = maxBudget + } + + switch mode % 8 { + case 0: // random directed + for edgesAdded < budget { + u := r.Intn(numNodes) + v := r.Intn(numNodes) + addEdge(u, v) + edgesAdded++ + if r.Intn(256) < bidirFrac { + addEdge(v, u) + } + if r.Intn(256) < selfLoopFrac { + addEdge(u, u) + } + } + case 1: // many disjoint 2-cycles + isolates + for i := 0; i+1 < numNodes && edgesAdded+2 <= budget; i += 2 { + addEdge(i, i+1) + addEdge(i+1, i) + edgesAdded += 2 + } + case 2: // lollipop: clique K_m + tail T + m := int(math.Sqrt(float64(numNodes))) + if m < 2 { + m = 2 + } + if m > numNodes { + m = numNodes + } + T := numNodes - m + for i := 0; i < m; i++ { + for j := 0; j < m; j++ { + if i == j || edgesAdded >= budget { + continue + } + addEdge(i, j) + edgesAdded++ + } + } + if T > 0 { + addEdge(m-1, m) + edgesAdded++ + for i := m; i+1 < numNodes && edgesAdded < budget; i++ { + addEdge(i, i+1) + edgesAdded++ + } + } + case 3: // bipartite; optionally bidirectional + a := numNodes / 2 + if a == 0 { + a = 1 + } + for i := 0; i < a; i++ { + for j := a; j < numNodes && edgesAdded < budget; j++ { + addEdge(i, j) + edgesAdded++ + if r.Intn(256) < bidirFrac && edgesAdded < budget { + addEdge(j, i) + edgesAdded++ + } + } + } + case 4: // multi-ring stitched by tails + start := 0 + for start < numNodes && edgesAdded < budget { + size := 3 + r.Intn(5) + if start+size > numNodes { + size = numNodes - start + } + if size >= 2 { + for i := 0; i 
< size; i++ { + u := start + i + v := start + ((i + 1) % size) + addEdge(u, v) + edgesAdded++ + if edgesAdded >= budget { + break + } + } + } + // one-way tail to next block + next := start + size + if next < numNodes && edgesAdded < budget { + addEdge(start+size-1, next) + edgesAdded++ + } + start += size + } + case 5: // star hub asymmetry + hub := r.Intn(numNodes) + for i := 0; i < numNodes && edgesAdded < budget; i++ { + if i == hub { + continue + } + addEdge(hub, i) + edgesAdded++ + if r.Intn(256) < bidirFrac && edgesAdded < budget { + addEdge(i, hub) + edgesAdded++ + } + } + case 6: // skewed external IDs (still using 0..N-1 as keys here; CSR handles mapping) + for edgesAdded < budget { + u := r.Intn(numNodes) + v := (r.Intn(numNodes) * 13) % numNodes + addEdge(u, v) + edgesAdded++ + } + case 7: // layered DAG with sparse backedges + layers := 1 + r.Intn(8) + per := (numNodes + layers - 1) / layers + // forward edges between layers + for L := 0; L+1 < layers && edgesAdded < budget; L++ { + aStart := L * per + aEnd := (L + 1) * per + if aEnd > numNodes { + aEnd = numNodes + } + bStart := (L + 1) * per + bEnd := (L + 2) * per + if bEnd > numNodes { + bEnd = numNodes + } + for u := aStart; u < aEnd && edgesAdded < budget; u++ { + for v := bStart; v < bEnd && edgesAdded < budget; v++ { + if r.Intn(3) == 0 { // sparsify + addEdge(u, v) + edgesAdded++ + } + } + } + } + // sparse backedges inside a layer + for L := 0; L < layers && edgesAdded < budget; L++ { + s := L * per + e := (L + 1) * per + if e > numNodes { + e = numNodes + } + for u := s; u < e; u++ { + bound := e - s + if bound < 1 { + bound = 1 + } + if r.Intn(10) == 0 && edgesAdded < budget { + v := s + r.Intn(bound) + addEdge(u, v) + edgesAdded++ + } + } + } + } + + // occasional self-loops + if selfLoopFrac > 0 { + for i := 0; i < numNodes; i++ { + if r.Intn(256) < selfLoopFrac { + addEdge(i, i) + } + } + } + return adj +} + +// Fuzzers + +// Cancellation fuzzer: short deadline; only assert return (no 
structural checks). +func FuzzCondenseFWBW_Cancellation(f *testing.F) { + f.Add(512, uint64(4), 2048, uint8(2), uint8(8), uint8(0)) + f.Fuzz(func(t *testing.T, numNodes int, seed uint64, edgeBudget int, mode uint8, selfLoopFrac uint8, bidirFrac uint8) { + numNodes = clamp(numNodes, 1, 1000) + if edgeBudget > 100000 { + edgeBudget = 100000 + } + maxEdges := numNodes * numNodes + if edgeBudget > maxEdges { + edgeBudget = maxEdges + } + r := rand.New(rand.NewSource(int64(seed))) //nolint:gosec // math/rand is acceptable for fuzzing/tests + adj := generateAdjacency(numNodes, edgeBudget, int(mode%8), r, int(selfLoopFrac), int(bidirFrac)) + opts := DefaultOptions() + opts.MaxWorkers = 1 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel() + _, _ = CondenseFWBW(ctx, adjSource{adj: adj}, opts) + }) +} + +// --- Byte-based fuzzer --- + +// decodeVarint decodes a little-endian base-128 varint (LEB128-style) from data starting at *i. +// Returns value and whether decoding succeeded. +func decodeVarint(data []byte, i *int) (uint64, bool) { + var x uint64 + var s uint + for { + if *i >= len(data) || s >= 64 { + return 0, false + } + b := data[*i] + *i++ + if b < 0x80 { + x |= uint64(b) << s + return x, true + } + x |= uint64(b&0x7f) << s + s += 7 + } +} + +// generateAdjFromBytes builds an adjacency from a byte stream with caps. +func generateAdjFromBytes(data []byte, maxN, maxM int) map[int]map[int]int { + if len(data) == 0 { + return map[int]map[int]int{0: {}} + } + i := 0 + n64, ok := decodeVarint(data, &i) + if !ok { + n64 = 1 + } + m64, ok := decodeVarint(data, &i) + if !ok { + m64 = 0 + } + // Clamp n to [1, maxN] and m to [0, maxM].
+ var n int + switch { + case n64 < 1: + n = 1 + case n64 > uint64(maxN): //nolint:gosec // maxN is a small, non-negative test bound + n = maxN + case n64 > uint64(^uint(0)>>1): + n = maxN + default: + n = int(n64) + } + var m int + switch { + case m64 > uint64(maxM): //nolint:gosec // maxM is a non-negative test bound + m = maxM + case m64 > uint64(^uint(0)>>1): + m = maxM + default: + m = int(m64) + } + adj := make(map[int]map[int]int, n) + for v := 0; v < n; v++ { + adj[v] = make(map[int]int) + } + // Edge pairs + for e := 0; e < m && i < len(data); e++ { + // If not enough bytes left, break + if i+8 > len(data) { + break + } + u := int(binary.LittleEndian.Uint32(data[i:])) % n + i += 4 + v := int(binary.LittleEndian.Uint32(data[i:])) % n + i += 4 + adj[u][v] = 1 + } + // Optional flags for reverse edges and self-loops + if i < len(data) { + flags := data[i] + // bit0: add reverse for all edges + if flags&0x1 != 0 { + for u, nbrs := range adj { + for v := range nbrs { + adj[v][u] = 1 + } + } + } + // bit1: sprinkle self-loops for some nodes based on subsequent bytes + if flags&0x2 != 0 { + i++ + for v := 0; v < n && i < len(data); v++ { + if data[i]%3 == 0 { + adj[v][v] = 1 + } + i++ + } + } + } + return adj +} + +func FuzzCondenseFWBW_FromBytes(f *testing.F) { + // Seed with simple patterns + f.Add([]byte{5, 10, 0, 0, 0, 0, 1, 0, 0, 0}) // n=5,m=10, one edge 0->1 + f.Add([]byte{10, 20}) // small n,m with empty pairs + + f.Fuzz(func(t *testing.T, data []byte) { + // Caps for CI-friendly fuzz + const maxN = 1000 + const maxM = 100000 + adj := generateAdjFromBytes(data, maxN, maxM) + opts := DefaultOptions() + opts.Deterministic = true + opts.MaxWorkers = 1 + groups, _ := CondenseFWBW(context.Background(), adjSource{adj: adj}, opts) + assertPartition(t, adj, groups) + assertDAGCondensation(t, adj, groups) + // idempotence in deterministic mode + groups2, _ := CondenseFWBW(context.Background(), adjSource{adj: adj}, opts) + if !equalGroups(normalizeGroups(groups), 
normalizeGroups(groups2)) { + t.Fatalf("non-deterministic result with Deterministic=true") + } + }) +} diff --git a/pkg/sync/expand/scc/scc_nohang_test.go b/pkg/sync/expand/scc/scc_nohang_test.go new file mode 100644 index 000000000..48fcd66fc --- /dev/null +++ b/pkg/sync/expand/scc/scc_nohang_test.go @@ -0,0 +1,339 @@ +package scc + +import ( + "context" + "strconv" + "testing" + "time" +) + +// withTimeout runs f and fails the test if it doesn't complete within d. +func withTimeout(t *testing.T, d time.Duration, f func(t *testing.T)) { + t.Helper() + done := make(chan struct{}) + go func() { + defer close(done) + f(t) + }() + select { + case <-done: + return + case <-time.After(d): + t.Fatalf("function did not complete within %v (possible hang)", d) + } +} + +// adversarialGraphs returns a set of graphs that exercise different recursion paths. +func adversarialGraphs() []map[int]map[int]int { + var graphs []map[int]map[int]int + + // Small ring (single SCC) + { + n := 16 + nodes := make([]int, n) + for i := 0; i < n; i++ { + nodes[i] = i + } + var edges [][2]int + for i := 0; i < n; i++ { + edges = append(edges, [2]int{i, (i + 1) % n}) + } + graphs = append(graphs, makeAdj(nodes, edges)) + } + + // Chain (all acyclic singletons) + { + n := 64 + nodes := make([]int, n) + for i := 0; i < n; i++ { + nodes[i] = i + } + var edges [][2]int + for i := 0; i+1 < n; i++ { + edges = append(edges, [2]int{i, i + 1}) + } + graphs = append(graphs, makeAdj(nodes, edges)) + } + + // Two disjoint rings plus some isolated nodes + { + nodes := []int{0, 1, 2, 10, 11, 12, 13, 100, 101} + edges := [][2]int{{0, 1}, {1, 2}, {2, 0}, {10, 11}, {11, 12}, {12, 13}, {13, 10}} + graphs = append(graphs, makeAdj(nodes, edges)) + } + + // Dense bidirectional bipartite (one big SCC) + { + a, b := 8, 8 + var nodes []int + for i := 0; i < a+b; i++ { + nodes = append(nodes, i) + } + var edges [][2]int + for i := 0; i < a; i++ { + for j := a; j < a+b; j++ { + edges = append(edges, [2]int{i, j}) + 
edges = append(edges, [2]int{j, i}) + } + } + graphs = append(graphs, makeAdj(nodes, edges)) + } + + // Tail into ring (classic FW–BW partitions) + { + ringN := 10 + tailN := 12 + var nodes []int + for i := 0; i < ringN; i++ { + nodes = append(nodes, i) + } + for i := 0; i < tailN; i++ { + nodes = append(nodes, 1000+i) + } + var edges [][2]int + // ring + for i := 0; i < ringN; i++ { + edges = append(edges, [2]int{i, (i + 1) % ringN}) + } + // tail into ring start (0) + for i := 0; i+1 < tailN; i++ { + edges = append(edges, [2]int{1000 + i, 1000 + i + 1}) + } + edges = append(edges, [2]int{1000 + tailN - 1, 0}) + graphs = append(graphs, makeAdj(nodes, edges)) + } + + return graphs +} + +// TestNoHang_GeneralCase runs CondenseFWBW with MaxWorkers=4 on every adversarial graph and fails if a run exceeds the 2s timeout. +func TestNoHang_GeneralCase(t *testing.T) { + graphs := adversarialGraphs() + opts := DefaultOptions() + opts.MaxWorkers = 4 + for gi, adj := range graphs { + gi, adj := gi, adj + t.Run( + funcName("general", gi, false), + func(t *testing.T) { + withTimeout(t, 2*time.Second, func(t *testing.T) { + _, _ = CondenseFWBW(context.Background(), adjSource{adj: adj}, opts) + }) + }, + ) + } +} + +// TestNoHang_BaseCase removed; single unified path now. + +// funcName formats a helpful subtest name.
+func funcName(kind string, idx int, _ bool) string { + return kind + "_#" + strconv.Itoa(idx) +} + +// ---- Generators for adversarial graphs ---- + +func genChain(n int) map[int]map[int]int { + nodes := make([]int, n) + for i := 0; i < n; i++ { + nodes[i] = i + } + var edges [][2]int + for i := 0; i+1 < n; i++ { + edges = append(edges, [2]int{i, i + 1}) + } + return makeAdj(nodes, edges) +} + +func genLollipop(m, t int) map[int]map[int]int { + var nodes []int + for i := 0; i < m+t; i++ { + nodes = append(nodes, i) + } + var edges [][2]int + // clique K_m + for i := 0; i < m; i++ { + for j := 0; j < m; j++ { + if i == j { + continue + } + edges = append(edges, [2]int{i, j}) + } + } + // tail + edges = append(edges, [2]int{m - 1, m}) + for i := m; i+1 < m+t; i++ { + edges = append(edges, [2]int{i, i + 1}) + } + return makeAdj(nodes, edges) +} + +func genBipartite(a, b int, both bool) map[int]map[int]int { + var nodes []int + for i := 0; i < a+b; i++ { + nodes = append(nodes, i) + } + var edges [][2]int + for i := 0; i < a; i++ { + for j := a; j < a+b; j++ { + edges = append(edges, [2]int{i, j}) + if both { + edges = append(edges, [2]int{j, i}) + } + } + } + return makeAdj(nodes, edges) +} + +func genDisjointTwoCycles(k int) map[int]map[int]int { + // nodes: 0..2k-1, edges: (2i <-> 2i+1) + var nodes []int + var edges [][2]int + for i := 0; i < k; i++ { + u := 2 * i + v := u + 1 + nodes = append(nodes, u, v) + edges = append(edges, [2]int{u, v}, [2]int{v, u}) + } + return makeAdj(nodes, edges) +} + +func genIsolates(n int) map[int]map[int]int { + var nodes []int + for i := 0; i < n; i++ { + nodes = append(nodes, i) + } + return makeAdj(nodes, nil) +} + +func genSelfLoops(n int) map[int]map[int]int { + var nodes []int + var edges [][2]int + for i := 0; i < n; i++ { + nodes = append(nodes, i) + edges = append(edges, [2]int{i, i}) + } + return makeAdj(nodes, edges) +} + +// ---- Adversarial specific tests ---- + +func TestNoHang_DeepChain(t *testing.T) { + adj := 
genChain(4000) + for _, cfg := range []struct { + name string + opt Options + }{ + {"recursion_w1", func() Options { + o := DefaultOptions() + o.MaxWorkers = 1 + return o + }()}, + {"recursion_w2", func() Options { + o := DefaultOptions() + o.MaxWorkers = 2 + return o + }()}, + } { + cfg := cfg + t.Run(cfg.name, func(t *testing.T) { + withTimeout(t, 2*time.Second, func(t *testing.T) { + _, _ = CondenseFWBW(context.Background(), adjSource{adj: adj}, cfg.opt) + }) + }) + } +} + +func TestNoHang_Lollipop(t *testing.T) { + adj := genLollipop(64, 512) + for _, cfg := range []struct { + name string + opt Options + }{ + {"recursion_w4", func() Options { + o := DefaultOptions() + o.MaxWorkers = 4 + return o + }()}, + {"recursion_w2", func() Options { + o := DefaultOptions() + o.MaxWorkers = 2 + return o + }()}, + } { + cfg := cfg + t.Run(cfg.name, func(t *testing.T) { + withTimeout(t, 2*time.Second, func(t *testing.T) { + _, _ = CondenseFWBW(context.Background(), adjSource{adj: adj}, cfg.opt) + }) + }) + } +} + +func TestNoHang_Isolates_And_SelfLoops(t *testing.T) { + iso := genIsolates(3000) + self := genSelfLoops(3000) + for _, tc := range []struct { + name string + adj map[int]map[int]int + }{ + {"isolates", iso}, + {"selfloops", self}, + } { + t.Run(tc.name, func(t *testing.T) { + withTimeout(t, 3*time.Second, func(t *testing.T) { + o := DefaultOptions() + o.MaxWorkers = 4 + _, _ = CondenseFWBW(context.Background(), adjSource{adj: tc.adj}, o) + }) + }) + } +} + +func TestNoHang_BipartiteDense(t *testing.T) { + adj := genBipartite(128, 128, true) + for _, mw := range []int{1, 4} { + mw := mw + t.Run("mw_"+strconv.Itoa(mw), func(t *testing.T) { + withTimeout(t, 2*time.Second, func(t *testing.T) { + o := DefaultOptions() + o.MaxWorkers = mw + _, _ = CondenseFWBW(context.Background(), adjSource{adj: adj}, o) + }) + }) + } +} + +func TestNoHang_ManyTwoCycles(t *testing.T) { + adj := genDisjointTwoCycles(2000) + withTimeout(t, 2*time.Second, func(t *testing.T) { + o := 
DefaultOptions() + o.MaxWorkers = 4 + _, _ = CondenseFWBW(context.Background(), adjSource{adj: adj}, o) + }) +} + +func TestCancel_HeavyGraphs(t *testing.T) { + // Use contexts with very short deadlines; ensure prompt return. + lollipop := genLollipop(128, 2048) + bip := genBipartite(256, 256, true) + for _, tc := range []struct { + name string + adj map[int]map[int]int + }{ + {"lollipop", lollipop}, + {"bipartite", bip}, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 75*time.Millisecond) + defer cancel() + start := time.Now() + o := DefaultOptions() + o.MaxWorkers = 8 + _, _ = CondenseFWBW(ctx, adjSource{adj: tc.adj}, o) + if time.Since(start) > 200*time.Millisecond { + t.Fatalf("cancellation not honored promptly") + } + }) + } +} diff --git a/pkg/sync/expand/scc/scc_test.go b/pkg/sync/expand/scc/scc_test.go index b6a395d96..e4fb2dd7f 100644 --- a/pkg/sync/expand/scc/scc_test.go +++ b/pkg/sync/expand/scc/scc_test.go @@ -59,13 +59,16 @@ func TestRingSingleComponent(t *testing.T) { } adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) if len(groups) != 1 { t.Fatalf("expected 1 component, got %d: %+v", len(groups), groups) } if len(groups[0]) != n { t.Fatalf("expected ring size %d, got %d", n, len(groups[0])) } + if m == nil || m.Components != 1 || m.Nodes != n || m.Edges != n || m.Peeled != 0 { + t.Fatalf("metrics unexpected: %+v", m) + } } func TestChainAllSingletons(t *testing.T) { @@ -80,7 +83,7 @@ func TestChainAllSingletons(t *testing.T) { } adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) if len(groups) != n { t.Fatalf("expected %d singleton components, got %d", n, len(groups)) } @@ -89,6 +92,9 @@ func 
TestChainAllSingletons(t *testing.T) { t.Fatalf("expected singleton at comp %d, got size %d", idx, len(g)) } } + if m == nil || m.Components != n || m.Nodes != n || m.Edges != n-1 || m.Peeled != n || m.BFScalls != 0 { + t.Fatalf("metrics unexpected: %+v", m) + } } func TestSelfLoopIsolatedCyclicSingleton(t *testing.T) { @@ -96,7 +102,7 @@ func TestSelfLoopIsolatedCyclicSingleton(t *testing.T) { edges := [][2]int{{1, 1}, {1, 2}} adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) // Expect two components: {1} and {2} if len(groups) != 2 { t.Fatalf("expected 2 components, got %d: %+v", len(groups), groups) @@ -118,6 +124,9 @@ func TestSelfLoopIsolatedCyclicSingleton(t *testing.T) { if adj[1][1] == 0 { t.Fatalf("expected self-loop for node 1 in adjacency") } + if m == nil || m.Components != 2 || m.Nodes != 2 || m.Peeled < 1 { // at least node 2 is peeled + t.Fatalf("metrics unexpected: %+v", m) + } } func TestCliqueSingleComponent(t *testing.T) { @@ -136,10 +145,13 @@ func TestCliqueSingleComponent(t *testing.T) { } } adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) if len(groups) != 1 || len(groups[0]) != n { t.Fatalf("expected one SCC of size %d, got %+v", n, groups) } + if m == nil || m.Components != 1 || m.Nodes != n || m.Peeled != 0 || m.Edges != n*(n-1) { + t.Fatalf("metrics unexpected: %+v", m) + } } func TestTailIntoRing(t *testing.T) { @@ -159,7 +171,7 @@ func TestTailIntoRing(t *testing.T) { edges = append(edges, [2]int{tail[0], tail[1]}, [2]int{tail[1], tail[2]}, [2]int{tail[2], 0}) adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) 
// Expect: 1 SCC of size ringN, plus len(tail) singletons if len(groups) != 1+len(tail) { t.Fatalf("expected %d components, got %d: %+v", 1+len(tail), len(groups), groups) @@ -178,6 +190,9 @@ func TestTailIntoRing(t *testing.T) { if !reflect.DeepEqual(sizes, want) { t.Fatalf("component sizes mismatch: got %v want %v", sizes, want) } + if m == nil || m.Components != 1+len(tail) || m.Nodes != ringN+len(tail) || m.Peeled != len(tail) { + t.Fatalf("metrics unexpected: %+v", m) + } } func TestMultipleDisjointSCCs(t *testing.T) { @@ -188,7 +203,7 @@ func TestMultipleDisjointSCCs(t *testing.T) { edges := [][2]int{{0, 1}, {1, 2}, {2, 0}, {10, 11}, {11, 12}, {12, 13}, {13, 10}} adj := makeAdj(nodes, edges) - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, defaultOpts()) + groups, m := CondenseFWBW(context.Background(), adjSource{adj: adj}, defaultOpts()) // Filter out any empty groups (should not occur logically, but tolerate // preallocated empty slots when packing components). sizes := make([]int, 0, len(groups)) @@ -215,6 +230,9 @@ func TestMultipleDisjointSCCs(t *testing.T) { t.Fatalf("node %d missing from partition", id) } } + if m == nil || m.Components != 4 || m.Peeled != 2 { + t.Fatalf("metrics unexpected: %+v", m) + } } func TestDeterminismWithSingleWorker(t *testing.T) { @@ -226,7 +244,7 @@ func TestDeterminismWithSingleWorker(t *testing.T) { var ref [][]int for i := 0; i < 5; i++ { - groups := CondenseFWBWGroupsFromAdj(context.Background(), adj, opts) + groups, _ := CondenseFWBW(context.Background(), adjSource{adj: adj}, opts) ng := normalizeGroups(groups) if i == 0 { ref = ng diff --git a/pkg/sync/expand/scc/test_source.go b/pkg/sync/expand/scc/test_source.go new file mode 100644 index 000000000..16b5b244b --- /dev/null +++ b/pkg/sync/expand/scc/test_source.go @@ -0,0 +1,24 @@ +package scc + +// adjSource adapts a map[int]map[int]int adjacency to the Source interface for tests. 
+type adjSource struct { + adj map[int]map[int]int +} + +func (a adjSource) ForEachNode(fn func(id int) bool) { + for id := range a.adj { + if !fn(id) { + return + } + } +} + +func (a adjSource) ForEachEdgeFrom(src int, fn func(dst int) bool) { + if row, ok := a.adj[src]; ok { + for dst := range row { + if !fn(dst) { + return + } + } + } +} diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index f3109687f..7127ba879 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -1362,12 +1362,13 @@ func (s *syncer) SyncGrantExpansion(ctx context.Context) error { } if entitlementGraph.Loaded { - comps := entitlementGraph.ComputeCyclicComponents(ctx) + comps, sccMetrics := entitlementGraph.ComputeCyclicComponents(ctx) if len(comps) > 0 { // Log a sample cycle l.Warn( "cycle detected in entitlement graph", zap.Any("cycle", comps[0]), + zap.Any("scc_metrics", sccMetrics), ) l.Debug("initial graph", zap.Any("initial graph", entitlementGraph)) if dontFixCycles {