hcl/ast: Get should return a list

This commit is contained in:
Mitchell Hashimoto 2015-11-07 09:52:50 -08:00
parent 8754ac7343
commit 32411ba6af
2 changed files with 23 additions and 19 deletions

View File

@ -533,20 +533,11 @@ func (d *decoder) decodeStruct(name string, node ast.Node, result reflect.Value)
// Determine the element we'll use to decode. If it is a single // Determine the element we'll use to decode. If it is a single
// match (only object with the field), then we decode it exactly. // match (only object with the field), then we decode it exactly.
// If it is a prefix match, then we decode the matches. // If it is a prefix match, then we decode the matches.
var decodeNode ast.Node prefixMatches := list.Prefix(fieldName)
matches := list.Prefix(fieldName) matches := list.Get(fieldName)
if len(matches.Items) == 0 { if len(matches.Items) == 0 && len(prefixMatches.Items) == 0 {
single := list.Get(fieldName)
if single == nil {
continue continue
} }
decodeNode = single.Val
if ot, ok := decodeNode.(*ast.ObjectType); ok {
decodeNode = &ast.ObjectList{Items: ot.List.Items}
}
} else {
decodeNode = matches
}
// Track the used key // Track the used key
usedKeys[fieldName] = struct{}{} usedKeys[fieldName] = struct{}{}
@ -554,9 +545,21 @@ func (d *decoder) decodeStruct(name string, node ast.Node, result reflect.Value)
// Create the field name and decode. We range over the elements // Create the field name and decode. We range over the elements
// because we actually want the value. // because we actually want the value.
fieldName = fmt.Sprintf("%s.%s", name, fieldName) fieldName = fmt.Sprintf("%s.%s", name, fieldName)
if len(prefixMatches.Items) > 0 {
if err := d.decode(fieldName, prefixMatches, field); err != nil {
return err
}
}
for _, match := range matches.Items {
var decodeNode ast.Node = match.Val
if ot, ok := decodeNode.(*ast.ObjectType); ok {
decodeNode = &ast.ObjectList{Items: ot.List.Items}
}
if err := d.decode(fieldName, decodeNode, field); err != nil { if err := d.decode(fieldName, decodeNode, field); err != nil {
return err return err
} }
}
decodedFields = append(decodedFields, fieldType.Name) decodedFields = append(decodedFields, fieldType.Name)
} }

View File

@ -45,19 +45,20 @@ func (o *ObjectList) Add(item *ObjectItem) {
o.Items = append(o.Items, item) o.Items = append(o.Items, item)
} }
func (o *ObjectList) Get(key string) *ObjectItem { func (o *ObjectList) Get(key string) *ObjectList {
var result ObjectList
for _, item := range o.Items { for _, item := range o.Items {
if len(item.Keys) == 0 { if len(item.Keys) != 1 {
continue continue
} }
text := item.Keys[0].Token.Text text := item.Keys[0].Token.Text
if text == key || strings.EqualFold(text, key) { if text == key || strings.EqualFold(text, key) {
return item result.Add(item)
} }
} }
return nil return &result
} }
func (o *ObjectList) Prefix(keys ...string) *ObjectList { func (o *ObjectList) Prefix(keys ...string) *ObjectList {