case insensitive Get

This commit is contained in:
Mitchell Hashimoto 2014-08-03 17:57:54 -07:00
parent 8c0a6c555f
commit a095771be9
3 changed files with 17 additions and 30 deletions

View File

@ -1,5 +1,9 @@
package ast package ast
import (
"strings"
)
// ValueType is an enum represnting the type of a value in // ValueType is an enum represnting the type of a value in
// a LiteralNode. // a LiteralNode.
type ValueType byte type ValueType byte
@ -69,11 +73,13 @@ func (n ObjectNode) Accept(v Visitor) {
// Get returns all the elements of this object with the given key. // Get returns all the elements of this object with the given key.
// This is a case-sensitive search. // This is a case-sensitive search.
func (n ObjectNode) Get(k string) []Node { func (n ObjectNode) Get(k string, insensitive bool) []Node {
result := make([]Node, 0, 1) result := make([]Node, 0, 1)
for _, elem := range n.Elem { for _, elem := range n.Elem {
if elem.Key() != k { if elem.Key() != k {
continue if !insensitive || !strings.EqualFold(elem.Key(), k) {
continue
}
} }
switch n := elem.(type) { switch n := elem.(type) {

View File

@ -86,7 +86,7 @@ func TestObjectNodeGet(t *testing.T) {
LiteralNode{Value: "baz"}, LiteralNode{Value: "baz"},
} }
actual := n.Get("foo") actual := n.Get("foo", false)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: %#v", actual) t.Fatalf("bad: %#v", actual)

View File

@ -419,39 +419,20 @@ func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value)
} }
// Find the element matching this name // Find the element matching this name
var elem ast.Node elems := obj.Get(fieldName, true)
elemKey := fieldName if len(elems) == 0 {
if elems := obj.Get(fieldName); len(elems) > 0 { continue
elem = elems[len(elems)-1]
} else {
// Do a slower search by iterating over each key and
// doing case-insensitive search.
for _, v := range obj.Elem {
if strings.EqualFold(v.Key(), fieldName) {
elem = v
elemKey = v.Key()
break
}
}
if elem == nil {
// No key matching this field.
continue
}
}
// Make sure we get the value of the element
if an, ok := elem.(ast.AssignmentNode); ok {
elem = an.Value
} }
// Track the used key // Track the used key
usedKeys[elemKey] = struct{}{} usedKeys[fieldName] = struct{}{}
// Create the field name and decode // Create the field name and decode
fieldName = fmt.Sprintf("%s.%s", name, fieldName) fieldName = fmt.Sprintf("%s.%s", name, fieldName)
if err := d.decode(fieldName, elem, field); err != nil { for _, elem := range elems {
return err if err := d.decode(fieldName, elem, field); err != nil {
return err
}
} }
decodedFields = append(decodedFields, fieldType.Name) decodedFields = append(decodedFields, fieldType.Name)