Skip to content

Commit 96ecf38

Browse files
authored
trace+tester: Adding local var values to trace and test report (#6815)
Fixing: #2546 Signed-off-by: Johan Fylling <[email protected]>
1 parent 31120ce commit 96ecf38

File tree

21 files changed

+3061
-72
lines changed

21 files changed

+3061
-72
lines changed

.github/workflows/pull-request.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ jobs:
321321
fail-fast: false
322322
matrix:
323323
os: [ubuntu-22.04, macos-14]
324-
version: ["1.20"]
324+
version: ["1.21"]
325325
steps:
326326
- uses: actions/checkout@v4
327327
- name: Download generated artifacts

ast/compile.go

Lines changed: 121 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -120,34 +120,35 @@ type Compiler struct {
120120
// Capabliities required by the modules that were compiled.
121121
Required *Capabilities
122122

123-
localvargen *localVarGenerator
124-
moduleLoader ModuleLoader
125-
ruleIndices *util.HashMap
126-
stages []stage
127-
maxErrs int
128-
sorted []string // list of sorted module names
129-
pathExists func([]string) (bool, error)
130-
after map[string][]CompilerStageDefinition
131-
metrics metrics.Metrics
132-
capabilities *Capabilities // user-supplied capabilities
133-
imports map[string][]*Import // saved imports from stripping
134-
builtins map[string]*Builtin // universe of built-in functions
135-
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
136-
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
137-
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
138-
enablePrintStatements bool // indicates if print statements should be elided (default)
139-
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
140-
initialized bool // indicates if init() has been called
141-
debug debug.Debug // emits debug information produced during compilation
142-
schemaSet *SchemaSet // user-supplied schemas for input and data documents
143-
inputType types.Type // global input type retrieved from schema set
144-
annotationSet *AnnotationSet // hierarchical set of annotations
145-
strict bool // enforce strict compilation checks
146-
keepModules bool // whether to keep the unprocessed, parse modules (below)
147-
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
148-
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
149-
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
150-
evalMode CompilerEvalMode
123+
localvargen *localVarGenerator
124+
moduleLoader ModuleLoader
125+
ruleIndices *util.HashMap
126+
stages []stage
127+
maxErrs int
128+
sorted []string // list of sorted module names
129+
pathExists func([]string) (bool, error)
130+
after map[string][]CompilerStageDefinition
131+
metrics metrics.Metrics
132+
capabilities *Capabilities // user-supplied capabilities
133+
imports map[string][]*Import // saved imports from stripping
134+
builtins map[string]*Builtin // universe of built-in functions
135+
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
136+
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
137+
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
138+
enablePrintStatements bool // indicates if print statements should be elided (default)
139+
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
140+
initialized bool // indicates if init() has been called
141+
debug debug.Debug // emits debug information produced during compilation
142+
schemaSet *SchemaSet // user-supplied schemas for input and data documents
143+
inputType types.Type // global input type retrieved from schema set
144+
annotationSet *AnnotationSet // hierarchical set of annotations
145+
strict bool // enforce strict compilation checks
146+
keepModules bool // whether to keep the unprocessed, parse modules (below)
147+
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
148+
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
149+
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
150+
evalMode CompilerEvalMode //
151+
rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
151152
}
152153

153154
// CompilerStage defines the interface for stages in the compiler.
@@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
346347
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
347348
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
348349
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
350+
{"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms
349351
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
350352
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
351353
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
@@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
469471
return c
470472
}
471473

474+
// WithRewriteTestRules enables rewriting test rules to capture dynamic values in local variables,
475+
// so they can be accessed by tracing.
476+
func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler {
477+
c.rewriteTestRulesForTracing = rewrite
478+
return c
479+
}
480+
472481
// ParsedModules returns the parsed, unprocessed modules from the compiler.
473482
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
474483
// The map includes all modules loaded via the ModuleLoader, if one was used.
@@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
21672176
}
21682177
}
21692178

2179+
// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
2180+
// not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var.
2181+
// For example, given the following module:
2182+
//
2183+
// package test
2184+
//
2185+
// p.q contains v if {
2186+
// some v in numbers.range(1, 3)
2187+
// }
2188+
//
2189+
// p.r := "foo"
2190+
//
2191+
// test_rule {
2192+
// p == {
2193+
// "q": {4, 5, 6}
2194+
// }
2195+
// }
2196+
//
2197+
// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact.
2198+
// If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting.
2199+
func (c *Compiler) rewriteTestRuleEqualities() {
2200+
if !c.rewriteTestRulesForTracing {
2201+
return
2202+
}
2203+
2204+
f := newEqualityFactory(c.localvargen)
2205+
for _, name := range c.sorted {
2206+
mod := c.Modules[name]
2207+
WalkRules(mod, func(rule *Rule) bool {
2208+
if strings.HasPrefix(string(rule.Head.Name), "test_") {
2209+
rule.Body = rewriteTestEqualities(f, rule.Body)
2210+
}
2211+
return false
2212+
})
2213+
}
2214+
}
2215+
21702216
func (c *Compiler) parseMetadataBlocks() {
21712217
// Only parse annotations if rego.metadata built-ins are called
21722218
regoMetadataCalled := false
@@ -4517,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
45174563
return modified
45184564
}
45194565

4566+
func rewriteTestEqualities(f *equalityFactory, body Body) Body {
4567+
result := make(Body, 0, len(body))
4568+
for _, expr := range body {
4569+
// We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before
4570+
// reaching the negation check.
4571+
if !expr.Negated && !expr.Generated {
4572+
switch {
4573+
case expr.IsEquality():
4574+
terms := expr.Terms.([]*Term)
4575+
result, terms[1] = rewriteDynamicsShallow(expr, f, terms[1], result)
4576+
result, terms[2] = rewriteDynamicsShallow(expr, f, terms[2], result)
4577+
case expr.IsEvery():
4578+
// We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail.
4579+
// Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those.
4580+
every := expr.Terms.(*Every)
4581+
every.Body = rewriteTestEqualities(f, every.Body)
4582+
}
4583+
}
4584+
result = appendExpr(result, expr)
4585+
}
4586+
return result
4587+
}
4588+
4589+
func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
4590+
switch term.Value.(type) {
4591+
case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension:
4592+
generated := f.Generate(term)
4593+
generated.With = original.With
4594+
result.Append(generated)
4595+
connectGeneratedExprs(original, generated)
4596+
return result, result[len(result)-1].Operand(0)
4597+
}
4598+
return result, term
4599+
}
4600+
45204601
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
45214602
// comprehensions) are bound to vars earlier in the query. This translation
45224603
// results in eager evaluation.
@@ -4608,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
46084689
generated := f.Generate(term)
46094690
generated.With = original.With
46104691
result.Append(generated)
4692+
connectGeneratedExprs(original, generated)
46114693
return result, result[len(result)-1].Operand(0)
46124694
case *Array:
46134695
for i := 0; i < v.Len(); i++ {
@@ -4636,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
46364718
var extra *Expr
46374719
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
46384720
result.Append(extra)
4721+
connectGeneratedExprs(original, extra)
46394722
return result, result[len(result)-1].Operand(0)
46404723
case *SetComprehension:
46414724
var extra *Expr
46424725
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
46434726
result.Append(extra)
4727+
connectGeneratedExprs(original, extra)
46444728
return result, result[len(result)-1].Operand(0)
46454729
case *ObjectComprehension:
46464730
var extra *Expr
46474731
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
46484732
result.Append(extra)
4733+
connectGeneratedExprs(original, extra)
46494734
return result, result[len(result)-1].Operand(0)
46504735
}
46514736
return result, term
@@ -4713,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
47134798
for i := 1; i < len(terms); i++ {
47144799
var extras []*Expr
47154800
extras, terms[i] = expandExprTerm(gen, terms[i])
4801+
connectGeneratedExprs(expr, extras...)
47164802
if len(expr.With) > 0 {
47174803
for i := range extras {
47184804
extras[i].With = expr.With
@@ -4740,6 +4826,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
47404826
return
47414827
}
47424828

4829+
func connectGeneratedExprs(parent *Expr, children ...*Expr) {
4830+
for _, child := range children {
4831+
child.generatedFrom = parent
4832+
parent.generates = append(parent.generates, child)
4833+
}
4834+
}
4835+
47434836
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
47444837
output = term
47454838
switch v := term.Value.(type) {

ast/compile_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10679,3 +10679,165 @@ deny {
1067910679
t.Fatal(c.Errors)
1068010680
}
1068110681
}
10682+
10683+
func TestCompilerRewriteTestRulesForTracing(t *testing.T) {
10684+
tests := []struct {
10685+
note string
10686+
rewrite bool
10687+
module string
10688+
exp string
10689+
}{
10690+
{
10691+
note: "ref comparison, no rewrite",
10692+
module: `package test
10693+
import rego.v1
10694+
10695+
a := 1
10696+
b := 2
10697+
10698+
test_something if {
10699+
a == b
10700+
}`,
10701+
exp: `package test
10702+
10703+
a := 1 { true }
10704+
b := 2 { true }
10705+
10706+
test_something = true {
10707+
data.test.a = data.test.b
10708+
}`,
10709+
},
10710+
{
10711+
note: "ref comparison, rewrite",
10712+
rewrite: true,
10713+
module: `package test
10714+
import rego.v1
10715+
10716+
a := 1
10717+
b := 2
10718+
10719+
test_something if {
10720+
a == b
10721+
}`,
10722+
// When the test fails on '__local0__ = __local1__', the values for 'a' and 'b' are captured in local bindings,
10723+
// accessible by the tracer.
10724+
exp: `package test
10725+
10726+
a := 1 { true }
10727+
b := 2 { true }
10728+
10729+
test_something = true {
10730+
__local0__ = data.test.a
10731+
__local1__ = data.test.b
10732+
__local0__ = __local1__
10733+
}`,
10734+
},
10735+
{
10736+
note: "ref comparison, not-stmt, rewrite",
10737+
rewrite: true,
10738+
module: `package test
10739+
import rego.v1
10740+
10741+
a := 1
10742+
b := 2
10743+
10744+
test_something if {
10745+
not a == b
10746+
}`,
10747+
// We don't break out local vars from a not-stmt, as that would change the semantics of the rule.
10748+
exp: `package test
10749+
10750+
a := 1 { true }
10751+
b := 2 { true }
10752+
10753+
test_something = true {
10754+
not data.test.a = data.test.b
10755+
}`,
10756+
},
10757+
{
10758+
note: "ref comparison, inside every-stmt, no rewrite",
10759+
module: `package test
10760+
import rego.v1
10761+
10762+
a := 1
10763+
b := 2
10764+
l := [1, 2, 3]
10765+
10766+
test_something if {
10767+
every x in l {
10768+
a < b + x
10769+
}
10770+
}`,
10771+
exp: `package test
10772+
import future.keywords
10773+
10774+
a := 1 { true }
10775+
b := 2 { true }
10776+
l := [1, 2, 3] { true }
10777+
10778+
test_something = true {
10779+
__local2__ = data.test.l
10780+
every __local0__, __local1__ in __local2__ {
10781+
__local4__ = data.test.b
10782+
plus(__local4__, __local1__, __local3__)
10783+
__local5__ = data.test.a
10784+
lt(__local5__, __local3__)
10785+
}
10786+
}`,
10787+
},
10788+
{
10789+
note: "ref comparison, inside every-stmt, rewrite",
10790+
rewrite: true,
10791+
module: `package test
10792+
import rego.v1
10793+
10794+
a := 1
10795+
b := 2
10796+
l := [1, 2, 3]
10797+
10798+
test_something if {
10799+
every x in l {
10800+
a < b + x
10801+
}
10802+
}`,
10803+
// When tests contain an 'every' statement, we're interested in the circumstances that made the every fail,
10804+
// so it's body is rewritten.
10805+
exp: `package test
10806+
import future.keywords
10807+
10808+
a := 1 { true }
10809+
b := 2 { true }
10810+
l := [1, 2, 3] { true }
10811+
10812+
test_something = true {
10813+
__local2__ = data.test.l;
10814+
every __local0__, __local1__ in __local2__ {
10815+
__local4__ = data.test.b
10816+
plus(__local4__, __local1__, __local3__)
10817+
__local5__ = data.test.a
10818+
lt(__local5__, __local3__)
10819+
}
10820+
}`,
10821+
},
10822+
}
10823+
10824+
for _, tc := range tests {
10825+
t.Run(tc.note, func(t *testing.T) {
10826+
ms := map[string]string{
10827+
"test.rego": tc.module,
10828+
}
10829+
c := getCompilerWithParsedModules(ms).
10830+
WithRewriteTestRules(tc.rewrite)
10831+
10832+
compileStages(c, c.rewriteTestRuleEqualities)
10833+
assertNotFailed(t, c)
10834+
10835+
result := c.Modules["test.rego"]
10836+
exp := MustParseModule(tc.exp)
10837+
exp.Imports = nil // We strip the imports since the compiler will too
10838+
if result.Compare(exp) != 0 {
10839+
t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result)
10840+
}
10841+
})
10842+
}
10843+
}

0 commit comments

Comments
 (0)