Skip to content

Commit

Permalink
Substitute comprehension terms requring eval
Browse files Browse the repository at this point in the history
  • Loading branch information
tsandall committed Sep 13, 2017
1 parent 8019c7e commit 86f1a9d
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 76 deletions.
167 changes: 111 additions & 56 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ func NewCompiler() *Compiler {
c.setRuleTree,
c.setFuncTree,
c.setGraph,
c.rewriteComprehensionTerms,
c.rewriteRefsInHead,
c.checkWithModifiers,
c.checkRuleConflicts,
Expand Down Expand Up @@ -819,6 +820,18 @@ func (c *Compiler) resolveAllRefs() {
}
}

func (c *Compiler) rewriteComprehensionTerms() {
for _, mod := range c.Modules {
f := newEqualityFactory(newLocalVarGenerator(mod))
rewriteComprehensionTerms(f, mod)
if vs, ok := c.generatedVars[mod]; !ok {
c.generatedVars[mod] = f.gen.Generated()
} else {
vs.Update(f.gen.Generated())
}
}
}

// rewriteTermsInHead will rewrite rules so that the head does not contain any
// terms that require evaluation (e.g., refs or comprehensions). If the key or
// value contains or more of these terms, the key or value will be moved into
Expand All @@ -834,61 +847,25 @@ func (c *Compiler) resolveAllRefs() {
// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
func (c *Compiler) rewriteRefsInHead() {
for _, mod := range c.Modules {
generator := newLocalVarGenerator(mod)
f := newEqualityFactory(newLocalVarGenerator(mod))
WalkRules(mod, func(rule *Rule) bool {
if rule.Head.Key != nil {
found := false
vis := NewGenericVisitor(func(x interface{}) bool {
if found {
return true
}
switch x.(type) {
case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
found = true
return true
}
return false
})
Walk(vis, rule.Head.Key)
if found {
// Replace rule key with generated var
key := rule.Head.Key
local := generator.Generate()
term := &Term{Value: local}
rule.Head.Key = term
expr := Equality.Expr(term, key)
expr.Location = rule.Loc()
rule.Body.Append(expr)
}
if requiresEval(rule.Head.Key) {
expr := f.Generate(rule.Head.Key)
rule.Head.Key = expr.Operand(0)
rule.Body.Append(expr)
}
if rule.Head.Value != nil {
found := false
vis := NewGenericVisitor(func(x interface{}) bool {
if found {
return true
}
switch x.(type) {
case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
found = true
return true
}
return false
})
Walk(vis, rule.Head.Value)
if found {
// Replace rule value with generated var
value := rule.Head.Value
local := generator.Generate()
term := &Term{Value: local}
rule.Head.Value = term
expr := Equality.Expr(term, value)
expr.Location = rule.Loc()
rule.Body.Append(expr)
}
if requiresEval(rule.Head.Value) {
expr := f.Generate(rule.Head.Value)
rule.Head.Value = expr.Operand(0)
rule.Body.Append(expr)
}
return false
})
c.generatedVars[mod] = generator.Generated()
if vs, ok := c.generatedVars[mod]; !ok {
c.generatedVars[mod] = f.gen.Generated()
} else {
vs.Update(f.gen.Generated())
}
}
}

Expand Down Expand Up @@ -931,6 +908,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) {

stages := []func(*QueryContext, Body) (Body, error){
qc.resolveRefs,
qc.rewriteComprehensionTerms,
qc.checkWithModifiers,
qc.checkSafety,
qc.checkTypes,
Expand Down Expand Up @@ -980,6 +958,16 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error
return resolveRefsInBody(globals, body), nil
}

func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
gen := newLocalVarGenerator(body)
f := newEqualityFactory(gen)
node, err := rewriteComprehensionTerms(f, body)
if err != nil {
return nil, err
}
return node.(Body), nil
}

func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {

safe := ReservedVars.Copy()
Expand Down Expand Up @@ -1631,26 +1619,37 @@ func reorderBodyForClosures(globals VarSet, body Body) (Body, unsafeVars) {
return reordered, unsafe
}

type equalityFactory struct {
gen *localVarGenerator
}

func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
return &equalityFactory{gen}
}

func (f *equalityFactory) Generate(other *Term) *Expr {
term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
expr := Equality.Expr(term, other)
expr.Location = other.Location
return expr
}

const localVarFmt = "__local%d__"

type localVarGenerator struct {
exclude VarSet
generated VarSet
}

func newLocalVarGenerator(module *Module) *localVarGenerator {
func newLocalVarGenerator(node interface{}) *localVarGenerator {
exclude := NewVarSet()
vis := &VarVisitor{
vars: exclude,
}
Walk(vis, module)
Walk(vis, node)
return &localVarGenerator{exclude, NewVarSet()}
}

func (l *localVarGenerator) Generated() VarSet {
return l.generated
}

func (l *localVarGenerator) Generate() Var {
name := Var("")
x := 0
Expand All @@ -1662,6 +1661,10 @@ func (l *localVarGenerator) Generate() Var {
return name
}

func (l *localVarGenerator) Generated() VarSet {
return l.generated
}

func getGlobals(pkg *Package, rules []Var, funcs []*Func, imports []*Import) map[Var]Ref {

globals := map[Var]Ref{}
Expand Down Expand Up @@ -1695,6 +1698,13 @@ func getGlobals(pkg *Package, rules []Var, funcs []*Func, imports []*Import) map
return globals
}

func requiresEval(x *Term) bool {
if x == nil {
return false
}
return ContainsRefs(x) || ContainsComprehensions(x)
}

func resolveRef(globals map[Var]Ref, ref Ref) Ref {

r := Ref{}
Expand Down Expand Up @@ -1849,3 +1859,48 @@ func resolveRefsInTerm(globals map[Var]Ref, term *Term) *Term {
return term
}
}

// rewriteComprehensionTerms will rewrite comprehensions so that the term part
// is bound to a variable in the body. This allows any type of term to be used
// in the term part (even if the term requires evaluation.)
//
// For instance, given the following comprehension:
//
// [x[0] | x = y[_]; y = [1,2,3]]
//
// The comprehension would be rewritten as:
//
// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
return TransformComprehensions(node, func(x interface{}) (Value, error) {
switch x := x.(type) {
case *ArrayComprehension:
if requiresEval(x.Term) {
expr := f.Generate(x.Term)
x.Term = expr.Operand(0)
x.Body.Append(expr)
}
return x, nil
case *SetComprehension:
if requiresEval(x.Term) {
expr := f.Generate(x.Term)
x.Term = expr.Operand(0)
x.Body.Append(expr)
}
return x, nil
case *ObjectComprehension:
if requiresEval(x.Key) {
expr := f.Generate(x.Key)
x.Key = expr.Operand(0)
x.Body.Append(expr)
}
if requiresEval(x.Value) {
expr := f.Generate(x.Value)
x.Value = expr.Operand(0)
x.Body.Append(expr)
}
return x, nil
}
panic("illegal type")
})
}
29 changes: 28 additions & 1 deletion ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {
{"array-compr-closure", `p { _ = [v | v = [x | x = data.a[_]]; x > 1] }`, `{x,}`},
{"array-compr-term", `p { _ = [u | true] }`, `{u,}`},
{"array-compr-term-nested", `p { _ = [v | v = [w | w != 0]] }`, `{w,}`},
{"array-compr-term-output", `p { _ = [x[i] | x = []] }`, `{i,}`},
{"array-compr-mixed", `p { _ = [x | y = [a | a = z[i]]] }`, `{a, x, z, i}`},
{"array-compr-builtin", `p { [true | eq != 2] }`, `{eq,}`},
{"closure-self", `p { x = [x | x = 1] }`, `{x,}`},
Expand Down Expand Up @@ -840,6 +839,33 @@ elsekw {
assertRulesEqual(t, rule5, expected5)
}

func TestRewriteComprehensionTerm(t *testing.T) {

c := NewCompiler()
c.Modules["head"] = MustParseModule(`package head
arr = [[1], [2], [3]]
arr2 = [["a"], ["b"], ["c"]]
arr_comp = [[x[i]] | arr[j] = x]
set_comp = {[x[i]] | arr[j] = x}
obj_comp = {x[i]: x[i] | arr2[j] = x}
`)

compileStages(c, c.rewriteComprehensionTerms)
assertNotFailed(t, c)

arrCompRule := c.Modules["head"].Rules[2]
exp1 := MustParseRule(`arr_comp = [__local0__ | data.head.arr[j] = x; __local0__ = [x[i]]] { true }`)
assertRulesEqual(t, arrCompRule, exp1)

setCompRule := c.Modules["head"].Rules[3]
exp2 := MustParseRule(`set_comp = {__local1__ | data.head.arr[j] = x; __local1__ = [x[i]]} { true }`)
assertRulesEqual(t, setCompRule, exp2)

objCompRule := c.Modules["head"].Rules[4]
exp3 := MustParseRule(`obj_comp = {__local2__: __local3__ | data.head.arr2[j] = x; __local2__ = x[i]; __local3__ = x[i]} { true }`)
assertRulesEqual(t, objCompRule, exp3)
}

func TestCompilerSetGraph(t *testing.T) {
c := NewCompiler()
c.Modules = getCompilerTestModules()
Expand Down Expand Up @@ -1501,6 +1527,7 @@ func TestQueryCompiler(t *testing.T) {
}{
{"exports resolved", "z", `package a.b.c`, nil, "", "data.a.b.c.z"},
{"imports resolved", "z", `package a.b.c.d`, []string{"import data.a.b.c.z"}, "", "data.a.b.c.z"},
{"rewrite comprehensions", "[x[i] | a = [[1], [2]]; x = a[j]]", "", nil, "", "[__local0__ | a = [[1], [2]]; x = a[j]; __local0__ = x[i]]"},
{"unsafe vars", "z", "", nil, "", fmt.Errorf("1 error occurred: 1:1: rego_unsafe_var_error: var z is unsafe")},
{"safe vars", `data; abc`, `package ex`, []string{"import input.xyz as abc"}, `{}`, `data; input.xyz`},
{"reorder", `x != 1; x = 0`, "", nil, "", `x = 0; x != 1`},
Expand Down
53 changes: 34 additions & 19 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,27 +334,42 @@ func (term *Term) Vars() VarSet {

// IsConstant returns true if the AST value is constant.
func IsConstant(v Value) bool {
switch v := v.(type) {
case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
return false
case Array:
for i := 0; i < len(v); i++ {
if !IsConstant(v[i].Value) {
return false
}
}
case Object:
for i := 0; i < len(v); i++ {
if !IsConstant(v[i][0].Value) || !IsConstant(v[i][1].Value) {
return false
found := false
Walk(&GenericVisitor{
func(x interface{}) bool {
switch x.(type) {
case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
found = true
return true
}
return false
},
}, v)
return !found
}

// ContainsRefs returns true if the Value v contains refs.
func ContainsRefs(v interface{}) bool {
found := false
WalkRefs(v, func(r Ref) bool {
found = true
return found
})
return found
}

// ContainsComprehensions returns true if the Value v contains comprehensions.
func ContainsComprehensions(v interface{}) bool {
found := false
WalkClosures(v, func(x interface{}) bool {
switch x.(type) {
case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
found = true
return found
}
case *Set:
return !v.Iter(func(t *Term) bool {
return !IsConstant(t.Value)
})
}
return true
return found
})
return found
}

// IsScalar returns true if the AST value is a scalar.
Expand Down
16 changes: 16 additions & 0 deletions ast/transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ func TransformRefs(x interface{}, f func(Ref) (Value, error)) (interface{}, erro
return Transform(t, x)
}

// TransformComprehensions calls the functio nf on all comprehensions under x.
func TransformComprehensions(x interface{}, f func(interface{}) (Value, error)) (interface{}, error) {
t := &GenericTransformer{func(x interface{}) (interface{}, error) {
switch x := x.(type) {
case *ArrayComprehension:
return f(x)
case *SetComprehension:
return f(x)
case *ObjectComprehension:
return f(x)
}
return x, nil
}}
return Transform(t, x)
}

// GenericTransformer implements the Transformer interface to provide a utility
// to transform AST nodes using a closure.
type GenericTransformer struct {
Expand Down

0 comments on commit 86f1a9d

Please sign in to comment.