diff --git a/mapping.go b/mapping.go index d72a15e..16057a3 100644 --- a/mapping.go +++ b/mapping.go @@ -4,12 +4,24 @@ package dig import ( + "encoding/json" + "fmt" "reflect" ) // Mapping is a nested key-value map where the keys are strings and values are any. In Ruby it is called a Hash (with string keys), in YAML it's called a "mapping". type Mapping map[string]any +// UnmarshalText for supporting json.Unmarshal +func (m *Mapping) UnmarshalJSON(text []byte) error { + var result map[string]any + if err := json.Unmarshal(text, &result); err != nil { + return err + } + *m = cleanUpInterfaceMap(result) + return nil +} + // UnmarshalYAML for supporting yaml.Unmarshal func (m *Mapping) UnmarshalYAML(unmarshal func(any) error) error { var result map[string]any @@ -131,13 +143,21 @@ func cleanUpInterfaceArray(in []any) []any { // Cleans up the map keys to be strings func cleanUpInterfaceMap(in map[string]any) Mapping { - result := make(Mapping) + result := make(Mapping, len(in)) for k, v := range in { result[k] = cleanUpValue(v) } return result } +func stringifyKeys(in map[any]any) map[string]any { + result := make(map[string]any) + for k, v := range in { + result[fmt.Sprintf("%v", k)] = v + } + return result +} + // Cleans up the value in the map, recurses in case of arrays and maps func cleanUpValue(v any) any { switch v := v.(type) { @@ -145,6 +165,8 @@ func cleanUpValue(v any) any { return cleanUpInterfaceArray(v) case map[string]any: return cleanUpInterfaceMap(v) + case map[any]any: + return cleanUpInterfaceMap(stringifyKeys(v)) default: return v } diff --git a/mapping_test.go b/mapping_test.go index 0f0507e..b956827 100644 --- a/mapping_test.go +++ b/mapping_test.go @@ -45,10 +45,12 @@ func TestDig(t *testing.T) { m.DigMapping("foo")["int"] = 1 mustEqual(t, 1, m.Dig("foo", "int")) }) + t.Run("float value", func(t *testing.T) { m.DigMapping("foo")["float"] = 0.5 mustEqual(t, 0.5, m.Dig("foo", "float")) }) + t.Run("bool value", func(t *testing.T) { m.DigMapping("foo")["bool"] = true mustEqual(t, true, m.Dig("foo", "bool")) @@ -155,6 +157,17 @@ func TestUnmarshalJSONWithFloat(t *testing.T) { mustEqual(t, 0.5, val) } +func TestUnmarshalJSONWithSliceOfMaps(t *testing.T) { + data := []byte(`{"foo": [{"bar": "baz"}]}`) + var m dig.Mapping + mustBeNil(t, json.Unmarshal(data, &m)) + val, ok := m.Dig("foo").([]any) + mustEqual(t, true, ok) + obj, ok := val[0].(dig.Mapping) + mustEqual(t, true, ok) + mustEqual(t, "baz", obj["bar"]) +} + func ExampleMapping_Dig() { h := dig.Mapping{ "greeting": dig.Mapping{