diff --git a/decoder.go b/decoder.go index 485ac4e..672861f 100644 --- a/decoder.go +++ b/decoder.go @@ -25,7 +25,18 @@ func DecodeAST(out interface{}, obj *ast.ObjectNode) error { } func decode(name string, n ast.Node, result reflect.Value) error { - switch result.Kind() { + k := result + + // If we have an interface with a valid value, we use that + // for the check. + if result.Kind() == reflect.Interface { + elem := result.Elem() + if elem.IsValid() { + k = elem + } + } + + switch k.Kind() { case reflect.Int: return decodeInt(name, n, result) case reflect.Interface: @@ -185,15 +196,41 @@ func decodeMap(name string, raw ast.Node, result reflect.Value) error { } func decodeSlice(name string, raw ast.Node, result reflect.Value) error { - // Create the slice + n, ok := raw.(ast.ListNode) + if !ok { + return fmt.Errorf("%s: not a list type", name) + } + + // If we have an interface, then we can address the interface, + // but not the slice itself, so get the element but set the interface + set := result + if result.Kind() == reflect.Interface { + result = result.Elem() + } + + // Create the slice if it isn't nil resultType := result.Type() resultElemType := resultType.Elem() - resultSliceType := reflect.SliceOf(resultElemType) - resultSlice := reflect.MakeSlice( - resultSliceType, 0, 0) + if result.IsNil() { + resultSliceType := reflect.SliceOf(resultElemType) + result = reflect.MakeSlice( + resultSliceType, 0, 0) + } - result.Set(resultSlice) + for i, elem := range n.Elem { + fieldName := fmt.Sprintf("%s[%d]", name, i) + // Decode + val := reflect.Indirect(reflect.New(resultElemType)) + if err := decode(fieldName, elem, val); err != nil { + return err + } + + // Append it onto the slice + result = reflect.Append(result, val) + } + + set.Set(result) return nil } diff --git a/decoder_test.go b/decoder_test.go index e89e35f..a98f6b7 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -44,7 +44,7 @@ func TestDecode(t *testing.T) { t.Fatalf("err: %s", err) } - var out map[string]interface{} + var out interface{} err = Decode(&out, string(d)) if (err != nil) != tc.Err { t.Fatalf("Input: %s\n\nError: %s", tc.File, err) @@ -72,6 +72,10 @@ func TestDecode_equal(t *testing.T) { "structure.hcl", "structure_flat.json", }, + { + "structure_multi.hcl", + "structure_multi.json", + }, } for _, tc := range cases { diff --git a/test-fixtures/structure_multi.hcl b/test-fixtures/structure_multi.hcl new file mode 100644 index 0000000..6acfcde --- /dev/null +++ b/test-fixtures/structure_multi.hcl @@ -0,0 +1,7 @@ +foo "baz" { + key = 7 +} + +foo "bar" { + key = 12 +} diff --git a/test-fixtures/structure_multi.json b/test-fixtures/structure_multi.json new file mode 100644 index 0000000..9a170f3 --- /dev/null +++ b/test-fixtures/structure_multi.json @@ -0,0 +1,11 @@ +{ + "foo": [{ + "baz": [{ + "key": 7 + }] + }, { + "bar": [{ + "key": 12 + }] + }] +}