decoder can decode into structures
This commit is contained in:
parent
d3d3f86887
commit
daed1fd01d
20
ast/ast.go
20
ast/ast.go
@ -67,10 +67,30 @@ func (n ObjectNode) Accept(v Visitor) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get returns all the elements of this object with the given key.
|
||||||
|
// This is a case-sensitive search.
|
||||||
|
func (n ObjectNode) Get(k string) []KeyedNode {
|
||||||
|
result := make([]KeyedNode, 0, 1)
|
||||||
|
for _, elem := range n.Elem {
|
||||||
|
if elem.Key() == k {
|
||||||
|
result = append(result, elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key returns the key of this object. If this is "", then it is
|
||||||
|
// the root object.
|
||||||
func (n ObjectNode) Key() string {
|
func (n ObjectNode) Key() string {
|
||||||
return n.K
|
return n.K
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Len returns the number of elements of this object.
|
||||||
|
func (n ObjectNode) Len() int {
|
||||||
|
return len(n.Elem)
|
||||||
|
}
|
||||||
|
|
||||||
func (n AssignmentNode) Accept(v Visitor) {
|
func (n AssignmentNode) Accept(v Visitor) {
|
||||||
v.Visit(n)
|
v.Visit(n)
|
||||||
n.Value.Accept(v)
|
n.Value.Accept(v)
|
||||||
|
169
decoder.go
169
decoder.go
@ -4,10 +4,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/hashicorp/hcl/ast"
|
"github.com/hashicorp/hcl/ast"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This is the tag to use with structures to have settings for HCL
|
||||||
|
const tagName = "hcl"
|
||||||
|
|
||||||
// Decode reads the given input and decodes it into the structure
|
// Decode reads the given input and decodes it into the structure
|
||||||
// given by `out`.
|
// given by `out`.
|
||||||
func Decode(out interface{}, in string) error {
|
func Decode(out interface{}, in string) error {
|
||||||
@ -58,8 +62,11 @@ func (d *decoder) decode(name string, n ast.Node, result reflect.Value) error {
|
|||||||
return d.decodeSlice(name, n, result)
|
return d.decodeSlice(name, n, result)
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
return d.decodeString(name, n, result)
|
return d.decodeString(name, n, result)
|
||||||
|
case reflect.Struct:
|
||||||
|
return d.decodeStruct(name, n, result)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s: unknown kind: %s", name, result.Kind())
|
return fmt.Errorf(
|
||||||
|
"%s: unknown kind to decode into: %s", name, result.Kind())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -73,7 +80,7 @@ func (d *decoder) decodeInt(name string, raw ast.Node, result reflect.Value) err
|
|||||||
|
|
||||||
switch n.Type {
|
switch n.Type {
|
||||||
case ast.ValueTypeInt:
|
case ast.ValueTypeInt:
|
||||||
result.Set(reflect.ValueOf(int64(n.Value.(int))))
|
result.Set(reflect.ValueOf(n.Value.(int)))
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("%s: unknown type %s", name, n.Type)
|
return fmt.Errorf("%s: unknown type %s", name, n.Type)
|
||||||
}
|
}
|
||||||
@ -288,3 +295,161 @@ func (d *decoder) decodeString(name string, raw ast.Node, result reflect.Value)
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *decoder) decodeStruct(name string, raw ast.Node, result reflect.Value) error {
|
||||||
|
obj, ok := raw.(ast.ObjectNode)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%s: not an object type", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This slice will keep track of all the structs we'll be decoding.
|
||||||
|
// There can be more than one struct if there are embedded structs
|
||||||
|
// that are squashed.
|
||||||
|
structs := make([]reflect.Value, 1, 5)
|
||||||
|
structs[0] = result
|
||||||
|
|
||||||
|
// Compile the list of all the fields that we're going to be decoding
|
||||||
|
// from all the structs.
|
||||||
|
fields := make(map[*reflect.StructField]reflect.Value)
|
||||||
|
for len(structs) > 0 {
|
||||||
|
structVal := structs[0]
|
||||||
|
structs = structs[1:]
|
||||||
|
|
||||||
|
structType := structVal.Type()
|
||||||
|
for i := 0; i < structType.NumField(); i++ {
|
||||||
|
fieldType := structType.Field(i)
|
||||||
|
|
||||||
|
if fieldType.Anonymous {
|
||||||
|
fieldKind := fieldType.Type.Kind()
|
||||||
|
if fieldKind != reflect.Struct {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"%s: unsupported type to struct: %s",
|
||||||
|
fieldType.Name, fieldKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have an embedded field. We "squash" the fields down
|
||||||
|
// if specified in the tag.
|
||||||
|
squash := false
|
||||||
|
tagParts := strings.Split(fieldType.Tag.Get(tagName), ",")
|
||||||
|
for _, tag := range tagParts[1:] {
|
||||||
|
if tag == "squash" {
|
||||||
|
squash = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if squash {
|
||||||
|
structs = append(
|
||||||
|
structs, result.FieldByName(fieldType.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal struct field, store it away
|
||||||
|
fields[&fieldType] = structVal.Field(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usedKeys := make(map[string]struct{})
|
||||||
|
decodedFields := make([]string, 0, len(fields))
|
||||||
|
decodedFieldsVal := make([]reflect.Value, 0)
|
||||||
|
unusedKeysVal := make([]reflect.Value, 0)
|
||||||
|
for fieldType, field := range fields {
|
||||||
|
if !field.IsValid() {
|
||||||
|
// This should never happen
|
||||||
|
panic("field is not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can't set the field, then it is unexported or something,
|
||||||
|
// and we just continue onwards.
|
||||||
|
if !field.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldName := fieldType.Name
|
||||||
|
|
||||||
|
tagValue := fieldType.Tag.Get(tagName)
|
||||||
|
tagParts := strings.SplitN(tagValue, ",", 2)
|
||||||
|
if len(tagParts) >= 2 {
|
||||||
|
switch tagParts[1] {
|
||||||
|
case "decodedFields":
|
||||||
|
decodedFieldsVal = append(decodedFieldsVal, field)
|
||||||
|
continue
|
||||||
|
case "key":
|
||||||
|
field.SetString(obj.Key())
|
||||||
|
continue
|
||||||
|
case "unusedKeys":
|
||||||
|
unusedKeysVal = append(unusedKeysVal, field)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tagParts[0] != "" {
|
||||||
|
fieldName = tagParts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the element matching this name
|
||||||
|
var elem ast.Node
|
||||||
|
elemKey := fieldName
|
||||||
|
if elems := obj.Get(fieldName); len(elems) > 0 {
|
||||||
|
elem = elems[len(elems)-1]
|
||||||
|
} else {
|
||||||
|
// Do a slower search by iterating over each key and
|
||||||
|
// doing case-insensitive search.
|
||||||
|
for _, v := range obj.Elem {
|
||||||
|
if strings.EqualFold(v.Key(), fieldName) {
|
||||||
|
elem = v
|
||||||
|
elemKey = v.Key()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if elem == nil {
|
||||||
|
// No key matching this field.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we get the value of the element
|
||||||
|
if an, ok := elem.(ast.AssignmentNode); ok {
|
||||||
|
elem = an.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track the used key
|
||||||
|
usedKeys[elemKey] = struct{}{}
|
||||||
|
|
||||||
|
// Create the field name and decode
|
||||||
|
fieldName = fmt.Sprintf("%s.%s", name, fieldName)
|
||||||
|
if err := d.decode(fieldName, elem, field); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
decodedFields = append(decodedFields, fieldType.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range decodedFieldsVal {
|
||||||
|
v.Set(reflect.ValueOf(decodedFields))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we want to know what keys are unused, compile that
|
||||||
|
if len(unusedKeysVal) > 0 {
|
||||||
|
unusedKeys := make([]string, 0, int(obj.Len())-len(usedKeys))
|
||||||
|
|
||||||
|
for _, elem := range obj.Elem {
|
||||||
|
k := elem.Key()
|
||||||
|
if _, ok := usedKeys[k]; !ok {
|
||||||
|
unusedKeys = append(unusedKeys, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unusedKeys) == 0 {
|
||||||
|
unusedKeys = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range unusedKeysVal {
|
||||||
|
v.Set(reflect.ValueOf(unusedKeys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -137,3 +137,26 @@ func TestDecode_flatMap(t *testing.T) {
|
|||||||
t.Fatalf("Actual: %#v\n\nExpected: %#v", val, expected)
|
t.Fatalf("Actual: %#v\n\nExpected: %#v", val, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecode_structure(t *testing.T) {
|
||||||
|
type V struct {
|
||||||
|
Key int
|
||||||
|
Foo string
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual V
|
||||||
|
|
||||||
|
err := Decode(&actual, testReadFile(t, "flat.hcl"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := V{
|
||||||
|
Key: 7,
|
||||||
|
Foo: "bar",
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Fatalf("Actual: %#v\n\nExpected: %#v", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
2
test-fixtures/flat.hcl
Normal file
2
test-fixtures/flat.hcl
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
foo = "bar"
|
||||||
|
Key = 7
|
Loading…
x
Reference in New Issue
Block a user