case insensitive Get
This commit is contained in:
parent
8c0a6c555f
commit
a095771be9
10
ast/ast.go
10
ast/ast.go
@ -1,5 +1,9 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValueType is an enum represnting the type of a value in
|
||||
// a LiteralNode.
|
||||
type ValueType byte
|
||||
@ -69,11 +73,13 @@ func (n ObjectNode) Accept(v Visitor) {
|
||||
|
||||
// Get returns all the elements of this object with the given key.
|
||||
// This is a case-sensitive search.
|
||||
func (n ObjectNode) Get(k string) []Node {
|
||||
func (n ObjectNode) Get(k string, insensitive bool) []Node {
|
||||
result := make([]Node, 0, 1)
|
||||
for _, elem := range n.Elem {
|
||||
if elem.Key() != k {
|
||||
continue
|
||||
if !insensitive || !strings.EqualFold(elem.Key(), k) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
switch n := elem.(type) {
|
||||
|
@ -86,7 +86,7 @@ func TestObjectNodeGet(t *testing.T) {
|
||||
LiteralNode{Value: "baz"},
|
||||
}
|
||||
|
||||
actual := n.Get("foo")
|
||||
actual := n.Get("foo", false)
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("bad: %#v", actual)
|
||||
|
35
decoder.go
35
decoder.go
@ -419,39 +419,20 @@ func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value)
|
||||
}
|
||||
|
||||
// Find the element matching this name
|
||||
var elem ast.Node
|
||||
elemKey := fieldName
|
||||
if elems := obj.Get(fieldName); len(elems) > 0 {
|
||||
elem = elems[len(elems)-1]
|
||||
} else {
|
||||
// Do a slower search by iterating over each key and
|
||||
// doing case-insensitive search.
|
||||
for _, v := range obj.Elem {
|
||||
if strings.EqualFold(v.Key(), fieldName) {
|
||||
elem = v
|
||||
elemKey = v.Key()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if elem == nil {
|
||||
// No key matching this field.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we get the value of the element
|
||||
if an, ok := elem.(ast.AssignmentNode); ok {
|
||||
elem = an.Value
|
||||
elems := obj.Get(fieldName, true)
|
||||
if len(elems) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track the used key
|
||||
usedKeys[elemKey] = struct{}{}
|
||||
usedKeys[fieldName] = struct{}{}
|
||||
|
||||
// Create the field name and decode
|
||||
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
|
||||
if err := d.decode(fieldName, elem, field); err != nil {
|
||||
return err
|
||||
for _, elem := range elems {
|
||||
if err := d.decode(fieldName, elem, field); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
decodedFields = append(decodedFields, fieldType.Name)
|
||||
|
Loading…
Reference in New Issue
Block a user