decoder work

This commit is contained in:
Mitchell Hashimoto 2015-11-06 23:12:15 -08:00
parent 0c18c66fff
commit 9501fc5ad0
2 changed files with 293 additions and 247 deletions

View File

@ -4,11 +4,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strconv" "strconv"
"strings"
"github.com/hashicorp/hcl/hcl" "github.com/hashicorp/hcl/hcl/ast"
"github.com/hashicorp/hcl/hcl/token"
) )
// This is the tag to use with structures to have settings for HCL // This is the tag to use with structures to have settings for HCL
@ -27,21 +26,21 @@ func Decode(out interface{}, in string) error {
// DecodeObject is a lower-level version of Decode. It decodes a // DecodeObject is a lower-level version of Decode. It decodes a
// raw Object into the given output. // raw Object into the given output.
func DecodeObject(out interface{}, n *hcl.Object) error { func DecodeObject(out interface{}, n *ast.File) error {
val := reflect.ValueOf(out) val := reflect.ValueOf(out)
if val.Kind() != reflect.Ptr { if val.Kind() != reflect.Ptr {
return errors.New("result must be a pointer") return errors.New("result must be a pointer")
} }
var d decoder var d decoder
return d.decode("root", n, val.Elem()) return d.decode("root", n.Node, val.Elem())
} }
type decoder struct { type decoder struct {
stack []reflect.Kind stack []reflect.Kind
} }
func (d *decoder) decode(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decode(name string, node ast.Node, result reflect.Value) error {
k := result k := result
// If we have an interface with a valid value, we use that // If we have an interface with a valid value, we use that
@ -65,24 +64,24 @@ func (d *decoder) decode(name string, o *hcl.Object, result reflect.Value) error
switch k.Kind() { switch k.Kind() {
case reflect.Bool: case reflect.Bool:
return d.decodeBool(name, o, result) return d.decodeBool(name, node, result)
case reflect.Float64: case reflect.Float64:
return d.decodeFloat(name, o, result) return d.decodeFloat(name, node, result)
case reflect.Int: case reflect.Int:
return d.decodeInt(name, o, result) return d.decodeInt(name, node, result)
case reflect.Interface: case reflect.Interface:
// When we see an interface, we make our own thing // When we see an interface, we make our own thing
return d.decodeInterface(name, o, result) return d.decodeInterface(name, node, result)
case reflect.Map: case reflect.Map:
return d.decodeMap(name, o, result) return d.decodeMap(name, node, result)
case reflect.Ptr: case reflect.Ptr:
return d.decodePtr(name, o, result) return d.decodePtr(name, node, result)
case reflect.Slice: case reflect.Slice:
return d.decodeSlice(name, o, result) return d.decodeSlice(name, node, result)
case reflect.String: case reflect.String:
return d.decodeString(name, o, result) return d.decodeString(name, node, result)
case reflect.Struct: case reflect.Struct:
return d.decodeStruct(name, o, result) return d.decodeStruct(name, node, result)
default: default:
return fmt.Errorf( return fmt.Errorf(
"%s: unknown kind to decode into: %s", name, k.Kind()) "%s: unknown kind to decode into: %s", name, k.Kind())
@ -91,52 +90,66 @@ func (d *decoder) decode(name string, o *hcl.Object, result reflect.Value) error
return nil return nil
} }
func (d *decoder) decodeBool(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeBool(name string, node ast.Node, result reflect.Value) error {
switch o.Type { switch n := node.(type) {
case hcl.ValueTypeBool: case *ast.LiteralType:
result.Set(reflect.ValueOf(o.Value.(bool))) if n.Token.Type == token.BOOL {
default: v, err := strconv.ParseBool(n.Token.Text)
return fmt.Errorf("%s: unknown type %v", name, o.Type) if err != nil {
} return err
}
return nil result.SetBool(v)
} return nil
func (d *decoder) decodeFloat(name string, o *hcl.Object, result reflect.Value) error {
switch o.Type {
case hcl.ValueTypeFloat:
result.Set(reflect.ValueOf(o.Value.(float64)))
default:
return fmt.Errorf("%s: unknown type %v", name, o.Type)
}
return nil
}
func (d *decoder) decodeInt(name string, o *hcl.Object, result reflect.Value) error {
switch o.Type {
case hcl.ValueTypeInt:
result.Set(reflect.ValueOf(o.Value.(int)))
case hcl.ValueTypeString:
v, err := strconv.ParseInt(o.Value.(string), 0, 0)
if err != nil {
return err
} }
result.SetInt(int64(v))
default:
return fmt.Errorf("%s: unknown type %v", name, o.Type)
} }
return nil return fmt.Errorf("%s: unknown type %t", name, node)
} }
func (d *decoder) decodeInterface(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeFloat(name string, node ast.Node, result reflect.Value) error {
switch n := node.(type) {
case *ast.LiteralType:
if n.Token.Type == token.FLOAT {
v, err := strconv.ParseFloat(n.Token.Text, 64)
if err != nil {
return err
}
result.Set(reflect.ValueOf(v))
return nil
}
}
return fmt.Errorf("%s: unknown type %t", name, node)
}
func (d *decoder) decodeInt(name string, node ast.Node, result reflect.Value) error {
switch n := node.(type) {
case *ast.LiteralType:
switch n.Token.Type {
case token.NUMBER:
fallthrough
case token.STRING:
v, err := strconv.ParseInt(n.Token.Text, 0, 64)
if err != nil {
return err
}
result.SetInt(int64(v))
return nil
}
}
return fmt.Errorf("%s: unknown type %t", name, node)
}
func (d *decoder) decodeInterface(name string, node ast.Node, result reflect.Value) error {
var set reflect.Value var set reflect.Value
redecode := true redecode := true
switch o.Type { switch n := node.(type) {
case hcl.ValueTypeObject: case *ast.ObjectList:
// If we're at the root or we're directly within a slice, then we // If we're at the root or we're directly within a slice, then we
// decode objects into map[string]interface{}, otherwise we decode // decode objects into map[string]interface{}, otherwise we decode
// them into lists. // them into lists.
@ -153,30 +166,43 @@ func (d *decoder) decodeInterface(name string, o *hcl.Object, result reflect.Val
var temp []map[string]interface{} var temp []map[string]interface{}
tempVal := reflect.ValueOf(temp) tempVal := reflect.ValueOf(temp)
result := reflect.MakeSlice( result := reflect.MakeSlice(
reflect.SliceOf(tempVal.Type().Elem()), 0, int(o.Len())) reflect.SliceOf(tempVal.Type().Elem()), 0, len(n.Items))
set = result set = result
} }
case hcl.ValueTypeList: case *ast.ObjectType:
var temp []map[string]interface{}
tempVal := reflect.ValueOf(temp)
result := reflect.MakeSlice(
reflect.SliceOf(tempVal.Type().Elem()), 0, 1)
set = result
case *ast.ListType:
var temp []interface{} var temp []interface{}
tempVal := reflect.ValueOf(temp) tempVal := reflect.ValueOf(temp)
result := reflect.MakeSlice( result := reflect.MakeSlice(
reflect.SliceOf(tempVal.Type().Elem()), 0, 0) reflect.SliceOf(tempVal.Type().Elem()), 0, 0)
set = result set = result
case hcl.ValueTypeBool: case *ast.LiteralType:
var result bool switch n.Token.Type {
set = reflect.Indirect(reflect.New(reflect.TypeOf(result))) case token.BOOL:
case hcl.ValueTypeFloat: var result bool
var result float64 set = reflect.Indirect(reflect.New(reflect.TypeOf(result)))
set = reflect.Indirect(reflect.New(reflect.TypeOf(result))) case token.FLOAT:
case hcl.ValueTypeInt: var result float64
var result int set = reflect.Indirect(reflect.New(reflect.TypeOf(result)))
set = reflect.Indirect(reflect.New(reflect.TypeOf(result))) case token.NUMBER:
case hcl.ValueTypeString: var result int
set = reflect.Indirect(reflect.New(reflect.TypeOf(""))) set = reflect.Indirect(reflect.New(reflect.TypeOf(result)))
case token.STRING:
set = reflect.Indirect(reflect.New(reflect.TypeOf("")))
default:
return fmt.Errorf(
"%s: cannot decode into interface: %T",
name, node)
}
default: default:
return fmt.Errorf( return fmt.Errorf(
"%s: cannot decode into interface: %T", "%s: cannot decode into interface: %T",
name, o) name, node)
} }
// Set the result to what its supposed to be, then reset // Set the result to what its supposed to be, then reset
@ -186,7 +212,7 @@ func (d *decoder) decodeInterface(name string, o *hcl.Object, result reflect.Val
if redecode { if redecode {
// Revisit the node so that we can use the newly instantiated // Revisit the node so that we can use the newly instantiated
// thing and populate it. // thing and populate it.
if err := d.decode(name, o, result); err != nil { if err := d.decode(name, node, result); err != nil {
return err return err
} }
} }
@ -194,9 +220,18 @@ func (d *decoder) decodeInterface(name string, o *hcl.Object, result reflect.Val
return nil return nil
} }
func (d *decoder) decodeMap(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeMap(name string, node ast.Node, result reflect.Value) error {
if o.Type != hcl.ValueTypeObject { if item, ok := node.(*ast.ObjectItem); ok {
return fmt.Errorf("%s: not an object type for map (%v)", name, o.Type) node = &ast.ObjectList{Items: []*ast.ObjectItem{item}}
}
if ot, ok := node.(*ast.ObjectType); ok {
node = ot.List
}
n, ok := node.(*ast.ObjectList)
if !ok {
return fmt.Errorf("%s: not an object type for map (%T)", name, node)
} }
// If we have an interface, then we can address the interface, // If we have an interface, then we can address the interface,
@ -222,33 +257,48 @@ func (d *decoder) decodeMap(name string, o *hcl.Object, result reflect.Value) er
} }
// Go through each element and decode it. // Go through each element and decode it.
for _, o := range o.Elem(false) { done := make(map[string]struct{})
if o.Value == nil { for _, item := range n.Items {
if item.Val == nil {
continue continue
} }
for _, o := range o.Elem(true) { // Get the key we're dealing with, which is the first item
// Make the field name keyStr := item.Keys[0].Token.Value().(string)
fieldName := fmt.Sprintf("%s.%s", name, o.Key)
// Get the key/value as reflection values // If we've already processed this key, then ignore it
key := reflect.ValueOf(o.Key) if _, ok := done[keyStr]; ok {
val := reflect.Indirect(reflect.New(resultElemType)) continue
// If we have a pre-existing value in the map, use that
oldVal := resultMap.MapIndex(key)
if oldVal.IsValid() {
val.Set(oldVal)
}
// Decode!
if err := d.decode(fieldName, o, val); err != nil {
return err
}
// Set the value on the map
resultMap.SetMapIndex(key, val)
} }
// Determine the value. If we have more than one key, then we
// get the objectlist of only these keys.
itemVal := item.Val
if len(item.Keys) > 1 {
itemVal = n.Prefix(keyStr)
done[keyStr] = struct{}{}
}
// Make the field name
fieldName := fmt.Sprintf("%s.%s", name, keyStr)
// Get the key/value as reflection values
key := reflect.ValueOf(keyStr)
val := reflect.Indirect(reflect.New(resultElemType))
// If we have a pre-existing value in the map, use that
oldVal := resultMap.MapIndex(key)
if oldVal.IsValid() {
val.Set(oldVal)
}
// Decode!
if err := d.decode(fieldName, itemVal, val); err != nil {
return err
}
// Set the value on the map
resultMap.SetMapIndex(key, val)
} }
// Set the final map if we can // Set the final map if we can
@ -256,13 +306,13 @@ func (d *decoder) decodeMap(name string, o *hcl.Object, result reflect.Value) er
return nil return nil
} }
func (d *decoder) decodePtr(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodePtr(name string, node ast.Node, result reflect.Value) error {
// Create an element of the concrete (non pointer) type and decode // Create an element of the concrete (non pointer) type and decode
// into that. Then set the value of the pointer to this type. // into that. Then set the value of the pointer to this type.
resultType := result.Type() resultType := result.Type()
resultElemType := resultType.Elem() resultElemType := resultType.Elem()
val := reflect.New(resultElemType) val := reflect.New(resultElemType)
if err := d.decode(name, o, reflect.Indirect(val)); err != nil { if err := d.decode(name, node, reflect.Indirect(val)); err != nil {
return err return err
} }
@ -270,7 +320,7 @@ func (d *decoder) decodePtr(name string, o *hcl.Object, result reflect.Value) er
return nil return nil
} }
func (d *decoder) decodeSlice(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeSlice(name string, node ast.Node, result reflect.Value) error {
// If we have an interface, then we can address the interface, // If we have an interface, then we can address the interface,
// but not the slice itself, so get the element but set the interface // but not the slice itself, so get the element but set the interface
set := result set := result
@ -287,197 +337,192 @@ func (d *decoder) decodeSlice(name string, o *hcl.Object, result reflect.Value)
resultSliceType, 0, 0) resultSliceType, 0, 0)
} }
// Determine how we're doing this // Figure out the items we'll be copying into the slice
expand := true var items []ast.Node
switch o.Type { switch n := node.(type) {
case hcl.ValueTypeObject: case *ast.ObjectList:
expand = false items = make([]ast.Node, len(n.Items))
default: for i, item := range n.Items {
// Array or anything else: we expand values and take it all items[i] = item
}
case *ast.ObjectType:
items = []ast.Node{n}
} }
i := 0 if items == nil {
for _, o := range o.Elem(expand) { return fmt.Errorf("unknown slice type: %T", node)
}
for i, item := range items {
fieldName := fmt.Sprintf("%s[%d]", name, i) fieldName := fmt.Sprintf("%s[%d]", name, i)
// Decode // Decode
val := reflect.Indirect(reflect.New(resultElemType)) val := reflect.Indirect(reflect.New(resultElemType))
if err := d.decode(fieldName, o, val); err != nil { if err := d.decode(fieldName, item, val); err != nil {
return err return err
} }
// Append it onto the slice // Append it onto the slice
result = reflect.Append(result, val) result = reflect.Append(result, val)
i += 1
} }
set.Set(result) set.Set(result)
return nil return nil
} }
func (d *decoder) decodeString(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeString(name string, node ast.Node, result reflect.Value) error {
switch o.Type { switch n := node.(type) {
case hcl.ValueTypeInt: case *ast.LiteralType:
result.Set(reflect.ValueOf( switch n.Token.Type {
strconv.FormatInt(int64(o.Value.(int)), 10)).Convert(result.Type())) case token.NUMBER:
case hcl.ValueTypeString: fallthrough
result.Set(reflect.ValueOf(o.Value.(string)).Convert(result.Type())) case token.STRING:
default: result.Set(reflect.ValueOf(n.Token.Value()))
return fmt.Errorf("%s: unknown type to string: %v", name, o.Type) return nil
}
} }
return nil return fmt.Errorf("%s: unknown type %t", name, node)
} }
func (d *decoder) decodeStruct(name string, o *hcl.Object, result reflect.Value) error { func (d *decoder) decodeStruct(name string, node ast.Node, result reflect.Value) error {
if o.Type != hcl.ValueTypeObject { return nil
return fmt.Errorf("%s: not an object type for struct (%v)", name, o.Type) /*
} item, ok := node.(*ast.ObjectItem)
if !ok {
return fmt.Errorf("%s: not an object type for map (%t)", name, node)
}
// This slice will keep track of all the structs we'll be decoding. val, ok := node.(*ast.ObjectList)
// There can be more than one struct if there are embedded structs if !ok {
// that are squashed. return fmt.Errorf("%s: not an object type for map (%t)", name, node)
structs := make([]reflect.Value, 1, 5) }
structs[0] = result
// Compile the list of all the fields that we're going to be decoding // This slice will keep track of all the structs we'll be decoding.
// from all the structs. // There can be more than one struct if there are embedded structs
fields := make(map[*reflect.StructField]reflect.Value) // that are squashed.
for len(structs) > 0 { structs := make([]reflect.Value, 1, 5)
structVal := structs[0] structs[0] = result
structs = structs[1:]
structType := structVal.Type() // Compile the list of all the fields that we're going to be decoding
for i := 0; i < structType.NumField(); i++ { // from all the structs.
fieldType := structType.Field(i) fields := make(map[*reflect.StructField]reflect.Value)
for len(structs) > 0 {
structVal := structs[0]
structs = structs[1:]
if fieldType.Anonymous { structType := structVal.Type()
fieldKind := fieldType.Type.Kind() for i := 0; i < structType.NumField(); i++ {
if fieldKind != reflect.Struct { fieldType := structType.Field(i)
return fmt.Errorf(
"%s: unsupported type to struct: %s",
fieldType.Name, fieldKind)
}
// We have an embedded field. We "squash" the fields down if fieldType.Anonymous {
// if specified in the tag. fieldKind := fieldType.Type.Kind()
squash := false if fieldKind != reflect.Struct {
tagParts := strings.Split(fieldType.Tag.Get(tagName), ",") return fmt.Errorf(
for _, tag := range tagParts[1:] { "%s: unsupported type to struct: %s",
if tag == "squash" { fieldType.Name, fieldKind)
squash = true }
break
// We have an embedded field. We "squash" the fields down
// if specified in the tag.
squash := false
tagParts := strings.Split(fieldType.Tag.Get(tagName), ",")
for _, tag := range tagParts[1:] {
if tag == "squash" {
squash = true
break
}
}
if squash {
structs = append(
structs, result.FieldByName(fieldType.Name))
continue
} }
} }
if squash { // Normal struct field, store it away
structs = append( fields[&fieldType] = structVal.Field(i)
structs, result.FieldByName(fieldType.Name)) }
}
usedKeys := make(map[string]struct{})
decodedFields := make([]string, 0, len(fields))
decodedFieldsVal := make([]reflect.Value, 0)
unusedKeysVal := make([]reflect.Value, 0)
for fieldType, field := range fields {
if !field.IsValid() {
// This should never happen
panic("field is not valid")
}
// If we can't set the field, then it is unexported or something,
// and we just continue onwards.
if !field.CanSet() {
continue
}
fieldName := fieldType.Name
// This is whether or not we expand the object into its children
// later.
expand := false
tagValue := fieldType.Tag.Get(tagName)
tagParts := strings.SplitN(tagValue, ",", 2)
if len(tagParts) >= 2 {
switch tagParts[1] {
case "expand":
expand = true
case "decodedFields":
decodedFieldsVal = append(decodedFieldsVal, field)
continue
case "key":
field.SetString(item.Keys[0].Token.Text)
continue
case "unusedKeys":
unusedKeysVal = append(unusedKeysVal, field)
continue continue
} }
} }
// Normal struct field, store it away if tagParts[0] != "" {
fields[&fieldType] = structVal.Field(i) fieldName = tagParts[0]
}
}
usedKeys := make(map[string]struct{})
decodedFields := make([]string, 0, len(fields))
decodedFieldsVal := make([]reflect.Value, 0)
unusedKeysVal := make([]reflect.Value, 0)
for fieldType, field := range fields {
if !field.IsValid() {
// This should never happen
panic("field is not valid")
}
// If we can't set the field, then it is unexported or something,
// and we just continue onwards.
if !field.CanSet() {
continue
}
fieldName := fieldType.Name
// This is whether or not we expand the object into its children
// later.
expand := false
tagValue := fieldType.Tag.Get(tagName)
tagParts := strings.SplitN(tagValue, ",", 2)
if len(tagParts) >= 2 {
switch tagParts[1] {
case "expand":
expand = true
case "decodedFields":
decodedFieldsVal = append(decodedFieldsVal, field)
continue
case "key":
field.SetString(o.Key)
continue
case "unusedKeys":
unusedKeysVal = append(unusedKeysVal, field)
continue
} }
}
if tagParts[0] != "" { // Find the element matching this name
fieldName = tagParts[0]
}
// Find the element matching this name
obj := o.Get(fieldName, true)
if obj == nil {
continue continue
} /*
obj := o.Get(fieldName, true)
if obj == nil {
continue
}
// Track the used key // Track the used key
usedKeys[fieldName] = struct{}{} usedKeys[fieldName] = struct{}{}
// Create the field name and decode. We range over the elements // Create the field name and decode. We range over the elements
// because we actually want the value. // because we actually want the value.
fieldName = fmt.Sprintf("%s.%s", name, fieldName) fieldName = fmt.Sprintf("%s.%s", name, fieldName)
for _, obj := range obj.Elem(expand) { for _, obj := range obj.Elem(expand) {
if err := d.decode(fieldName, obj, field); err != nil { if err := d.decode(fieldName, obj, field); err != nil {
return err return err
} }
}
decodedFields = append(decodedFields, fieldType.Name)
}
if len(decodedFieldsVal) > 0 {
// Sort it so that it is deterministic
sort.Strings(decodedFields)
for _, v := range decodedFieldsVal {
v.Set(reflect.ValueOf(decodedFields))
}
}
// If we want to know what keys are unused, compile that
if len(unusedKeysVal) > 0 {
/*
unusedKeys := make([]string, 0, int(obj.Len())-len(usedKeys))
for _, elem := range obj.Elem {
k := elem.Key()
if _, ok := usedKeys[k]; !ok {
unusedKeys = append(unusedKeys, k)
} }
}
if len(unusedKeys) == 0 { decodedFields = append(decodedFields, fieldType.Name)
unusedKeys = nil }
}
for _, v := range unusedKeysVal { if len(decodedFieldsVal) > 0 {
v.Set(reflect.ValueOf(unusedKeys)) // Sort it so that it is deterministic
sort.Strings(decodedFields)
for _, v := range decodedFieldsVal {
v.Set(reflect.ValueOf(decodedFields))
} }
*/ }
}
*/
return nil return nil
} }

View File

@ -3,17 +3,18 @@ package hcl
import ( import (
"fmt" "fmt"
"github.com/hashicorp/hcl/hcl" "github.com/hashicorp/hcl/hcl/ast"
hclParser "github.com/hashicorp/hcl/hcl/parser"
"github.com/hashicorp/hcl/json" "github.com/hashicorp/hcl/json"
) )
// Parse parses the given input and returns the root object. // Parse parses the given input and returns the root object.
// //
// The input format can be either HCL or JSON. // The input format can be either HCL or JSON.
func Parse(input string) (*hcl.Object, error) { func Parse(input string) (*ast.File, error) {
switch lexMode(input) { switch lexMode(input) {
case lexModeHcl: case lexModeHcl:
return hcl.Parse(input) return hclParser.Parse([]byte(input))
case lexModeJson: case lexModeJson:
return json.Parse(input) return json.Parse(input)
} }