Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04, macos-14]
version: ["1.20"]
version: ["1.21"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.20 isn't a supported Go version anyways. I think we simply forgot to bump this when we bumped the build version to 1.22.

steps:
- uses: actions/checkout@v4
- name: Download generated artifacts
Expand Down
149 changes: 121 additions & 28 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,34 +120,35 @@ type Compiler struct {
// Capabilities required by the modules that were compiled.
Required *Capabilities

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode //
rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
Expand Down Expand Up @@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
return c
}

// WithRewriteTestRules sets whether the compiler rewrites test rules so that dynamic
// values are captured in local variables, where they can be accessed by tracing.
// The flag is stored on the receiver, and the same compiler is returned so the call
// can be chained with other With* options.
func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler {
c.rewriteTestRulesForTracing = rewrite
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
Expand Down Expand Up @@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
}
}

// rewriteTestRuleEqualities is a compiler stage that only does work when
// c.rewriteTestRulesForTracing is set. It visits every rule of every module in
// c.sorted and, for rules whose head name starts with "test_", replaces the rule
// body with the result of rewriteTestEqualities. The purpose is to create local
// vars for terms whose values would otherwise not be captured through tracing,
// such as refs and comprehensions not unified/assigned to a local var.
//
// For example, given the following module:
//
//	package test
//
//	p.q contains v if {
//		some v in numbers.range(1, 3)
//	}
//
//	p.r := "foo"
//
//	test_rule {
//		p == {
//			"q": {4, 5, 6}
//		}
//	}
//
// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the
// virtual-cache and must therefore be calculated after-the-fact. If `p` isn't
// captured in a local var, there is no trivial way to retrieve its value for
// test reporting.
func (c *Compiler) rewriteTestRuleEqualities() {
	if !c.rewriteTestRulesForTracing {
		return
	}

	// A single factory is shared by all modules, so generated local vars all
	// come from the compiler's own generator.
	factory := newEqualityFactory(c.localvargen)

	rewriteIfTestRule := func(r *Rule) bool {
		if strings.HasPrefix(string(r.Head.Name), "test_") {
			r.Body = rewriteTestEqualities(factory, r.Body)
		}
		return false
	}

	for _, moduleName := range c.sorted {
		WalkRules(c.Modules[moduleName], rewriteIfTestRule)
	}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it would be possible to add a few test cases showing off the rewriting that's happening here? I'm afraid I don't understand it just from the code 😅 Maybe I just haven't found them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll add some of those 👍 .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some tests for this compiler stage. Hope they make things more clear.


func (c *Compiler) parseMetadataBlocks() {
// Only parse annotations if rego.metadata built-ins are called
regoMetadataCalled := false
Expand Down Expand Up @@ -4517,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
return modified
}

// rewriteTestEqualities builds and returns a new body from the expressions of
// body, in order. For each expression that is neither negated nor generated:
//
//   - if it is an equality, both operands are passed through
//     rewriteDynamicsShallow, which may append generated expressions to the
//     result ahead of the equality and swap the operand in place;
//   - if it is an `every` expression, its body is rewritten recursively.
//
// Every input expression, rewritten or not, is then added to the result with
// appendExpr. Note that equality operands and `every` bodies are updated on the
// original expressions, not on copies.
func rewriteTestEqualities(f *equalityFactory, body Body) Body {
	rewritten := make(Body, 0, len(body))
	for _, e := range body {
		switch {
		case e.Negated || e.Generated:
			// Negated expressions can't be rewritten: if the extracted term is
			// undefined, evaluation would fail before reaching the negation
			// check. Generated expressions are left untouched as well.
		case e.IsEquality():
			operands := e.Terms.([]*Term)
			rewritten, operands[1] = rewriteDynamicsShallow(e, f, operands[1], rewritten)
			rewritten, operands[2] = rewriteDynamicsShallow(e, f, operands[2], rewritten)
		case e.IsEvery():
			// Equalities inside every-bodies are rewritten because a failure
			// there is what makes the test rule fail. Failures inside other
			// expressions with closures, such as comprehensions, won't cause the
			// test rule to fail, so those are skipped.
			quantifier := e.Terms.(*Every)
			quantifier.Body = rewriteTestEqualities(f, quantifier.Body)
		}
		rewritten = appendExpr(rewritten, e)
	}
	return rewritten
}

// rewriteDynamicsShallow extracts term into a generated expression when its
// value is a ref or an array/set/object comprehension. It inspects only the
// term's own value and does not descend into composite terms.
//
// When the term qualifies, an expression is generated for it via f, given the
// with-modifiers of original, appended to result, and linked to original through
// connectGeneratedExprs. The function then returns the grown body together with
// the first operand of the body's last expression, which the caller uses in place
// of term. For any other kind of value, result and term are returned unchanged.
func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	extract := false
	switch term.Value.(type) {
	case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension:
		extract = true
	}
	if !extract {
		return result, term
	}

	binding := f.Generate(term)
	binding.With = original.With
	result.Append(binding)
	connectGeneratedExprs(original, binding)

	last := result[len(result)-1]
	return result, last.Operand(0)
}

// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
Expand Down Expand Up @@ -4608,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
case *Array:
for i := 0; i < v.Len(); i++ {
Expand Down Expand Up @@ -4636,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *SetComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *ObjectComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
}
return result, term
Expand Down Expand Up @@ -4713,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
for i := 1; i < len(terms); i++ {
var extras []*Expr
extras, terms[i] = expandExprTerm(gen, terms[i])
connectGeneratedExprs(expr, extras...)
if len(expr.With) > 0 {
for i := range extras {
extras[i].With = expr.With
Expand Down Expand Up @@ -4740,6 +4826,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
return
}

// connectGeneratedExprs links parent with each of the given children in both
// directions: the child's generatedFrom field is pointed at parent, and the child
// is appended to parent's generates list. Children are processed in argument
// order; with no children, nothing is modified.
func connectGeneratedExprs(parent *Expr, children ...*Expr) {
	for i := range children {
		generated := children[i]
		generated.generatedFrom = parent
		parent.generates = append(parent.generates, generated)
	}
}

func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
output = term
switch v := term.Value.(type) {
Expand Down
162 changes: 162 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10679,3 +10679,165 @@ deny {
t.Fatal(c.Errors)
}
}

// TestCompilerRewriteTestRulesForTracing is a table-driven test for the
// rewriteTestRuleEqualities compiler stage. Each case compiles a single module
// ("test.rego") with WithRewriteTestRules set to the case's rewrite flag and
// compares the resulting module against the module parsed from exp.
// NOTE(review): compileStages presumably runs the pipeline up to and including the
// given stage, which would explain the locals already present in the "no rewrite"
// every-stmt expectation — confirm against its definition.
func TestCompilerRewriteTestRulesForTracing(t *testing.T) {
tests := []struct {
note string
rewrite bool
module string
exp string
}{
{
note: "ref comparison, no rewrite",
module: `package test
import rego.v1

a := 1
b := 2

test_something if {
a == b
}`,
exp: `package test

a := 1 { true }
b := 2 { true }

test_something = true {
data.test.a = data.test.b
}`,
},
{
note: "ref comparison, rewrite",
rewrite: true,
module: `package test
import rego.v1

a := 1
b := 2

test_something if {
a == b
}`,
// When the test fails on '__local0__ = __local1__', the values for 'a' and 'b' are captured in local bindings,
// accessible by the tracer.
exp: `package test

a := 1 { true }
b := 2 { true }

test_something = true {
__local0__ = data.test.a
__local1__ = data.test.b
__local0__ = __local1__
}`,
},
{
note: "ref comparison, not-stmt, rewrite",
rewrite: true,
module: `package test
import rego.v1

a := 1
b := 2

test_something if {
not a == b
}`,
// We don't break out local vars from a not-stmt, as that would change the semantics of the rule.
exp: `package test

a := 1 { true }
b := 2 { true }

test_something = true {
not data.test.a = data.test.b
}`,
},
{
note: "ref comparison, inside every-stmt, no rewrite",
module: `package test
import rego.v1

a := 1
b := 2
l := [1, 2, 3]

test_something if {
every x in l {
a < b + x
}
}`,
exp: `package test
import future.keywords

a := 1 { true }
b := 2 { true }
l := [1, 2, 3] { true }

test_something = true {
__local2__ = data.test.l
every __local0__, __local1__ in __local2__ {
__local4__ = data.test.b
plus(__local4__, __local1__, __local3__)
__local5__ = data.test.a
lt(__local5__, __local3__)
}
}`,
},
{
note: "ref comparison, inside every-stmt, rewrite",
rewrite: true,
module: `package test
import rego.v1

a := 1
b := 2
l := [1, 2, 3]

test_something if {
every x in l {
a < b + x
}
}`,
// When tests contain an 'every' statement, we're interested in the circumstances that made the every fail,
// so its body is rewritten.
exp: `package test
import future.keywords

a := 1 { true }
b := 2 { true }
l := [1, 2, 3] { true }

test_something = true {
__local2__ = data.test.l;
every __local0__, __local1__ in __local2__ {
__local4__ = data.test.b
plus(__local4__, __local1__, __local3__)
__local5__ = data.test.a
lt(__local5__, __local3__)
}
}`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
ms := map[string]string{
"test.rego": tc.module,
}
c := getCompilerWithParsedModules(ms).
WithRewriteTestRules(tc.rewrite)

compileStages(c, c.rewriteTestRuleEqualities)
assertNotFailed(t, c)

result := c.Modules["test.rego"]
exp := MustParseModule(tc.exp)
exp.Imports = nil // We strip the imports since the compiler will too
if result.Compare(exp) != 0 {
t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result)
}
})
}
}
Loading