Merge pull request #53 from hashicorp/add-rewrite
ast: add Rewrite() to rewrite AST
This commit is contained in:
commit
fbd0456768
@ -2,6 +2,7 @@ package ast
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/hcl/hcl/token"
|
||||
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalk(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
order := []string{
|
||||
"*ast.ObjectList",
|
||||
"*ast.ObjectItem",
|
||||
"*ast.ObjectKey",
|
||||
"*ast.ObjectKey",
|
||||
"*ast.LiteralType",
|
||||
"*ast.ObjectItem",
|
||||
"*ast.ObjectKey",
|
||||
}
|
||||
count := 0
|
||||
|
||||
Walk(node, func(n Node) (Node, bool) {
|
||||
if n == nil {
|
||||
return n, false
|
||||
}
|
||||
|
||||
typeName := reflect.TypeOf(n).String()
|
||||
if order[count] != typeName {
|
||||
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
|
||||
}
|
||||
count++
|
||||
return n, true
|
||||
})
|
||||
}
|
||||
|
||||
func TestWalkEquality(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
|
||||
|
||||
newNode, ok := rewritten.(*ObjectList)
|
||||
if !ok {
|
||||
t.Fatalf("expected Objectlist, got %T", rewritten)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(node, newNode) {
|
||||
t.Fatal("rewritten node is not equal to the given node")
|
||||
}
|
||||
|
||||
if len(newNode.Items) != 2 {
|
||||
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
`"foo"`,
|
||||
`"bar"`,
|
||||
}
|
||||
|
||||
for i, item := range newNode.Items {
|
||||
if len(item.Keys) != 1 {
|
||||
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
|
||||
}
|
||||
|
||||
if item.Keys[0].Token.Text != expected[i] {
|
||||
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
|
||||
}
|
||||
|
||||
if item.Val != nil {
|
||||
t.Errorf("expected item value should be nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkRewrite(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
suffix := "_example"
|
||||
node = Walk(node, func(n Node) (Node, bool) {
|
||||
switch i := n.(type) {
|
||||
case *ObjectKey:
|
||||
i.Token.Text = i.Token.Text + suffix
|
||||
n = i
|
||||
}
|
||||
return n, true
|
||||
}).(*ObjectList)
|
||||
|
||||
Walk(node, func(n Node) (Node, bool) {
|
||||
switch i := n.(type) {
|
||||
case *ObjectKey:
|
||||
if !strings.HasSuffix(i.Token.Text, suffix) {
|
||||
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
|
||||
}
|
||||
}
|
||||
return n, true
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -2,39 +2,51 @@ package ast
|
||||
|
||||
import "fmt"
|
||||
|
||||
// WalkFunc describes a function to be called for each node during a Walk. The
|
||||
// returned node can be used to rewrite the AST. Walking stops the returned
|
||||
// bool is false.
|
||||
type WalkFunc func(Node) (Node, bool)
|
||||
|
||||
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
|
||||
// node must not be nil. If f returns true, Walk invokes f recursively for
|
||||
// each of the non-nil children of node, followed by a call of f(nil).
|
||||
func Walk(node Node, fn func(Node) bool) {
|
||||
if !fn(node) {
|
||||
return
|
||||
// node must not be nil. If fn returns true, Walk invokes fn recursively for
|
||||
// each of the non-nil children of node, followed by a call of fn(nil). The
|
||||
// returned node of fn can be used to rewrite the passed node to fn.
|
||||
func Walk(node Node, fn WalkFunc) Node {
|
||||
rewritten, ok := fn(node)
|
||||
if !ok {
|
||||
return rewritten
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *File:
|
||||
Walk(n.Node, fn)
|
||||
n.Node = Walk(n.Node, fn)
|
||||
case *ObjectList:
|
||||
for _, item := range n.Items {
|
||||
Walk(item, fn)
|
||||
for i, item := range n.Items {
|
||||
n.Items[i] = Walk(item, fn).(*ObjectItem)
|
||||
}
|
||||
case *ObjectKey:
|
||||
// nothing to do
|
||||
case *ObjectItem:
|
||||
for _, k := range n.Keys {
|
||||
Walk(k, fn)
|
||||
for i, k := range n.Keys {
|
||||
n.Keys[i] = Walk(k, fn).(*ObjectKey)
|
||||
}
|
||||
|
||||
if n.Val != nil {
|
||||
n.Val = Walk(n.Val, fn)
|
||||
}
|
||||
Walk(n.Val, fn)
|
||||
case *LiteralType:
|
||||
// nothing to do
|
||||
case *ListType:
|
||||
for _, l := range n.List {
|
||||
Walk(l, fn)
|
||||
for i, l := range n.List {
|
||||
n.List[i] = Walk(l, fn)
|
||||
}
|
||||
case *ObjectType:
|
||||
Walk(n.List, fn)
|
||||
n.List = Walk(n.List, fn).(*ObjectList)
|
||||
default:
|
||||
fmt.Printf(" unknown type: %T\n", n)
|
||||
// should we panic here?
|
||||
fmt.Printf("unknown type: %T\n", n)
|
||||
}
|
||||
|
||||
fn(nil)
|
||||
return rewritten
|
||||
}
|
||||
|
@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
|
||||
func (p *printer) collectComments(node ast.Node) {
|
||||
// first collect all comments. This is already stored in
|
||||
// ast.File.(comments)
|
||||
ast.Walk(node, func(nn ast.Node) bool {
|
||||
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||
switch t := nn.(type) {
|
||||
case *ast.File:
|
||||
p.comments = t.Comments
|
||||
return false
|
||||
return nn, false
|
||||
}
|
||||
return true
|
||||
return nn, true
|
||||
})
|
||||
|
||||
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
|
||||
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
|
||||
// next remove all lead and line comments from the overall comment map.
|
||||
// This will give us comments which are standalone, comments which are not
|
||||
// assigned to any kind of node.
|
||||
ast.Walk(node, func(nn ast.Node) bool {
|
||||
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||
switch t := nn.(type) {
|
||||
case *ast.LiteralType:
|
||||
if t.LineComment != nil {
|
||||
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return nn, true
|
||||
})
|
||||
|
||||
for _, c := range standaloneComments {
|
||||
|
@ -6,11 +6,11 @@ import (
|
||||
|
||||
// flattenObjects takes an AST node, walks it, and flattens
|
||||
func flattenObjects(node ast.Node) {
|
||||
ast.Walk(node, func(n ast.Node) bool {
|
||||
ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
|
||||
// We only care about lists, because this is what we modify
|
||||
list, ok := n.(*ast.ObjectList)
|
||||
if !ok {
|
||||
return true
|
||||
return n, true
|
||||
}
|
||||
|
||||
// Rebuild the item list
|
||||
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
|
||||
|
||||
// Done! Set the original items
|
||||
list.Items = items
|
||||
return true
|
||||
return n, true
|
||||
})
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user