ast: change signature of Walk() to allow rewriting AST

With the previous Walk function it's not easy to rewrite the node as we
don't have any kind of reference to the parent. If we want to rewrite a
given AST, we have to manually traverse it as Walk is not usable. To
allow us rewriting the AST we change the signature of the function
passed to Walk. It'll allow us to rewrite the AST and return back.
Internally Walk() overrides the returned AST.

This idea was also talked here:
https://groups.google.com/forum/#!topic/golang-nuts/cRZQV36IckM
extensively.
This commit is contained in:
Fatih Arslan 2015-11-13 14:14:22 +02:00
parent 8ec7833c13
commit d45f5d133c
4 changed files with 169 additions and 23 deletions

View File

@ -2,6 +2,7 @@ package ast
import (
"reflect"
"strings"
"testing"
"github.com/hashicorp/hcl/hcl/token"
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
}
}
}
func TestWalk(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
order := []string{
"*ast.ObjectList",
"*ast.ObjectItem",
"*ast.ObjectKey",
"*ast.ObjectKey",
"*ast.LiteralType",
"*ast.ObjectItem",
"*ast.ObjectKey",
}
count := 0
Walk(node, func(n Node) (Node, bool) {
if n == nil {
return n, false
}
typeName := reflect.TypeOf(n).String()
if order[count] != typeName {
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
}
count++
return n, true
})
}
func TestWalkEquality(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
}
node := &ObjectList{Items: items}
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
newNode, ok := rewritten.(*ObjectList)
if !ok {
t.Fatalf("expected Objectlist, got %T", rewritten)
}
if !reflect.DeepEqual(node, newNode) {
t.Fatal("rewritten node is not equal to the given node")
}
if len(newNode.Items) != 2 {
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
}
expected := []string{
`"foo"`,
`"bar"`,
}
for i, item := range newNode.Items {
if len(item.Keys) != 1 {
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
}
if item.Keys[0].Token.Text != expected[i] {
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
}
if item.Val != nil {
t.Errorf("expected item value should be nil")
}
}
}
func TestWalkRewrite(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
suffix := "_example"
node = Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
i.Token.Text = i.Token.Text + suffix
n = i
}
return n, true
}).(*ObjectList)
Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
if !strings.HasSuffix(i.Token.Text, suffix) {
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
}
}
return n, true
})
}

View File

@ -2,39 +2,51 @@ package ast
import "fmt"
// WalkFunc describes a function to be called for each node during a Walk. The
// returned node can be used to rewrite the AST. Walking stops the returned
// bool is false.
type WalkFunc func(Node) (Node, bool)
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
// node must not be nil. If f returns true, Walk invokes f recursively for
// each of the non-nil children of node, followed by a call of f(nil).
func Walk(node Node, fn func(Node) bool) {
if !fn(node) {
return
// node must not be nil. If fn returns true, Walk invokes fn recursively for
// each of the non-nil children of node, followed by a call of fn(nil). The
// returned node of fn can be used to rewrite the passed node to fn.
func Walk(node Node, fn WalkFunc) Node {
rewritten, ok := fn(node)
if !ok {
return rewritten
}
switch n := node.(type) {
case *File:
Walk(n.Node, fn)
n.Node = Walk(n.Node, fn)
case *ObjectList:
for _, item := range n.Items {
Walk(item, fn)
for i, item := range n.Items {
n.Items[i] = Walk(item, fn).(*ObjectItem)
}
case *ObjectKey:
// nothing to do
case *ObjectItem:
for _, k := range n.Keys {
Walk(k, fn)
for i, k := range n.Keys {
n.Keys[i] = Walk(k, fn).(*ObjectKey)
}
if n.Val != nil {
n.Val = Walk(n.Val, fn)
}
Walk(n.Val, fn)
case *LiteralType:
// nothing to do
case *ListType:
for _, l := range n.List {
Walk(l, fn)
for i, l := range n.List {
n.List[i] = Walk(l, fn)
}
case *ObjectType:
Walk(n.List, fn)
n.List = Walk(n.List, fn).(*ObjectList)
default:
fmt.Printf(" unknown type: %T\n", n)
// should we panic here?
fmt.Printf("unknown type: %T\n", n)
}
fn(nil)
return rewritten
}

View File

@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
func (p *printer) collectComments(node ast.Node) {
// first collect all comments. This is already stored in
// ast.File.(comments)
ast.Walk(node, func(nn ast.Node) bool {
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) {
case *ast.File:
p.comments = t.Comments
return false
return nn, false
}
return true
return nn, true
})
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
// next remove all lead and line comments from the overall comment map.
// This will give us comments which are standalone, comments which are not
// assigned to any kind of node.
ast.Walk(node, func(nn ast.Node) bool {
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) {
case *ast.LiteralType:
if t.LineComment != nil {
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
}
}
return true
return nn, true
})
for _, c := range standaloneComments {

View File

@ -6,11 +6,11 @@ import (
// flattenObjects takes an AST node, walks it, and flattens
func flattenObjects(node ast.Node) {
ast.Walk(node, func(n ast.Node) bool {
ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
// We only care about lists, because this is what we modify
list, ok := n.(*ast.ObjectList)
if !ok {
return true
return n, true
}
// Rebuild the item list
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
// Done! Set the original items
list.Items = items
return true
return n, true
})
}