Merge pull request #53 from hashicorp/add-rewrite

ast: add Rewrite() to rewrite AST
This commit is contained in:
Mitchell Hashimoto 2015-11-14 17:38:20 -08:00
commit fbd0456768
4 changed files with 169 additions and 23 deletions

View File

@ -2,6 +2,7 @@ package ast
import (
"reflect"
"strings"
"testing"
"github.com/hashicorp/hcl/hcl/token"
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
}
}
}
func TestWalk(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
order := []string{
"*ast.ObjectList",
"*ast.ObjectItem",
"*ast.ObjectKey",
"*ast.ObjectKey",
"*ast.LiteralType",
"*ast.ObjectItem",
"*ast.ObjectKey",
}
count := 0
Walk(node, func(n Node) (Node, bool) {
if n == nil {
return n, false
}
typeName := reflect.TypeOf(n).String()
if order[count] != typeName {
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
}
count++
return n, true
})
}
func TestWalkEquality(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
}
node := &ObjectList{Items: items}
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
newNode, ok := rewritten.(*ObjectList)
if !ok {
t.Fatalf("expected Objectlist, got %T", rewritten)
}
if !reflect.DeepEqual(node, newNode) {
t.Fatal("rewritten node is not equal to the given node")
}
if len(newNode.Items) != 2 {
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
}
expected := []string{
`"foo"`,
`"bar"`,
}
for i, item := range newNode.Items {
if len(item.Keys) != 1 {
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
}
if item.Keys[0].Token.Text != expected[i] {
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
}
if item.Val != nil {
t.Errorf("expected item value should be nil")
}
}
}
func TestWalkRewrite(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
suffix := "_example"
node = Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
i.Token.Text = i.Token.Text + suffix
n = i
}
return n, true
}).(*ObjectList)
Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
if !strings.HasSuffix(i.Token.Text, suffix) {
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
}
}
return n, true
})
}

View File

@ -2,39 +2,51 @@ package ast
import "fmt"
// WalkFunc describes a function to be called for each node during a Walk. The
// returned node can be used to rewrite the AST. Walking stops the returned
// bool is false.
type WalkFunc func(Node) (Node, bool)
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
// node must not be nil. If f returns true, Walk invokes f recursively for
// each of the non-nil children of node, followed by a call of f(nil).
func Walk(node Node, fn func(Node) bool) {
if !fn(node) {
return
// node must not be nil. If fn returns true, Walk invokes fn recursively for
// each of the non-nil children of node, followed by a call of fn(nil). The
// returned node of fn can be used to rewrite the passed node to fn.
func Walk(node Node, fn WalkFunc) Node {
rewritten, ok := fn(node)
if !ok {
return rewritten
}
switch n := node.(type) {
case *File:
Walk(n.Node, fn)
n.Node = Walk(n.Node, fn)
case *ObjectList:
for _, item := range n.Items {
Walk(item, fn)
for i, item := range n.Items {
n.Items[i] = Walk(item, fn).(*ObjectItem)
}
case *ObjectKey:
// nothing to do
case *ObjectItem:
for _, k := range n.Keys {
Walk(k, fn)
for i, k := range n.Keys {
n.Keys[i] = Walk(k, fn).(*ObjectKey)
}
if n.Val != nil {
n.Val = Walk(n.Val, fn)
}
Walk(n.Val, fn)
case *LiteralType:
// nothing to do
case *ListType:
for _, l := range n.List {
Walk(l, fn)
for i, l := range n.List {
n.List[i] = Walk(l, fn)
}
case *ObjectType:
Walk(n.List, fn)
n.List = Walk(n.List, fn).(*ObjectList)
default:
fmt.Printf(" unknown type: %T\n", n)
// should we panic here?
fmt.Printf("unknown type: %T\n", n)
}
fn(nil)
return rewritten
}

View File

@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
func (p *printer) collectComments(node ast.Node) {
// first collect all comments. This is already stored in
// ast.File.(comments)
ast.Walk(node, func(nn ast.Node) bool {
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) {
case *ast.File:
p.comments = t.Comments
return false
return nn, false
}
return true
return nn, true
})
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
// next remove all lead and line comments from the overall comment map.
// This will give us comments which are standalone, comments which are not
// assigned to any kind of node.
ast.Walk(node, func(nn ast.Node) bool {
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) {
case *ast.LiteralType:
if t.LineComment != nil {
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
}
}
return true
return nn, true
})
for _, c := range standaloneComments {

View File

@ -6,11 +6,11 @@ import (
// flattenObjects takes an AST node, walks it, and flattens
func flattenObjects(node ast.Node) {
ast.Walk(node, func(n ast.Node) bool {
ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
// We only care about lists, because this is what we modify
list, ok := n.(*ast.ObjectList)
if !ok {
return true
return n, true
}
// Rebuild the item list
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
// Done! Set the original items
list.Items = items
return true
return n, true
})
}