Skip to content
149 changes: 121 additions & 28 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,34 +120,35 @@ type Compiler struct {
// Capabliities required by the modules that were compiled.
Required *Capabilities

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode
rewriteTestRulesToCaptureUnboundDynamics bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Dynamics" -- it's not a term that's common in ast or topdown, is it? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True 😄, I think that trickled up from the rewriteDynamics*() naming convention we have in the compiler. Naming suggestions are welcome :). It's a bit too long for my liking too ..

}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"RewriteTestRuleEqualities", "compile_stage_rewrite_test_rule_equalities", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
Expand Down Expand Up @@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
return c
}

// WithRewriteTestRulesToCaptureUnboundDynamics enables rewriting test rules to capture dynamic values in local variables,
// so they can be accessed by tracing.
func (c *Compiler) WithRewriteTestRulesToCaptureUnboundDynamics(rewrite bool) *Compiler {
c.rewriteTestRulesToCaptureUnboundDynamics = rewrite
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
Expand Down Expand Up @@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
}
}

// rewriteDynamics rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
// not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var.
// For example, given the following module:
//
// package test
//
// p.q contains v if {
// some v in numbers.range(1, 3)
// }
//
// p.r := "foo"
//
// test_rule {
// p == {
// "q": {4, 5, 6}
// }
// }
//
// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact.
// If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting.
func (c *Compiler) rewriteTestRuleEqualities() {
if !c.rewriteTestRulesToCaptureUnboundDynamics {
return
}

f := newEqualityFactory(c.localvargen)
for _, name := range c.sorted {
mod := c.Modules[name]
WalkRules(mod, func(rule *Rule) bool {
if strings.HasPrefix(string(rule.Head.Name), "test_") {
rule.Body = rewriteTestEqualities(f, rule.Body)
}
return false
})
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it would be possible to add a few tests cases showing of the rewriting that's happening here? I'm afraid I don't understand it just from the code 😅 Maybe I just haven't found them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll add some of those 👍 .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some tests for this compiler stage. Hope they make things more clear.


func (c *Compiler) parseMetadataBlocks() {
// Only parse annotations if rego.metadata built-ins are called
regoMetadataCalled := false
Expand Down Expand Up @@ -4517,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
return modified
}

func rewriteTestEqualities(f *equalityFactory, body Body) Body {
result := make(Body, 0, len(body))
for _, expr := range body {
// We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before
// reaching the negation check.
if !expr.Negated && !expr.Generated {
switch {
case expr.IsEquality():
terms := expr.Terms.([]*Term)
result, terms[1] = rewriteDynamicsShallow(expr, f, terms[1], result)
result, terms[2] = rewriteDynamicsShallow(expr, f, terms[2], result)
case expr.IsEvery():
// We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail.
// Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those.
every := expr.Terms.(*Every)
every.Body = rewriteTestEqualities(f, every.Body)
}
}
result = appendExpr(result, expr)
}
return result
}

func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
switch term.Value.(type) {
case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension:
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
}
return result, term
}

// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
Expand Down Expand Up @@ -4608,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
case *Array:
for i := 0; i < v.Len(); i++ {
Expand Down Expand Up @@ -4636,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *SetComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *ObjectComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
}
return result, term
Expand Down Expand Up @@ -4713,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
for i := 1; i < len(terms); i++ {
var extras []*Expr
extras, terms[i] = expandExprTerm(gen, terms[i])
connectGeneratedExprs(expr, extras...)
if len(expr.With) > 0 {
for i := range extras {
extras[i].With = expr.With
Expand Down Expand Up @@ -4740,6 +4826,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
return
}

func connectGeneratedExprs(parent *Expr, children ...*Expr) {
for _, child := range children {
child.generatedFrom = parent
parent.generates = append(parent.generates, child)
}
}

func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
output = term
switch v := term.Value.(type) {
Expand Down
17 changes: 12 additions & 5 deletions ast/internal/scanner/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Scanner struct {
width int
errors []Error
keywords map[string]tokens.Token
tabs []int
regoV1Compatible bool
}

Expand All @@ -37,10 +38,11 @@ type Error struct {

// Position represents a point in the scanned source code.
type Position struct {
Offset int // start offset in bytes
End int // end offset in bytes
Row int // line number computed in bytes
Col int // column number computed in bytes
Offset int // start offset in bytes
End int // end offset in bytes
Row int // line number computed in bytes
Col int // column number computed in bytes
Tabs []int // positions of any tabs preceding Col
}

// New returns an initialized scanner that will scan
Expand All @@ -60,6 +62,7 @@ func New(r io.Reader) (*Scanner, error) {
curr: -1,
width: 0,
keywords: tokens.Keywords(),
tabs: []int{},
}

s.next()
Expand Down Expand Up @@ -156,7 +159,7 @@ func (s *Scanner) WithoutKeywords(kws map[string]tokens.Token) (*Scanner, map[st
// for any errors before using the other values.
func (s *Scanner) Scan() (tokens.Token, Position, string, []Error) {

pos := Position{Offset: s.offset - s.width, Row: s.row, Col: s.col}
pos := Position{Offset: s.offset - s.width, Row: s.row, Col: s.col, Tabs: s.tabs}
var tok tokens.Token
var lit string

Expand Down Expand Up @@ -410,8 +413,12 @@ func (s *Scanner) next() {
if s.curr == '\n' {
s.row++
s.col = 0
s.tabs = []int{}
} else {
s.col++
if s.curr == '\t' {
s.tabs = append(s.tabs, s.col)
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions ast/location/location.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type Location struct {

// JSONOptions specifies options for marshaling and unmarshalling of locations
JSONOptions astJSON.Options

Tabs []int `json:"-"` // The column offsets of tabs in the source.
}

// NewLocation returns a new Location object.
Expand Down
1 change: 1 addition & 0 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2130,6 +2130,7 @@ func (p *Parser) doScan(skipws bool) {
p.s.loc.Col = pos.Col
p.s.loc.Offset = pos.Offset
p.s.loc.Text = p.s.Text(pos.Offset, pos.End)
p.s.loc.Tabs = pos.Tabs

for _, err := range errs {
p.error(p.s.Loc(), err.Message)
Expand Down
44 changes: 43 additions & 1 deletion ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ type (
Negated bool `json:"negated,omitempty"`
Location *Location `json:"location,omitempty"`

jsonOptions astJSON.Options
jsonOptions astJSON.Options
generatedFrom *Expr
generates []*Expr
}

// SomeDecl represents a variable declaration statement. The symbols are variables.
Expand Down Expand Up @@ -1593,6 +1595,46 @@ func NewBuiltinExpr(terms ...*Term) *Expr {
return &Expr{Terms: terms}
}

func (expr *Expr) CogeneratedExprs() []*Expr {
visited := map[*Expr]struct{}{}
visitCogeneratedExprs(expr, func(e *Expr) bool {
if expr.Equal(e) {
return true
}
if _, ok := visited[e]; ok {
return true
}
visited[e] = struct{}{}
return false
})

result := make([]*Expr, 0, len(visited))
for e := range visited {
result = append(result, e)
}
return result
}

func (expr *Expr) BaseCogeneratedExpr() *Expr {
if expr.generatedFrom == nil {
return expr
}
return expr.generatedFrom.BaseCogeneratedExpr()
}

func visitCogeneratedExprs(expr *Expr, f func(*Expr) bool) {
if parent := expr.generatedFrom; parent != nil {
if stop := f(parent); !stop {
visitCogeneratedExprs(parent, f)
}
}
for _, child := range expr.generates {
if stop := f(child); !stop {
visitCogeneratedExprs(child, f)
}
}
}

func (d *SomeDecl) String() string {
if call, ok := d.Symbols[0].Value.(Call); ok {
if len(call) == 4 {
Expand Down
10 changes: 8 additions & 2 deletions cmd/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type evalCommandParams struct {
entrypoints repeatedStringFlag
strict bool
v1Compatible bool
traceVarValues bool
}

func newEvalCommandParams() evalCommandParams {
Expand Down Expand Up @@ -307,9 +308,9 @@ access.
evalCommand.Flags().VarP(&params.prettyLimit, "pretty-limit", "", "set limit after which pretty output gets truncated")
evalCommand.Flags().BoolVarP(&params.failDefined, "fail-defined", "", false, "exits with non-zero exit code on defined/non-empty result and errors")
evalCommand.Flags().DurationVar(&params.timeout, "timeout", 0, "set eval timeout (default unlimited)")

evalCommand.Flags().IntVarP(&params.optimizationLevel, "optimize", "O", 0, "set optimization level")
evalCommand.Flags().VarP(&params.entrypoints, "entrypoint", "e", "set slash separated entrypoint path")
evalCommand.Flags().BoolVar(&params.traceVarValues, "var-values", false, "show local variable values in pretty trace output")

// Shared flags
addCapabilitiesFlag(evalCommand.Flags(), params.capabilities)
Expand Down Expand Up @@ -398,7 +399,12 @@ func eval(args []string, params evalCommandParams, w io.Writer) (bool, error) {
case evalValuesOutput:
err = pr.Values(w, result)
case evalPrettyOutput:
err = pr.Pretty(w, result)
err = pr.PrettyWithOptions(w, result, pr.PrettyOptions{
TraceOpts: topdown.PrettyTraceOptions{
Locations: true,
ExprVariables: ectx.params.traceVarValues,
},
})
case evalSourceOutput:
err = pr.Source(w, result)
case evalRawOutput:
Expand Down
Loading