ast: change signature of Walk() to allow rewriting AST

With the previous Walk function it's not easy to rewrite the node as we
don't have any kind of reference to the parent. If we want to rewrite a
given AST, we have to manually traverse it as Walk is not usable. To
allow us rewriting the AST we change the signature of the function
passed to Walk. It'll allow us to rewrite the AST and return back.
Internally Walk() overrides the returned AST.

This idea was also talked here:
https://groups.google.com/forum/#!topic/golang-nuts/cRZQV36IckM
extensively.
This commit is contained in:
Fatih Arslan 2015-11-13 14:14:22 +02:00
parent 8ec7833c13
commit d45f5d133c
4 changed files with 169 additions and 23 deletions

View File

@ -2,6 +2,7 @@ package ast
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
"github.com/hashicorp/hcl/hcl/token" "github.com/hashicorp/hcl/hcl/token"
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
} }
} }
} }
func TestWalk(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
order := []string{
"*ast.ObjectList",
"*ast.ObjectItem",
"*ast.ObjectKey",
"*ast.ObjectKey",
"*ast.LiteralType",
"*ast.ObjectItem",
"*ast.ObjectKey",
}
count := 0
Walk(node, func(n Node) (Node, bool) {
if n == nil {
return n, false
}
typeName := reflect.TypeOf(n).String()
if order[count] != typeName {
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
}
count++
return n, true
})
}
func TestWalkEquality(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
}
node := &ObjectList{Items: items}
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
newNode, ok := rewritten.(*ObjectList)
if !ok {
t.Fatalf("expected Objectlist, got %T", rewritten)
}
if !reflect.DeepEqual(node, newNode) {
t.Fatal("rewritten node is not equal to the given node")
}
if len(newNode.Items) != 2 {
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
}
expected := []string{
`"foo"`,
`"bar"`,
}
for i, item := range newNode.Items {
if len(item.Keys) != 1 {
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
}
if item.Keys[0].Token.Text != expected[i] {
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
}
if item.Val != nil {
t.Errorf("expected item value should be nil")
}
}
}
func TestWalkRewrite(t *testing.T) {
items := []*ObjectItem{
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
},
},
&ObjectItem{
Keys: []*ObjectKey{
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
},
},
}
node := &ObjectList{Items: items}
suffix := "_example"
node = Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
i.Token.Text = i.Token.Text + suffix
n = i
}
return n, true
}).(*ObjectList)
Walk(node, func(n Node) (Node, bool) {
switch i := n.(type) {
case *ObjectKey:
if !strings.HasSuffix(i.Token.Text, suffix) {
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
}
}
return n, true
})
}

View File

@ -2,39 +2,51 @@ package ast
import "fmt" import "fmt"
// WalkFunc describes a function to be called for each node during a Walk. The
// returned node can be used to rewrite the AST. Walking stops the returned
// bool is false.
type WalkFunc func(Node) (Node, bool)
// Walk traverses an AST in depth-first order: It starts by calling fn(node); // Walk traverses an AST in depth-first order: It starts by calling fn(node);
// node must not be nil. If f returns true, Walk invokes f recursively for // node must not be nil. If fn returns true, Walk invokes fn recursively for
// each of the non-nil children of node, followed by a call of f(nil). // each of the non-nil children of node, followed by a call of fn(nil). The
func Walk(node Node, fn func(Node) bool) { // returned node of fn can be used to rewrite the passed node to fn.
if !fn(node) { func Walk(node Node, fn WalkFunc) Node {
return rewritten, ok := fn(node)
if !ok {
return rewritten
} }
switch n := node.(type) { switch n := node.(type) {
case *File: case *File:
Walk(n.Node, fn) n.Node = Walk(n.Node, fn)
case *ObjectList: case *ObjectList:
for _, item := range n.Items { for i, item := range n.Items {
Walk(item, fn) n.Items[i] = Walk(item, fn).(*ObjectItem)
} }
case *ObjectKey: case *ObjectKey:
// nothing to do // nothing to do
case *ObjectItem: case *ObjectItem:
for _, k := range n.Keys { for i, k := range n.Keys {
Walk(k, fn) n.Keys[i] = Walk(k, fn).(*ObjectKey)
}
if n.Val != nil {
n.Val = Walk(n.Val, fn)
} }
Walk(n.Val, fn)
case *LiteralType: case *LiteralType:
// nothing to do // nothing to do
case *ListType: case *ListType:
for _, l := range n.List { for i, l := range n.List {
Walk(l, fn) n.List[i] = Walk(l, fn)
} }
case *ObjectType: case *ObjectType:
Walk(n.List, fn) n.List = Walk(n.List, fn).(*ObjectList)
default: default:
fmt.Printf(" unknown type: %T\n", n) // should we panic here?
fmt.Printf("unknown type: %T\n", n)
} }
fn(nil) fn(nil)
return rewritten
} }

View File

@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
func (p *printer) collectComments(node ast.Node) { func (p *printer) collectComments(node ast.Node) {
// first collect all comments. This is already stored in // first collect all comments. This is already stored in
// ast.File.(comments) // ast.File.(comments)
ast.Walk(node, func(nn ast.Node) bool { ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) { switch t := nn.(type) {
case *ast.File: case *ast.File:
p.comments = t.Comments p.comments = t.Comments
return false return nn, false
} }
return true return nn, true
}) })
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0) standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
// next remove all lead and line comments from the overall comment map. // next remove all lead and line comments from the overall comment map.
// This will give us comments which are standalone, comments which are not // This will give us comments which are standalone, comments which are not
// assigned to any kind of node. // assigned to any kind of node.
ast.Walk(node, func(nn ast.Node) bool { ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
switch t := nn.(type) { switch t := nn.(type) {
case *ast.LiteralType: case *ast.LiteralType:
if t.LineComment != nil { if t.LineComment != nil {
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
} }
} }
return true return nn, true
}) })
for _, c := range standaloneComments { for _, c := range standaloneComments {

View File

@ -6,11 +6,11 @@ import (
// flattenObjects takes an AST node, walks it, and flattens // flattenObjects takes an AST node, walks it, and flattens
func flattenObjects(node ast.Node) { func flattenObjects(node ast.Node) {
ast.Walk(node, func(n ast.Node) bool { ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
// We only care about lists, because this is what we modify // We only care about lists, because this is what we modify
list, ok := n.(*ast.ObjectList) list, ok := n.(*ast.ObjectList)
if !ok { if !ok {
return true return n, true
} }
// Rebuild the item list // Rebuild the item list
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
// Done! Set the original items // Done! Set the original items
list.Items = items list.Items = items
return true return n, true
}) })
} }