Skip to content

Commit 6940abc

Browse files
committed
experiment: per-module fan-out in compiler stages
Signed-off-by: Stephan Renatus <[email protected]>
1 parent 3188e04 commit 6940abc

File tree

2 files changed, with 64 additions and 72 deletions.

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/open-policy-agent/opa
22

3-
go 1.24.6
3+
go 1.25
44

55
require (
66
github.com/agnivade/levenshtein v1.2.1

v1/ast/compile.go

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"sort"
1414
"strconv"
1515
"strings"
16+
"sync"
17+
"sync/atomic"
1618

1719
"github.com/open-policy-agent/opa/internal/debug"
1820
"github.com/open-policy-agent/opa/internal/gojsonschema"
@@ -1031,10 +1033,8 @@ func (c *Compiler) buildRequiredCapabilities() {
10311033
c.Required.FutureKeywords = util.KeysSorted(keywords)
10321034

10331035
// extract required features from modules
1034-
10351036
for _, name := range c.sorted {
10361037
mod := c.Modules[name]
1037-
10381038
if c.moduleIsRegoV1(mod) {
10391039
features[FeatureRegoV1] = struct{}{}
10401040
} else {
@@ -1225,12 +1225,11 @@ func (c *Compiler) checkRuleConflicts() {
12251225
}
12261226

12271227
func (c *Compiler) checkUndefinedFuncs() {
1228-
for _, name := range c.sorted {
1229-
m := c.Modules[name]
1230-
for _, err := range checkUndefinedFuncs(c.TypeEnv, m, c.GetArity, c.RewrittenVars) {
1228+
c.forEachModule(func(mod *Module) {
1229+
for _, err := range checkUndefinedFuncs(c.TypeEnv, mod, c.GetArity, c.RewrittenVars) {
12311230
c.err(err)
12321231
}
1233-
}
1232+
})
12341233
}
12351234

12361235
func checkUndefinedFuncs(env *TypeEnv, x any, arity func(Ref) int, rwVars map[Var]Var) Errors {
@@ -1285,17 +1284,16 @@ func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
12851284
// positions of built-in expressions will be bound when evaluating the rule from left
12861285
// to right, re-ordering as necessary.
12871286
func (c *Compiler) checkSafetyRuleBodies() {
1288-
for _, name := range c.sorted {
1289-
m := c.Modules[name]
1290-
WalkRules(m, func(r *Rule) bool {
1287+
c.forEachModule(func(mod *Module) {
1288+
WalkRules(mod, func(r *Rule) bool {
12911289
safe := ReservedVars.Copy()
12921290
if len(r.Head.Args) > 0 {
12931291
safe.Update(r.Head.Args.Vars())
12941292
}
12951293
r.Body = c.checkBodySafety(safe, r.Body)
12961294
return false
12971295
})
1298-
}
1296+
})
12991297
}
13001298

13011299
func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body {
@@ -1320,8 +1318,8 @@ var SafetyCheckVisitorParams = VarVisitorParams{
13201318
// checkSafetyRuleHeads ensures that variables appearing in the head of a
13211319
// rule also appear in the body.
13221320
func (c *Compiler) checkSafetyRuleHeads() {
1323-
for _, name := range c.sorted {
1324-
WalkRules(c.Modules[name], func(r *Rule) bool {
1321+
c.forEachModule(func(mod *Module) {
1322+
WalkRules(mod, func(r *Rule) bool {
13251323
safe := r.Body.Vars(SafetyCheckVisitorParams)
13261324
if len(r.Head.Args) > 0 {
13271325
safe.Update(r.Head.Args.Vars())
@@ -1342,7 +1340,7 @@ func (c *Compiler) checkSafetyRuleHeads() {
13421340
}
13431341
return false
13441342
})
1345-
}
1343+
})
13461344
}
13471345

13481346
func compileSchema(goSchema any, allowNet []string) (*gojsonschema.Schema, error) {
@@ -1609,13 +1607,12 @@ func (c *Compiler) checkUnsafeBuiltins() {
16091607
if len(c.unsafeBuiltinsMap) == 0 {
16101608
return
16111609
}
1612-
1613-
for _, name := range c.sorted {
1614-
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name])
1610+
c.forEachModule(func(mod *Module) {
1611+
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, mod)
16151612
for _, err := range errs {
16161613
c.err(err)
16171614
}
1618-
}
1615+
})
16191616
}
16201617

16211618
func (c *Compiler) checkDeprecatedBuiltins() {
@@ -1630,15 +1627,14 @@ func (c *Compiler) checkDeprecatedBuiltins() {
16301627
return
16311628
}
16321629

1633-
for _, name := range c.sorted {
1634-
mod := c.Modules[name]
1630+
c.forEachModule(func(mod *Module) {
16351631
if c.strict || mod.regoV1Compatible() {
16361632
errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod)
16371633
for _, err := range errs {
16381634
c.err(err)
16391635
}
16401636
}
1641-
}
1637+
})
16421638
}
16431639

16441640
func (c *Compiler) runStage(metricName string, f func()) {
@@ -1756,6 +1752,15 @@ func (c *Compiler) init() {
17561752
c.initialized = true
17571753
}
17581754

1755+
func (c *Compiler) forEachModule(f func(mod *Module)) {
1756+
wg := &sync.WaitGroup{}
1757+
for _, name := range c.sorted {
1758+
wg.Go(func() { f(c.Modules[name]) })
1759+
}
1760+
wg.Wait()
1761+
}
1762+
1763+
// TODO(sr): Fix err below: it is not concurrency-safe, yet the forEachModule callbacks now call it from several goroutines. And the panic is bad form, too.
17591764
func (c *Compiler) err(err *Error) {
17601765
if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
17611766
c.Errors = append(c.Errors, errLimitReached)
@@ -1957,7 +1962,7 @@ func (c *Compiler) resolveAllRefs() {
19571962

19581963
func (c *Compiler) removeImports() {
19591964
c.imports = make(map[string][]*Import, len(c.Modules))
1960-
for name := range c.Modules {
1965+
for name := range c.Modules { // Trivial. No fan-out for this.
19611966
c.imports[name] = c.Modules[name].Imports
19621967
c.Modules[name].Imports = nil
19631968
}
@@ -1969,27 +1974,25 @@ func (c *Compiler) initLocalVarGen() {
19691974

19701975
func (c *Compiler) rewriteComprehensionTerms() {
19711976
f := newEqualityFactory(c.localvargen)
1972-
for _, name := range c.sorted {
1973-
mod := c.Modules[name]
1977+
c.forEachModule(func(mod *Module) {
19741978
_, _ = rewriteComprehensionTerms(f, mod) // ignore error
1975-
}
1979+
})
19761980
}
19771981

19781982
func (c *Compiler) rewriteExprTerms() {
1979-
for _, name := range c.sorted {
1980-
mod := c.Modules[name]
1983+
c.forEachModule(func(mod *Module) {
19811984
WalkRules(mod, func(rule *Rule) bool {
19821985
rewriteExprTermsInHead(c.localvargen, rule)
19831986
rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
19841987
return false
19851988
})
1986-
}
1989+
})
19871990
}
19881991

19891992
func (c *Compiler) rewriteRuleHeadRefs() {
19901993
f := newEqualityFactory(c.localvargen)
1991-
for _, name := range c.sorted {
1992-
WalkRules(c.Modules[name], func(rule *Rule) bool {
1994+
c.forEachModule(func(mod *Module) {
1995+
WalkRules(mod, func(rule *Rule) bool {
19931996

19941997
ref := rule.Head.Ref()
19951998
// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
@@ -2042,16 +2045,15 @@ func (c *Compiler) rewriteRuleHeadRefs() {
20422045

20432046
return true
20442047
})
2045-
}
2048+
})
20462049
}
20472050

20482051
func (c *Compiler) checkVoidCalls() {
2049-
for _, name := range c.sorted {
2050-
mod := c.Modules[name]
2052+
c.forEachModule(func(mod *Module) {
20512053
for _, err := range checkVoidCalls(c.TypeEnv, mod) {
20522054
c.err(err)
20532055
}
2054-
}
2056+
})
20552057
}
20562058

20572059
func (c *Compiler) rewritePrintCalls() {
@@ -2063,8 +2065,7 @@ func (c *Compiler) rewritePrintCalls() {
20632065
}
20642066
}
20652067
} else {
2066-
for _, name := range c.sorted {
2067-
mod := c.Modules[name]
2068+
c.forEachModule(func(mod *Module) {
20682069
WalkRules(mod, func(r *Rule) bool {
20692070
safe := r.Head.Args.Vars()
20702071
safe.Update(ReservedVars)
@@ -2082,7 +2083,7 @@ func (c *Compiler) rewritePrintCalls() {
20822083
WalkBodies(r.Body, vis)
20832084
return false
20842085
})
2085-
}
2086+
})
20862087
}
20872088
if modified {
20882089
c.Required.addBuiltinSorted(Print)
@@ -2295,8 +2296,7 @@ func isPrintCall(x *Expr) bool {
22952296
// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
22962297
func (c *Compiler) rewriteRefsInHead() {
22972298
f := newEqualityFactory(c.localvargen)
2298-
for _, name := range c.sorted {
2299-
mod := c.Modules[name]
2299+
c.forEachModule(func(mod *Module) {
23002300
WalkRules(mod, func(rule *Rule) bool {
23012301
if requiresEval(rule.Head.Key) {
23022302
expr := f.Generate(rule.Head.Key)
@@ -2317,27 +2317,27 @@ func (c *Compiler) rewriteRefsInHead() {
23172317
}
23182318
return false
23192319
})
2320-
}
2320+
})
23212321
}
23222322

23232323
func (c *Compiler) rewriteEquals() {
23242324
modified := false
2325-
for _, name := range c.sorted {
2326-
modified = rewriteEquals(c.Modules[name]) || modified
2327-
}
2325+
c.forEachModule(func(mod *Module) {
2326+
modified = rewriteEquals(mod) || modified
2327+
})
23282328
if modified {
23292329
c.Required.addBuiltinSorted(Equal)
23302330
}
23312331
}
23322332

23332333
func (c *Compiler) rewriteDynamicTerms() {
23342334
f := newEqualityFactory(c.localvargen)
2335-
for _, name := range c.sorted {
2336-
WalkRules(c.Modules[name], func(rule *Rule) bool {
2335+
c.forEachModule(func(mod *Module) {
2336+
WalkRules(mod, func(rule *Rule) bool {
23372337
rule.Body = rewriteDynamics(f, rule.Body)
23382338
return false
23392339
})
2340-
}
2340+
})
23412341
}
23422342

23432343
// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
@@ -2366,39 +2366,34 @@ func (c *Compiler) rewriteTestRuleEqualities() {
23662366
}
23672367

23682368
f := newEqualityFactory(c.localvargen)
2369-
for _, name := range c.sorted {
2370-
mod := c.Modules[name]
2369+
c.forEachModule(func(mod *Module) {
23712370
WalkRules(mod, func(rule *Rule) bool {
23722371
if strings.HasPrefix(string(rule.Head.Name), "test_") {
23732372
rule.Body = rewriteTestEqualities(f, rule.Body)
23742373
}
23752374
return false
23762375
})
2377-
}
2376+
})
23782377
}
23792378

23802379
func (c *Compiler) parseMetadataBlocks() {
23812380
// Only parse annotations if rego.metadata built-ins are called
23822381
regoMetadataCalled := false
2383-
for _, name := range c.sorted {
2384-
mod := c.Modules[name]
2382+
c.forEachModule(func(mod *Module) {
2383+
if regoMetadataCalled {
2384+
return
2385+
}
23852386
WalkExprs(mod, func(expr *Expr) bool {
23862387
if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) {
23872388
regoMetadataCalled = true
23882389
}
23892390
return regoMetadataCalled
23902391
})
2391-
2392-
if regoMetadataCalled {
2393-
break
2394-
}
2395-
}
2392+
})
23962393

23972394
if regoMetadataCalled {
23982395
// NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
2399-
for _, name := range c.sorted {
2400-
mod := c.Modules[name]
2401-
2396+
c.forEachModule(func(mod *Module) {
24022397
if len(mod.Annotations) == 0 {
24032398
var errs Errors
24042399
mod.Annotations, errs = parseAnnotations(mod.Comments)
@@ -2409,7 +2404,7 @@ func (c *Compiler) parseMetadataBlocks() {
24092404

24102405
attachRuleAnnotations(mod)
24112406
}
2412-
}
2407+
})
24132408
}
24142409
}
24152410

@@ -2419,9 +2414,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24192414
_, chainFuncAllowed := c.builtins[RegoMetadataChain.Name]
24202415
_, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name]
24212416

2422-
for _, name := range c.sorted {
2423-
mod := c.Modules[name]
2424-
2417+
c.forEachModule(func(mod *Module) {
24252418
WalkRules(mod, func(rule *Rule) bool {
24262419
var firstChainCall *Expr
24272420
var firstRuleCall *Expr
@@ -2499,7 +2492,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24992492

25002493
return false
25012494
})
2502-
}
2495+
})
25032496
}
25042497

25052498
func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations {
@@ -2893,8 +2886,7 @@ func (vis *ruleArgLocalRewriter) Visit(x any) Visitor {
28932886

28942887
func (c *Compiler) rewriteWithModifiers() {
28952888
f := newEqualityFactory(c.localvargen)
2896-
for _, name := range c.sorted {
2897-
mod := c.Modules[name]
2889+
c.forEachModule(func(mod *Module) {
28982890
t := NewGenericTransformer(func(x any) (any, error) {
28992891
body, ok := x.(Body)
29002892
if !ok {
@@ -2908,7 +2900,7 @@ func (c *Compiler) rewriteWithModifiers() {
29082900
return body, nil
29092901
})
29102902
_, _ = Transform(t, mod) // ignore error
2911-
}
2903+
})
29122904
}
29132905

29142906
func (c *Compiler) setModuleTree() {
@@ -4333,27 +4325,27 @@ const LocalVarPrefix = "__local"
43334325
type localVarGenerator struct {
43344326
exclude VarSet
43354327
suffix string
4336-
next int
4328+
next *atomic.Int32
43374329
}
43384330

43394331
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
43404332
vis := NewVarVisitor()
43414333
for _, key := range sorted {
43424334
vis.Walk(modules[key])
43434335
}
4344-
return &localVarGenerator{exclude: vis.vars, next: 0}
4336+
return &localVarGenerator{exclude: vis.vars, next: &atomic.Int32{}}
43454337
}
43464338

43474339
func newLocalVarGenerator(suffix string, node any) *localVarGenerator {
43484340
vis := NewVarVisitor()
43494341
vis.Walk(node)
4350-
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: 0}
4342+
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: &atomic.Int32{}}
43514343
}
43524344

43534345
func (l *localVarGenerator) Generate() Var {
43544346
for {
4355-
result := Var(LocalVarPrefix + l.suffix + strconv.Itoa(l.next) + "__")
4356-
l.next++
4347+
next := l.next.Add(1) - 1 // we want the old number
4348+
result := Var(LocalVarPrefix + l.suffix + strconv.Itoa(int(next)) + "__")
43574349
if !l.exclude.Contains(result) {
43584350
return result
43594351
}

0 commit comments

Comments
 (0)