From 010e045c4eb965ad42d8feccd193abb1b80ee42f Mon Sep 17 00:00:00 2001 From: Robert Findley Date: Fri, 1 Sep 2023 10:05:02 -0400 Subject: [PATCH] internal/persistent: use generics Now that we're on 1.18+, make internal/persistent.Map generic. Change-Id: I3403241fe22e28f969d7feb09a752b52f0d2ee4d Reviewed-on: https://go-review.googlesource.com/c/tools/+/524759 gopls-CI: kokoro LUCI-TryBot-Result: Go LUCI Reviewed-by: Alan Donovan --- gopls/internal/lsp/cache/check.go | 9 ++-- gopls/internal/lsp/cache/load.go | 3 +- gopls/internal/lsp/cache/maps.go | 31 +++-------- gopls/internal/lsp/cache/mod.go | 6 +-- gopls/internal/lsp/cache/mod_tidy.go | 2 +- gopls/internal/lsp/cache/mod_vuln.go | 2 +- gopls/internal/lsp/cache/session.go | 17 +++--- gopls/internal/lsp/cache/snapshot.go | 27 +++++----- gopls/internal/lsp/cache/symbols.go | 3 +- internal/constraints/constraint.go | 52 ++++++++++++++++++ internal/persistent/map.go | 81 ++++++++++++++-------------- internal/persistent/map_test.go | 36 ++++++------- 12 files changed, 148 insertions(+), 121 deletions(-) create mode 100644 internal/constraints/constraint.go diff --git a/gopls/internal/lsp/cache/check.go b/gopls/internal/lsp/cache/check.go index 74404af98c1..b7267e983ec 100644 --- a/gopls/internal/lsp/cache/check.go +++ b/gopls/internal/lsp/cache/check.go @@ -849,7 +849,7 @@ func (s *snapshot) getPackageHandles(ctx context.Context, ids []PackageID) (map[ unfinishedSuccs: int32(len(m.DepsByPkgPath)), } if entry, hit := b.s.packages.Get(m.ID); hit { - n.ph = entry.(*packageHandle) + n.ph = entry } if n.unfinishedSuccs == 0 { leaves = append(leaves, n) @@ -1118,12 +1118,11 @@ func (b *packageHandleBuilder) buildPackageHandle(ctx context.Context, n *handle } // Check the packages map again in case another goroutine got there first. - if alt, ok := b.s.packages.Get(n.m.ID); ok && alt.(*packageHandle).validated { - altPH := alt.(*packageHandle) - if altPH.m != n.m { + if alt, ok := b.s.packages.Get(n.m.ID); ok && alt.validated { + if alt.m != n.m { bug.Reportf("existing package handle does not match for %s", n.m.ID) } - n.ph = altPH + n.ph = alt } else { b.s.packages.Set(n.m.ID, n.ph, nil) } diff --git a/gopls/internal/lsp/cache/load.go b/gopls/internal/lsp/cache/load.go index 05d44329c20..03db2a35d0d 100644 --- a/gopls/internal/lsp/cache/load.go +++ b/gopls/internal/lsp/cache/load.go @@ -217,8 +217,7 @@ func (s *snapshot) load(ctx context.Context, allowNetwork bool, scopes ...loadSc s.mu.Lock() // Assert the invariant s.packages.Get(id).m == s.meta.metadata[id]. - s.packages.Range(func(k, v interface{}) { - id, ph := k.(PackageID), v.(*packageHandle) + s.packages.Range(func(id PackageID, ph *packageHandle) { if s.meta.metadata[id] != ph.m { panic("inconsistent metadata") } diff --git a/gopls/internal/lsp/cache/maps.go b/gopls/internal/lsp/cache/maps.go index de6187da255..3fa866cb840 100644 --- a/gopls/internal/lsp/cache/maps.go +++ b/gopls/internal/lsp/cache/maps.go @@ -10,21 +10,14 @@ import ( "golang.org/x/tools/internal/persistent" ) -// TODO(euroelessar): Use generics once support for go1.17 is dropped. - type filesMap struct { - impl *persistent.Map + impl *persistent.Map[span.URI, source.FileHandle] overlayMap map[span.URI]*Overlay // the subset that are overlays } -// uriLessInterface is the < relation for "any" values containing span.URIs. -func uriLessInterface(a, b interface{}) bool { - return a.(span.URI) < b.(span.URI) -} - func newFilesMap() filesMap { return filesMap{ - impl: persistent.NewMap(uriLessInterface), + impl: new(persistent.Map[span.URI, source.FileHandle]), overlayMap: make(map[span.URI]*Overlay), } } @@ -53,9 +46,7 @@ func (m filesMap) Get(key span.URI) (source.FileHandle, bool) { } func (m filesMap) Range(do func(key span.URI, value source.FileHandle)) { - m.impl.Range(func(key, value interface{}) { - do(key.(span.URI), value.(source.FileHandle)) - }) + m.impl.Range(do) } func (m filesMap) Set(key span.URI, value source.FileHandle) { @@ -86,19 +77,13 @@ func (m filesMap) overlays() []*Overlay { return overlays } -func packageIDLessInterface(x, y interface{}) bool { - return x.(PackageID) < y.(PackageID) -} - type knownDirsSet struct { - impl *persistent.Map + impl *persistent.Map[span.URI, struct{}] } func newKnownDirsSet() knownDirsSet { return knownDirsSet{ - impl: persistent.NewMap(func(a, b interface{}) bool { - return a.(span.URI) < b.(span.URI) - }), + impl: new(persistent.Map[span.URI, struct{}]), } } @@ -118,8 +103,8 @@ func (s knownDirsSet) Contains(key span.URI) bool { } func (s knownDirsSet) Range(do func(key span.URI)) { - s.impl.Range(func(key, value interface{}) { - do(key.(span.URI)) + s.impl.Range(func(key span.URI, value struct{}) { + do(key) }) } @@ -128,7 +113,7 @@ func (s knownDirsSet) SetAll(other knownDirsSet) { } func (s knownDirsSet) Insert(key span.URI) { - s.impl.Set(key, nil, nil) + s.impl.Set(key, struct{}{}, nil) } func (s knownDirsSet) Remove(key span.URI) { diff --git a/gopls/internal/lsp/cache/mod.go b/gopls/internal/lsp/cache/mod.go index db0ab0a64b8..8a452ab086d 100644 --- a/gopls/internal/lsp/cache/mod.go +++ b/gopls/internal/lsp/cache/mod.go @@ -52,7 +52,7 @@ func (s *snapshot) ParseMod(ctx context.Context, fh source.FileHandle) (*source. } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (s *snapshot) ParseWork(ctx context.Context, fh source.FileHandle) (*source } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } @@ -240,7 +240,7 @@ func (s *snapshot) ModWhy(ctx context.Context, fh source.FileHandle) (map[string } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } diff --git a/gopls/internal/lsp/cache/mod_tidy.go b/gopls/internal/lsp/cache/mod_tidy.go index 64e02d1c01e..b806edb7499 100644 --- a/gopls/internal/lsp/cache/mod_tidy.go +++ b/gopls/internal/lsp/cache/mod_tidy.go @@ -85,7 +85,7 @@ func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*sourc } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } diff --git a/gopls/internal/lsp/cache/mod_vuln.go b/gopls/internal/lsp/cache/mod_vuln.go index 942ca52525c..dcd58bfa94a 100644 --- a/gopls/internal/lsp/cache/mod_vuln.go +++ b/gopls/internal/lsp/cache/mod_vuln.go @@ -55,7 +55,7 @@ func (s *snapshot) ModVuln(ctx context.Context, modURI span.URI) (*govulncheck.R } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } diff --git a/gopls/internal/lsp/cache/session.go b/gopls/internal/lsp/cache/session.go index 6b75f10b36f..cd51e6d498a 100644 --- a/gopls/internal/lsp/cache/session.go +++ b/gopls/internal/lsp/cache/session.go @@ -20,6 +20,7 @@ import ( "golang.org/x/tools/internal/event" "golang.org/x/tools/internal/gocommand" "golang.org/x/tools/internal/imports" + "golang.org/x/tools/internal/memoize" "golang.org/x/tools/internal/persistent" "golang.org/x/tools/internal/xcontext" ) @@ -169,18 +170,18 @@ func (s *Session) createView(ctx context.Context, name string, folder span.URI, backgroundCtx: backgroundCtx, cancel: cancel, store: s.cache.store, - packages: persistent.NewMap(packageIDLessInterface), + packages: new(persistent.Map[PackageID, *packageHandle]), meta: new(metadataGraph), files: newFilesMap(), - activePackages: persistent.NewMap(packageIDLessInterface), - symbolizeHandles: persistent.NewMap(uriLessInterface), + activePackages: new(persistent.Map[PackageID, *Package]), + symbolizeHandles: new(persistent.Map[span.URI, *memoize.Promise]), workspacePackages: make(map[PackageID]PackagePath), unloadableFiles: make(map[span.URI]struct{}), - parseModHandles: persistent.NewMap(uriLessInterface), - parseWorkHandles: persistent.NewMap(uriLessInterface), - modTidyHandles: persistent.NewMap(uriLessInterface), - modVulnHandles: persistent.NewMap(uriLessInterface), - modWhyHandles: persistent.NewMap(uriLessInterface), + parseModHandles: new(persistent.Map[span.URI, *memoize.Promise]), + parseWorkHandles: new(persistent.Map[span.URI, *memoize.Promise]), + modTidyHandles: new(persistent.Map[span.URI, *memoize.Promise]), + modVulnHandles: new(persistent.Map[span.URI, *memoize.Promise]), + modWhyHandles: new(persistent.Map[span.URI, *memoize.Promise]), knownSubdirs: newKnownDirsSet(), workspaceModFiles: wsModFiles, workspaceModFilesErr: wsModFilesErr, diff --git a/gopls/internal/lsp/cache/snapshot.go b/gopls/internal/lsp/cache/snapshot.go index a1fe4753e2f..a914880a4e3 100644 --- a/gopls/internal/lsp/cache/snapshot.go +++ b/gopls/internal/lsp/cache/snapshot.go @@ -101,7 +101,7 @@ type snapshot struct { // symbolizeHandles maps each file URI to a handle for the future // result of computing the symbols declared in that file. - symbolizeHandles *persistent.Map // from span.URI to *memoize.Promise[symbolizeResult] + symbolizeHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[symbolizeResult] // packages maps a packageKey to a *packageHandle. // It may be invalidated when a file's content changes. @@ -110,13 +110,13 @@ type snapshot struct { // - packages.Get(id).meta == meta.metadata[id] for all ids // - if a package is in packages, then all of its dependencies should also // be in packages, unless there is a missing import - packages *persistent.Map // from packageID to *packageHandle + packages *persistent.Map[PackageID, *packageHandle] // activePackages maps a package ID to a memoized active package, or nil if // the package is known not to be open. // // IDs not contained in the map are not known to be open or not open. - activePackages *persistent.Map // from packageID to *Package + activePackages *persistent.Map[PackageID, *Package] // workspacePackages contains the workspace's packages, which are loaded // when the view is created. It contains no intermediate test variants. @@ -137,18 +137,18 @@ type snapshot struct { // parseModHandles keeps track of any parseModHandles for the snapshot. // The handles need not refer to only the view's go.mod file. - parseModHandles *persistent.Map // from span.URI to *memoize.Promise[parseModResult] + parseModHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[parseModResult] // parseWorkHandles keeps track of any parseWorkHandles for the snapshot. // The handles need not refer to only the view's go.work file. - parseWorkHandles *persistent.Map // from span.URI to *memoize.Promise[parseWorkResult] + parseWorkHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[parseWorkResult] // Preserve go.mod-related handles to avoid garbage-collecting the results // of various calls to the go command. The handles need not refer to only // the view's go.mod file. - modTidyHandles *persistent.Map // from span.URI to *memoize.Promise[modTidyResult] - modWhyHandles *persistent.Map // from span.URI to *memoize.Promise[modWhyResult] - modVulnHandles *persistent.Map // from span.URI to *memoize.Promise[modVulnResult] + modTidyHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modTidyResult] + modWhyHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modWhyResult] + modVulnHandles *persistent.Map[span.URI, *memoize.Promise] // *memoize.Promise[modVulnResult] // knownSubdirs is the set of subdirectory URIs in the workspace, // used to create glob patterns for file watching. @@ -871,7 +871,7 @@ func (s *snapshot) getActivePackage(id PackageID) *Package { defer s.mu.Unlock() if value, ok := s.activePackages.Get(id); ok { - return value.(*Package) // possibly nil, if we have already checked this id. + return value } return nil } @@ -895,7 +895,7 @@ func (s *snapshot) setActivePackage(id PackageID, pkg *Package) { func (s *snapshot) resetActivePackagesLocked() { s.activePackages.Destroy() - s.activePackages = persistent.NewMap(packageIDLessInterface) + s.activePackages = new(persistent.Map[PackageID, *Package]) } const fileExtensions = "go,mod,sum,work" @@ -2189,7 +2189,7 @@ func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileC result.packages.Delete(id) } else { if entry, hit := result.packages.Get(id); hit { - ph := entry.(*packageHandle).clone(false) + ph := entry.clone(false) result.packages.Set(id, ph, nil) } } @@ -2291,12 +2291,11 @@ func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileC // changed that happens not to be present in the map, but that's OK: the goal // of this function is to guarantee that IF the nearest mod file is present in // the map, it is invalidated. -func deleteMostRelevantModFile(m *persistent.Map, changed span.URI) { +func deleteMostRelevantModFile(m *persistent.Map[span.URI, *memoize.Promise], changed span.URI) { var mostRelevant span.URI changedFile := changed.Filename() - m.Range(func(key, value interface{}) { - modURI := key.(span.URI) + m.Range(func(modURI span.URI, _ *memoize.Promise) { if len(modURI) > len(mostRelevant) { if source.InDir(filepath.Dir(modURI.Filename()), changedFile) { mostRelevant = modURI diff --git a/gopls/internal/lsp/cache/symbols.go b/gopls/internal/lsp/cache/symbols.go index 466d9dc71a6..3ecd794303b 100644 --- a/gopls/internal/lsp/cache/symbols.go +++ b/gopls/internal/lsp/cache/symbols.go @@ -15,7 +15,6 @@ import ( "golang.org/x/tools/gopls/internal/lsp/protocol" "golang.org/x/tools/gopls/internal/lsp/source" "golang.org/x/tools/gopls/internal/span" - "golang.org/x/tools/internal/memoize" ) // symbolize returns the result of symbolizing the file identified by uri, using a cache. @@ -51,7 +50,7 @@ func (s *snapshot) symbolize(ctx context.Context, uri span.URI) ([]source.Symbol } // Await result. - v, err := s.awaitPromise(ctx, entry.(*memoize.Promise)) + v, err := s.awaitPromise(ctx, entry) if err != nil { return nil, err } diff --git a/internal/constraints/constraint.go b/internal/constraints/constraint.go new file mode 100644 index 00000000000..4e6ab61ea34 --- /dev/null +++ b/internal/constraints/constraint.go @@ -0,0 +1,52 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package constraints defines a set of useful constraints to be used +// with type parameters. +package constraints + +// Copied from x/exp/constraints. + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Complex is a constraint that permits any complex numeric type. +// If future releases of Go add new predeclared complex numeric types, +// this constraint will be modified to include them. +type Complex interface { + ~complex64 | ~complex128 +} + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type Ordered interface { + Integer | Float | ~string +} diff --git a/internal/persistent/map.go b/internal/persistent/map.go index a9d878f4146..02389f89dc5 100644 --- a/internal/persistent/map.go +++ b/internal/persistent/map.go @@ -12,6 +12,8 @@ import ( "math/rand" "strings" "sync/atomic" + + "golang.org/x/tools/internal/constraints" ) // Implementation details: @@ -25,9 +27,7 @@ import ( // Each argument is followed by a delta change to its reference counter. // In case if no change is expected, the delta will be `-0`. -// Map is an associative mapping from keys to values, both represented as -// interface{}. Key comparison and iteration order is defined by a -// client-provided function that implements a strict weak order. +// Map is an associative mapping from keys to values. // // Maps can be Cloned in constant time. // Get, Store, and Delete operations are done on average in logarithmic time. @@ -38,16 +38,23 @@ import ( // // Internally the implementation is based on a randomized persistent treap: // https://en.wikipedia.org/wiki/Treap. -type Map struct { - less func(a, b interface{}) bool +// +// The zero value is ready to use. +type Map[K constraints.Ordered, V any] struct { + // Map is a generic wrapper around a non-generic implementation to avoid a + // significant increase in the size of the executable. root *mapNode } -func (m *Map) String() string { +func (*Map[K, V]) less(l, r any) bool { + return l.(K) < r.(K) +} + +func (m *Map[K, V]) String() string { var buf strings.Builder buf.WriteByte('{') var sep string - m.Range(func(k, v interface{}) { + m.Range(func(k K, v V) { fmt.Fprintf(&buf, "%s%v: %v", sep, k, v) sep = ", " }) @@ -56,7 +63,7 @@ func (m *Map) String() string { } type mapNode struct { - key interface{} + key any value *refValue weight uint64 refCount int32 @@ -65,11 +72,11 @@ type mapNode struct { type refValue struct { refCount int32 - value interface{} - release func(key, value interface{}) + value any + release func(key, value any) } -func newNodeWithRef(key, value interface{}, release func(key, value interface{})) *mapNode { +func newNodeWithRef[K constraints.Ordered, V any](key K, value V, release func(key, value any)) *mapNode { return &mapNode{ key: key, value: &refValue{ @@ -116,20 +123,10 @@ func (node *mapNode) decref() { } } -// NewMap returns a new map whose keys are ordered by the given comparison -// function (a strict weak order). It is the responsibility of the caller to -// Destroy it at later time. -func NewMap(less func(a, b interface{}) bool) *Map { - return &Map{ - less: less, - } -} - // Clone returns a copy of the given map. It is a responsibility of the caller // to Destroy it at later time. -func (pm *Map) Clone() *Map { - return &Map{ - less: pm.less, +func (pm *Map[K, V]) Clone() *Map[K, V] { + return &Map[K, V]{ root: pm.root.incref(), } } @@ -137,24 +134,26 @@ func (pm *Map) Clone() *Map { // Destroy destroys the map. // // After Destroy, the Map should not be used again. -func (pm *Map) Destroy() { +func (pm *Map[K, V]) Destroy() { // The implementation of these two functions is the same, // but their intent is different. pm.Clear() } // Clear removes all entries from the map. -func (pm *Map) Clear() { +func (pm *Map[K, V]) Clear() { pm.root.decref() pm.root = nil } // Range calls f sequentially in ascending key order for all entries in the map. -func (pm *Map) Range(f func(key, value interface{})) { - pm.root.forEach(f) +func (pm *Map[K, V]) Range(f func(key K, value V)) { + pm.root.forEach(func(k, v any) { + f(k.(K), v.(V)) + }) } -func (node *mapNode) forEach(f func(key, value interface{})) { +func (node *mapNode) forEach(f func(key, value any)) { if node == nil { return } @@ -163,26 +162,26 @@ func (node *mapNode) forEach(f func(key, value interface{})) { node.right.forEach(f) } -// Get returns the map value associated with the specified key, or nil if no entry -// is present. The ok result indicates whether an entry was found in the map. -func (pm *Map) Get(key interface{}) (interface{}, bool) { +// Get returns the map value associated with the specified key. +// The ok result indicates whether an entry was found in the map. +func (pm *Map[K, V]) Get(key K) (V, bool) { node := pm.root for node != nil { - if pm.less(key, node.key) { + if key < node.key.(K) { node = node.left - } else if pm.less(node.key, key) { + } else if node.key.(K) < key { node = node.right } else { - return node.value.value, true + return node.value.value.(V), true } } - return nil, false + var zero V + return zero, false } // SetAll updates the map with key/value pairs from the other map, overwriting existing keys. // It is equivalent to calling Set for each entry in the other map but is more efficient. -// Both maps must have the same comparison function, otherwise behavior is undefined. -func (pm *Map) SetAll(other *Map) { +func (pm *Map[K, V]) SetAll(other *Map[K, V]) { root := pm.root pm.root = union(root, other.root, pm.less, true) root.decref() @@ -191,7 +190,7 @@ func (pm *Map) SetAll(other *Map) { // Set updates the value associated with the specified key. // If release is non-nil, it will be called with entry's key and value once the // key is no longer contained in the map or any clone. -func (pm *Map) Set(key, value interface{}, release func(key, value interface{})) { +func (pm *Map[K, V]) Set(key K, value V, release func(key, value any)) { first := pm.root second := newNodeWithRef(key, value, release) pm.root = union(first, second, pm.less, true) @@ -205,7 +204,7 @@ func (pm *Map) Set(key, value interface{}, release func(key, value interface{})) // union(first:-0, second:-0) (result:+1) // Union borrows both subtrees without affecting their refcount and returns a // new reference that the caller is expected to call decref. -func union(first, second *mapNode, less func(a, b interface{}) bool, overwrite bool) *mapNode { +func union(first, second *mapNode, less func(any, any) bool, overwrite bool) *mapNode { if first == nil { return second.incref() } @@ -243,7 +242,7 @@ func union(first, second *mapNode, less func(a, b interface{}) bool, overwrite b // split(n:-0) (left:+1, mid:+1, right:+1) // Split borrows n without affecting its refcount, and returns three // new references that the caller is expected to call decref. -func split(n *mapNode, key interface{}, less func(a, b interface{}) bool, requireMid bool) (left, mid, right *mapNode) { +func split(n *mapNode, key any, less func(any, any) bool, requireMid bool) (left, mid, right *mapNode) { if n == nil { return nil, nil, nil } @@ -272,7 +271,7 @@ func split(n *mapNode, key interface{}, less func(a, b interface{}) bool, requir } // Delete deletes the value for a key. -func (pm *Map) Delete(key interface{}) { +func (pm *Map[K, V]) Delete(key K) { root := pm.root left, mid, right := split(root, key, pm.less, true) if mid == nil { diff --git a/internal/persistent/map_test.go b/internal/persistent/map_test.go index 9f89a1d300c..c73e5662d90 100644 --- a/internal/persistent/map_test.go +++ b/internal/persistent/map_test.go @@ -18,7 +18,7 @@ type mapEntry struct { } type validatedMap struct { - impl *Map + impl *Map[int, int] expected map[int]int // current key-value mapping. deleted map[mapEntry]int // maps deleted entries to their clock time of last deletion seen map[mapEntry]int // maps seen entries to their clock time of last insertion @@ -30,9 +30,7 @@ func TestSimpleMap(t *testing.T) { seenEntries := make(map[mapEntry]int) m1 := &validatedMap{ - impl: NewMap(func(a, b interface{}) bool { - return a.(int) < b.(int) - }), + impl: new(Map[int, int]), expected: make(map[int]int), deleted: deletedEntries, seen: seenEntries, @@ -123,9 +121,7 @@ func TestRandomMap(t *testing.T) { seenEntries := make(map[mapEntry]int) m := &validatedMap{ - impl: NewMap(func(a, b interface{}) bool { - return a.(int) < b.(int) - }), + impl: new(Map[int, int]), expected: make(map[int]int), deleted: deletedEntries, seen: seenEntries, @@ -165,9 +161,7 @@ func TestUpdate(t *testing.T) { seenEntries := make(map[mapEntry]int) m1 := &validatedMap{ - impl: NewMap(func(a, b interface{}) bool { - return a.(int) < b.(int) - }), + impl: new(Map[int, int]), expected: make(map[int]int), deleted: deletedEntries, seen: seenEntries, @@ -233,7 +227,7 @@ func dumpMap(t *testing.T, prefix string, n *mapNode) { func (vm *validatedMap) validate(t *testing.T) { t.Helper() - validateNode(t, vm.impl.root, vm.impl.less) + validateNode(t, vm.impl.root) // Note: this validation may not make sense if maps were constructed using // SetAll operations. If this proves to be problematic, remove the clock, @@ -246,23 +240,23 @@ func (vm *validatedMap) validate(t *testing.T) { } actualMap := make(map[int]int, len(vm.expected)) - vm.impl.Range(func(key, value interface{}) { - if other, ok := actualMap[key.(int)]; ok { + vm.impl.Range(func(key, value int) { + if other, ok := actualMap[key]; ok { t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other) } - actualMap[key.(int)] = value.(int) + actualMap[key] = value }) assertSameMap(t, actualMap, vm.expected) } -func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) { +func validateNode(t *testing.T, node *mapNode) { if node == nil { return } if node.left != nil { - if less(node.key, node.left.key) { + if node.key.(int) < node.left.key.(int) { t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key) } if node.left.weight > node.weight { @@ -271,7 +265,7 @@ func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) } if node.right != nil { - if less(node.right.key, node.key) { + if node.right.key.(int) < node.key.(int) { t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key) } if node.right.weight > node.weight { @@ -279,8 +273,8 @@ func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) } } - validateNode(t, node.left, less) - validateNode(t, node.right, less) + validateNode(t, node.left) + validateNode(t, node.right) } func (vm *validatedMap) setAll(t *testing.T, other *validatedMap) { @@ -300,7 +294,7 @@ func (vm *validatedMap) set(t *testing.T, key, value int) { vm.clock++ vm.seen[entry] = vm.clock - vm.impl.Set(key, value, func(deletedKey, deletedValue interface{}) { + vm.impl.Set(key, value, func(deletedKey, deletedValue any) { if deletedKey != key || deletedValue != value { t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value) } @@ -346,7 +340,7 @@ func (vm *validatedMap) destroy() { vm.impl.Destroy() } -func assertSameMap(t *testing.T, map1, map2 interface{}) { +func assertSameMap(t *testing.T, map1, map2 any) { t.Helper() if !reflect.DeepEqual(map1, map2) {