Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 62 additions & 103 deletions pkg/sync/expand/cycle.go
Original file line number Diff line number Diff line change
@@ -1,114 +1,70 @@
package expand

import (
mapset "github.com/deckarep/golang-set/v2"
)
"context"

const (
colorWhite uint8 = iota
colorGray
colorBlack
"github.com/conductorone/baton-sdk/pkg/sync/expand/scc"
mapset "github.com/deckarep/golang-set/v2"
)

// cycleDetector encapsulates coloring state for cycle detection on an
// EntitlementGraph. Node IDs are dense (1..NextNodeID), so slices are used for
// O(1) access and zero per-op allocations. Both slices are indexed directly by
// node ID and are sized NextNodeID+1 by newCycleDetector.
type cycleDetector struct {
// g is the graph whose SourcesToDestinations edges are walked.
g *EntitlementGraph
// state holds each node's DFS color: colorWhite (the zero value) for
// unvisited, colorGray while the node is on the current recursion stack, and
// colorBlack once all of its outgoing edges have been explored.
state []uint8
// parent records the node from which each node was first reached during the
// DFS, or -1 if it has none; it is used to rebuild a cycle on a back-edge.
parent []int
}

func newCycleDetector(g *EntitlementGraph) *cycleDetector {
cd := &cycleDetector{
g: g,
state: make([]uint8, g.NextNodeID+1),
parent: make([]int, g.NextNodeID+1),
// GetFirstCycle given an entitlements graph, return a cycle by node ID if it
// exists. Returns nil if no cycle exists. If there is a single
// node pointing to itself, that will count as a cycle.
func (g *EntitlementGraph) GetFirstCycle(ctx context.Context) []int {
if g.HasNoCycles {
return nil
}
for i := range cd.parent {
cd.parent[i] = -1
comps := g.ComputeCyclicComponents(ctx)
if len(comps) == 0 {
return nil
}
return cd
return comps[0]
}

// dfs performs a coloring-based DFS from u, returning the first detected cycle
// as a slice of node IDs or nil if no cycle is reachable from u.
func (cd *cycleDetector) dfs(u int) ([]int, bool) {
// Self-loop fast path.
if nbrs, ok := cd.g.SourcesToDestinations[u]; ok {
if _, ok := nbrs[u]; ok {
return []int{u}, true
}
}

cd.state[u] = colorGray
if nbrs, ok := cd.g.SourcesToDestinations[u]; ok {
for v := range nbrs {
switch cd.state[v] {
case colorWhite:
cd.parent[v] = u
if cyc, ok := cd.dfs(v); ok {
return cyc, true
}
case colorGray:
// Back-edge to a node on the current recursion stack.
// Reconstruct cycle by walking parents from u back to v (inclusive), then reverse.
cycle := make([]int, 0, 8)
for x := u; ; x = cd.parent[x] {
cycle = append(cycle, x)
if x == v || cd.parent[x] == -1 {
break
}
}
for i, j := 0, len(cycle)-1; i < j; i, j = i+1, j-1 {
cycle[i], cycle[j] = cycle[j], cycle[i]
}
return cycle, true
}
}
// HasCycles returns true if the graph contains any cycle.
func (g *EntitlementGraph) HasCycles(ctx context.Context) bool {
if g.HasNoCycles {
return false
}
cd.state[u] = colorBlack
return nil, false
return len(g.ComputeCyclicComponents(ctx)) > 0
}

// FindAny scans all nodes and returns the first detected cycle or nil if none exist.
func (cd *cycleDetector) FindAny() []int {
for nodeID := range cd.g.Nodes {
if cd.state[nodeID] != colorWhite {
continue
}
if cyc, ok := cd.dfs(nodeID); ok {
return cyc
func (g *EntitlementGraph) cycleDetectionHelper(
nodeID int,
) ([]int, bool) {
reach := g.reachableFrom(nodeID)
if len(reach) == 0 {
return nil, false
}
adj := g.toAdjacency(reach)
groups := scc.CondenseFWBWGroupsFromAdj(context.Background(), adj, scc.DefaultOptions())
for _, comp := range groups {
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
return comp, true
}
}
return nil
return nil, false
}

// FindFrom starts cycle detection from a specific node and returns the first
// cycle reachable from that node, or nil,false if none.
func (cd *cycleDetector) FindFrom(start int) ([]int, bool) {
return cd.dfs(start)
func (g *EntitlementGraph) FixCycles(ctx context.Context) error {
return g.FixCyclesFromComponents(ctx, g.ComputeCyclicComponents(ctx))
}

// GetFirstCycle given an entitlements graph, return a cycle by node ID if it
// exists. Returns nil if no cycle exists. If there is a single
// node pointing to itself, that will count as a cycle.
func (g *EntitlementGraph) GetFirstCycle() []int {
// ComputeCyclicComponents runs SCC once and returns only cyclic components.
// A component is cyclic if len>1 or a singleton with a self-loop.
func (g *EntitlementGraph) ComputeCyclicComponents(ctx context.Context) [][]int {
if g.HasNoCycles {
return nil
}
cd := newCycleDetector(g)
return cd.FindAny()
}

func (g *EntitlementGraph) cycleDetectionHelper(
nodeID int,
) ([]int, bool) {
// Thin wrapper around the coloring-based DFS, starting from a specific node.
// The provided visited/currentCycle are ignored here; coloring provides the
// necessary state for correctness and performance.
cd := newCycleDetector(g)
return cd.FindFrom(nodeID)
adj := g.toAdjacency(nil)
groups := scc.CondenseFWBWGroupsFromAdj(ctx, adj, scc.DefaultOptions())
cyclic := make([][]int, 0)
for _, comp := range groups {
if len(comp) > 1 || (len(comp) == 1 && adj[comp[0]][comp[0]] != 0) {
cyclic = append(cyclic, comp)
}
}
return cyclic
}

// removeNode obliterates a node and all incoming/outgoing edges.
Expand Down Expand Up @@ -145,30 +101,33 @@ func (g *EntitlementGraph) removeNode(nodeID int) {
delete(g.SourcesToDestinations, nodeID)
}

// FixCycles if any cycles of nodes exist, merge all nodes in that cycle into a
// single node and then repeat. Iteration ends when there are no more cycles.
func (g *EntitlementGraph) FixCycles() error {
// FixCyclesFromComponents merges all provided cyclic components in one pass.
func (g *EntitlementGraph) FixCyclesFromComponents(ctx context.Context, cyclic [][]int) error {
if g.HasNoCycles {
return nil
}
cycle := g.GetFirstCycle()
if cycle == nil {
if len(cyclic) == 0 {
g.HasNoCycles = true
return nil
}

if err := g.fixCycle(cycle); err != nil {
return err
for _, comp := range cyclic {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := g.fixCycle(comp); err != nil {
return err
}
}

// Recurse!
return g.FixCycles()
g.HasNoCycles = true
return nil
}

// fixCycle takes a list of Node IDs that form a cycle and merges them into a
// single, new node.
func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
entitlementIDs := mapset.NewSet[string]()
entitlementIDs := mapset.NewThreadUnsafeSet[string]()
outgoingEdgesToResourceTypeIDs := map[int]mapset.Set[string]{}
incomingEdgesToResourceTypeIDs := map[int]mapset.Set[string]{}
for _, nodeID := range nodeIDs {
Expand All @@ -184,7 +143,7 @@ func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
if edge, ok := g.Edges[edgeID]; ok {
resourceTypeIDs, ok := incomingEdgesToResourceTypeIDs[sourceNodeID]
if !ok {
resourceTypeIDs = mapset.NewSet[string]()
resourceTypeIDs = mapset.NewThreadUnsafeSet[string]()
}
for _, resourceTypeID := range edge.ResourceTypeIDs {
resourceTypeIDs.Add(resourceTypeID)
Expand All @@ -200,7 +159,7 @@ func (g *EntitlementGraph) fixCycle(nodeIDs []int) error {
if edge, ok := g.Edges[edgeID]; ok {
resourceTypeIDs, ok := outgoingEdgesToResourceTypeIDs[destinationNodeID]
if !ok {
resourceTypeIDs = mapset.NewSet[string]()
resourceTypeIDs = mapset.NewThreadUnsafeSet[string]()
}
for _, resourceTypeID := range edge.ResourceTypeIDs {
resourceTypeIDs.Add(resourceTypeID)
Expand Down
6 changes: 4 additions & 2 deletions pkg/sync/expand/cycle_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,16 @@ func BenchmarkCycleDetectionHelper(b *testing.B) {
}

func BenchmarkGetFirstCycle(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sizes := []int{100, 1000}

for _, n := range sizes {
b.Run(fmt.Sprintf("ring-%d", n), func(b *testing.B) {
g := buildRing(b, n)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = g.GetFirstCycle()
_ = g.GetFirstCycle(ctx)
}
})
}
Expand All @@ -168,7 +170,7 @@ func BenchmarkGetFirstCycle(b *testing.B) {
g := buildChain(b, n)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = g.GetFirstCycle()
_ = g.GetFirstCycle(ctx)
}
})
}
Expand Down
61 changes: 61 additions & 0 deletions pkg/sync/expand/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,64 @@ func (g *EntitlementGraph) DeleteEdge(ctx context.Context, srcEntitlementID stri
}
return nil
}

// toAdjacency builds the adjacency map consumed by the SCC routines. Each key
// is a node ID and each value is the set of that node's successors, stored as
// a map whose values are always 1 (the values held in SourcesToDestinations
// are not copied).
//
// When nodesSubset is nil every node in g.Nodes is included. Otherwise only
// IDs that appear in both nodesSubset and g.Nodes are included. Every included
// node is present as a key, even when it has no outgoing edges, and an edge is
// emitted only when both of its endpoints are keys of the result, so the
// returned map never refers to a node it does not contain.
//
// With a subset, the work is proportional to the subset and the outgoing edges
// of its members rather than to the size of the whole graph.
func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int {
	var adj map[int]map[int]int
	if nodesSubset == nil {
		adj = make(map[int]map[int]int, len(g.Nodes))
		for id := range g.Nodes {
			adj[id] = make(map[int]int)
		}
	} else {
		// Walk the (typically small) subset instead of filtering all of
		// g.Nodes, and size the result accordingly.
		adj = make(map[int]map[int]int, len(nodesSubset))
		for id := range nodesSubset {
			if _, ok := g.Nodes[id]; ok {
				adj[id] = make(map[int]int)
			}
		}
	}

	// Drive the edge pass from adj's own keys. Every row is therefore non-nil
	// (a source that is missing from g.Nodes can no longer trigger a write to
	// a nil map), and sources outside the selection are never visited.
	// Ranging over a missing (nil) destination map is a no-op.
	for src, row := range adj {
		for dst := range g.SourcesToDestinations[src] {
			if _, ok := adj[dst]; ok {
				row[dst] = 1
			}
		}
	}
	return adj
}
Comment on lines +313 to +346
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Build adjacency in O(E_subset) when a subset is provided.

The current loop walks every source in SourcesToDestinations even when nodesSubset is small. Iterate only over the sources in the subset instead.

-	// Add edges where both endpoints are included.
-	for src, dsts := range g.SourcesToDestinations {
-		if !include(src) {
-			continue
-		}
-		row := adj[src]
-		for dst := range dsts {
-			if include(dst) {
-				row[dst] = 1
-			}
-		}
-	}
+	// Add edges where both endpoints are included.
+	if nodesSubset == nil {
+		for src, dsts := range g.SourcesToDestinations {
+			row := adj[src]
+			for dst := range dsts {
+				row[dst] = 1
+			}
+		}
+	} else {
+		for src := range nodesSubset {
+			dsts, ok := g.SourcesToDestinations[src]
+			if !ok {
+				continue
+			}
+			row := adj[src]
+			for dst := range dsts {
+				if _, ok := nodesSubset[dst]; ok {
+					row[dst] = 1
+				}
+			}
+		}
+	}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// toAdjacency builds an adjacency map for SCC. If nodesSubset is non-nil, only
// include those nodes (and edges between them). Every included node that is
// present in g.Nodes becomes a key, even if it has zero outgoing edges. Each
// key maps to its included successors, marked with the value 1; the values
// stored in SourcesToDestinations are not copied.
func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int {
adj := make(map[int]map[int]int, len(g.Nodes))
// include reports whether id passes the subset filter; with a nil subset
// every id passes.
include := func(id int) bool {
if nodesSubset == nil {
return true
}
_, ok := nodesSubset[id]
return ok
}
// Ensure keys for all included nodes.
for id := range g.Nodes {
if include(id) {
adj[id] = make(map[int]int)
}
}
// Add edges where both endpoints are included. Every source in the graph is
// visited here, whatever the size of the subset.
for src, dsts := range g.SourcesToDestinations {
if !include(src) {
continue
}
// NOTE(review): row is nil when src is not a key of g.Nodes, and writing to a
// nil map panics — presumably every source is also a node; confirm.
row := adj[src]
for dst := range dsts {
if include(dst) {
row[dst] = 1
}
}
}
return adj
}
// toAdjacency builds an adjacency map for SCC. A nil nodesSubset selects every
// node in g.Nodes; otherwise only IDs present in both nodesSubset and g.Nodes
// become keys. Each key maps to its included successors, marked with the
// value 1; the values stored in SourcesToDestinations are not copied.
func (g *EntitlementGraph) toAdjacency(nodesSubset map[int]struct{}) map[int]map[int]int {
adj := make(map[int]map[int]int, len(g.Nodes))
// include reports whether id passes the subset filter; with a nil subset
// every id passes.
include := func(id int) bool {
if nodesSubset == nil {
return true
}
_, ok := nodesSubset[id]
return ok
}
// Ensure keys for all included nodes.
for id := range g.Nodes {
if include(id) {
adj[id] = make(map[int]int)
}
}
// Add edges where both endpoints are included.
if nodesSubset == nil {
// No filter: copy every edge of the graph.
for src, dsts := range g.SourcesToDestinations {
// NOTE(review): row is nil when src is not a key of g.Nodes, and writing to a
// nil map panics — presumably every source is also a node; confirm.
row := adj[src]
for dst := range dsts {
row[dst] = 1
}
}
} else {
// Visit only the subset's own sources so the rest of the graph's sources are
// never touched.
for src := range nodesSubset {
dsts, ok := g.SourcesToDestinations[src]
if !ok {
continue
}
// NOTE(review): same nil-row hazard as above when src is in nodesSubset but
// not in g.Nodes — confirm callers only pass existing node IDs.
row := adj[src]
for dst := range dsts {
if _, ok := nodesSubset[dst]; ok {
row[dst] = 1
}
}
}
}
return adj
}


// reachableFrom returns the set of node IDs that can be reached from start by
// following SourcesToDestinations edges. start itself is always a member of
// the result. A nil map is returned when start is not a key of g.Nodes.
//
// The traversal is an iterative breadth-first search: frontier doubles as the
// FIFO queue and head is the index of the next node to expand.
func (g *EntitlementGraph) reachableFrom(start int) map[int]struct{} {
	if _, known := g.Nodes[start]; !known {
		return nil
	}

	seen := make(map[int]struct{}, 16)
	seen[start] = struct{}{}
	frontier := append(make([]int, 0, 16), start)

	for head := 0; head < len(frontier); head++ {
		// A node without outgoing edges yields a nil map, which ranges zero times.
		for next := range g.SourcesToDestinations[frontier[head]] {
			if _, dup := seen[next]; !dup {
				seen[next] = struct{}{}
				frontier = append(frontier, next)
			}
		}
	}
	return seen
}
16 changes: 8 additions & 8 deletions pkg/sync/expand/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func TestGetFirstCycle(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.message, func(t *testing.T) {
graph := parseExpression(t, ctx, testCase.expression)
cycle := graph.GetFirstCycle()
cycle := graph.GetFirstCycle(ctx)
if testCase.expectedCycleSize == 0 {
require.Nil(t, cycle)
} else {
Expand Down Expand Up @@ -189,7 +189,7 @@ func TestHandleCycle(t *testing.T) {

graph := parseExpression(t, ctx, testCase.expression)

cycle := graph.GetFirstCycle()
cycle := graph.GetFirstCycle(ctx)
expectedCycles := createNodeIDList(testCase.expectedCycles)
require.NotNil(t, cycle)
found := false
Expand All @@ -201,11 +201,11 @@ func TestHandleCycle(t *testing.T) {
}
require.True(t, found)

err := graph.FixCycles()
err := graph.FixCycles(ctx)
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)
cycle = graph.GetFirstCycle()
cycle = graph.GetFirstCycle(ctx)
require.Nil(t, cycle)
})
}
Expand All @@ -221,7 +221,7 @@ func TestHandleComplexCycle(t *testing.T) {
require.Equal(t, 4, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

err := graph.FixCycles()
err := graph.FixCycles(ctx)
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)
Expand All @@ -230,7 +230,7 @@ func TestHandleComplexCycle(t *testing.T) {
require.Equal(t, 0, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

cycle := graph.GetFirstCycle()
cycle := graph.GetFirstCycle(ctx)
require.Nil(t, cycle)
}

Expand All @@ -248,7 +248,7 @@ func TestHandleCliqueCycle(t *testing.T) {
require.Equal(t, 6, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

err := graph.FixCycles()
err := graph.FixCycles(ctx)
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)
Expand All @@ -257,7 +257,7 @@ func TestHandleCliqueCycle(t *testing.T) {
require.Equal(t, 0, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

cycle := graph.GetFirstCycle()
cycle := graph.GetFirstCycle(ctx)
require.Nil(t, cycle)
}
}
Expand Down
Loading
Loading