From 86f1a9dc151dc6942bfd463eb83a51e95305b6ad Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Wed, 13 Sep 2017 08:32:48 -0700 Subject: [PATCH] Substitute comprehension terms requring eval Fixes #453 --- ast/compile.go | 167 +++++++++++++++++++++++++++++--------------- ast/compile_test.go | 29 +++++++- ast/term.go | 53 +++++++++----- ast/transform.go | 16 +++++ 4 files changed, 189 insertions(+), 76 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 2302e37cb6..1c938378ac 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -218,6 +218,7 @@ func NewCompiler() *Compiler { c.setRuleTree, c.setFuncTree, c.setGraph, + c.rewriteComprehensionTerms, c.rewriteRefsInHead, c.checkWithModifiers, c.checkRuleConflicts, @@ -819,6 +820,18 @@ func (c *Compiler) resolveAllRefs() { } } +func (c *Compiler) rewriteComprehensionTerms() { + for _, mod := range c.Modules { + f := newEqualityFactory(newLocalVarGenerator(mod)) + rewriteComprehensionTerms(f, mod) + if vs, ok := c.generatedVars[mod]; !ok { + c.generatedVars[mod] = f.gen.Generated() + } else { + vs.Update(f.gen.Generated()) + } + } +} + // rewriteTermsInHead will rewrite rules so that the head does not contain any // terms that require evaluation (e.g., refs or comprehensions). If the key or // value contains or more of these terms, the key or value will be moved into @@ -834,61 +847,25 @@ func (c *Compiler) resolveAllRefs() { // p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} } func (c *Compiler) rewriteRefsInHead() { for _, mod := range c.Modules { - generator := newLocalVarGenerator(mod) + f := newEqualityFactory(newLocalVarGenerator(mod)) WalkRules(mod, func(rule *Rule) bool { - if rule.Head.Key != nil { - found := false - vis := NewGenericVisitor(func(x interface{}) bool { - if found { - return true - } - switch x.(type) { - case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: - found = true - return true - } - return false - }) - Walk(vis, rule.Head.Key) - if found { - // Replace rule key with generated var - key := rule.Head.Key - local := generator.Generate() - term := &Term{Value: local} - rule.Head.Key = term - expr := Equality.Expr(term, key) - expr.Location = rule.Loc() - rule.Body.Append(expr) - } + if requiresEval(rule.Head.Key) { + expr := f.Generate(rule.Head.Key) + rule.Head.Key = expr.Operand(0) + rule.Body.Append(expr) } - if rule.Head.Value != nil { - found := false - vis := NewGenericVisitor(func(x interface{}) bool { - if found { - return true - } - switch x.(type) { - case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: - found = true - return true - } - return false - }) - Walk(vis, rule.Head.Value) - if found { - // Replace rule value with generated var - value := rule.Head.Value - local := generator.Generate() - term := &Term{Value: local} - rule.Head.Value = term - expr := Equality.Expr(term, value) - expr.Location = rule.Loc() - rule.Body.Append(expr) - } + if requiresEval(rule.Head.Value) { + expr := f.Generate(rule.Head.Value) + rule.Head.Value = expr.Operand(0) + rule.Body.Append(expr) } return false }) - c.generatedVars[mod] = generator.Generated() + if vs, ok := c.generatedVars[mod]; !ok { + c.generatedVars[mod] = f.gen.Generated() + } else { + vs.Update(f.gen.Generated()) + } } } @@ -931,6 +908,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) { stages := []func(*QueryContext, Body) (Body, error){ qc.resolveRefs, + qc.rewriteComprehensionTerms, qc.checkWithModifiers, qc.checkSafety, qc.checkTypes, @@ -980,6 +958,16 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error return resolveRefsInBody(globals, body), nil } +func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) { + gen := newLocalVarGenerator(body) + f := newEqualityFactory(gen) + node, err := rewriteComprehensionTerms(f, body) + if err != nil { + return nil, err + } + return node.(Body), nil +} + func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { safe := ReservedVars.Copy() @@ -1631,6 +1619,21 @@ func reorderBodyForClosures(globals VarSet, body Body) (Body, unsafeVars) { return reordered, unsafe } +type equalityFactory struct { + gen *localVarGenerator +} + +func newEqualityFactory(gen *localVarGenerator) *equalityFactory { + return &equalityFactory{gen} +} + +func (f *equalityFactory) Generate(other *Term) *Expr { + term := NewTerm(f.gen.Generate()).SetLocation(other.Location) + expr := Equality.Expr(term, other) + expr.Location = other.Location + return expr +} + const localVarFmt = "__local%d__" type localVarGenerator struct { @@ -1638,19 +1641,15 @@ type localVarGenerator struct { generated VarSet } -func newLocalVarGenerator(module *Module) *localVarGenerator { +func newLocalVarGenerator(node interface{}) *localVarGenerator { exclude := NewVarSet() vis := &VarVisitor{ vars: exclude, } - Walk(vis, module) + Walk(vis, node) return &localVarGenerator{exclude, NewVarSet()} } -func (l *localVarGenerator) Generated() VarSet { - return l.generated -} - func (l *localVarGenerator) Generate() Var { name := Var("") x := 0 @@ -1662,6 +1661,10 @@ func (l *localVarGenerator) Generate() Var { return name } +func (l *localVarGenerator) Generated() VarSet { + return l.generated +} + func getGlobals(pkg *Package, rules []Var, funcs []*Func, imports []*Import) map[Var]Ref { globals := map[Var]Ref{} @@ -1695,6 +1698,13 @@ func getGlobals(pkg *Package, rules []Var, funcs []*Func, imports []*Import) map return globals } +func requiresEval(x *Term) bool { + if x == nil { + return false + } + return ContainsRefs(x) || ContainsComprehensions(x) +} + func resolveRef(globals map[Var]Ref, ref Ref) Ref { r := Ref{} @@ -1849,3 +1859,48 @@ func resolveRefsInTerm(globals map[Var]Ref, term *Term) *Term { return term } } + +// rewriteComprehensionTerms will rewrite comprehensions so that the term part +// is bound to a variable in the body. This allows any type of term to be used +// in the term part (even if the term requires evaluation.) +// +// For instance, given the following comprehension: +// +// [x[0] | x = y[_]; y = [1,2,3]] +// +// The comprehension would be rewritten as: +// +// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]] +func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) { + return TransformComprehensions(node, func(x interface{}) (Value, error) { + switch x := x.(type) { + case *ArrayComprehension: + if requiresEval(x.Term) { + expr := f.Generate(x.Term) + x.Term = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + case *SetComprehension: + if requiresEval(x.Term) { + expr := f.Generate(x.Term) + x.Term = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + case *ObjectComprehension: + if requiresEval(x.Key) { + expr := f.Generate(x.Key) + x.Key = expr.Operand(0) + x.Body.Append(expr) + } + if requiresEval(x.Value) { + expr := f.Generate(x.Value) + x.Value = expr.Operand(0) + x.Body.Append(expr) + } + return x, nil + } + panic("illegal type") + }) +} diff --git a/ast/compile_test.go b/ast/compile_test.go index f762cc1ef9..3ec1f1ab92 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -524,7 +524,6 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) { {"array-compr-closure", `p { _ = [v | v = [x | x = data.a[_]]; x > 1] }`, `{x,}`}, {"array-compr-term", `p { _ = [u | true] }`, `{u,}`}, {"array-compr-term-nested", `p { _ = [v | v = [w | w != 0]] }`, `{w,}`}, - {"array-compr-term-output", `p { _ = [x[i] | x = []] }`, `{i,}`}, {"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`}, {"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`}, {"closure-self", `p { x = [x | x = 1] }`, `{x,}`}, @@ -840,6 +839,33 @@ elsekw { assertRulesEqual(t, rule5, expected5) } +func TestRewriteComprehensionTerm(t *testing.T) { + + c := NewCompiler() + c.Modules["head"] = MustParseModule(`package head + arr = [[1], [2], [3]] + arr2 = [["a"], ["b"], ["c"]] + arr_comp = [[x[i]] | arr[j] = x] + set_comp = {[x[i]] | arr[j] = x} + obj_comp = {x[i]: x[i] | arr2[j] = x} + `) + + compileStages(c, c.rewriteComprehensionTerms) + assertNotFailed(t, c) + + arrCompRule := c.Modules["head"].Rules[2] + exp1 := MustParseRule(`arr_comp = [__local0__ | data.head.arr[j] = x; __local0__ = [x[i]]] { true }`) + assertRulesEqual(t, arrCompRule, exp1) + + setCompRule := c.Modules["head"].Rules[3] + exp2 := MustParseRule(`set_comp = {__local1__ | data.head.arr[j] = x; __local1__ = [x[i]]} { true }`) + assertRulesEqual(t, setCompRule, exp2) + + objCompRule := c.Modules["head"].Rules[4] + exp3 := MustParseRule(`obj_comp = {__local2__: __local3__ | data.head.arr2[j] = x; __local2__ = x[i]; __local3__ = x[i]} { true }`) + assertRulesEqual(t, objCompRule, exp3) +} + func TestCompilerSetGraph(t *testing.T) { c := NewCompiler() c.Modules = getCompilerTestModules() @@ -1501,6 +1527,7 @@ func TestQueryCompiler(t *testing.T) { }{ {"exports resolved", "z", `package a.b.c`, nil, "", "data.a.b.c.z"}, {"imports resolved", "z", `package a.b.c.d`, []string{"import data.a.b.c.z"}, "", "data.a.b.c.z"}, + {"rewrite comprehensions", "[x[i] | a = [[1], [2]]; x = a[j]]", "", nil, "", "[__local0__ | a = [[1], [2]]; x = a[j]; __local0__ = x[i]]"}, {"unsafe vars", "z", "", nil, "", fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: var z is unsafe")}, {"safe vars", `data; abc`, `package ex`, []string{"import input.xyz as abc"}, `{}`, `data; input.xyz`}, {"reorder", `x != 1; x = 0`, "", nil, "", `x = 0; x != 1`}, diff --git a/ast/term.go b/ast/term.go index 5641e754b7..9445501518 100644 --- a/ast/term.go +++ b/ast/term.go @@ -334,27 +334,42 @@ func (term *Term) Vars() VarSet { // IsConstant returns true if the AST value is constant. func IsConstant(v Value) bool { - switch v := v.(type) { - case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: - return false - case Array: - for i := 0; i < len(v); i++ { - if !IsConstant(v[i].Value) { - return false - } - } - case Object: - for i := 0; i < len(v); i++ { - if !IsConstant(v[i][0].Value) || !IsConstant(v[i][1].Value) { - return false + found := false + Walk(&GenericVisitor{ + func(x interface{}) bool { + switch x.(type) { + case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: + found = true + return true } + return false + }, + }, v) + return !found +} + +// ContainsRefs returns true if the Value v contains refs. +func ContainsRefs(v interface{}) bool { + found := false + WalkRefs(v, func(r Ref) bool { + found = true + return found + }) + return found +} + +// ContainsComprehensions returns true if the Value v contains comprehensions. +func ContainsComprehensions(v interface{}) bool { + found := false + WalkClosures(v, func(x interface{}) bool { + switch x.(type) { + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: + found = true + return found } - case *Set: - return !v.Iter(func(t *Term) bool { - return !IsConstant(t.Value) - }) - } - return true + return found + }) + return found } // IsScalar returns true if the AST value is a scalar. diff --git a/ast/transform.go b/ast/transform.go index 9233154f42..2779ca263b 100644 --- a/ast/transform.go +++ b/ast/transform.go @@ -267,6 +267,22 @@ func TransformRefs(x interface{}, f func(Ref) (Value, error)) (interface{}, erro return Transform(t, x) } +// TransformComprehensions calls the functio nf on all comprehensions under x. +func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) (interface{}, error) { + t := &GenericTransformer{func(x interface{}) (interface{}, error) { + switch x := x.(type) { + case *ArrayComprehension: + return f(x) + case *SetComprehension: + return f(x) + case *ObjectComprehension: + return f(x) + } + return x, nil + }} + return Transform(t, x) +} + // GenericTransformer implements the Transformer interface to provide a utility // to transform AST nodes using a closure. type GenericTransformer struct {