decoder can decode into structures
This commit is contained in:
parent
d3d3f86887
commit
daed1fd01d
20
ast/ast.go
20
ast/ast.go
@ -67,10 +67,30 @@ func (n ObjectNode) Accept(v Visitor) {
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns all the elements of this object with the given key.
|
||||
// This is a case-sensitive search.
|
||||
func (n ObjectNode) Get(k string) []KeyedNode {
|
||||
result := make([]KeyedNode, 0, 1)
|
||||
for _, elem := range n.Elem {
|
||||
if elem.Key() == k {
|
||||
result = append(result, elem)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Key returns the key of this object. If this is "", then it is
|
||||
// the root object.
|
||||
func (n ObjectNode) Key() string {
|
||||
return n.K
|
||||
}
|
||||
|
||||
// Len returns the number of elements of this object.
|
||||
func (n ObjectNode) Len() int {
|
||||
return len(n.Elem)
|
||||
}
|
||||
|
||||
func (n AssignmentNode) Accept(v Visitor) {
|
||||
v.Visit(n)
|
||||
n.Value.Accept(v)
|
||||
|
169
decoder.go
169
decoder.go
@ -4,10 +4,14 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/hcl/ast"
|
||||
)
|
||||
|
||||
// This is the tag to use with structures to have settings for HCL
|
||||
const tagName = "hcl"
|
||||
|
||||
// Decode reads the given input and decodes it into the structure
|
||||
// given by `out`.
|
||||
func Decode(out interface{}, in string) error {
|
||||
@ -58,8 +62,11 @@ func (d *decoder) decode(name string, n ast.Node, result reflect.Value) error {
|
||||
return d.decodeSlice(name, n, result)
|
||||
case reflect.String:
|
||||
return d.decodeString(name, n, result)
|
||||
case reflect.Struct:
|
||||
return d.decodeStruct(name, n, result)
|
||||
default:
|
||||
return fmt.Errorf("%s: unknown kind: %s", name, result.Kind())
|
||||
return fmt.Errorf(
|
||||
"%s: unknown kind to decode into: %s", name, result.Kind())
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -73,7 +80,7 @@ func (d *decoder) decodeInt(name string, raw ast.Node, result reflect.Value) err
|
||||
|
||||
switch n.Type {
|
||||
case ast.ValueTypeInt:
|
||||
result.Set(reflect.ValueOf(int64(n.Value.(int))))
|
||||
result.Set(reflect.ValueOf(n.Value.(int)))
|
||||
default:
|
||||
return fmt.Errorf("%s: unknown type %s", name, n.Type)
|
||||
}
|
||||
@ -288,3 +295,161 @@ func (d *decoder) decodeString(name string, raw ast.Node, result reflect.Value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value) error {
|
||||
obj, ok := raw.(ast.ObjectNode)
|
||||
if !ok {
|
||||
return fmt.Errorf("%s: not an object type", name)
|
||||
}
|
||||
|
||||
// This slice will keep track of all the structs we'll be decoding.
|
||||
// There can be more than one struct if there are embedded structs
|
||||
// that are squashed.
|
||||
structs := make([]reflect.Value, 1, 5)
|
||||
structs[0] = result
|
||||
|
||||
// Compile the list of all the fields that we're going to be decoding
|
||||
// from all the structs.
|
||||
fields := make(map[*reflect.StructField]reflect.Value)
|
||||
for len(structs) > 0 {
|
||||
structVal := structs[0]
|
||||
structs = structs[1:]
|
||||
|
||||
structType := structVal.Type()
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
fieldType := structType.Field(i)
|
||||
|
||||
if fieldType.Anonymous {
|
||||
fieldKind := fieldType.Type.Kind()
|
||||
if fieldKind != reflect.Struct {
|
||||
return fmt.Errorf(
|
||||
"%s: unsupported type to struct: %s",
|
||||
fieldType.Name, fieldKind)
|
||||
}
|
||||
|
||||
// We have an embedded field. We "squash" the fields down
|
||||
// if specified in the tag.
|
||||
squash := false
|
||||
tagParts := strings.Split(fieldType.Tag.Get(tagName), ",")
|
||||
for _, tag := range tagParts[1:] {
|
||||
if tag == "squash" {
|
||||
squash = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if squash {
|
||||
structs = append(
|
||||
structs, result.FieldByName(fieldType.Name))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Normal struct field, store it away
|
||||
fields[&fieldType] = structVal.Field(i)
|
||||
}
|
||||
}
|
||||
|
||||
usedKeys := make(map[string]struct{})
|
||||
decodedFields := make([]string, 0, len(fields))
|
||||
decodedFieldsVal := make([]reflect.Value, 0)
|
||||
unusedKeysVal := make([]reflect.Value, 0)
|
||||
for fieldType, field := range fields {
|
||||
if !field.IsValid() {
|
||||
// This should never happen
|
||||
panic("field is not valid")
|
||||
}
|
||||
|
||||
// If we can't set the field, then it is unexported or something,
|
||||
// and we just continue onwards.
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := fieldType.Name
|
||||
|
||||
tagValue := fieldType.Tag.Get(tagName)
|
||||
tagParts := strings.SplitN(tagValue, ",", 2)
|
||||
if len(tagParts) >= 2 {
|
||||
switch tagParts[1] {
|
||||
case "decodedFields":
|
||||
decodedFieldsVal = append(decodedFieldsVal, field)
|
||||
continue
|
||||
case "key":
|
||||
field.SetString(obj.Key())
|
||||
continue
|
||||
case "unusedKeys":
|
||||
unusedKeysVal = append(unusedKeysVal, field)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if tagParts[0] != "" {
|
||||
fieldName = tagParts[0]
|
||||
}
|
||||
|
||||
// Find the element matching this name
|
||||
var elem ast.Node
|
||||
elemKey := fieldName
|
||||
if elems := obj.Get(fieldName); len(elems) > 0 {
|
||||
elem = elems[len(elems)-1]
|
||||
} else {
|
||||
// Do a slower search by iterating over each key and
|
||||
// doing case-insensitive search.
|
||||
for _, v := range obj.Elem {
|
||||
if strings.EqualFold(v.Key(), fieldName) {
|
||||
elem = v
|
||||
elemKey = v.Key()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if elem == nil {
|
||||
// No key matching this field.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we get the value of the element
|
||||
if an, ok := elem.(ast.AssignmentNode); ok {
|
||||
elem = an.Value
|
||||
}
|
||||
|
||||
// Track the used key
|
||||
usedKeys[elemKey] = struct{}{}
|
||||
|
||||
// Create the field name and decode
|
||||
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
|
||||
if err := d.decode(fieldName, elem, field); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
decodedFields = append(decodedFields, fieldType.Name)
|
||||
}
|
||||
|
||||
for _, v := range decodedFieldsVal {
|
||||
v.Set(reflect.ValueOf(decodedFields))
|
||||
}
|
||||
|
||||
// If we want to know what keys are unused, compile that
|
||||
if len(unusedKeysVal) > 0 {
|
||||
unusedKeys := make([]string, 0, int(obj.Len())-len(usedKeys))
|
||||
|
||||
for _, elem := range obj.Elem {
|
||||
k := elem.Key()
|
||||
if _, ok := usedKeys[k]; !ok {
|
||||
unusedKeys = append(unusedKeys, k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unusedKeys) == 0 {
|
||||
unusedKeys = nil
|
||||
}
|
||||
|
||||
for _, v := range unusedKeysVal {
|
||||
v.Set(reflect.ValueOf(unusedKeys))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -137,3 +137,26 @@ func TestDecode_flatMap(t *testing.T) {
|
||||
t.Fatalf("Actual: %#v\n\nExpected: %#v", val, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecode_structure(t *testing.T) {
|
||||
type V struct {
|
||||
Key int
|
||||
Foo string
|
||||
}
|
||||
|
||||
var actual V
|
||||
|
||||
err := Decode(&actual, testReadFile(t, "flat.hcl"))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expected := V{
|
||||
Key: 7,
|
||||
Foo: "bar",
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(actual, expected) {
|
||||
t.Fatalf("Actual: %#v\n\nExpected: %#v", actual, expected)
|
||||
}
|
||||
}
|
||||
|
2
test-fixtures/flat.hcl
Normal file
2
test-fixtures/flat.hcl
Normal file
@ -0,0 +1,2 @@
|
||||
foo = "bar"
|
||||
Key = 7
|
Loading…
Reference in New Issue
Block a user