Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Handle adjacent prefixes in ipinfo tool aggregate #193

7 changes: 1 addition & 6 deletions ipinfo/cmd_tool_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ func printHelpToolAggregate() {
`Usage: %s tool aggregate [<opts>] <cidr | ip | ip-range | filepath>

Description:
Accepts IPs, IP ranges, and CIDRs, aggregating them efficiently.
Input can be IPs, IP ranges, CIDRs, and/or filepath to a file
containing any of these. Works for both IPv4 and IPv6.
Accepts IPv4 IPs and CIDRs, aggregating them efficiently.

If input contains single IPs, it tries to merge them into the input CIDRs,
otherwise they are printed to the output as they are.
Expand All @@ -37,9 +35,6 @@ Examples:
# Aggregate two CIDRs.
$ %[1]s tool aggregate 1.1.1.0/30 1.1.1.0/28

# Aggregate IP range and CIDR.
$ %[1]s tool aggregate 1.1.1.0-1.1.1.244 1.1.1.0/28

# Aggregate enteries from 2 files.
$ %[1]s tool aggregate /path/to/file1.txt /path/to/file2.txt

Expand Down
68 changes: 68 additions & 0 deletions lib/cidr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package lib

import (
"bytes"
"encoding/binary"
"math"
"net"
"sort"
)

// CIDR represens a Classless Inter-Domain Routing structure.
type CIDR struct {
IP net.IP
Network *net.IPNet
}

// newCidr creates a newCidr CIDR structure.
func newCidr(s string) *CIDR {
ip, ipnet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
return &CIDR{
IP: ip,
Network: ipnet,
}
}

func (c *CIDR) String() string {
return c.Network.String()
}

// MaskLen returns a network mask length.
func (c *CIDR) MaskLen() uint32 {
i, _ := c.Network.Mask.Size()
return uint32(i)
}

// PrefixUint32 returns a prefix.
func (c *CIDR) PrefixUint32() uint32 {
return binary.BigEndian.Uint32(c.IP.To4())
}

// Size returns a size of a CIDR range.
func (c *CIDR) Size() int {
ones, bits := c.Network.Mask.Size()
return int(math.Pow(2, float64(bits-ones)))
}

// list returns a slice of sorted CIDR structures.
func list(s []string) []*CIDR {
out := make([]*CIDR, 0)
for _, c := range s {
out = append(out, newCidr(c))
}
sort.Sort(cidrSort(out))
return out
}

type cidrSort []*CIDR

func (s cidrSort) Len() int { return len(s) }
func (s cidrSort) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

func (s cidrSort) Less(i, j int) bool {
cmp := bytes.Compare(s[i].IP, s[j].IP)
return cmp < 0 || (cmp == 0 && s[i].MaskLen() < s[j].MaskLen())
}
198 changes: 79 additions & 119 deletions lib/cmd_tool_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package lib

import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"sort"
"strings"

"github.com/spf13/pflag"
Expand Down Expand Up @@ -55,72 +53,21 @@ func CmdToolAggregate(
return nil
}

// Parses a list of CIDRs.
parseCIDRs := func(cidrs []string) []net.IPNet {
parsedCIDRs := make([]net.IPNet, 0)
for _, cidrStr := range cidrs {
_, ipNet, err := net.ParseCIDR(cidrStr)
if err != nil {
if !f.Quiet {
fmt.Printf("Invalid CIDR: %s\n", cidrStr)
}
continue
}
parsedCIDRs = append(parsedCIDRs, *ipNet)
}

return parsedCIDRs
}

// Input parser.
parseInput := func(rows []string) ([]net.IPNet, []net.IP) {
parsedCIDRs := make([]net.IPNet, 0)
parseInput := func(rows []string) ([]string, []net.IP) {
parsedCIDRs := make([]string, 0)
parsedIPs := make([]net.IP, 0)
var separator string
for _, rowStr := range rows {
if strings.ContainsAny(rowStr, ",-") {
if delim := strings.ContainsRune(rowStr, ','); delim {
separator = ","
} else {
separator = "-"
}

ipRange := strings.Split(rowStr, separator)
if len(ipRange) != 2 {
if !f.Quiet {
fmt.Printf("Invalid IP range: %s\n", rowStr)
}
continue
}

if strings.ContainsRune(rowStr, ':') {
cidrs, err := CIDRsFromIP6RangeStrRaw(rowStr)
if err == nil {
parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...)
continue
} else {
if !f.Quiet {
fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err)
}
continue
}
} else {
cidrs, err := CIDRsFromIPRangeStrRaw(rowStr)
if err == nil {
parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...)
continue
} else {
if !f.Quiet {
fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err)
}
continue
}
}
continue
} else if strings.ContainsRune(rowStr, '/') {
parsedCIDRs = append(parsedCIDRs, parseCIDRs([]string{rowStr})...)
_, ipnet, err := net.ParseCIDR(rowStr)
if err == nil && IsCIDRIPv4(ipnet) {
parsedCIDRs = append(parsedCIDRs, []string{rowStr}...)
}
continue
} else {
if ip := net.ParseIP(rowStr); ip != nil {
if ip := net.ParseIP(rowStr); IsIPv4(ip) {
parsedIPs = append(parsedIPs, ip)
} else {
if !f.Quiet {
Expand Down Expand Up @@ -165,7 +112,7 @@ func CmdToolAggregate(
}

// Vars to contain CIDRs/IPs from all input sources.
parsedCIDRs := make([]net.IPNet, 0)
parsedCIDRs := make([]string, 0)
parsedIPs := make([]net.IP, 0)

// Collect CIDRs/IPs from stdin.
Expand All @@ -187,93 +134,106 @@ func CmdToolAggregate(
rows := scanrdr(file)
file.Close()
cidrs, ips := parseInput(rows)

parsedCIDRs = append(parsedCIDRs, cidrs...)
parsedIPs = append(parsedIPs, ips...)
}

// Sort and merge collected CIDRs and IPs.
aggregatedCIDRs := aggregateCIDRs(parsedCIDRs)
adjacentCombined := combineAdjacent(stripOverlapping(list(parsedCIDRs)))

outlierIPs := make([]net.IP, 0)
length := len(aggregatedCIDRs)
for _, ip := range parsedIPs {
for i, cidr := range aggregatedCIDRs {
if cidr.Contains(ip) {
break
} else if i == length-1 {
outlierIPs = append(outlierIPs, ip)
length := len(adjacentCombined)
if length != 0 {
for _, ip := range parsedIPs {
for i, cidr := range adjacentCombined {
if cidr.Network.Contains(ip) {
break
} else if i == length-1 {
outlierIPs = append(outlierIPs, ip)
}
}
}
} else {
outlierIPs = append(outlierIPs, parsedIPs...)
}

// Print the aggregated CIDRs.
for _, r := range aggregatedCIDRs {
for _, r := range adjacentCombined {
fmt.Println(r.String())
}

// Print outliers.
// Print the outlierIPs.
for _, r := range outlierIPs {
fmt.Println(r.String())
}

return nil
}

// Helper function to aggregate IP ranges.
func aggregateCIDRs(cidrs []net.IPNet) []net.IPNet {
aggregatedCIDRs := make([]net.IPNet, 0)

// Sort CIDRs by starting IP.
sortCIDRs(cidrs)

for _, r := range cidrs {
if len(aggregatedCIDRs) == 0 {
aggregatedCIDRs = append(aggregatedCIDRs, r)
// stripOverlapping returns a slice of CIDR structures with overlapping ranges
// stripped.
func stripOverlapping(s []*CIDR) []*CIDR {
l := len(s)
for i := 0; i < l-1; i++ {
if s[i] == nil {
continue
}

last := len(aggregatedCIDRs) - 1
prev := aggregatedCIDRs[last]

if canAggregate(prev, r) {
// Merge overlapping CIDRs.
aggregatedCIDRs[last] = aggregateCIDR(prev, r)
} else {
aggregatedCIDRs = append(aggregatedCIDRs, r)
for j := i + 1; j < l; j++ {
if overlaps(s[j], s[i]) {
s[j] = nil
}
}
}

return aggregatedCIDRs
}

// Helper function to sort IP ranges by starting IP.
func sortCIDRs(ipRanges []net.IPNet) {
sort.SliceStable(ipRanges, func(i, j int) bool {
return bytes.Compare(ipRanges[i].IP, ipRanges[j].IP) < 0
})
return filter(s)
}

// Helper function to check if two CIDRs can be aggregated.
func canAggregate(r1, r2 net.IPNet) bool {
return r1.Contains(r2.IP) || r2.Contains(r1.IP)
func overlaps(a, b *CIDR) bool {
return (a.PrefixUint32() / (1 << (32 - b.MaskLen()))) ==
(b.PrefixUint32() / (1 << (32 - b.MaskLen())))
}

// Helper function to aggregate two CIDRs.
func aggregateCIDR(r1, r2 net.IPNet) net.IPNet {
mask1, _ := r1.Mask.Size()
mask2, _ := r2.Mask.Size()

ipLen := net.IPv6len * 8
if r1.IP.To4() != nil {
ipLen = net.IPv4len * 8
}
// combineAdjacent returns a slice of CIDR structures with adjacent ranges
// combined.
func combineAdjacent(s []*CIDR) []*CIDR {
for {
found := false
l := len(s)
for i := 0; i < l-1; i++ {
if s[i] == nil {
continue
}
for j := i + 1; j < l; j++ {
if s[j] == nil {
continue
}
if adjacent(s[i], s[j]) {
c := fmt.Sprintf("%s/%d", s[i].IP.String(), s[i].MaskLen()-1)
s[i] = newCidr(c)
s[j] = nil
found = true
}
}
}

// Find the common prefix length
commonPrefixLen := mask1
if mask2 < commonPrefixLen {
commonPrefixLen = mask2
if !found {
break
}
}
return filter(s)
}

commonPrefix := r1.IP.Mask(net.CIDRMask(commonPrefixLen, ipLen))
func adjacent(a, b *CIDR) bool {
return (a.MaskLen() == b.MaskLen()) &&
(a.PrefixUint32()%(2<<(32-b.MaskLen())) == 0) &&
(b.PrefixUint32()-a.PrefixUint32() == (1 << (32 - a.MaskLen())))
}

return net.IPNet{IP: commonPrefix, Mask: net.CIDRMask(commonPrefixLen, ipLen)}
func filter(s []*CIDR) []*CIDR {
out := s[:0]
for _, x := range s {
if x != nil {
out = append(out, x)
}
}
return out
}