diff --git a/internal/template/groupby.go b/internal/template/groupby.go index f8eae67d..b7063b93 100644 --- a/internal/template/groupby.go +++ b/internal/template/groupby.go @@ -2,7 +2,6 @@ package template import ( "fmt" - "reflect" "strings" "github.com/nginx-proxy/docker-gen/internal/context" @@ -18,7 +17,7 @@ func generalizedGroupBy(funcName string, entries interface{}, getValue func(inte groups := make(map[string][]interface{}) for i := 0; i < entriesVal.Len(); i++ { - v := reflect.Indirect(entriesVal.Index(i)).Interface() + v := entriesVal.Index(i).Interface() value, err := getValue(v) if err != nil { return nil, err @@ -73,13 +72,13 @@ func groupByKeys(entries interface{}, key string) ([]string, error) { // groupByLabel is the same as groupBy but over a given label func groupByLabel(entries interface{}, label string) (map[string][]interface{}, error) { getLabel := func(v interface{}) (interface{}, error) { - if container, ok := v.(context.RuntimeContainer); ok { + if container, ok := v.(*context.RuntimeContainer); ok { if value, ok := container.Labels[label]; ok { return value, nil } return nil, nil } - return nil, fmt.Errorf("must pass an array or slice of RuntimeContainer to 'groupByLabel'; received %v", v) + return nil, fmt.Errorf("must pass an array or slice of *RuntimeContainer to 'groupByLabel'; received %v", v) } return generalizedGroupBy("groupByLabel", entries, getLabel, func(groups map[string][]interface{}, value interface{}, v interface{}) { groups[value.(string)] = append(groups[value.(string)], v) diff --git a/internal/template/groupby_test.go b/internal/template/groupby_test.go index 6a2acf6a..628bb7cd 100644 --- a/internal/template/groupby_test.go +++ b/internal/template/groupby_test.go @@ -35,7 +35,7 @@ func TestGroupByExistingKey(t *testing.T) { assert.Len(t, groups, 2) assert.Len(t, groups["demo1.localhost"], 2) assert.Len(t, groups["demo2.localhost"], 1) - assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } func TestGroupByAfterWhere(t *testing.T) { @@ -69,7 +69,7 @@ func TestGroupByAfterWhere(t *testing.T) { assert.Len(t, groups, 2) assert.Len(t, groups["demo1.localhost"], 1) assert.Len(t, groups["demo2.localhost"], 1) - assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } func TestGroupByKeys(t *testing.T) { @@ -149,7 +149,7 @@ func TestGroupByLabel(t *testing.T) { assert.Len(t, groups["one"], 2) assert.Len(t, groups[""], 1) assert.Len(t, groups["two"], 1) - assert.Equal(t, "2", groups["two"][0].(context.RuntimeContainer).ID) + assert.Equal(t, "2", groups["two"][0].(*context.RuntimeContainer).ID) } func TestGroupByLabelError(t *testing.T) { @@ -193,13 +193,13 @@ func TestGroupByMulti(t *testing.T) { if len(groups["demo2.localhost"]) != 1 { t.Fatalf("expected 1 got %d", len(groups["demo2.localhost"])) } - if groups["demo2.localhost"][0].(context.RuntimeContainer).ID != "3" { - t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(context.RuntimeContainer).ID) + if groups["demo2.localhost"][0].(*context.RuntimeContainer).ID != "3" { + t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID) } if len(groups["demo3.localhost"]) != 1 { t.Fatalf("expect 1 got %d", len(groups["demo3.localhost"])) } - if groups["demo3.localhost"][0].(context.RuntimeContainer).ID != "2" { - t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(context.RuntimeContainer).ID) + if groups["demo3.localhost"][0].(*context.RuntimeContainer).ID != "2" { + t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(*context.RuntimeContainer).ID) } } diff --git a/internal/template/reflect.go b/internal/template/reflect.go index dfb465bc..e2e16b43 100644 --- a/internal/template/reflect.go +++ b/internal/template/reflect.go @@ -2,48 +2,57 @@ package template import ( "log" + "math" "reflect" + "strconv" "strings" ) -func stripPrefix(s, prefix string) string { - path := s - for { - if strings.HasPrefix(path, ".") { - path = path[1:] - continue - } - break +func deepGetImpl(v reflect.Value, path []string) interface{} { + if !v.IsValid() { + log.Printf("invalid value\n") + return nil } - return path -} - -func deepGet(item interface{}, path string) interface{} { - if path == "" { - return item + if len(path) == 0 { + return v.Interface() } - - path = stripPrefix(path, ".") - parts := strings.Split(path, ".") - itemValue := reflect.ValueOf(item) - - if len(parts) > 0 { - switch itemValue.Kind() { - case reflect.Struct: - fieldValue := itemValue.FieldByName(parts[0]) - if fieldValue.IsValid() { - return deepGet(fieldValue.Interface(), strings.Join(parts[1:], ".")) - } - case reflect.Map: - mapValue := itemValue.MapIndex(reflect.ValueOf(parts[0])) - if mapValue.IsValid() { - return deepGet(mapValue.Interface(), strings.Join(parts[1:], ".")) - } - default: - log.Printf("Can't group by %s (value %v, kind %s)\n", path, itemValue, itemValue.Kind()) + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + if v.Kind() == reflect.Pointer { + log.Printf("unable to descend into pointer of a pointer\n") + return nil + } + switch v.Kind() { + case reflect.Struct: + return deepGetImpl(v.FieldByName(path[0]), path[1:]) + case reflect.Map: + return deepGetImpl(v.MapIndex(reflect.ValueOf(path[0])), path[1:]) + case reflect.Slice, reflect.Array: + iu64, err := strconv.ParseUint(path[0], 10, 64) + if err != nil { + log.Printf("non-negative decimal number required for array/slice index, got %#v\n", path[0]) + return nil + } + if iu64 > math.MaxInt { + iu64 = math.MaxInt + } + i := int(iu64) + if i >= v.Len() { + log.Printf("index %v out of bounds", i) + return nil } + return deepGetImpl(v.Index(i), path[1:]) + default: + log.Printf("unable to index by %s (value %v, kind %s)\n", path[0], v, v.Kind()) return nil } +} - return itemValue.Interface() +func deepGet(item interface{}, path string) interface{} { + var parts []string + if path != "" { + parts = strings.Split(strings.TrimPrefix(path, "."), ".") + } + return deepGetImpl(reflect.ValueOf(item), parts) } diff --git a/internal/template/reflect_test.go b/internal/template/reflect_test.go index 0211c5e9..d8367455 100644 --- a/internal/template/reflect_test.go +++ b/internal/template/reflect_test.go @@ -34,7 +34,7 @@ func TestDeepGetSimpleDotPrefix(t *testing.T) { item := context.RuntimeContainer{ ID: "expected", } - value := deepGet(item, "...ID") + value := deepGet(item, ".ID") assert.IsType(t, "", value) assert.Equal(t, "expected", value) @@ -51,3 +51,45 @@ func TestDeepGetMap(t *testing.T) { assert.Equal(t, "value", value) } + +func TestDeepGet(t *testing.T) { + s := struct{ X string }{"foo"} + sp := &s + + for _, tc := range []struct { + desc string + item interface{} + path string + want interface{} + }{ + { + "map key empty string", + map[string]map[string]map[string]string{ + "": map[string]map[string]string{ + "": map[string]string{ + "": "foo", + }, + }, + }, + "...", + "foo", + }, + {"struct", s, "X", "foo"}, + {"pointer to struct", sp, "X", "foo"}, + {"double pointer to struct", &sp, ".X", nil}, + {"slice index", []string{"foo", "bar"}, "1", "bar"}, + {"slice index out of bounds", []string{}, "0", nil}, + {"slice index negative", []string{}, "-1", nil}, + {"slice index nonnumber", []string{}, "foo", nil}, + {"array index", [2]string{"foo", "bar"}, "1", "bar"}, + {"array index out of bounds", [1]string{"foo"}, "1", nil}, + {"array index negative", [1]string{"foo"}, "-1", nil}, + {"array index nonnumber", [1]string{"foo"}, "foo", nil}, + } { + t.Run(tc.desc, func(t *testing.T) { + got := deepGet(tc.item, tc.path) + assert.IsType(t, tc.want, got) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/template/sort.go b/internal/template/sort.go index e27f5658..76689fc4 100644 --- a/internal/template/sort.go +++ b/internal/template/sort.go @@ -47,7 +47,7 @@ func (s *sortableByKey) set(funcName string, entries interface{}) (err error) { } s.data = make([]interface{}, entriesVal.Len()) for i := 0; i < entriesVal.Len(); i++ { - s.data[i] = reflect.Indirect(entriesVal.Index(i)).Interface() + s.data[i] = entriesVal.Index(i).Interface() } return } diff --git a/internal/template/sort_test.go b/internal/template/sort_test.go index 1a003713..5c7405a1 100644 --- a/internal/template/sort_test.go +++ b/internal/template/sort_test.go @@ -49,15 +49,15 @@ func TestSortObjectsByKeysAsc(t *testing.T) { assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "foo.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "9", sorted[3].(context.RuntimeContainer).ID) + assert.Equal(t, "foo.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "9", sorted[3].(*context.RuntimeContainer).ID) sorted, err = sortObjectsByKeysAsc(sorted, "Env.VIRTUAL_HOST") assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "foo.localhost", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "8", sorted[0].(context.RuntimeContainer).ID) + assert.Equal(t, "foo.localhost", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "8", sorted[0].(*context.RuntimeContainer).ID) } func TestSortObjectsByKeysDesc(t *testing.T) { @@ -90,13 +90,13 @@ func TestSortObjectsByKeysDesc(t *testing.T) { assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "bar.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "1", sorted[3].(context.RuntimeContainer).ID) + assert.Equal(t, "bar.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "1", sorted[3].(*context.RuntimeContainer).ID) sorted, err = sortObjectsByKeysDesc(sorted, "Env.VIRTUAL_HOST") assert.NoError(t, err) assert.Len(t, sorted, 4) - assert.Equal(t, "", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"]) - assert.Equal(t, "1", sorted[0].(context.RuntimeContainer).ID) + assert.Equal(t, "", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"]) + assert.Equal(t, "1", sorted[0].(*context.RuntimeContainer).ID) } diff --git a/internal/template/template.go b/internal/template/template.go index 3b8eaf80..b9c71268 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -30,7 +30,7 @@ func getArrayValues(funcName string, entries interface{}) (*reflect.Value, error kind := entriesVal.Kind() if kind == reflect.Ptr { - entriesVal = reflect.Indirect(entriesVal) + entriesVal = entriesVal.Elem() kind = entriesVal.Kind() } diff --git a/internal/template/where.go b/internal/template/where.go index 2ca6bdd3..ac95fc0f 100644 --- a/internal/template/where.go +++ b/internal/template/where.go @@ -18,7 +18,7 @@ func generalizedWhere(funcName string, entries interface{}, key string, test fun selection := make([]interface{}, 0) for i := 0; i < entriesVal.Len(); i++ { - v := reflect.Indirect(entriesVal.Index(i)).Interface() + v := entriesVal.Index(i).Interface() value := deepGet(v, key) if test(value) {