ast: change signature of Walk() to allow rewriting AST
With the previous Walk function it's not easy to rewrite the node as we don't have any kind of reference to the parent. If we want to rewrite a given AST, we have to manually traverse it as Walk is not usable. To allow us rewriting the AST we change the signature of the function passed to Walk. It'll allow us to rewrite the AST and return back. Internally Walk() overrides the returned AST. This idea was also talked here: https://groups.google.com/forum/#!topic/golang-nuts/cRZQV36IckM extensively.
This commit is contained in:
parent
8ec7833c13
commit
d45f5d133c
@ -2,6 +2,7 @@ package ast
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/hcl/hcl/token"
|
||||
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalk(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
order := []string{
|
||||
"*ast.ObjectList",
|
||||
"*ast.ObjectItem",
|
||||
"*ast.ObjectKey",
|
||||
"*ast.ObjectKey",
|
||||
"*ast.LiteralType",
|
||||
"*ast.ObjectItem",
|
||||
"*ast.ObjectKey",
|
||||
}
|
||||
count := 0
|
||||
|
||||
Walk(node, func(n Node) (Node, bool) {
|
||||
if n == nil {
|
||||
return n, false
|
||||
}
|
||||
|
||||
typeName := reflect.TypeOf(n).String()
|
||||
if order[count] != typeName {
|
||||
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
|
||||
}
|
||||
count++
|
||||
return n, true
|
||||
})
|
||||
}
|
||||
|
||||
func TestWalkEquality(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
|
||||
|
||||
newNode, ok := rewritten.(*ObjectList)
|
||||
if !ok {
|
||||
t.Fatalf("expected Objectlist, got %T", rewritten)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(node, newNode) {
|
||||
t.Fatal("rewritten node is not equal to the given node")
|
||||
}
|
||||
|
||||
if len(newNode.Items) != 2 {
|
||||
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
`"foo"`,
|
||||
`"bar"`,
|
||||
}
|
||||
|
||||
for i, item := range newNode.Items {
|
||||
if len(item.Keys) != 1 {
|
||||
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
|
||||
}
|
||||
|
||||
if item.Keys[0].Token.Text != expected[i] {
|
||||
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
|
||||
}
|
||||
|
||||
if item.Val != nil {
|
||||
t.Errorf("expected item value should be nil")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkRewrite(t *testing.T) {
|
||||
items := []*ObjectItem{
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||
},
|
||||
},
|
||||
&ObjectItem{
|
||||
Keys: []*ObjectKey{
|
||||
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node := &ObjectList{Items: items}
|
||||
|
||||
suffix := "_example"
|
||||
node = Walk(node, func(n Node) (Node, bool) {
|
||||
switch i := n.(type) {
|
||||
case *ObjectKey:
|
||||
i.Token.Text = i.Token.Text + suffix
|
||||
n = i
|
||||
}
|
||||
return n, true
|
||||
}).(*ObjectList)
|
||||
|
||||
Walk(node, func(n Node) (Node, bool) {
|
||||
switch i := n.(type) {
|
||||
case *ObjectKey:
|
||||
if !strings.HasSuffix(i.Token.Text, suffix) {
|
||||
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
|
||||
}
|
||||
}
|
||||
return n, true
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -2,39 +2,51 @@ package ast
|
||||
|
||||
import "fmt"
|
||||
|
||||
// WalkFunc describes a function to be called for each node during a Walk. The
|
||||
// returned node can be used to rewrite the AST. Walking stops the returned
|
||||
// bool is false.
|
||||
type WalkFunc func(Node) (Node, bool)
|
||||
|
||||
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
|
||||
// node must not be nil. If f returns true, Walk invokes f recursively for
|
||||
// each of the non-nil children of node, followed by a call of f(nil).
|
||||
func Walk(node Node, fn func(Node) bool) {
|
||||
if !fn(node) {
|
||||
return
|
||||
// node must not be nil. If fn returns true, Walk invokes fn recursively for
|
||||
// each of the non-nil children of node, followed by a call of fn(nil). The
|
||||
// returned node of fn can be used to rewrite the passed node to fn.
|
||||
func Walk(node Node, fn WalkFunc) Node {
|
||||
rewritten, ok := fn(node)
|
||||
if !ok {
|
||||
return rewritten
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *File:
|
||||
Walk(n.Node, fn)
|
||||
n.Node = Walk(n.Node, fn)
|
||||
case *ObjectList:
|
||||
for _, item := range n.Items {
|
||||
Walk(item, fn)
|
||||
for i, item := range n.Items {
|
||||
n.Items[i] = Walk(item, fn).(*ObjectItem)
|
||||
}
|
||||
case *ObjectKey:
|
||||
// nothing to do
|
||||
case *ObjectItem:
|
||||
for _, k := range n.Keys {
|
||||
Walk(k, fn)
|
||||
for i, k := range n.Keys {
|
||||
n.Keys[i] = Walk(k, fn).(*ObjectKey)
|
||||
}
|
||||
|
||||
if n.Val != nil {
|
||||
n.Val = Walk(n.Val, fn)
|
||||
}
|
||||
Walk(n.Val, fn)
|
||||
case *LiteralType:
|
||||
// nothing to do
|
||||
case *ListType:
|
||||
for _, l := range n.List {
|
||||
Walk(l, fn)
|
||||
for i, l := range n.List {
|
||||
n.List[i] = Walk(l, fn)
|
||||
}
|
||||
case *ObjectType:
|
||||
Walk(n.List, fn)
|
||||
n.List = Walk(n.List, fn).(*ObjectList)
|
||||
default:
|
||||
fmt.Printf(" unknown type: %T\n", n)
|
||||
// should we panic here?
|
||||
fmt.Printf("unknown type: %T\n", n)
|
||||
}
|
||||
|
||||
fn(nil)
|
||||
return rewritten
|
||||
}
|
||||
|
@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
|
||||
func (p *printer) collectComments(node ast.Node) {
|
||||
// first collect all comments. This is already stored in
|
||||
// ast.File.(comments)
|
||||
ast.Walk(node, func(nn ast.Node) bool {
|
||||
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||
switch t := nn.(type) {
|
||||
case *ast.File:
|
||||
p.comments = t.Comments
|
||||
return false
|
||||
return nn, false
|
||||
}
|
||||
return true
|
||||
return nn, true
|
||||
})
|
||||
|
||||
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
|
||||
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
|
||||
// next remove all lead and line comments from the overall comment map.
|
||||
// This will give us comments which are standalone, comments which are not
|
||||
// assigned to any kind of node.
|
||||
ast.Walk(node, func(nn ast.Node) bool {
|
||||
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||
switch t := nn.(type) {
|
||||
case *ast.LiteralType:
|
||||
if t.LineComment != nil {
|
||||
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return nn, true
|
||||
})
|
||||
|
||||
for _, c := range standaloneComments {
|
||||
|
@ -6,11 +6,11 @@ import (
|
||||
|
||||
// flattenObjects takes an AST node, walks it, and flattens
|
||||
func flattenObjects(node ast.Node) {
|
||||
ast.Walk(node, func(n ast.Node) bool {
|
||||
ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
|
||||
// We only care about lists, because this is what we modify
|
||||
list, ok := n.(*ast.ObjectList)
|
||||
if !ok {
|
||||
return true
|
||||
return n, true
|
||||
}
|
||||
|
||||
// Rebuild the item list
|
||||
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
|
||||
|
||||
// Done! Set the original items
|
||||
list.Items = items
|
||||
return true
|
||||
return n, true
|
||||
})
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user