diff --git a/hclwrite/ast_body.go b/hclwrite/ast_body.go index a843dc1..c16d13e 100644 --- a/hclwrite/ast_body.go +++ b/hclwrite/ast_body.go @@ -84,6 +84,22 @@ func (b *Body) GetAttribute(name string) *Attribute { return nil } +// getAttributeNode is like GetAttribute but it returns the node containing +// the selected attribute (if one is found) rather than the attribute itself. +func (b *Body) getAttributeNode(name string) *node { + for n := range b.items { + if attr, isAttr := n.content.(*Attribute); isAttr { + nameObj := attr.name.content.(*identifier) + if nameObj.hasName(name) { + // We've found it! + return n + } + } + } + + return nil +} + // FirstMatchingBlock returns a first matching block from the body that has the // given name and labels or returns nil if there is currently no matching // block. @@ -91,8 +107,10 @@ func (b *Body) FirstMatchingBlock(typeName string, labels []string) *Block { for _, block := range b.Blocks() { if typeName == block.Type() { labelNames := block.Labels() + if len(labels) == 0 && len(labelNames) == 0 { + return block + } if reflect.DeepEqual(labels, labelNames) { - // We've found it! return block } } @@ -101,6 +119,21 @@ func (b *Body) FirstMatchingBlock(typeName string, labels []string) *Block { return nil } +// RemoveBlock removes the given block from the body, if it's in that body. +// If it isn't present, this is a no-op. +// +// Returns true if it removed something, or false otherwise. +func (b *Body) RemoveBlock(block *Block) bool { + for n := range b.items { + if n.content == block { + n.Detach() + b.items.Remove(n) + return true + } + } + return false +} + // SetAttributeValue either replaces the expression of an existing attribute // of the given name or adds a new attribute definition to the end of the block. // @@ -143,6 +176,20 @@ func (b *Body) SetAttributeTraversal(name string, traversal hcl.Traversal) *Attr return attr } +// RemoveAttribute removes the attribute with the given name from the body. +// +// The return value is the attribute that was removed, or nil if there was +// no such attribute (in which case the call was a no-op). +func (b *Body) RemoveAttribute(name string) *Attribute { + node := b.getAttributeNode(name) + if node == nil { + return nil + } + node.Detach() + b.items.Remove(node) + return node.content.(*Attribute) +} + // AppendBlock appends an existing block (which must not be already attached // to a body) to the end of the receiving body. func (b *Body) AppendBlock(block *Block) *Block { diff --git a/hclwrite/ast_body_test.go b/hclwrite/ast_body_test.go index e3319b5..d6ff789 100644 --- a/hclwrite/ast_body_test.go +++ b/hclwrite/ast_body_test.go @@ -869,6 +869,119 @@ func TestBodySetAttributeValueInNestedBlock(t *testing.T) { } } +func TestBodyRemoveAttribute(t *testing.T) { + tests := []struct { + src string + name string + want Tokens + }{ + { + "", + "a", + Tokens{ + { + Type: hclsyntax.TokenEOF, + Bytes: []byte{}, + SpacesBefore: 0, + }, + }, + }, + { + "b = false\n", + "a", + Tokens{ + { + Type: hclsyntax.TokenIdent, + Bytes: []byte{'b'}, + SpacesBefore: 0, + }, + { + Type: hclsyntax.TokenEqual, + Bytes: []byte{'='}, + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenIdent, + Bytes: []byte("false"), + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenNewline, + Bytes: []byte{'\n'}, + SpacesBefore: 0, + }, + { + Type: hclsyntax.TokenEOF, + Bytes: []byte{}, + SpacesBefore: 0, + }, + }, + }, + { + "a = false\n", + "a", + Tokens{ + { + Type: hclsyntax.TokenEOF, + Bytes: []byte{}, + SpacesBefore: 0, + }, + }, + }, + { + "a = 1\nb = false\n", + "a", + Tokens{ + { + Type: hclsyntax.TokenIdent, + Bytes: []byte{'b'}, + SpacesBefore: 0, + }, + { + Type: hclsyntax.TokenEqual, + Bytes: []byte{'='}, + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenIdent, + Bytes: []byte("false"), + SpacesBefore: 1, + }, + { + Type: hclsyntax.TokenNewline, + Bytes: []byte{'\n'}, + SpacesBefore: 0, + }, + { + Type: hclsyntax.TokenEOF, + Bytes: []byte{}, + SpacesBefore: 0, + }, + }, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s in %s", test.name, test.src), func(t *testing.T) { + f, diags := ParseConfig([]byte(test.src), "", hcl.Pos{Line: 1, Column: 1}) + if len(diags) != 0 { + for _, diag := range diags { + t.Logf("- %s", diag.Error()) + } + t.Fatalf("unexpected diagnostics") + } + + f.Body().RemoveAttribute(test.name) + got := f.BuildTokens(nil) + format(got) + if !reflect.DeepEqual(got, test.want) { + diff := cmp.Diff(test.want, got) + t.Errorf("wrong result\ngot: %s\nwant: %s\ndiff:\n%s", spew.Sdump(got), spew.Sdump(test.want), diff) + } + }) + } +} + func TestBodyAppendBlock(t *testing.T) { tests := []struct { src string @@ -1111,3 +1224,202 @@ func TestBodyAppendBlock(t *testing.T) { }) } } + +func TestBodyRemoveBlock(t *testing.T) { + src := strings.TrimSpace(` +a = 1 + +# Foo +foo { + b = 1 +} +foo { + b = 2 +} +bar {} +`) + f, diags := ParseConfig([]byte(src), "", hcl.Pos{Line: 1, Column: 1}) + if len(diags) != 0 { + for _, diag := range diags { + t.Logf("- %s", diag.Error()) + } + t.Fatalf("unexpected diagnostics") + } + + t.Logf("Removing the first block") + t.Logf("initial content:\n%s", f.Bytes()) + body := f.Body() + block := body.FirstMatchingBlock("foo", nil) + if block == nil { + t.Fatalf("didn't find a 'foo' block") + } + removed := body.RemoveBlock(block) + if !removed { + t.Fatalf("didn't remove first block") + } + t.Logf("updated content:\n%s", f.Bytes()) + got := f.BuildTokens(nil) + want := Tokens{ + 0: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`a`), + SpacesBefore: 0, + }, + 1: { + Type: hclsyntax.TokenEqual, + Bytes: []byte(`=`), + SpacesBefore: 1, + }, + 2: { + Type: hclsyntax.TokenNumberLit, + Bytes: []byte(`1`), + SpacesBefore: 1, + }, + 3: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 4: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 5: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`foo`), + SpacesBefore: 0, + }, + 6: { + Type: hclsyntax.TokenOBrace, + Bytes: []byte(`{`), + SpacesBefore: 1, + }, + 7: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 8: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`b`), + SpacesBefore: 2, + }, + 9: { + Type: hclsyntax.TokenEqual, + Bytes: []byte(`=`), + SpacesBefore: 1, + }, + 10: { + Type: hclsyntax.TokenNumberLit, + Bytes: []byte(`2`), + SpacesBefore: 1, + }, + 11: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 12: { + Type: hclsyntax.TokenCBrace, + Bytes: []byte(`}`), + SpacesBefore: 0, + }, + 13: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 14: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`bar`), + SpacesBefore: 0, + }, + 15: { + Type: hclsyntax.TokenOBrace, + Bytes: []byte(`{`), + SpacesBefore: 1, + }, + 16: { + Type: hclsyntax.TokenCBrace, + Bytes: []byte(`}`), + SpacesBefore: 0, + }, + 17: { + Type: hclsyntax.TokenEOF, + Bytes: []byte(""), + SpacesBefore: 0, + }, + } + format(got) + if !reflect.DeepEqual(got, want) { + diff := cmp.Diff(want, got) + t.Errorf("wrong result\ngot: %s\nwant: %s\ndiff:\n%s", spew.Sdump(got), spew.Sdump(want), diff) + } + + t.Logf("removing the second block") + t.Logf("initial content:\n%s", f.Bytes()) + block = body.FirstMatchingBlock("foo", nil) + if block == nil { + t.Fatalf("didn't find a 'foo' block") + } + removed = body.RemoveBlock(block) + if !removed { + t.Fatalf("didn't remove second block") + } + t.Logf("updated content:\n%s", f.Bytes()) + got = f.BuildTokens(nil) + want = Tokens{ + 0: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`a`), + SpacesBefore: 0, + }, + 1: { + Type: hclsyntax.TokenEqual, + Bytes: []byte(`=`), + SpacesBefore: 1, + }, + 2: { + Type: hclsyntax.TokenNumberLit, + Bytes: []byte(`1`), + SpacesBefore: 1, + }, + 3: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 4: { + Type: hclsyntax.TokenNewline, + Bytes: []byte("\n"), + SpacesBefore: 0, + }, + 5: { + Type: hclsyntax.TokenIdent, + Bytes: []byte(`bar`), + SpacesBefore: 0, + }, + 6: { + Type: hclsyntax.TokenOBrace, + Bytes: []byte(`{`), + SpacesBefore: 1, + }, + 7: { + Type: hclsyntax.TokenCBrace, + Bytes: []byte(`}`), + SpacesBefore: 0, + }, + 8: { + Type: hclsyntax.TokenEOF, + Bytes: []byte(""), + SpacesBefore: 0, + }, + } + format(got) + if !reflect.DeepEqual(got, want) { + diff := cmp.Diff(want, got) + t.Errorf("wrong result\ngot: %s\nwant: %s\ndiff:\n%s", spew.Sdump(got), spew.Sdump(want), diff) + } + +} diff --git a/hclwrite/node.go b/hclwrite/node.go index 71fd00f..45669f7 100644 --- a/hclwrite/node.go +++ b/hclwrite/node.go @@ -140,6 +140,18 @@ func (ns *nodes) AppendUnstructuredTokens(tokens Tokens) *node { return n } +// FindNodeWithContent searches the nodes for a node whose content equals +// the given content. If it finds one then it returns it. Otherwise it returns +// nil. +func (ns *nodes) FindNodeWithContent(content nodeContent) *node { + for n := ns.first; n != nil; n = n.after { + if n.content == content { + return n + } + } + return nil +} + // nodeSet is an unordered set of nodes. It is used to describe a set of nodes // that all belong to the same list that have some role or characteristic // in common. @@ -192,6 +204,18 @@ func (ns nodeSet) List() []*node { return ret } +// FindNodeWithContent searches the nodes for a node whose content equals +// the given content. If it finds one then it returns it. Otherwise it returns +// nil. +func (ns nodeSet) FindNodeWithContent(content nodeContent) *node { + for n := range ns { + if n.content == content { + return n + } + } + return nil +} + type internalWalkFunc func(*node) // inTree can be embedded into a content struct that has child nodes to get