decoder can decode into structures

This commit is contained in:
Mitchell Hashimoto 2014-08-03 14:06:18 -07:00
parent d3d3f86887
commit daed1fd01d
4 changed files with 212 additions and 2 deletions

View File

@ -67,10 +67,30 @@ func (n ObjectNode) Accept(v Visitor) {
}
}
// Get returns all the elements of this object with the given key.
// This is a case-sensitive search.
func (n ObjectNode) Get(k string) []KeyedNode {
result := make([]KeyedNode, 0, 1)
for _, elem := range n.Elem {
if elem.Key() == k {
result = append(result, elem)
}
}
return result
}
// Key returns the key of this object. If this is "", then it is
// the root object.
func (n ObjectNode) Key() string {
return n.K
}
// Len returns the number of elements of this object.
func (n ObjectNode) Len() int {
return len(n.Elem)
}
func (n AssignmentNode) Accept(v Visitor) {
v.Visit(n)
n.Value.Accept(v)

View File

@ -4,10 +4,14 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/hashicorp/hcl/ast"
)
// This is the tag to use with structures to have settings for HCL
const tagName = "hcl"
// Decode reads the given input and decodes it into the structure
// given by `out`.
func Decode(out interface{}, in string) error {
@ -58,8 +62,11 @@ func (d *decoder) decode(name string, n ast.Node, result reflect.Value) error {
return d.decodeSlice(name, n, result)
case reflect.String:
return d.decodeString(name, n, result)
case reflect.Struct:
return d.decodeStruct(name, n, result)
default:
return fmt.Errorf("%s: unknown kind: %s", name, result.Kind())
return fmt.Errorf(
"%s: unknown kind to decode into: %s", name, result.Kind())
}
return nil
@ -73,7 +80,7 @@ func (d *decoder) decodeInt(name string, raw ast.Node, result reflect.Value) err
switch n.Type {
case ast.ValueTypeInt:
result.Set(reflect.ValueOf(int64(n.Value.(int))))
result.Set(reflect.ValueOf(n.Value.(int)))
default:
return fmt.Errorf("%s: unknown type %s", name, n.Type)
}
@ -288,3 +295,161 @@ func (d *decoder) decodeString(name string, raw ast.Node, result reflect.Value)
return nil
}
func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value) error {
obj, ok := raw.(ast.ObjectNode)
if !ok {
return fmt.Errorf("%s: not an object type", name)
}
// This slice will keep track of all the structs we'll be decoding.
// There can be more than one struct if there are embedded structs
// that are squashed.
structs := make([]reflect.Value, 1, 5)
structs[0] = result
// Compile the list of all the fields that we're going to be decoding
// from all the structs.
fields := make(map[*reflect.StructField]reflect.Value)
for len(structs) > 0 {
structVal := structs[0]
structs = structs[1:]
structType := structVal.Type()
for i := 0; i < structType.NumField(); i++ {
fieldType := structType.Field(i)
if fieldType.Anonymous {
fieldKind := fieldType.Type.Kind()
if fieldKind != reflect.Struct {
return fmt.Errorf(
"%s: unsupported type to struct: %s",
fieldType.Name, fieldKind)
}
// We have an embedded field. We "squash" the fields down
// if specified in the tag.
squash := false
tagParts := strings.Split(fieldType.Tag.Get(tagName), ",")
for _, tag := range tagParts[1:] {
if tag == "squash" {
squash = true
break
}
}
if squash {
structs = append(
structs, result.FieldByName(fieldType.Name))
continue
}
}
// Normal struct field, store it away
fields[&fieldType] = structVal.Field(i)
}
}
usedKeys := make(map[string]struct{})
decodedFields := make([]string, 0, len(fields))
decodedFieldsVal := make([]reflect.Value, 0)
unusedKeysVal := make([]reflect.Value, 0)
for fieldType, field := range fields {
if !field.IsValid() {
// This should never happen
panic("field is not valid")
}
// If we can't set the field, then it is unexported or something,
// and we just continue onwards.
if !field.CanSet() {
continue
}
fieldName := fieldType.Name
tagValue := fieldType.Tag.Get(tagName)
tagParts := strings.SplitN(tagValue, ",", 2)
if len(tagParts) >= 2 {
switch tagParts[1] {
case "decodedFields":
decodedFieldsVal = append(decodedFieldsVal, field)
continue
case "key":
field.SetString(obj.Key())
continue
case "unusedKeys":
unusedKeysVal = append(unusedKeysVal, field)
continue
}
}
if tagParts[0] != "" {
fieldName = tagParts[0]
}
// Find the element matching this name
var elem ast.Node
elemKey := fieldName
if elems := obj.Get(fieldName); len(elems) > 0 {
elem = elems[len(elems)-1]
} else {
// Do a slower search by iterating over each key and
// doing case-insensitive search.
for _, v := range obj.Elem {
if strings.EqualFold(v.Key(), fieldName) {
elem = v
elemKey = v.Key()
break
}
}
if elem == nil {
// No key matching this field.
continue
}
}
// Make sure we get the value of the element
if an, ok := elem.(ast.AssignmentNode); ok {
elem = an.Value
}
// Track the used key
usedKeys[elemKey] = struct{}{}
// Create the field name and decode
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
if err := d.decode(fieldName, elem, field); err != nil {
return err
}
decodedFields = append(decodedFields, fieldType.Name)
}
for _, v := range decodedFieldsVal {
v.Set(reflect.ValueOf(decodedFields))
}
// If we want to know what keys are unused, compile that
if len(unusedKeysVal) > 0 {
unusedKeys := make([]string, 0, int(obj.Len())-len(usedKeys))
for _, elem := range obj.Elem {
k := elem.Key()
if _, ok := usedKeys[k]; !ok {
unusedKeys = append(unusedKeys, k)
}
}
if len(unusedKeys) == 0 {
unusedKeys = nil
}
for _, v := range unusedKeysVal {
v.Set(reflect.ValueOf(unusedKeys))
}
}
return nil
}

View File

@ -137,3 +137,26 @@ func TestDecode_flatMap(t *testing.T) {
t.Fatalf("Actual: %#v\n\nExpected: %#v", val, expected)
}
}
func TestDecode_structure(t *testing.T) {
type V struct {
Key int
Foo string
}
var actual V
err := Decode(&actual, testReadFile(t, "flat.hcl"))
if err != nil {
t.Fatalf("err: %s", err)
}
expected := V{
Key: 7,
Foo: "bar",
}
if !reflect.DeepEqual(actual, expected) {
t.Fatalf("Actual: %#v\n\nExpected: %#v", actual, expected)
}
}

2
test-fixtures/flat.hcl Normal file
View File

@ -0,0 +1,2 @@
foo = "bar"
Key = 7