Merge pull request #53 from hashicorp/add-rewrite
ast: add Rewrite() to rewrite AST
This commit is contained in:
commit
fbd0456768
@ -2,6 +2,7 @@ package ast
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/hcl/hcl/token"
|
"github.com/hashicorp/hcl/hcl/token"
|
||||||
@ -64,3 +65,136 @@ func TestObjectListFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWalk(t *testing.T) {
|
||||||
|
items := []*ObjectItem{
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||||
|
},
|
||||||
|
Val: &LiteralType{Token: token.Token{Type: token.STRING, Text: `"example"`}},
|
||||||
|
},
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &ObjectList{Items: items}
|
||||||
|
|
||||||
|
order := []string{
|
||||||
|
"*ast.ObjectList",
|
||||||
|
"*ast.ObjectItem",
|
||||||
|
"*ast.ObjectKey",
|
||||||
|
"*ast.ObjectKey",
|
||||||
|
"*ast.LiteralType",
|
||||||
|
"*ast.ObjectItem",
|
||||||
|
"*ast.ObjectKey",
|
||||||
|
}
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
Walk(node, func(n Node) (Node, bool) {
|
||||||
|
if n == nil {
|
||||||
|
return n, false
|
||||||
|
}
|
||||||
|
|
||||||
|
typeName := reflect.TypeOf(n).String()
|
||||||
|
if order[count] != typeName {
|
||||||
|
t.Errorf("expected '%s' got: '%s'", order[count], typeName)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
return n, true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkEquality(t *testing.T) {
|
||||||
|
items := []*ObjectItem{
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &ObjectList{Items: items}
|
||||||
|
|
||||||
|
rewritten := Walk(node, func(n Node) (Node, bool) { return n, true })
|
||||||
|
|
||||||
|
newNode, ok := rewritten.(*ObjectList)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected Objectlist, got %T", rewritten)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(node, newNode) {
|
||||||
|
t.Fatal("rewritten node is not equal to the given node")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(newNode.Items) != 2 {
|
||||||
|
t.Error("expected newNode length 2, got: %d", len(newNode.Items))
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{
|
||||||
|
`"foo"`,
|
||||||
|
`"bar"`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item := range newNode.Items {
|
||||||
|
if len(item.Keys) != 1 {
|
||||||
|
t.Error("expected keys newNode length 1, got: %d", len(item.Keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.Keys[0].Token.Text != expected[i] {
|
||||||
|
t.Errorf("expected key %s, got %s", expected[i], item.Keys[0].Token.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.Val != nil {
|
||||||
|
t.Errorf("expected item value should be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkRewrite(t *testing.T) {
|
||||||
|
items := []*ObjectItem{
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"foo"`}},
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"bar"`}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&ObjectItem{
|
||||||
|
Keys: []*ObjectKey{
|
||||||
|
&ObjectKey{Token: token.Token{Type: token.STRING, Text: `"baz"`}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &ObjectList{Items: items}
|
||||||
|
|
||||||
|
suffix := "_example"
|
||||||
|
node = Walk(node, func(n Node) (Node, bool) {
|
||||||
|
switch i := n.(type) {
|
||||||
|
case *ObjectKey:
|
||||||
|
i.Token.Text = i.Token.Text + suffix
|
||||||
|
n = i
|
||||||
|
}
|
||||||
|
return n, true
|
||||||
|
}).(*ObjectList)
|
||||||
|
|
||||||
|
Walk(node, func(n Node) (Node, bool) {
|
||||||
|
switch i := n.(type) {
|
||||||
|
case *ObjectKey:
|
||||||
|
if !strings.HasSuffix(i.Token.Text, suffix) {
|
||||||
|
t.Errorf("Token '%s' should have suffix: %s", i.Token.Text, suffix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, true
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -2,39 +2,51 @@ package ast
|
|||||||
|
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
|
// WalkFunc describes a function to be called for each node during a Walk. The
|
||||||
|
// returned node can be used to rewrite the AST. Walking stops the returned
|
||||||
|
// bool is false.
|
||||||
|
type WalkFunc func(Node) (Node, bool)
|
||||||
|
|
||||||
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
|
// Walk traverses an AST in depth-first order: It starts by calling fn(node);
|
||||||
// node must not be nil. If f returns true, Walk invokes f recursively for
|
// node must not be nil. If fn returns true, Walk invokes fn recursively for
|
||||||
// each of the non-nil children of node, followed by a call of f(nil).
|
// each of the non-nil children of node, followed by a call of fn(nil). The
|
||||||
func Walk(node Node, fn func(Node) bool) {
|
// returned node of fn can be used to rewrite the passed node to fn.
|
||||||
if !fn(node) {
|
func Walk(node Node, fn WalkFunc) Node {
|
||||||
return
|
rewritten, ok := fn(node)
|
||||||
|
if !ok {
|
||||||
|
return rewritten
|
||||||
}
|
}
|
||||||
|
|
||||||
switch n := node.(type) {
|
switch n := node.(type) {
|
||||||
case *File:
|
case *File:
|
||||||
Walk(n.Node, fn)
|
n.Node = Walk(n.Node, fn)
|
||||||
case *ObjectList:
|
case *ObjectList:
|
||||||
for _, item := range n.Items {
|
for i, item := range n.Items {
|
||||||
Walk(item, fn)
|
n.Items[i] = Walk(item, fn).(*ObjectItem)
|
||||||
}
|
}
|
||||||
case *ObjectKey:
|
case *ObjectKey:
|
||||||
// nothing to do
|
// nothing to do
|
||||||
case *ObjectItem:
|
case *ObjectItem:
|
||||||
for _, k := range n.Keys {
|
for i, k := range n.Keys {
|
||||||
Walk(k, fn)
|
n.Keys[i] = Walk(k, fn).(*ObjectKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.Val != nil {
|
||||||
|
n.Val = Walk(n.Val, fn)
|
||||||
}
|
}
|
||||||
Walk(n.Val, fn)
|
|
||||||
case *LiteralType:
|
case *LiteralType:
|
||||||
// nothing to do
|
// nothing to do
|
||||||
case *ListType:
|
case *ListType:
|
||||||
for _, l := range n.List {
|
for i, l := range n.List {
|
||||||
Walk(l, fn)
|
n.List[i] = Walk(l, fn)
|
||||||
}
|
}
|
||||||
case *ObjectType:
|
case *ObjectType:
|
||||||
Walk(n.List, fn)
|
n.List = Walk(n.List, fn).(*ObjectList)
|
||||||
default:
|
default:
|
||||||
fmt.Printf(" unknown type: %T\n", n)
|
// should we panic here?
|
||||||
|
fmt.Printf("unknown type: %T\n", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(nil)
|
fn(nil)
|
||||||
|
return rewritten
|
||||||
}
|
}
|
||||||
|
@ -42,13 +42,13 @@ func (b ByPosition) Less(i, j int) bool { return b[i].Pos().Before(b[j].Pos()) }
|
|||||||
func (p *printer) collectComments(node ast.Node) {
|
func (p *printer) collectComments(node ast.Node) {
|
||||||
// first collect all comments. This is already stored in
|
// first collect all comments. This is already stored in
|
||||||
// ast.File.(comments)
|
// ast.File.(comments)
|
||||||
ast.Walk(node, func(nn ast.Node) bool {
|
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||||
switch t := nn.(type) {
|
switch t := nn.(type) {
|
||||||
case *ast.File:
|
case *ast.File:
|
||||||
p.comments = t.Comments
|
p.comments = t.Comments
|
||||||
return false
|
return nn, false
|
||||||
}
|
}
|
||||||
return true
|
return nn, true
|
||||||
})
|
})
|
||||||
|
|
||||||
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
|
standaloneComments := make(map[token.Pos]*ast.CommentGroup, 0)
|
||||||
@ -59,7 +59,7 @@ func (p *printer) collectComments(node ast.Node) {
|
|||||||
// next remove all lead and line comments from the overall comment map.
|
// next remove all lead and line comments from the overall comment map.
|
||||||
// This will give us comments which are standalone, comments which are not
|
// This will give us comments which are standalone, comments which are not
|
||||||
// assigned to any kind of node.
|
// assigned to any kind of node.
|
||||||
ast.Walk(node, func(nn ast.Node) bool {
|
ast.Walk(node, func(nn ast.Node) (ast.Node, bool) {
|
||||||
switch t := nn.(type) {
|
switch t := nn.(type) {
|
||||||
case *ast.LiteralType:
|
case *ast.LiteralType:
|
||||||
if t.LineComment != nil {
|
if t.LineComment != nil {
|
||||||
@ -87,7 +87,7 @@ func (p *printer) collectComments(node ast.Node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return nn, true
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, c := range standaloneComments {
|
for _, c := range standaloneComments {
|
||||||
|
@ -6,11 +6,11 @@ import (
|
|||||||
|
|
||||||
// flattenObjects takes an AST node, walks it, and flattens
|
// flattenObjects takes an AST node, walks it, and flattens
|
||||||
func flattenObjects(node ast.Node) {
|
func flattenObjects(node ast.Node) {
|
||||||
ast.Walk(node, func(n ast.Node) bool {
|
ast.Walk(node, func(n ast.Node) (ast.Node, bool) {
|
||||||
// We only care about lists, because this is what we modify
|
// We only care about lists, because this is what we modify
|
||||||
list, ok := n.(*ast.ObjectList)
|
list, ok := n.(*ast.ObjectList)
|
||||||
if !ok {
|
if !ok {
|
||||||
return true
|
return n, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rebuild the item list
|
// Rebuild the item list
|
||||||
@ -41,7 +41,7 @@ func flattenObjects(node ast.Node) {
|
|||||||
|
|
||||||
// Done! Set the original items
|
// Done! Set the original items
|
||||||
list.Items = items
|
list.Items = items
|
||||||
return true
|
return n, true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user