diff --git a/README.md b/README.md index ca8d603..c811bc6 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ import ( ``` Create a new ranger implemented using Path-Compressed prefix trie. ```go -ipt := iptrie.NewTrie() +ipt := iptrie.NewTrie[any]() ``` Inserts CIDR blocks. @@ -71,8 +71,8 @@ ipt.ContainingNetworks(netip.MustParseAddr("10.1.0.0")) For insertion of a large number (millions) of addresses, it will likely be much faster to use TrieLoader. ```go -ipt := iptrie.NewTrie() -loader := iptrie.NewTrieLoader(ipt) +ipt := iptrie.NewTrie[any]() +loader := iptrie.NewTrieLoader[any](ipt) for network in []string{ "10.0.0.0/8", diff --git a/benchmark/benchmark_table_test.go b/benchmark/benchmark_table_test.go index d2e29e5..163b8a9 100644 --- a/benchmark/benchmark_table_test.go +++ b/benchmark/benchmark_table_test.go @@ -90,7 +90,7 @@ func parseToTable(buf *bytes.Buffer) string { pkgNames := []string{} for _, times := range testTimes { - for k, _ := range times { + for k := range times { pkgNames = append(pkgNames, k) } break @@ -105,7 +105,7 @@ func parseToTable(buf *bytes.Buffer) string { tbl.SetAutoFormatHeaders(false) tbl.SetHeader(append([]string{"*(OPs/Sec)*"}, pkgNames...)) var testNames []string - for testName, _ := range testTimes { + for testName := range testTimes { testNames = append(testNames, testName) } sort.Strings(testNames) diff --git a/trie.go b/trie.go index f61fcef..8030a14 100644 --- a/trie.go +++ b/trie.go @@ -6,71 +6,85 @@ import ( "math/bits" "net/netip" "strings" - "unsafe" ) -// Trie is a compressed IP radix trie implementation, similar to what is described at +// Trie represents operations of a trie. +type Trie[T any] interface { + fmt.Stringer + + Insert(network netip.Prefix, value T) Trie[T] + Remove(network netip.Prefix) (T, bool) + Find(ip netip.Addr) (T, bool) + FindLargest(ip netip.Addr) (T, bool) + Contains(ip netip.Addr) bool + ContainingNetworks(ip netip.Addr) []netip.Prefix + CoveredNetworks(network netip.Prefix) []netip.Prefix + GetParent() Trie[T] + GetNetwork() netip.Prefix +} + +// trie is the default implementation of Trie. It is a compressed +// IP radix trie implementation, similar to what is described at // https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux // -// Path compression merges nodes with only one child into their parent, decreasing the amount of traversals needed when -// looking up a value. -type Trie struct { - parent *Trie - children [2]*Trie +// Path compression merges nodes with only one child into their parent, +// decreasing the amount of traversals needed when looking up a value. +type trie[T any] struct { + parent *trie[T] + children [2]*trie[T] network netip.Prefix - value any + value *T } // NewTrie creates a new Trie. -func NewTrie() *Trie { - return &Trie{ - network: netip.PrefixFrom(netip.IPv6Unspecified(), 0), - } +func NewTrie[T any]() Trie[T] { + return newSubTree[T](netip.PrefixFrom(netip.IPv6Unspecified(), 0), nil) } -func newSubTree(network netip.Prefix, value any) *Trie { - return &Trie{ +func newSubTree[T any](network netip.Prefix, value *T) *trie[T] { + return &trie[T]{ network: network, value: value, } } // Insert inserts an entry into the trie. -func (pt *Trie) Insert(network netip.Prefix, value any) { +func (pt *trie[T]) Insert(network netip.Prefix, value T) Trie[T] { network = normalizePrefix(network) - pt.insert(network, emptyize(value)) + return pt.insert(network, &value) } // Remove removes the entry identified by given network from trie. -func (pt *Trie) Remove(network netip.Prefix) any { +func (pt *trie[T]) Remove(network netip.Prefix) (T, bool) { network = normalizePrefix(network) return pt.remove(network) } // Find returns the value from the most specific network (largest prefix) containing the given address. -func (pt *Trie) Find(ip netip.Addr) any { +func (pt *trie[T]) Find(ip netip.Addr) (T, bool) { ip = normalizeAddr(ip) - return unempty(pt.find(ip)) + return pt.find(ip) } // FindLargest returns the value from the largest network (smallest prefix) containing the given address. -func (pt *Trie) FindLargest(ip netip.Addr) any { +func (pt *trie[T]) FindLargest(ip netip.Addr) (T, bool) { ip = normalizeAddr(ip) - return unempty(pt.findLargest(ip)) + return pt.findLargest(ip) } // Contains indicates whether the trie contains the given ip. -func (pt *Trie) Contains(ip netip.Addr) bool { +func (pt *trie[T]) Contains(ip netip.Addr) bool { ip = normalizeAddr(ip) - return pt.findLargest(ip) != nil + _, contained := pt.findLargest(ip) + return contained } // ContainingNetworks returns the list of networks containing the given ip in ascending prefix order (largest network to // smallest). // // Note: Inserted addresses are normalized to IPv6, so the returned list will be IPv6 only. -func (pt *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix { +func (pt *trie[T]) ContainingNetworks(ip netip.Addr) []netip.Prefix { ip = normalizeAddr(ip) return pt.containingNetworks(ip) } @@ -78,18 +92,28 @@ func (pt *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix { // CoveredNetworks returns the list of networks contained within the given network. // // Note: Inserted addresses are normalized to IPv6, so the returned list will be IPv6 only. -func (pt *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix { +func (pt *trie[T]) CoveredNetworks(network netip.Prefix) []netip.Prefix { network = normalizePrefix(network) return pt.coveredNetworks(network) } +// GetParent returns the trie's parent (if any). +func (pt *trie[T]) GetParent() Trie[T] { + return pt.parent +} + +// GetNetwork returns the trie's network (if any). +func (pt *trie[T]) GetNetwork() netip.Prefix { + return pt.network +} + // String returns string representation of trie. // // The result will contain implicit nodes which exist as parents for multiple entries, but can be distinguished by the // lack of a value. // // Note: Addresses are normalized to IPv6. -func (pt *Trie) String() string { +func (pt *trie[T]) String() string { children := []string{} padding := strings.Repeat("├ ", pt.level()+1) for _, child := range pt.children { @@ -102,7 +126,7 @@ func (pt *Trie) String() string { var value string if pt.value != nil { - value = fmt.Sprintf("%v", unempty(pt.value)) + value = fmt.Sprintf("%v", *pt.value) if len(value) > 32 { value = value[0:31] + "…" } @@ -113,37 +137,40 @@ func (pt *Trie) String() string { value, strings.Join(children, "")) } -func (pt *Trie) find(ip netip.Addr) any { +func (pt *trie[T]) find(ip netip.Addr) (T, bool) { if !netContains(pt.network, ip) { - return nil + return zeroValueOfT[T](), false } if pt.network.Bits() == 128 { - return pt.value + return *pt.value, true } bit := pt.discriminatorBitFromIP(ip) child := pt.children[bit] if child != nil { - if v := child.find(ip); v != nil { - return v + if v, found := child.find(ip); found { + return v, found } } - return unempty(pt.value) + if pt.value != nil { + return *pt.value, true + } + return zeroValueOfT[T](), false } -func (pt *Trie) findLargest(ip netip.Addr) any { +func (pt *trie[T]) findLargest(ip netip.Addr) (T, bool) { if !netContains(pt.network, ip) { - return nil + return zeroValueOfT[T](), false } if pt.value != nil { - return pt.value + return *pt.value, true } if pt.network.Bits() == 128 { - return nil + return zeroValueOfT[T](), false } bit := pt.discriminatorBitFromIP(ip) @@ -152,10 +179,10 @@ func (pt *Trie) findLargest(ip netip.Addr) any { return child.findLargest(ip) } - return nil + return zeroValueOfT[T](), false } -func (pt *Trie) containingNetworks(ip netip.Addr) []netip.Prefix { +func (pt *trie[T]) containingNetworks(ip netip.Addr) []netip.Prefix { var results []netip.Prefix if !pt.network.Contains(ip) { return results @@ -181,7 +208,7 @@ func (pt *Trie) containingNetworks(ip netip.Addr) []netip.Prefix { return results } -func (pt *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix { +func (pt *trie[T]) coveredNetworks(network netip.Prefix) []netip.Prefix { var results []netip.Prefix if network.Bits() <= pt.network.Bits() && network.Contains(pt.network.Addr()) { for entry := range pt.walkDepth() { @@ -228,7 +255,7 @@ func netDivergence(net1 netip.Prefix, net2 netip.Prefix) netip.Prefix { return pfx } -func (pt *Trie) insert(network netip.Prefix, value any) *Trie { +func (pt *trie[T]) insert(network netip.Prefix, value *T) *trie[T] { if pt.network == network { pt.value = value return pt @@ -248,7 +275,7 @@ func (pt *Trie) insert(network netip.Prefix, value any) *Trie { // in the case that inserted network diverges on its path to existing child. netdiv := netDivergence(existingChild.network, network) if netdiv != existingChild.network { - pathPrefix := newSubTree(netdiv, nil) + pathPrefix := newSubTree[T](netdiv, nil) pt.insertPrefix(bit, pathPrefix, existingChild) // Update new child existingChild = pathPrefix @@ -256,12 +283,12 @@ func (pt *Trie) insert(network netip.Prefix, value any) *Trie { return existingChild.insert(network, value) } -func (pt *Trie) appendTrie(bit uint8, prefix *Trie) { +func (pt *trie[T]) appendTrie(bit uint8, prefix *trie[T]) { pt.children[bit] = prefix prefix.parent = pt } -func (pt *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) { +func (pt *trie[T]) insertPrefix(bit uint8, pathPrefix, child *trie[T]) { // Set parent/child relationship between current trie and inserted pathPrefix pt.children[bit] = pathPrefix pathPrefix.parent = pt @@ -272,26 +299,26 @@ func (pt *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) { child.parent = pathPrefix } -func (pt *Trie) remove(network netip.Prefix) any { +func (pt *trie[T]) remove(network netip.Prefix) (T, bool) { if pt.value != nil && pt.network == network { - entry := pt.value + entry := *pt.value pt.value = nil pt.compressPathIfPossible() - return entry + return entry, true } if pt.network.Bits() == 128 { - return nil + return zeroValueOfT[T](), false } bit := pt.discriminatorBitFromIP(network.Addr()) child := pt.children[bit] if child != nil { return child.remove(network) } - return nil + return zeroValueOfT[T](), false } -func (pt *Trie) qualifiesForPathCompression() bool { +func (pt *trie[T]) qualifiesForPathCompression() bool { // Current prefix trie can be path compressed if it meets all following. // 1. records no CIDR entry // 2. has single or no child @@ -299,14 +326,14 @@ func (pt *Trie) qualifiesForPathCompression() bool { return pt.value == nil && pt.childrenCount() <= 1 && pt.parent != nil } -func (pt *Trie) compressPathIfPossible() { +func (pt *trie[T]) compressPathIfPossible() { if !pt.qualifiesForPathCompression() { // Does not qualify to be compressed return } // Find lone child. - var loneChild *Trie + var loneChild *trie[T] for _, child := range pt.children { if child != nil { loneChild = child @@ -326,7 +353,7 @@ func (pt *Trie) compressPathIfPossible() { parent.compressPathIfPossible() } -func (pt *Trie) childrenCount() int { +func (pt *trie[T]) childrenCount() int { count := 0 for _, child := range pt.children { if child != nil { @@ -336,7 +363,7 @@ func (pt *Trie) childrenCount() int { return count } -func (pt *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 { +func (pt *trie[T]) discriminatorBitFromIP(addr netip.Addr) uint8 { // This is a safe uint boxing of int since we should never attempt to get // target bit at a negative position. pos := pt.network.Bits() @@ -347,7 +374,7 @@ func (pt *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 { return uint8(a128.lo >> (63 - (pos - 64)) & 1) } -func (pt *Trie) level() int { +func (pt *trie[T]) level() int { if pt.parent == nil { return 0 } @@ -355,7 +382,7 @@ func (pt *Trie) level() int { } // walkDepth walks the trie in depth order -func (pt *Trie) walkDepth() <-chan netip.Prefix { +func (pt *trie[T]) walkDepth() <-chan netip.Prefix { entries := make(chan netip.Prefix) go func() { if pt.value != nil { @@ -377,88 +404,3 @@ func (pt *Trie) walkDepth() <-chan netip.Prefix { }() return entries } - -// TrieLoader can be used to improve the performance of bulk inserts to a Trie. It caches the node of the -// last insert in the tree, using it as the starting point to start searching for the location of the next insert. This -// is highly beneficial when the addresses are pre-sorted. -type TrieLoader struct { - trie *Trie - lastInsert *Trie -} - -func NewTrieLoader(trie *Trie) *TrieLoader { - return &TrieLoader{ - trie: trie, - lastInsert: trie, - } -} - -func (ptl *TrieLoader) Insert(pfx netip.Prefix, v any) { - pfx = normalizePrefix(pfx) - - diff := addr128(ptl.lastInsert.network.Addr()).xor(addr128(pfx.Addr())) - var pos int - if diff.hi != 0 { - pos = bits.LeadingZeros64(diff.hi) - } else { - pos = bits.LeadingZeros64(diff.lo) + 64 - } - if pos > pfx.Bits() { - pos = pfx.Bits() - } - if pos > ptl.lastInsert.network.Bits() { - pos = ptl.lastInsert.network.Bits() - } - - parent := ptl.lastInsert - for parent.network.Bits() > pos { - parent = parent.parent - } - ptl.lastInsert = parent.insert(pfx, v) -} - -func normalizeAddr(addr netip.Addr) netip.Addr { - if addr.Is4() { - return netip.AddrFrom16(addr.As16()) - } - return addr -} -func normalizePrefix(pfx netip.Prefix) netip.Prefix { - if pfx.Addr().Is4() { - pfx = netip.PrefixFrom(netip.AddrFrom16(pfx.Addr().As16()), pfx.Bits()+96) - } - return pfx.Masked() -} - -// A lot of the code uses nil value tests to determine whether a node is explicit or implicitly created. Therefore -// inserted values cannot be nil, and so `empty` is a placeholder to represent nil. -type emptyStruct struct{} - -var empty = emptyStruct{} - -func emptyize(v any) any { - if v == nil { - return empty - } - return v -} -func unempty(v any) any { - if v == empty { - return nil - } - return v -} - -func addr128(addr netip.Addr) uint128 { - return *(*uint128)(unsafe.Pointer(&addr)) -} -func init() { - // Accessing the underlying data of a `netip.Addr` relies upon the data being - // in a known format, which is not guaranteed to be stable. So this init() - // function is to detect if it ever changes. - ip := netip.AddrFrom16([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) - i128 := addr128(ip) - if i128.hi != 0x0001020304050607 || i128.lo != 0x08090a0b0c0d0e0f { - panic("netip.Addr format mismatch") - } -} diff --git a/trie_test.go b/trie_test.go index e7b0d87..36cbb36 100644 --- a/trie_test.go +++ b/trie_test.go @@ -12,18 +12,23 @@ import ( ) func ExampleTrie() { - ipt := NewTrie() + ipt := NewTrie[any]() ipt.Insert(netip.MustParsePrefix("10.0.0.0/8"), "foo") ipt.Insert(netip.MustParsePrefix("10.1.0.0/24"), "bar") - fmt.Printf("10.2.0.1: %+v\n", ipt.Find(netip.MustParseAddr("10.2.0.1"))) - fmt.Printf("10.1.0.1: %+v\n", ipt.Find(netip.MustParseAddr("10.1.0.1"))) - fmt.Printf("11.0.0.1: %+v\n", ipt.Find(netip.MustParseAddr("11.0.0.1"))) + value, found := ipt.Find(netip.MustParseAddr("10.2.0.1")) + fmt.Printf("10.2.0.1: %+v (%t)\n", value, found) + + value, found = ipt.Find(netip.MustParseAddr("10.1.0.1")) + fmt.Printf("10.1.0.1: %+v (%t)\n", value, found) + + value, found = ipt.Find(netip.MustParseAddr("11.0.0.1")) + fmt.Printf("11.0.0.1: %+v (%t)\n", value, found) // Output: - // 10.2.0.1: foo - // 10.1.0.1: bar - // 11.0.0.1: + // 10.2.0.1: foo (true) + // 10.1.0.1: bar (true) + // 11.0.0.1: (false) } func TestTrieInsert(t *testing.T) { @@ -76,13 +81,14 @@ func TestTrieInsert(t *testing.T) { v := any(1) for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + tr := NewTrie[any]() for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) - trie.Insert(network, v) + tr.Insert(network, v) } - walk := trie.walkDepth() + assert.IsType(t, &trie[any]{}, tr) + walk := tr.(*trie[any]).walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { expected := normalizePrefix(netip.MustParsePrefix(network)) actual := <-walk @@ -99,7 +105,7 @@ func TestTrieInsert(t *testing.T) { func ExampleTrie_String() { inserts := []string{"192.168.0.0/24", "192.168.1.0/24", "192.168.1.0/30"} - trie := NewTrie() + trie := NewTrie[string]() for _, insert := range inserts { network := netip.MustParsePrefix(insert) trie.Insert(network, "net="+insert) @@ -202,22 +208,25 @@ func TestTrieRemove(t *testing.T) { for tci, tc := range cases { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + tr := NewTrie[string]() for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) - trie.Insert(network, insert) + tr.Insert(network, insert) } for i, remove := range tc.removes { network := netip.MustParsePrefix(remove) - removed := trie.Remove(network) + removed, wasRemoved := tr.Remove(network) if str := tc.expectedRemoves[i]; str != "" { assert.Equal(t, str, removed, "tc=%d", tci) + assert.True(t, wasRemoved) } else { - assert.Nil(t, removed, "tc=%d", tci) + assert.Equal(t, zeroValueOfT[string](), removed, "tc=%d", tci) + assert.False(t, wasRemoved) } } - walk := trie.walkDepth() + assert.IsType(t, &trie[string]{}, tr) + walk := tr.(*trie[string]).walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { expected := normalizePrefix(netip.MustParsePrefix(network)) actual := <-walk @@ -229,13 +238,13 @@ func TestTrieRemove(t *testing.T) { assert.Nil(t, network) } - assert.Equal(t, tc.expectedTrieString, trie.String(), "tc=%d", tci) + assert.Equal(t, tc.expectedTrieString, tr.String(), "tc=%d", tci) }) } } func TestTrieContains(t *testing.T) { - pt := NewTrie() + pt := NewTrie[any]() assert.False(t, pt.Contains(netip.MustParseAddr("10.0.0.1"))) @@ -245,7 +254,7 @@ func TestTrieContains(t *testing.T) { } func TestTrieNilValue(t *testing.T) { - pt := NewTrie() + pt := NewTrie[any]() pt.Insert(netip.MustParsePrefix("10.0.0.0/8"), nil) pt.Insert(netip.MustParsePrefix("10.1.0.0/16"), nil) pt.Remove(netip.MustParsePrefix("10.1.0.0/16")) @@ -277,19 +286,20 @@ func TestFindFull128(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + tr := NewTrie[any]() for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) - trie.Insert(network, insert) + tr.Insert(network, insert) } expectedEntries := []netip.Prefix{} for _, network := range tc.networks { expected := normalizePrefix(netip.MustParsePrefix(network)) expectedEntries = append(expectedEntries, expected) } - contains := trie.Find(tc.ip) + contains, found := tr.Find(tc.ip) assert.NotNil(t, contains) - networks := trie.ContainingNetworks(tc.ip) + assert.True(t, found) + networks := tr.ContainingNetworks(tc.ip) assert.Equal(t, expectedEntries, networks) }) } @@ -325,32 +335,36 @@ func TestTrieFind(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + tr := NewTrie[any]() v := any(1) for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) - trie.Insert(network, v) + tr.Insert(network, v) } for _, expectedIPRange := range tc.expectedIPs { var contains any + var found bool start := expectedIPRange.start for ; expectedIPRange.end != start; start = start.Next() { - contains = trie.Find(start) + contains, found = tr.Find(start) assert.NotNil(t, contains) + assert.True(t, found) } // Check out of bounds ips on both ends - contains = trie.Find(expectedIPRange.start.Prev()) - assert.Nil(t, contains) - contains = trie.Find(expectedIPRange.end.Next()) - assert.Nil(t, contains) + contains, found = tr.Find(expectedIPRange.start.Prev()) + assert.Equal(t, zeroValueOfT[any](), contains) + assert.False(t, found) + contains, found = tr.Find(expectedIPRange.end.Next()) + assert.Equal(t, zeroValueOfT[any](), contains) + assert.False(t, found) } }) } } func TestTrieFindOverlap(t *testing.T) { - trie := NewTrie() + trie := NewTrie[any]() v1 := any(1) trie.Insert(netip.MustParsePrefix("192.168.0.0/24"), v1) @@ -358,8 +372,9 @@ func TestTrieFindOverlap(t *testing.T) { v2 := any(2) trie.Insert(netip.MustParsePrefix("192.168.0.0/25"), v2) - v := trie.Find(netip.MustParseAddr("192.168.0.1")) + v, found := trie.Find(netip.MustParseAddr("192.168.0.1")) assert.Equal(t, v2, v) + assert.True(t, found) } func TestTrieContainingNetworks(t *testing.T) { @@ -384,7 +399,7 @@ func TestTrieContainingNetworks(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + trie := NewTrie[any]() v := any(1) for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) @@ -467,7 +482,7 @@ var coveredNetworkTests = []coveredNetworkTest{ func TestTrieCoveredNetworks(t *testing.T) { for _, tc := range coveredNetworkTests { t.Run(tc.name, func(t *testing.T) { - trie := NewTrie() + trie := NewTrie[any]() v := any(1) for _, insert := range tc.inserts { network := netip.MustParsePrefix(insert) @@ -496,7 +511,7 @@ func TestTrieMemUsage(t *testing.T) { // by threshold, picking 1% as sane number for detecting memory leak. thresh := 1.01 - trie := NewTrie() + trie := NewTrie[any]() var baseLineHeap, totalHeapAllocOverRuns uint64 for i := 0; i < runs; i++ { @@ -556,7 +571,7 @@ func GetHeapAllocation() uint64 { } func ExampleTrieLoader() { - pt := NewTrie() + pt := NewTrie[any]() ptl := NewTrieLoader(pt) networks := []string{ diff --git a/trieloader.go b/trieloader.go new file mode 100644 index 0000000..5228458 --- /dev/null +++ b/trieloader.go @@ -0,0 +1,55 @@ +package iptrie + +import ( + "math/bits" + "net/netip" +) + +// TrieLoader inserts items into a Trie. +type TrieLoader[T any] interface { + Insert(pfx netip.Prefix, v T) +} + +// trieLoader is the default implementation of TrieLoader. +// It can be used to improve the performance of bulk inserts to a Trie. +// It caches the node of the last insert in the tree, using it as the +// starting point to start searching for the location of the next insert. +// This is highly beneficial when the addresses are pre-sorted. +type trieLoader[T any] struct { + trie Trie[T] + lastInsert Trie[T] +} + +// NewTrieLoader creates a new TrieLoader. +func NewTrieLoader[T any](trie Trie[T]) TrieLoader[T] { + return &trieLoader[T]{ + trie: trie, + lastInsert: trie, + } +} + +// Insert inserts the value v onto the trie at the given network. +func (ptl *trieLoader[T]) Insert(pfx netip.Prefix, v T) { + pfx = normalizePrefix(pfx) + + lastInsertNetwork := ptl.lastInsert.GetNetwork() + diff := addr128(lastInsertNetwork.Addr()).xor(addr128(pfx.Addr())) + var pos int + if diff.hi != 0 { + pos = bits.LeadingZeros64(diff.hi) + } else { + pos = bits.LeadingZeros64(diff.lo) + 64 + } + if pos > pfx.Bits() { + pos = pfx.Bits() + } + if pos > lastInsertNetwork.Bits() { + pos = lastInsertNetwork.Bits() + } + + parent := ptl.lastInsert + for parent.GetNetwork().Bits() > pos { + parent = parent.GetParent() + } + ptl.lastInsert = parent.Insert(pfx, v) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..fe27009 --- /dev/null +++ b/util.go @@ -0,0 +1,40 @@ +package iptrie + +import ( + "net/netip" + "unsafe" +) + +func zeroValueOfT[T any]() T { + var t T + return t +} + +func normalizeAddr(addr netip.Addr) netip.Addr { + if addr.Is4() { + return netip.AddrFrom16(addr.As16()) + } + return addr +} + +func normalizePrefix(pfx netip.Prefix) netip.Prefix { + if pfx.Addr().Is4() { + pfx = netip.PrefixFrom(netip.AddrFrom16(pfx.Addr().As16()), pfx.Bits()+96) + } + return pfx.Masked() +} + +func addr128(addr netip.Addr) uint128 { + return *(*uint128)(unsafe.Pointer(&addr)) +} + +func init() { + // Accessing the underlying data of a `netip.Addr` relies upon the data being + // in a known format, which is not guaranteed to be stable. So this init() + // function is to detect if it ever changes. + ip := netip.AddrFrom16([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) + i128 := addr128(ip) + if i128.hi != 0x0001020304050607 || i128.lo != 0x08090a0b0c0d0e0f { + panic("netip.Addr format mismatch") + } +}