@@ -120,34 +120,35 @@ type Compiler struct {
120120 // Capabliities required by the modules that were compiled.
121121 Required * Capabilities
122122
123- localvargen * localVarGenerator
124- moduleLoader ModuleLoader
125- ruleIndices * util.HashMap
126- stages []stage
127- maxErrs int
128- sorted []string // list of sorted module names
129- pathExists func ([]string ) (bool , error )
130- after map [string ][]CompilerStageDefinition
131- metrics metrics.Metrics
132- capabilities * Capabilities // user-supplied capabilities
133- imports map [string ][]* Import // saved imports from stripping
134- builtins map [string ]* Builtin // universe of built-in functions
135- customBuiltins map [string ]* Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
136- unsafeBuiltinsMap map [string ]struct {} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
137- deprecatedBuiltinsMap map [string ]struct {} // set of deprecated, but not removed, built-in functions
138- enablePrintStatements bool // indicates if print statements should be elided (default)
139- comprehensionIndices map [* Term ]* ComprehensionIndex // comprehension key index
140- initialized bool // indicates if init() has been called
141- debug debug.Debug // emits debug information produced during compilation
142- schemaSet * SchemaSet // user-supplied schemas for input and data documents
143- inputType types.Type // global input type retrieved from schema set
144- annotationSet * AnnotationSet // hierarchical set of annotations
145- strict bool // enforce strict compilation checks
146- keepModules bool // whether to keep the unprocessed, parse modules (below)
147- parsedModules map [string ]* Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
148- useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
149- allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
150- evalMode CompilerEvalMode
123+ localvargen * localVarGenerator
124+ moduleLoader ModuleLoader
125+ ruleIndices * util.HashMap
126+ stages []stage
127+ maxErrs int
128+ sorted []string // list of sorted module names
129+ pathExists func ([]string ) (bool , error )
130+ after map [string ][]CompilerStageDefinition
131+ metrics metrics.Metrics
132+ capabilities * Capabilities // user-supplied capabilities
133+ imports map [string ][]* Import // saved imports from stripping
134+ builtins map [string ]* Builtin // universe of built-in functions
135+ customBuiltins map [string ]* Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
136+ unsafeBuiltinsMap map [string ]struct {} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
137+ deprecatedBuiltinsMap map [string ]struct {} // set of deprecated, but not removed, built-in functions
138+ enablePrintStatements bool // indicates if print statements should be elided (default)
139+ comprehensionIndices map [* Term ]* ComprehensionIndex // comprehension key index
140+ initialized bool // indicates if init() has been called
141+ debug debug.Debug // emits debug information produced during compilation
142+ schemaSet * SchemaSet // user-supplied schemas for input and data documents
143+ inputType types.Type // global input type retrieved from schema set
144+ annotationSet * AnnotationSet // hierarchical set of annotations
145+ strict bool // enforce strict compilation checks
146+ keepModules bool // whether to keep the unprocessed, parse modules (below)
147+ parsedModules map [string ]* Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
148+ useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
149+ allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
150+ evalMode CompilerEvalMode //
151+ rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
151152}
152153
153154// CompilerStage defines the interface for stages in the compiler.
@@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
346347 {"CheckSafetyRuleBodies" , "compile_stage_check_safety_rule_bodies" , c .checkSafetyRuleBodies },
347348 {"RewriteEquals" , "compile_stage_rewrite_equals" , c .rewriteEquals },
348349 {"RewriteDynamicTerms" , "compile_stage_rewrite_dynamic_terms" , c .rewriteDynamicTerms },
350+ {"RewriteTestRulesForTracing" , "compile_stage_rewrite_test_rules_for_tracing" , c .rewriteTestRuleEqualities }, // must run after RewriteDynamicTerms
349351 {"CheckRecursion" , "compile_stage_check_recursion" , c .checkRecursion },
350352 {"CheckTypes" , "compile_stage_check_types" , c .checkTypes }, // must be run after CheckRecursion
351353 {"CheckUnsafeBuiltins" , "compile_state_check_unsafe_builtins" , c .checkUnsafeBuiltins },
@@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
469471 return c
470472}
471473
474+ // WithRewriteTestRules enables rewriting test rules to capture dynamic values in local variables,
475+ // so they can be accessed by tracing.
476+ func (c * Compiler ) WithRewriteTestRules (rewrite bool ) * Compiler {
477+ c .rewriteTestRulesForTracing = rewrite
478+ return c
479+ }
480+
472481// ParsedModules returns the parsed, unprocessed modules from the compiler.
473482// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
474483// The map includes all modules loaded via the ModuleLoader, if one was used.
@@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
21672176 }
21682177}
21692178
2179+ // rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
2180+ // not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var.
2181+ // For example, given the following module:
2182+ //
2183+ // package test
2184+ //
2185+ // p.q contains v if {
2186+ // some v in numbers.range(1, 3)
2187+ // }
2188+ //
2189+ // p.r := "foo"
2190+ //
2191+ // test_rule {
2192+ // p == {
2193+ // "q": {4, 5, 6}
2194+ // }
2195+ // }
2196+ //
2197+ // `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact.
2198+ // If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting.
2199+ func (c * Compiler ) rewriteTestRuleEqualities () {
2200+ if ! c .rewriteTestRulesForTracing {
2201+ return
2202+ }
2203+
2204+ f := newEqualityFactory (c .localvargen )
2205+ for _ , name := range c .sorted {
2206+ mod := c .Modules [name ]
2207+ WalkRules (mod , func (rule * Rule ) bool {
2208+ if strings .HasPrefix (string (rule .Head .Name ), "test_" ) {
2209+ rule .Body = rewriteTestEqualities (f , rule .Body )
2210+ }
2211+ return false
2212+ })
2213+ }
2214+ }
2215+
21702216func (c * Compiler ) parseMetadataBlocks () {
21712217 // Only parse annotations if rego.metadata built-ins are called
21722218 regoMetadataCalled := false
@@ -4517,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
45174563 return modified
45184564}
45194565
4566+ func rewriteTestEqualities (f * equalityFactory , body Body ) Body {
4567+ result := make (Body , 0 , len (body ))
4568+ for _ , expr := range body {
4569+ // We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before
4570+ // reaching the negation check.
4571+ if ! expr .Negated && ! expr .Generated {
4572+ switch {
4573+ case expr .IsEquality ():
4574+ terms := expr .Terms .([]* Term )
4575+ result , terms [1 ] = rewriteDynamicsShallow (expr , f , terms [1 ], result )
4576+ result , terms [2 ] = rewriteDynamicsShallow (expr , f , terms [2 ], result )
4577+ case expr .IsEvery ():
4578+ // We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail.
4579+ // Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those.
4580+ every := expr .Terms .(* Every )
4581+ every .Body = rewriteTestEqualities (f , every .Body )
4582+ }
4583+ }
4584+ result = appendExpr (result , expr )
4585+ }
4586+ return result
4587+ }
4588+
4589+ func rewriteDynamicsShallow (original * Expr , f * equalityFactory , term * Term , result Body ) (Body , * Term ) {
4590+ switch term .Value .(type ) {
4591+ case Ref , * ArrayComprehension , * SetComprehension , * ObjectComprehension :
4592+ generated := f .Generate (term )
4593+ generated .With = original .With
4594+ result .Append (generated )
4595+ connectGeneratedExprs (original , generated )
4596+ return result , result [len (result )- 1 ].Operand (0 )
4597+ }
4598+ return result , term
4599+ }
4600+
45204601// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
45214602// comprehensions) are bound to vars earlier in the query. This translation
45224603// results in eager evaluation.
@@ -4608,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
46084689 generated := f .Generate (term )
46094690 generated .With = original .With
46104691 result .Append (generated )
4692+ connectGeneratedExprs (original , generated )
46114693 return result , result [len (result )- 1 ].Operand (0 )
46124694 case * Array :
46134695 for i := 0 ; i < v .Len (); i ++ {
@@ -4636,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
46364718 var extra * Expr
46374719 v .Body , extra = rewriteDynamicsComprehensionBody (original , f , v .Body , term )
46384720 result .Append (extra )
4721+ connectGeneratedExprs (original , extra )
46394722 return result , result [len (result )- 1 ].Operand (0 )
46404723 case * SetComprehension :
46414724 var extra * Expr
46424725 v .Body , extra = rewriteDynamicsComprehensionBody (original , f , v .Body , term )
46434726 result .Append (extra )
4727+ connectGeneratedExprs (original , extra )
46444728 return result , result [len (result )- 1 ].Operand (0 )
46454729 case * ObjectComprehension :
46464730 var extra * Expr
46474731 v .Body , extra = rewriteDynamicsComprehensionBody (original , f , v .Body , term )
46484732 result .Append (extra )
4733+ connectGeneratedExprs (original , extra )
46494734 return result , result [len (result )- 1 ].Operand (0 )
46504735 }
46514736 return result , term
@@ -4713,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
47134798 for i := 1 ; i < len (terms ); i ++ {
47144799 var extras []* Expr
47154800 extras , terms [i ] = expandExprTerm (gen , terms [i ])
4801+ connectGeneratedExprs (expr , extras ... )
47164802 if len (expr .With ) > 0 {
47174803 for i := range extras {
47184804 extras [i ].With = expr .With
@@ -4740,6 +4826,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
47404826 return
47414827}
47424828
4829+ func connectGeneratedExprs (parent * Expr , children ... * Expr ) {
4830+ for _ , child := range children {
4831+ child .generatedFrom = parent
4832+ parent .generates = append (parent .generates , child )
4833+ }
4834+ }
4835+
47434836func expandExprTerm (gen * localVarGenerator , term * Term ) (support []* Expr , output * Term ) {
47444837 output = term
47454838 switch v := term .Value .(type ) {
0 commit comments