diff --git a/mapstr/mapstr.go b/mapstr/mapstr.go index e507ce5..9d92b71 100644 --- a/mapstr/mapstr.go +++ b/mapstr/mapstr.go @@ -181,55 +181,145 @@ func (m M) HasKey(key string) (bool, error) { // key of the map that matched the given key and the value stored under this key. // Returns `ErrKeyCollision` if multiple keys match the same request. // Returns `ErrNotMapType` when one of the values on the path is not a map and cannot be traversed. -func (m M) FindFold(key string) (matchedKey string, value interface{}, err error) { - path := strings.Split(key, ".") +// Returns `ErrKeyNotFound` when the path does not exist +func (m M) FindFold(path string) (matchedKey string, value interface{}, err error) { + segmentCount := strings.Count(path, ".") + 1 + err = m.Traverse(path, CaseInsensitiveMode, func(level M, key string) error { + segmentCount-- + matchedKey += key + if segmentCount != 0 { + matchedKey += "." + return nil + } + + value = level[key] + return nil + }) + if err != nil { + return "", nil, err + } + return matchedKey, value, nil +} + +type AlterFunc func(string) (string, error) + +// AlterPath walks the given `path` and replaces matching keys using the value returned by `alterFunc`. +// `mode` sets the behavior how the given path is matched throughout the levels. +// Returns `ErrKeyCollision` if multiple keys match the same request (when `mode` is `CaseInsensitiveMode`). +// Returns `ErrNotMapType` when one of the values on the path is not a map and cannot be traversed. +// Returns `ErrKeyNotFound` when the path does not exist +func (m M) AlterPath(path string, mode TraversalMode, alterFunc AlterFunc) (err error) { + return m.Traverse(path, mode, func(level M, key string) error { + val := level[key] + newKey, err := alterFunc(key) + if err != nil { + return fmt.Errorf("failed to apply a change to %q: %w", key, err) + } + if newKey == "" { + return fmt.Errorf("replacement key for %q cannot be empty", key) + } + _, exists := level[newKey] + if exists { + return fmt.Errorf("replacement key %q already exists: %w", newKey, ErrKeyCollision) + } + delete(level, key) + level[newKey] = val + + return nil + }) +} + +// TraversalMode used for traversing the map through multiple levels. +type TraversalMode int + +const ( + // The key match is strictly case-sensitive + CaseSensitiveMode = iota + // The key match is performed with `strings.EqualFold` + CaseInsensitiveMode = iota +) + +type TraversalVisitor func(M, string) error + +// Traverse walks the given nested `path` in the map and invokes the `visitor` function on each level passing +// the current-level map and the current key. +// `mode` sets the behavior how the given path is matched throughout the levels. +// The `visitor` function is allowed to make changes in the level or collect data. +// Returns `ErrKeyCollision` if multiple keys match the same request (when `mode` is `CaseInsensitiveMode`). +// Returns `ErrNotMapType` when one of the values on the path is not a map and cannot be traversed. +// Returns `ErrKeyNotFound` when the path does not exist +func (m M) Traverse(path string, mode TraversalMode, visitor TraversalVisitor) (err error) { + segments := strings.Split(path, ".") + var match func(string, string) bool + + switch mode { + case CaseInsensitiveMode: + match = strings.EqualFold + case CaseSensitiveMode: + match = func(a, b string) bool { return a == b } + } + // the initial value must be `true` for the first iteration to work found := true // start with the root current := m // allocate only once - var mapType bool + var ( + mapType bool + next interface{} + ) - for i, segment := range path { + for i, segment := range segments { if !found { - return "", nil, ErrKeyNotFound + return ErrKeyNotFound } found = false // we have to go through the list of all key on each level to detect case-insensitive collisions for k := range current { - if !strings.EqualFold(segment, k) { + if !match(segment, k) { continue } + // if already found on this level, it's a collision if found { - return "", nil, fmt.Errorf("key collision on the same path %q, previous match - %q, another subkey - %q: %w", key, matchedKey, k, ErrKeyCollision) + return fmt.Errorf("multiple keys match %q on the same level of the path %q: %w", k, path, ErrKeyCollision) } // mark for collision detection found = true - // build the result with the currently matched segment - matchedKey += k - value = current[k] + // we need to save this in case the visitor makes changes in keys + next = current[k] + err = visitor(current, k) + if err != nil { + return fmt.Errorf("error visiting key %q of the path %q: %w", k, path, err) + } - // if it's the last segment, we don't need to go deeper - if i == len(path)-1 { + // if it's the last segment, we don't need to go deeper, skipping... + if i == len(segments)-1 { continue } - // if it's not the last segment we put the separator dot - matchedKey += "." - - // go one level deeper - current, mapType = tryToMapStr(current[k]) + // try to go one level deeper + current, mapType = tryToMapStr(next) if !mapType { - return "", nil, fmt.Errorf("cannot continue path %q (full: %q), next element is not a map: %w", matchedKey, key, ErrNotMapType) + return fmt.Errorf("cannot continue path %q, next value %q is not a map: %w", path, k, ErrNotMapType) + } + + // if it's a case-sensitive key match, we don't have to care about collision detection + // and we can simply stop iterating here. + if mode == CaseSensitiveMode { + break } } } - return matchedKey, value, nil + if !found { + return ErrKeyNotFound + } + + return nil } // GetValue gets a value from the map. If the key does not exist then an error diff --git a/mapstr/mapstr_test.go b/mapstr/mapstr_test.go index 85be4f2..2d3a791 100644 --- a/mapstr/mapstr_test.go +++ b/mapstr/mapstr_test.go @@ -21,6 +21,7 @@ package mapstr import ( "encoding/json" + "errors" "fmt" "strings" "testing" @@ -1185,7 +1186,7 @@ func TestFindFold(t *testing.T) { { name: "returns non-map error", key: "level1_field1.non_map.some_key", - expErr: "next element is not a map", + expErr: "is not a map", }, { name: "returns non-found error", @@ -1210,3 +1211,221 @@ func TestFindFold(t *testing.T) { }) } } + +func TestAlterPath(t *testing.T) { + var ( + lower AlterFunc = func(s string) (string, error) { + return strings.ToLower(s), nil + } + + exclamation AlterFunc = func(s string) (string, error) { + return s + "!", nil + } + + empty AlterFunc = func(string) (string, error) { + return "", nil + } + + errorFunc AlterFunc = func(string) (string, error) { + return "", errors.New("oops") + } + ) + + cases := []struct { + name string + from string + mode TraversalMode + alterFunc AlterFunc + m M + exp M + expErr string + }{ + { + name: "alters keys on root level with case-insensitive matching", + from: "level1_field1", + mode: CaseInsensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "Value2", + "level3_Field1": "Value3", + }, + }, + }, + exp: M{ + "level1_field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "Value2", + "level3_Field1": "Value3", + }, + }, + }, + }, + { + name: "alters keys on all nested levels with case-insensitive matching", + from: "level1_field1.level2_field1.level3_field1", + mode: CaseInsensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "Value2", + "level3_Field1": "Value3", + }, + }, + }, + exp: M{ + "level1_field1": M{ + "Key": "value1", + "level2_field1": M{ + "Key": "Value2", + "level3_field1": "Value3", + }, + }, + }, + }, + { + name: "alters keys on all nested levels with case-sensitive matchig", + from: "level1_Field1.level2_Field1.level3_Field1", + mode: CaseSensitiveMode, + alterFunc: exclamation, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "Value2", + "level3_Field1": "Value3", + }, + }, + }, + exp: M{ + "level1_Field1!": M{ + "Key": "value1", + "level2_Field1!": M{ + "Key": "Value2", + "level3_Field1!": "Value3", + }, + }, + }, + }, + { + name: "errors if the source does not exist", + from: "level1_Field1.NOT_EXIST.level3_Field1", + mode: CaseInsensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + }, + }, + expErr: "key not found", + }, + { + name: "errors if the casing does not match", + from: "level1_Field1.level2_field1.level3_Field1", + mode: CaseSensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + }, + }, + expErr: "key not found", + }, + { + name: "errors if the last segment does not match", + from: "level1_Field1.level2_Field1.level3_field1", + mode: CaseSensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + }, + }, + expErr: "key not found", + }, + { + name: "errors if the new name already exists", + from: "level1_Field1.level2_Field1.level3_Field1", + mode: CaseInsensitiveMode, + alterFunc: lower, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + "Level2_field1": M{ + "Key": "value4", + "level3_Field2": "value5", + }, + }, + }, + expErr: "key collision", + }, + { + name: "errors if alter function returns empty string", + from: "level1_Field1.level2_Field1.level3_Field1", + mode: CaseInsensitiveMode, + alterFunc: empty, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + }, + }, + expErr: "cannot be empty", + }, + { + name: "errors if alter function returns error", + from: "level1_Field1.level2_Field1.level3_Field1", + mode: CaseInsensitiveMode, + alterFunc: errorFunc, + m: M{ + "level1_Field1": M{ + "Key": "value1", + "level2_Field1": M{ + "Key": "value2", + "level3_Field1": "value3", + }, + }, + }, + expErr: "oops", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cloned := tc.m.Clone() // we need to preserve the initial state + + err := cloned.AlterPath(tc.from, tc.mode, tc.alterFunc) + if tc.expErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.exp.StringToPrint(), cloned.StringToPrint()) + }) + } +}