From a095771be9cce0f27b98575cc0ee314ab4bb7ee4 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sun, 3 Aug 2014 17:57:54 -0700 Subject: [PATCH] case insensitive Get --- ast/ast.go | 10 ++++++++-- ast/ast_test.go | 2 +- decoder.go | 35 ++++++++--------------------------- 3 files changed, 17 insertions(+), 30 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index df85e13..51a9b4a 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -1,5 +1,9 @@ package ast +import ( + "strings" +) + // ValueType is an enum represnting the type of a value in // a LiteralNode. type ValueType byte @@ -69,11 +73,13 @@ func (n ObjectNode) Accept(v Visitor) { // Get returns all the elements of this object with the given key. // This is a case-sensitive search. -func (n ObjectNode) Get(k string) []Node { +func (n ObjectNode) Get(k string, insensitive bool) []Node { result := make([]Node, 0, 1) for _, elem := range n.Elem { if elem.Key() != k { - continue + if !insensitive || !strings.EqualFold(elem.Key(), k) { + continue + } } switch n := elem.(type) { diff --git a/ast/ast_test.go b/ast/ast_test.go index fa45382..16d4881 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -86,7 +86,7 @@ func TestObjectNodeGet(t *testing.T) { LiteralNode{Value: "baz"}, } - actual := n.Get("foo") + actual := n.Get("foo", false) if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) diff --git a/decoder.go b/decoder.go index 48a0b9d..5676821 100644 --- a/decoder.go +++ b/decoder.go @@ -419,39 +419,20 @@ func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value) } // Find the element matching this name - var elem ast.Node - elemKey := fieldName - if elems := obj.Get(fieldName); len(elems) > 0 { - elem = elems[len(elems)-1] - } else { - // Do a slower search by iterating over each key and - // doing case-insensitive search. - for _, v := range obj.Elem { - if strings.EqualFold(v.Key(), fieldName) { - elem = v - elemKey = v.Key() - break - } - } - - if elem == nil { - // No key matching this field. - continue - } - } - - // Make sure we get the value of the element - if an, ok := elem.(ast.AssignmentNode); ok { - elem = an.Value + elems := obj.Get(fieldName, true) + if len(elems) == 0 { + continue } // Track the used key - usedKeys[elemKey] = struct{}{} + usedKeys[fieldName] = struct{}{} // Create the field name and decode fieldName = fmt.Sprintf("%s.%s", name, fieldName) - if err := d.decode(fieldName, elem, field); err != nil { - return err + for _, elem := range elems { + if err := d.decode(fieldName, elem, field); err != nil { + return err + } } decodedFields = append(decodedFields, fieldType.Name)