@@ -13,6 +13,8 @@ import (
1313 "sort"
1414 "strconv"
1515 "strings"
16+ "sync"
17+ "sync/atomic"
1618
1719 "github.com/open-policy-agent/opa/internal/debug"
1820 "github.com/open-policy-agent/opa/internal/gojsonschema"
@@ -1031,10 +1033,8 @@ func (c *Compiler) buildRequiredCapabilities() {
10311033 c .Required .FutureKeywords = util .KeysSorted (keywords )
10321034
10331035 // extract required features from modules
1034-
10351036 for _ , name := range c .sorted {
10361037 mod := c .Modules [name ]
1037-
10381038 if c .moduleIsRegoV1 (mod ) {
10391039 features [FeatureRegoV1 ] = struct {}{}
10401040 } else {
@@ -1225,12 +1225,11 @@ func (c *Compiler) checkRuleConflicts() {
12251225}
12261226
12271227func (c * Compiler ) checkUndefinedFuncs () {
1228- for _ , name := range c .sorted {
1229- m := c .Modules [name ]
1230- for _ , err := range checkUndefinedFuncs (c .TypeEnv , m , c .GetArity , c .RewrittenVars ) {
1228+ c .forEachModule (func (mod * Module ) {
1229+ for _ , err := range checkUndefinedFuncs (c .TypeEnv , mod , c .GetArity , c .RewrittenVars ) {
12311230 c .err (err )
12321231 }
1233- }
1232+ })
12341233}
12351234
12361235func checkUndefinedFuncs (env * TypeEnv , x any , arity func (Ref ) int , rwVars map [Var ]Var ) Errors {
@@ -1285,17 +1284,16 @@ func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
12851284// positions of built-in expressions will be bound when evaluating the rule from left
12861285// to right, re-ordering as necessary.
12871286func (c * Compiler ) checkSafetyRuleBodies () {
1288- for _ , name := range c .sorted {
1289- m := c .Modules [name ]
1290- WalkRules (m , func (r * Rule ) bool {
1287+ c .forEachModule (func (mod * Module ) {
1288+ WalkRules (mod , func (r * Rule ) bool {
12911289 safe := ReservedVars .Copy ()
12921290 if len (r .Head .Args ) > 0 {
12931291 safe .Update (r .Head .Args .Vars ())
12941292 }
12951293 r .Body = c .checkBodySafety (safe , r .Body )
12961294 return false
12971295 })
1298- }
1296+ })
12991297}
13001298
13011299func (c * Compiler ) checkBodySafety (safe VarSet , b Body ) Body {
@@ -1320,8 +1318,8 @@ var SafetyCheckVisitorParams = VarVisitorParams{
13201318// checkSafetyRuleHeads ensures that variables appearing in the head of a
13211319// rule also appear in the body.
13221320func (c * Compiler ) checkSafetyRuleHeads () {
1323- for _ , name := range c . sorted {
1324- WalkRules (c . Modules [ name ] , func (r * Rule ) bool {
1321+ c . forEachModule ( func ( mod * Module ) {
1322+ WalkRules (mod , func (r * Rule ) bool {
13251323 safe := r .Body .Vars (SafetyCheckVisitorParams )
13261324 if len (r .Head .Args ) > 0 {
13271325 safe .Update (r .Head .Args .Vars ())
@@ -1342,7 +1340,7 @@ func (c *Compiler) checkSafetyRuleHeads() {
13421340 }
13431341 return false
13441342 })
1345- }
1343+ })
13461344}
13471345
13481346func compileSchema (goSchema any , allowNet []string ) (* gojsonschema.Schema , error ) {
@@ -1609,13 +1607,12 @@ func (c *Compiler) checkUnsafeBuiltins() {
16091607 if len (c .unsafeBuiltinsMap ) == 0 {
16101608 return
16111609 }
1612-
1613- for _ , name := range c .sorted {
1614- errs := checkUnsafeBuiltins (c .unsafeBuiltinsMap , c .Modules [name ])
1610+ c .forEachModule (func (mod * Module ) {
1611+ errs := checkUnsafeBuiltins (c .unsafeBuiltinsMap , mod )
16151612 for _ , err := range errs {
16161613 c .err (err )
16171614 }
1618- }
1615+ })
16191616}
16201617
16211618func (c * Compiler ) checkDeprecatedBuiltins () {
@@ -1630,15 +1627,14 @@ func (c *Compiler) checkDeprecatedBuiltins() {
16301627 return
16311628 }
16321629
1633- for _ , name := range c .sorted {
1634- mod := c .Modules [name ]
1630+ c .forEachModule (func (mod * Module ) {
16351631 if c .strict || mod .regoV1Compatible () {
16361632 errs := checkDeprecatedBuiltins (c .deprecatedBuiltinsMap , mod )
16371633 for _ , err := range errs {
16381634 c .err (err )
16391635 }
16401636 }
1641- }
1637+ })
16421638}
16431639
16441640func (c * Compiler ) runStage (metricName string , f func ()) {
@@ -1756,6 +1752,15 @@ func (c *Compiler) init() {
17561752 c .initialized = true
17571753}
17581754
1755+ func (c * Compiler ) forEachModule (f func (mod * Module )) {
1756+ wg := & sync.WaitGroup {}
1757+ for _ , name := range c .sorted {
1758+ wg .Go (func () { f (c .Modules [name ]) })
1759+ }
1760+ wg .Wait ()
1761+ }
1762+
1763+ // TODO(sr): Fix this, it's not concurrency-safe. And the panic is bad form, too.
17591764func (c * Compiler ) err (err * Error ) {
17601765 if c .maxErrs > 0 && len (c .Errors ) >= c .maxErrs {
17611766 c .Errors = append (c .Errors , errLimitReached )
@@ -1957,7 +1962,7 @@ func (c *Compiler) resolveAllRefs() {
19571962
19581963func (c * Compiler ) removeImports () {
19591964 c .imports = make (map [string ][]* Import , len (c .Modules ))
1960- for name := range c .Modules {
1965+ for name := range c .Modules { // Trivial. No fan-out for this.
19611966 c .imports [name ] = c .Modules [name ].Imports
19621967 c .Modules [name ].Imports = nil
19631968 }
@@ -1969,27 +1974,25 @@ func (c *Compiler) initLocalVarGen() {
19691974
19701975func (c * Compiler ) rewriteComprehensionTerms () {
19711976 f := newEqualityFactory (c .localvargen )
1972- for _ , name := range c .sorted {
1973- mod := c .Modules [name ]
1977+ c .forEachModule (func (mod * Module ) {
19741978 _ , _ = rewriteComprehensionTerms (f , mod ) // ignore error
1975- }
1979+ })
19761980}
19771981
19781982func (c * Compiler ) rewriteExprTerms () {
1979- for _ , name := range c .sorted {
1980- mod := c .Modules [name ]
1983+ c .forEachModule (func (mod * Module ) {
19811984 WalkRules (mod , func (rule * Rule ) bool {
19821985 rewriteExprTermsInHead (c .localvargen , rule )
19831986 rule .Body = rewriteExprTermsInBody (c .localvargen , rule .Body )
19841987 return false
19851988 })
1986- }
1989+ })
19871990}
19881991
19891992func (c * Compiler ) rewriteRuleHeadRefs () {
19901993 f := newEqualityFactory (c .localvargen )
1991- for _ , name := range c . sorted {
1992- WalkRules (c . Modules [ name ] , func (rule * Rule ) bool {
1994+ c . forEachModule ( func ( mod * Module ) {
1995+ WalkRules (mod , func (rule * Rule ) bool {
19931996
19941997 ref := rule .Head .Ref ()
19951998 // NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
@@ -2042,16 +2045,15 @@ func (c *Compiler) rewriteRuleHeadRefs() {
20422045
20432046 return true
20442047 })
2045- }
2048+ })
20462049}
20472050
20482051func (c * Compiler ) checkVoidCalls () {
2049- for _ , name := range c .sorted {
2050- mod := c .Modules [name ]
2052+ c .forEachModule (func (mod * Module ) {
20512053 for _ , err := range checkVoidCalls (c .TypeEnv , mod ) {
20522054 c .err (err )
20532055 }
2054- }
2056+ })
20552057}
20562058
20572059func (c * Compiler ) rewritePrintCalls () {
@@ -2063,8 +2065,7 @@ func (c *Compiler) rewritePrintCalls() {
20632065 }
20642066 }
20652067 } else {
2066- for _ , name := range c .sorted {
2067- mod := c .Modules [name ]
2068+ c .forEachModule (func (mod * Module ) {
20682069 WalkRules (mod , func (r * Rule ) bool {
20692070 safe := r .Head .Args .Vars ()
20702071 safe .Update (ReservedVars )
@@ -2082,7 +2083,7 @@ func (c *Compiler) rewritePrintCalls() {
20822083 WalkBodies (r .Body , vis )
20832084 return false
20842085 })
2085- }
2086+ })
20862087 }
20872088 if modified {
20882089 c .Required .addBuiltinSorted (Print )
@@ -2295,8 +2296,7 @@ func isPrintCall(x *Expr) bool {
22952296// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
22962297func (c * Compiler ) rewriteRefsInHead () {
22972298 f := newEqualityFactory (c .localvargen )
2298- for _ , name := range c .sorted {
2299- mod := c .Modules [name ]
2299+ c .forEachModule (func (mod * Module ) {
23002300 WalkRules (mod , func (rule * Rule ) bool {
23012301 if requiresEval (rule .Head .Key ) {
23022302 expr := f .Generate (rule .Head .Key )
@@ -2317,27 +2317,27 @@ func (c *Compiler) rewriteRefsInHead() {
23172317 }
23182318 return false
23192319 })
2320- }
2320+ })
23212321}
23222322
23232323func (c * Compiler ) rewriteEquals () {
23242324 modified := false
2325- for _ , name := range c . sorted {
2326- modified = rewriteEquals (c . Modules [ name ] ) || modified
2327- }
2325+ c . forEachModule ( func ( mod * Module ) {
2326+ modified = rewriteEquals (mod ) || modified
2327+ })
23282328 if modified {
23292329 c .Required .addBuiltinSorted (Equal )
23302330 }
23312331}
23322332
23332333func (c * Compiler ) rewriteDynamicTerms () {
23342334 f := newEqualityFactory (c .localvargen )
2335- for _ , name := range c . sorted {
2336- WalkRules (c . Modules [ name ] , func (rule * Rule ) bool {
2335+ c . forEachModule ( func ( mod * Module ) {
2336+ WalkRules (mod , func (rule * Rule ) bool {
23372337 rule .Body = rewriteDynamics (f , rule .Body )
23382338 return false
23392339 })
2340- }
2340+ })
23412341}
23422342
23432343// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
@@ -2366,39 +2366,34 @@ func (c *Compiler) rewriteTestRuleEqualities() {
23662366 }
23672367
23682368 f := newEqualityFactory (c .localvargen )
2369- for _ , name := range c .sorted {
2370- mod := c .Modules [name ]
2369+ c .forEachModule (func (mod * Module ) {
23712370 WalkRules (mod , func (rule * Rule ) bool {
23722371 if strings .HasPrefix (string (rule .Head .Name ), "test_" ) {
23732372 rule .Body = rewriteTestEqualities (f , rule .Body )
23742373 }
23752374 return false
23762375 })
2377- }
2376+ })
23782377}
23792378
23802379func (c * Compiler ) parseMetadataBlocks () {
23812380 // Only parse annotations if rego.metadata built-ins are called
23822381 regoMetadataCalled := false
2383- for _ , name := range c .sorted {
2384- mod := c .Modules [name ]
2382+ c .forEachModule (func (mod * Module ) {
2383+ if regoMetadataCalled {
2384+ return
2385+ }
23852386 WalkExprs (mod , func (expr * Expr ) bool {
23862387 if isRegoMetadataChainCall (expr ) || isRegoMetadataRuleCall (expr ) {
23872388 regoMetadataCalled = true
23882389 }
23892390 return regoMetadataCalled
23902391 })
2391-
2392- if regoMetadataCalled {
2393- break
2394- }
2395- }
2392+ })
23962393
23972394 if regoMetadataCalled {
23982395 // NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
2399- for _ , name := range c .sorted {
2400- mod := c .Modules [name ]
2401-
2396+ c .forEachModule (func (mod * Module ) {
24022397 if len (mod .Annotations ) == 0 {
24032398 var errs Errors
24042399 mod .Annotations , errs = parseAnnotations (mod .Comments )
@@ -2409,7 +2404,7 @@ func (c *Compiler) parseMetadataBlocks() {
24092404
24102405 attachRuleAnnotations (mod )
24112406 }
2412- }
2407+ })
24132408 }
24142409}
24152410
@@ -2419,9 +2414,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24192414 _ , chainFuncAllowed := c .builtins [RegoMetadataChain .Name ]
24202415 _ , ruleFuncAllowed := c .builtins [RegoMetadataRule .Name ]
24212416
2422- for _ , name := range c .sorted {
2423- mod := c .Modules [name ]
2424-
2417+ c .forEachModule (func (mod * Module ) {
24252418 WalkRules (mod , func (rule * Rule ) bool {
24262419 var firstChainCall * Expr
24272420 var firstRuleCall * Expr
@@ -2499,7 +2492,7 @@ func (c *Compiler) rewriteRegoMetadataCalls() {
24992492
25002493 return false
25012494 })
2502- }
2495+ })
25032496}
25042497
25052498func getPrimaryRuleAnnotations (as * AnnotationSet , rule * Rule ) * Annotations {
@@ -2893,8 +2886,7 @@ func (vis *ruleArgLocalRewriter) Visit(x any) Visitor {
28932886
28942887func (c * Compiler ) rewriteWithModifiers () {
28952888 f := newEqualityFactory (c .localvargen )
2896- for _ , name := range c .sorted {
2897- mod := c .Modules [name ]
2889+ c .forEachModule (func (mod * Module ) {
28982890 t := NewGenericTransformer (func (x any ) (any , error ) {
28992891 body , ok := x .(Body )
29002892 if ! ok {
@@ -2908,7 +2900,7 @@ func (c *Compiler) rewriteWithModifiers() {
29082900 return body , nil
29092901 })
29102902 _ , _ = Transform (t , mod ) // ignore error
2911- }
2903+ })
29122904}
29132905
29142906func (c * Compiler ) setModuleTree () {
@@ -4333,27 +4325,27 @@ const LocalVarPrefix = "__local"
43334325type localVarGenerator struct {
43344326 exclude VarSet
43354327 suffix string
4336- next int
4328+ next * atomic. Int32
43374329}
43384330
43394331func newLocalVarGeneratorForModuleSet (sorted []string , modules map [string ]* Module ) * localVarGenerator {
43404332 vis := NewVarVisitor ()
43414333 for _ , key := range sorted {
43424334 vis .Walk (modules [key ])
43434335 }
4344- return & localVarGenerator {exclude : vis .vars , next : 0 }
4336+ return & localVarGenerator {exclude : vis .vars , next : & atomic. Int32 {} }
43454337}
43464338
43474339func newLocalVarGenerator (suffix string , node any ) * localVarGenerator {
43484340 vis := NewVarVisitor ()
43494341 vis .Walk (node )
4350- return & localVarGenerator {exclude : vis .vars , suffix : suffix , next : 0 }
4342+ return & localVarGenerator {exclude : vis .vars , suffix : suffix , next : & atomic. Int32 {} }
43514343}
43524344
43534345func (l * localVarGenerator ) Generate () Var {
43544346 for {
4355- result := Var ( LocalVarPrefix + l . suffix + strconv . Itoa ( l . next ) + "__" )
4356- l . next ++
4347+ next := l . next . Add ( 1 ) - 1 // we want the old number
4348+ result := Var ( LocalVarPrefix + l . suffix + strconv . Itoa ( int ( next )) + "__" )
43574349 if ! l .exclude .Contains (result ) {
43584350 return result
43594351 }
0 commit comments