From e0d71f8165e76dafced6fcd1836bf82ee048ea9e Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 03:23:42 -0500 Subject: [PATCH 1/6] chore: Refactor `deepGet` This will make it easier to fix bugs and add new features. --- internal/template/reflect.go | 45 +++++++++++++++++------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/internal/template/reflect.go b/internal/template/reflect.go index dfb465bc..32e4c4ec 100644 --- a/internal/template/reflect.go +++ b/internal/template/reflect.go @@ -18,32 +18,29 @@ func stripPrefix(s, prefix string) string { return path } -func deepGet(item interface{}, path string) interface{} { - if path == "" { - return item +func deepGetImpl(v reflect.Value, path []string) interface{} { + if !v.IsValid() { + log.Printf("invalid value\n") + return nil } - - path = stripPrefix(path, ".") - parts := strings.Split(path, ".") - itemValue := reflect.ValueOf(item) - - if len(parts) > 0 { - switch itemValue.Kind() { - case reflect.Struct: - fieldValue := itemValue.FieldByName(parts[0]) - if fieldValue.IsValid() { - return deepGet(fieldValue.Interface(), strings.Join(parts[1:], ".")) - } - case reflect.Map: - mapValue := itemValue.MapIndex(reflect.ValueOf(parts[0])) - if mapValue.IsValid() { - return deepGet(mapValue.Interface(), strings.Join(parts[1:], ".")) - } - default: - log.Printf("Can't group by %s (value %v, kind %s)\n", path, itemValue, itemValue.Kind()) - } + if len(path) == 0 { + return v.Interface() + } + switch v.Kind() { + case reflect.Struct: + return deepGetImpl(v.FieldByName(path[0]), path[1:]) + case reflect.Map: + return deepGetImpl(v.MapIndex(reflect.ValueOf(path[0])), path[1:]) + default: + log.Printf("unable to index by %s (value %v, kind %s)\n", path[0], v, v.Kind()) return nil } +} - return itemValue.Interface() +func deepGet(item interface{}, path string) interface{} { + var parts []string + if path != "" { + parts = strings.Split(stripPrefix(path, "."), ".") + } + return deepGetImpl(reflect.ValueOf(item), parts) } From a0318cd598067fdd0b32f6c8d4c6b546d9ce8d7a Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 03:16:10 -0500 Subject: [PATCH 2/6] fix: Only strip the first "." from the path passed to `deepGet` This makes it possible to use the empty string as a map key. --- internal/template/reflect.go | 14 +------------- internal/template/reflect_test.go | 30 +++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/internal/template/reflect.go b/internal/template/reflect.go index 32e4c4ec..92a70046 100644 --- a/internal/template/reflect.go +++ b/internal/template/reflect.go @@ -6,18 +6,6 @@ import ( "strings" ) -func stripPrefix(s, prefix string) string { - path := s - for { - if strings.HasPrefix(path, ".") { - path = path[1:] - continue - } - break - } - return path -} - func deepGetImpl(v reflect.Value, path []string) interface{} { if !v.IsValid() { log.Printf("invalid value\n") @@ -40,7 +28,7 @@ func deepGetImpl(v reflect.Value, path []string) interface{} { func deepGet(item interface{}, path string) interface{} { var parts []string if path != "" { - parts = strings.Split(stripPrefix(path, "."), ".") + parts = strings.Split(strings.TrimPrefix(path, "."), ".") } return deepGetImpl(reflect.ValueOf(item), parts) } diff --git a/internal/template/reflect_test.go b/internal/template/reflect_test.go index 0211c5e9..9587f3a2 100644 --- a/internal/template/reflect_test.go +++ b/internal/template/reflect_test.go @@ -34,7 +34,7 @@ func TestDeepGetSimpleDotPrefix(t *testing.T) { item := context.RuntimeContainer{ ID: "expected", } - value := deepGet(item, "...ID") + value := deepGet(item, ".ID") assert.IsType(t, "", value) assert.Equal(t, "expected", value) @@ -51,3 +51,31 @@ func TestDeepGetMap(t *testing.T) { assert.Equal(t, "value", value) } + +func TestDeepGet(t *testing.T) { + for _, tc := range []struct { + desc string + item interface{} + path string + want interface{} + }{ + { + "map key empty string", + map[string]map[string]map[string]string{ + "": map[string]map[string]string{ + "": map[string]string{ + "": "foo", + }, + }, + }, + "...", + "foo", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + got := deepGet(tc.item, tc.path) + assert.IsType(t, tc.want, got) + assert.Equal(t, tc.want, got) + }) + } +} From 8648033193270d6887d084b064f9d8b33dd5adba Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 03:25:53 -0500 Subject: [PATCH 3/6] feat: Automatically dereference pointer types in `deepGet` This matches the behavior of Go, and makes it possible to use `groupBy` and friends on a slice of `*RuntimeContainers`. --- internal/template/reflect.go | 7 +++++++ internal/template/reflect_test.go | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/internal/template/reflect.go b/internal/template/reflect.go index 92a70046..4cac6f64 100644 --- a/internal/template/reflect.go +++ b/internal/template/reflect.go @@ -14,6 +14,13 @@ func deepGetImpl(v reflect.Value, path []string) interface{} { if len(path) == 0 { return v.Interface() } + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + if v.Kind() == reflect.Pointer { + log.Printf("unable to descend into pointer of a pointer\n") + return nil + } switch v.Kind() { case reflect.Struct: return deepGetImpl(v.FieldByName(path[0]), path[1:]) diff --git a/internal/template/reflect_test.go b/internal/template/reflect_test.go index 9587f3a2..5f5e437c 100644 --- a/internal/template/reflect_test.go +++ b/internal/template/reflect_test.go @@ -53,6 +53,9 @@ func TestDeepGetMap(t *testing.T) { } func TestDeepGet(t *testing.T) { + s := struct{ X string }{"foo"} + sp := &s + for _, tc := range []struct { desc string item interface{} @@ -71,6 +74,9 @@ func TestDeepGet(t *testing.T) { "...", "foo", }, + {"struct", s, "X", "foo"}, + {"pointer to struct", sp, "X", "foo"}, + {"double pointer to struct", &sp, ".X", nil}, } { t.Run(tc.desc, func(t *testing.T) { got := deepGet(tc.item, tc.path) From ba97c3705dc09f83cbee8e682217ef4af5c1243d Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 03:27:31 -0500 Subject: [PATCH 4/6] feat: Support indexing slices and arrays in `deepGet` --- internal/template/reflect.go | 17 +++++++++++++++++ internal/template/reflect_test.go | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/internal/template/reflect.go b/internal/template/reflect.go index 4cac6f64..e2e16b43 100644 --- a/internal/template/reflect.go +++ b/internal/template/reflect.go @@ -2,7 +2,9 @@ package template import ( "log" + "math" "reflect" + "strconv" "strings" ) @@ -26,6 +28,21 @@ func deepGetImpl(v reflect.Value, path []string) interface{} { return deepGetImpl(v.FieldByName(path[0]), path[1:]) case reflect.Map: return deepGetImpl(v.MapIndex(reflect.ValueOf(path[0])), path[1:]) + case reflect.Slice, reflect.Array: + iu64, err := strconv.ParseUint(path[0], 10, 64) + if err != nil { + log.Printf("non-negative decimal number required for array/slice index, got %#v\n", path[0]) + return nil + } + if iu64 > math.MaxInt { + iu64 = math.MaxInt + } + i := int(iu64) + if i >= v.Len() { + log.Printf("index %v out of bounds", i) + return nil + } + return deepGetImpl(v.Index(i), path[1:]) default: log.Printf("unable to index by %s (value %v, kind %s)\n", path[0], v, v.Kind()) return nil diff --git a/internal/template/reflect_test.go b/internal/template/reflect_test.go index 5f5e437c..d8367455 100644 --- a/internal/template/reflect_test.go +++ b/internal/template/reflect_test.go @@ -77,6 +77,14 @@ func TestDeepGet(t *testing.T) { {"struct", s, "X", "foo"}, {"pointer to struct", sp, "X", "foo"}, {"double pointer to struct", &sp, ".X", nil}, + {"slice index", []string{"foo", "bar"}, "1", "bar"}, + {"slice index out of bounds", []string{}, "0", nil}, + {"slice index negative", []string{}, "-1", nil}, + {"slice index nonnumber", []string{}, "foo", nil}, + {"array index", [2]string{"foo", "bar"}, "1", "bar"}, + {"array index out of bounds", [1]string{"foo"}, "1", nil}, + {"array index negative", [1]string{"foo"}, "-1", nil}, + {"array index nonnumber", [1]string{"foo"}, "foo", nil}, } { t.Run(tc.desc, func(t *testing.T) { got := deepGet(tc.item, tc.path) From 218a63df94a2bbe16f00d8e02a0276c3515daa87 Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 04:48:59 -0500 Subject: [PATCH 5/6] fix: Don't dereference pointers in `where`, `groupBy`, `sort`, etc. --- internal/template/groupby.go | 7 +++---- internal/template/groupby_test.go | 14 +++++++------- internal/template/sort.go | 2 +- internal/template/sort_test.go | 16 ++++++++-------- internal/template/where.go | 2 +- 5 files changed, 20 insertions(+), 21 deletions(-) diff --git a/internal/template/groupby.go b/internal/template/groupby.go index f8eae67d..b7063b93 100644 --- a/internal/template/groupby.go +++ b/internal/template/groupby.go @@ -2,7 +2,6 @@ package template import ( "fmt" - "reflect" "strings" "github.com/nginx-proxy/docker-gen/internal/context" @@ -18,7 +17,7 @@ func generalizedGroupBy(funcName string, entries interface{}, getValue func(inte groups := make(map[string][]interface{}) for i := 0; i < entriesVal.Len(); i++ { - v := reflect.Indirect(entriesVal.Index(i)).Interface() + v := entriesVal.Index(i).Interface() value, err := getValue(v) if err != nil { return nil, err @@ -73,13 +72,13 @@ func groupByKeys(entries interface{}, key string) ([]string, error) { // groupByLabel is the same as groupBy but over a given label func groupByLabel(entries interface{}, label string) (map[string][]interface{}, error) { getLabel := func(v interface{}) (interface{}, error) { - if container, ok := v.(context.RuntimeContainer); ok { + if container, ok := v.(*context.RuntimeContainer); ok { if value, ok := container.Labels[label]; ok { return value, nil } return nil, nil } - return nil, fmt.Errorf("must pass an array or slice of RuntimeContainer to 'groupByLabel'; received %v", v) + return nil, fmt.Errorf("must pass an array or slice of *RuntimeContainer to 'groupByLabel'; received %v", v) } return generalizedGroupBy("groupByLabel", entries, getLabel, func(groups map[string][]interface{}, value interface{}, v interface{}) { groups[value.(string)] = append(groups[value.(string)], v) diff --git a/internal/template/groupby_test.go b/internal/template/groupby_test.go index 6a2acf6a..628bb7cd 100644 --- a/internal/template/groupby_test.go +++ b/internal/template/groupby_test.go @@ -35,7 +35,7 @@ func TestGroupByExistingKey(t *testing.T) { assert.Len(t, groups, 2) assert.Len(t, groups["demo1.localhost"], 2) assert.Len(t, groups["demo2.localhost"], 1) - assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } func TestGroupByAfterWhere(t *testing.T) { @@ -69,7 +69,7 @@ func TestGroupByAfterWhere(t *testing.T) { assert.Len(t, groups, 2) assert.Len(t, groups["demo1.localhost"], 1) assert.Len(t, groups["demo2.localhost"], 1) - assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } func TestGroupByKeys(t *testing.T) { @@ -149,7 +149,7 @@ func TestGroupByLabel(t *testing.T) { assert.Len(t, groups["one"], 2) assert.Len(t, groups[""], 1) assert.Len(t, groups["two"], 1) - assert.Equal(t, "2", groups["two"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "2", groups["two"][0].(*context.RuntimeContainer).ID) } func TestGroupByLabelError(t *testing.T) { @@ -193,13 +193,13 @@ func TestGroupByMulti(t *testing.T) { if len(groups["demo2.localhost"]) != 1 { t.Fatalf("expected 1 got %d", len(groups["demo2.localhost"])) } - if groups["demo2.localhost"][0].(context.RuntimeContainer).ID != "3" { - t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + if groups["demo2.localhost"][0].(*context.RuntimeContainer).ID != "3" { + t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } if len(groups["demo3.localhost"]) != 1 { t.Fatalf("expect 1 got %d", len(groups["demo3.localhost"])) } - if groups["demo3.localhost"][0].(context.RuntimeContainer).ID != "2" { - t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(context.RuntimeContainer).ID) + if groups["demo3.localhost"][0].(*context.RuntimeContainer).ID != "2" { + t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(*context.RuntimeContainer).ID) } } diff --git a/internal/template/sort.go b/internal/template/sort.go index e27f5658..76689fc4 100644 --- a/internal/template/sort.go +++ b/internal/template/sort.go @@ -47,7 +47,7 @@ func (s *sortableByKey) set(funcName string, entries interface{}) (err error) { } s.data = make([]interface{}, entriesVal.Len()) for i := 0; i < entriesVal.Len(); i++ { - s.data[i] = reflect.Indirect(entriesVal.Index(i)).Interface() + s.data[i] = entriesVal.Index(i).Interface() } return } diff --git a/internal/template/sort_test.go b/internal/template/sort_test.go index 1a003713..5c7405a1 100644 --- a/internal/template/sort_test.go +++ b/internal/template/sort_test.go @@ -49,15 +49,15 @@ func TestSortObjectsByKeysAsc(t *testing.T) { assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "foo.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "9", sorted[3].(context.RuntimeContainer).ID) + assert.Equal(t, "foo.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "9", sorted[3].(*context.RuntimeContainer).ID) sorted, err = sortObjectsByKeysAsc(sorted, "Env.VIRTUAL_HOST") assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "foo.localhost", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "8", sorted[0].(context.RuntimeContainer).ID) + assert.Equal(t, "foo.localhost", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "8", sorted[0].(*context.RuntimeContainer).ID) } func TestSortObjectsByKeysDesc(t *testing.T) { @@ -90,13 +90,13 @@ func TestSortObjectsByKeysDesc(t *testing.T) { assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "bar.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "1", sorted[3].(context.RuntimeContainer).ID) + assert.Equal(t, "bar.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "1", sorted[3].(*context.RuntimeContainer).ID) sorted, err = sortObjectsByKeysDesc(sorted, "Env.VIRTUAL_HOST") assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "1", sorted[0].(context.RuntimeContainer).ID) + assert.Equal(t, "", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "1", sorted[0].(*context.RuntimeContainer).ID) } diff --git a/internal/template/where.go b/internal/template/where.go index 2ca6bdd3..ac95fc0f 100644 --- a/internal/template/where.go +++ b/internal/template/where.go @@ -18,7 +18,7 @@ func generalizedWhere(funcName string, entries interface{}, key string, test fun selection := make([]interface{}, 0) for i := 0; i < entriesVal.Len(); i++ { - v := reflect.Indirect(entriesVal.Index(i)).Interface() + v := entriesVal.Index(i).Interface() value := deepGet(v, key) if test(value) { From 298e6507b812f7b1ce4cbc4204b03ee7c7eab933 Mon Sep 17 00:00:00 2001 From: Richard Hansen Date: Sat, 21 Jan 2023 04:50:13 -0500 Subject: [PATCH 6/6] chore: Use `Value.Elem` instead of `reflect.Indirect` for readability --- internal/template/template.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/template/template.go b/internal/template/template.go index 3b8eaf80..b9c71268 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -30,7 +30,7 @@ func getArrayValues(funcName string, entries interface{}) (*reflect.Value, error kind := entriesVal.Kind() if kind == reflect.Ptr { - entriesVal = reflect.Indirect(entriesVal) + entriesVal = entriesVal.Elem() kind = entriesVal.Kind() }