diff --git a/util/netutils/ipv6.go b/util/netutils/ipv6.go index d48c5dd..9d4873a 100644 --- a/util/netutils/ipv6.go +++ b/util/netutils/ipv6.go @@ -19,10 +19,12 @@ import ( "fmt" "math/rand" "net" + "sort" "strconv" "strings" "yunion.io/x/pkg/errors" + "yunion.io/x/pkg/util/sortutils" ) type IPV6Addr [8]uint16 @@ -461,7 +463,7 @@ func (ar IPV6AddrRange) EndIp() IPV6Addr { return ar.end } -func (ar IPV6AddrRange) Merge(ar2 IPV6AddrRange) (*IPV6AddrRange, bool) { +func (ar IPV6AddrRange) Merge(ar2 IPV6AddrRange) (IPV6AddrRange, bool) { if ar.IsOverlap(ar2) || ar.end.StepUp().Equals(ar2.start) || ar2.end.StepUp().Equals(ar.start) { if ar2.start.Lt(ar.start) { ar.start = ar2.start @@ -469,13 +471,13 @@ func (ar IPV6AddrRange) Merge(ar2 IPV6AddrRange) (*IPV6AddrRange, bool) { if ar2.end.Gt(ar.end) { ar.end = ar2.end } - return &ar, true + return ar, true } - return nil, false + return ar, false } func (ar IPV6AddrRange) IsOverlap(ar2 IPV6AddrRange) bool { - if ar.start.Ge(ar2.end) || ar.end.Le(ar2.start) { + if ar.start.Gt(ar2.end) || ar.end.Lt(ar2.start) { return false } else { return true @@ -544,6 +546,15 @@ func NewIPV6Prefix(prefix string) (IPV6Prefix, error) { return pref, nil } +func NewIPV6PrefixFromAddr(addr IPV6Addr, masklen uint8) IPV6Prefix { + pref := IPV6Prefix{ + Address: addr.NetAddr(masklen), + MaskLen: masklen, + } + pref.ipRange = pref.ToIPRange() + return pref +} + func (prefix IPV6Prefix) ToIPRange() IPV6AddrRange { start := prefix.Address.NetAddr(prefix.MaskLen) end := prefix.Address.BroadcastAddr(prefix.MaskLen) @@ -581,3 +592,188 @@ func DeriveIPv6AddrFromIPv4AddrMac(ipAddr string, macAddr string, startIp6, endI } return "" } + +func (ar IPV6AddrRange) Substract(ar2 IPV6AddrRange) ([]IPV6AddrRange, *IPV6AddrRange) { + lefts, overlap, sub := ar.Substract2(ar2) + var subp *IPV6AddrRange + if overlap { + subp = &sub + } + return lefts, subp +} + +func (ar IPV6AddrRange) Substract2(ar2 IPV6AddrRange) ([]IPV6AddrRange, bool, IPV6AddrRange) { + lefts := []IPV6AddrRange{} + // no intersection, no substract + if ar.end.Lt(ar2.start) || ar.start.Gt(ar2.end) { + lefts = append(lefts, ar) + return lefts, false, IPV6AddrRange{} + } + + // ar contains ar2 + if ar.ContainsRange(ar2) { + if ar.start.Equals(ar2.start) && ar.end.Equals(ar2.end) { + // lefts empty + } else if ar.start.Lt(ar2.start) && ar.end.Equals(ar2.end) { + lefts = append(lefts, + NewIPV6AddrRange(ar.start, ar2.start.StepDown()), + ) + } else if ar.start.Equals(ar2.start) && ar.end.Gt(ar2.end) { + lefts = append(lefts, + NewIPV6AddrRange(ar2.end.StepUp(), ar.end), + ) + } else { + lefts = append(lefts, + NewIPV6AddrRange(ar.start, ar2.start.StepDown()), + NewIPV6AddrRange(ar2.end.StepUp(), ar.end), + ) + } + return lefts, true, ar2 + } + + // ar contained by ar2 + if ar2.ContainsRange(ar) { + return lefts, true, ar + } + + // intersect, ar on the left + if ar.start.Lt(ar2.start) && ar.end.Ge(ar2.start) { + lefts = append(lefts, NewIPV6AddrRange(ar.start, ar2.start.StepDown())) + sub_ := NewIPV6AddrRange(ar2.start, ar.end) + return lefts, true, sub_ + } + + // intersect, ar on the right + if ar.start.Le(ar2.end) && ar.end.Gt(ar2.end) { + lefts = append(lefts, NewIPV6AddrRange(ar2.end.StepUp(), ar.end)) + sub_ := NewIPV6AddrRange(ar.start, ar2.end) + return lefts, true, sub_ + } + + // no intersection + return lefts, false, IPV6AddrRange{} +} + +func (addr IPV6Addr) ToBytes() []byte { + ret := make([]byte, 16) + for i := 0; i < len(addr); i++ { + binary.BigEndian.PutUint16(ret[2*i:], addr[i]) + } + return ret +} + +func (addr IPV6Addr) ToIP() net.IP { + return net.IP(addr.ToBytes()) +} + +func (pref IPV6Prefix) ToIPNet() *net.IPNet { + return &net.IPNet{ + IP: pref.Address.ToIP(), + Mask: net.CIDRMask(int(pref.MaskLen), 128), + } +} + +func (ar IPV6AddrRange) ToIPNets() []*net.IPNet { + r := []*net.IPNet{} + mms := ar.ToPrefixes() + for _, mm := range mms { + r = append(r, mm.ToIPNet()) + } + return r +} + +func (ar IPV6AddrRange) ToPrefixes() []IPV6Prefix { + prefixes := make([]IPV6Prefix, 0) + sp := ar.StartIp() + ep := ar.EndIp() + for sp.Le(ep) { + masklen := uint8(128) + for sp.NetAddr(masklen-1).Equals(sp) && sp.BroadcastAddr(masklen-1).Le(ep) && masklen > 0 { + masklen-- + } + if masklen == 0 { + prefixes = append(prefixes, NewIPV6PrefixFromAddr(sp, 0)) + break + } + prefixes = append(prefixes, NewIPV6PrefixFromAddr(sp, masklen)) + sp = sp.BroadcastAddr(masklen).StepUp() + } + return prefixes +} + +type IPV6AddrRangeList []IPV6AddrRange + +func (rl IPV6AddrRangeList) Len() int { + return len(rl) +} + +func (rl IPV6AddrRangeList) Swap(i, j int) { + rl[i], rl[j] = rl[j], rl[i] +} + +func (rl IPV6AddrRangeList) Less(i, j int) bool { + return rl[i].Compare(rl[j]) == sortutils.Less +} + +func (v6range IPV6AddrRange) Compare(r2 IPV6AddrRange) sortutils.CompareResult { + if v6range.start.Lt(r2.start) { + return sortutils.Less + } else if v6range.start.Gt(r2.start) { + return sortutils.More + } else { + // start equals, compare ends + if v6range.end.Gt(r2.end) { + return sortutils.Less + } else if v6range.end.Lt(r2.end) { + return sortutils.More + } else { + return sortutils.Equal + } + } +} + +func (rl IPV6AddrRangeList) Merge() []IPV6AddrRange { + sort.Sort(rl) + ret := make([]IPV6AddrRange, 0, len(rl)) + for i := range rl { + if i == 0 { + ret = append(ret, rl[i]) + } else { + result, isMerged := ret[len(ret)-1].Merge(rl[i]) + if isMerged { + ret[len(ret)-1] = result + } else { + ret = append(ret, rl[i]) + } + } + } + return ret +} + +func (rl IPV6AddrRangeList) String() string { + ret := make([]string, len(rl)) + for i := range rl { + ret[i] = rl[i].String() + } + return strings.Join(ret, ",") +} + +var IPV6Zero = IPV6Addr([8]uint16{0, 0, 0, 0, 0, 0, 0, 0}) +var IPV6Ones = IPV6Addr([8]uint16{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff}) +var AllIPV6AddrRange = IPV6AddrRange{ + start: IPV6Zero, + end: IPV6Ones, +} + +func (r IPV6AddrRange) IsAll() bool { + return r.start.Equals(IPV6Zero) && r.end.Equals(IPV6Ones) +} + +func (rl IPV6AddrRangeList) Substract(addrRange IPV6AddrRange) []IPV6AddrRange { + ret := make([]IPV6AddrRange, 0) + for i := range rl { + lefts, _ := rl[i].Substract(addrRange) + ret = append(ret, lefts...) + } + return ret +} diff --git a/util/netutils/ipv6_test.go b/util/netutils/ipv6_test.go index 4bc82e9..f2987d7 100644 --- a/util/netutils/ipv6_test.go +++ b/util/netutils/ipv6_test.go @@ -14,7 +14,12 @@ package netutils -import "testing" +import ( + "strings" + "testing" + + "yunion.io/x/jsonutils" +) func TestNewIPV6Addr(t *testing.T) { cases := []struct { @@ -440,3 +445,247 @@ func TestDeriveIPv6Addr(t *testing.T) { } } } + +func TestPrefixV62Range(t *testing.T) { + cases := []struct { + prefix string + rangeStr string + }{ + { + prefix: "::/0", + rangeStr: "::-ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + }, + { + prefix: "fd:3ffe:3200:2::/64", + rangeStr: "fd:3ffe:3200:2::-fd:3ffe:3200:2:ffff:ffff:ffff:ffff", + }, + { + prefix: "fd:3ffe:3200:2::/120", + rangeStr: "fd:3ffe:3200:2::-fd:3ffe:3200:2::ff", + }, + } + for _, c := range cases { + pref, err := NewIPV6Prefix(c.prefix) + if err != nil { + t.Errorf("prefix %s fail %s", c.prefix, err) + } else { + ipRange := pref.ToIPRange() + if ipRange.String() != c.rangeStr { + t.Errorf("prefix %s to range got %s want %s", pref.String(), ipRange.String(), c.rangeStr) + } + } + } +} + +func TestV6RangeToPrefix(t *testing.T) { + cases := []struct { + start string + end string + prefixes []string + }{ + { + start: "fd:3ffe:3200:2::0", + end: "fd:3ffe:3200:2::ff", + prefixes: []string{ + "fd:3ffe:3200:2::/120", + }, + }, + { + start: "fd:3ffe:3200:2::0", + end: "fd:3ffe:3200:2::80", + prefixes: []string{ + "fd:3ffe:3200:2::/121", + "fd:3ffe:3200:2::80/128", + }, + }, + { + start: "fd:3ffe:3200:2::80", + end: "fd:3ffe:3200:2::ff", + prefixes: []string{ + "fd:3ffe:3200:2::80/121", + }, + }, + { + start: "fd:3ffe:3200:2::7f", + end: "fd:3ffe:3200:2::ff", + prefixes: []string{ + "fd:3ffe:3200:2::7f/128", + "fd:3ffe:3200:2::80/121", + }, + }, + { + start: "fd:3ffe:3200:2::7e", + end: "fd:3ffe:3200:2::ff", + prefixes: []string{ + "fd:3ffe:3200:2::7e/127", + "fd:3ffe:3200:2::80/121", + }, + }, + { + start: "fd:3ffe:3200:2::7d", + end: "fd:3ffe:3200:2::ff", + prefixes: []string{ + "fd:3ffe:3200:2::7d/128", + "fd:3ffe:3200:2::7e/127", + "fd:3ffe:3200:2::80/121", + }, + }, + { + start: "::", + end: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + prefixes: []string{ + "::/0", + }, + }, + } + for _, c := range cases { + startIp, err := NewIPV6Addr(c.start) + if err != nil { + t.Errorf("NewIPV6Addr %s fail %s", c.start, err) + } else { + endIp, err := NewIPV6Addr(c.end) + if err != nil { + t.Errorf("NewIPV6Addr %s fail %s", c.end, err) + } else { + v6Range := NewIPV6AddrRange(startIp, endIp) + prefixes := v6Range.ToIPNets() + prefStr := make([]string, 0, len(prefixes)) + for i := range prefixes { + prefStr = append(prefStr, prefixes[i].String()) + } + if jsonutils.Marshal(prefStr).String() != jsonutils.Marshal(c.prefixes).String() { + t.Errorf("expect %s got %s", jsonutils.Marshal(c.prefixes).String(), jsonutils.Marshal(prefStr).String()) + } + } + } + } +} + +func TestIPV6Substract(t *testing.T) { + cases := []struct { + start1 string + end1 string + start2 string + end2 string + lefts []string + sub string + }{ + { + start1: "fd:3ffe:3200:2::1", + end1: "fd:3ffe:3200:2::ff", + start2: "fd:3ffe:3200:2::80", + end2: "fd:3ffe:3200:2::100", + lefts: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::7f", + }, + sub: "fd:3ffe:3200:2::80-fd:3ffe:3200:2::ff", + }, + { + start1: "fd:3ffe:3200:2::1", + end1: "fd:3ffe:3200:2::ff", + start2: "fd:3ffe:3200:2::80", + end2: "fd:3ffe:3200:2::8f", + lefts: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::7f", + "fd:3ffe:3200:2::90-fd:3ffe:3200:2::ff", + }, + sub: "fd:3ffe:3200:2::80-fd:3ffe:3200:2::8f", + }, + { + start1: "fd:3ffe:3200:2::80", + end1: "fd:3ffe:3200:2::ff", + start2: "fd:3ffe:3200:2::1", + end2: "fd:3ffe:3200:2::8f", + lefts: []string{ + "fd:3ffe:3200:2::90-fd:3ffe:3200:2::ff", + }, + sub: "fd:3ffe:3200:2::80-fd:3ffe:3200:2::8f", + }, + { + start1: "fd:3ffe:3200:2::1", + end1: "fd:3ffe:3200:2::7f", + start2: "fd:3ffe:3200:2::80", + end2: "fd:3ffe:3200:2::ff", + lefts: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::7f", + }, + sub: "", + }, + } + for _, c := range cases { + start1, _ := NewIPV6Addr(c.start1) + end1, _ := NewIPV6Addr(c.end1) + start2, _ := NewIPV6Addr(c.start2) + end2, _ := NewIPV6Addr(c.end2) + range1 := NewIPV6AddrRange(start1, end1) + range2 := NewIPV6AddrRange(start2, end2) + lefts, sub := range1.Substract(range2) + leftStrs := make([]string, len(lefts)) + for i := range lefts { + leftStrs[i] = lefts[i].String() + } + if jsonutils.Marshal(leftStrs).String() != jsonutils.Marshal(c.lefts).String() { + t.Errorf("%s substrct %s expect lefts %s got %s", range1.String(), range2.String(), jsonutils.Marshal(c.lefts).String(), jsonutils.Marshal(leftStrs).String()) + } else { + if sub == nil { + if c.sub != "" { + t.Errorf("%s substrct %s expect empty sub got %s", range1.String(), range2.String(), c.sub) + } + } else if sub.String() != c.sub { + t.Errorf("%s substrct %s expect sub %s got %s", range1.String(), range2.String(), sub.String(), c.sub) + } + } + } +} + +func TestV6RangeListMerge(t *testing.T) { + cases := []struct { + ranges []string + wants []string + }{ + { + ranges: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ff", + "fd:3ffe:3200:2::ff-fd:3ffe:3200:2::ffff", + }, + wants: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ffff", + }, + }, + { + ranges: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ff", + "fd:3ffe:3200:2::80-fd:3ffe:3200:2::ffff", + }, + wants: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ffff", + }, + }, + { + ranges: []string{ + "fd:3ffe:3200:2::80-fd:3ffe:3200:2::ffff", + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ff", + }, + wants: []string{ + "fd:3ffe:3200:2::1-fd:3ffe:3200:2::ffff", + }, + }, + } + for _, c := range cases { + ranges := make([]IPV6AddrRange, 0) + for _, r := range c.ranges { + parts := strings.Split(r, "-") + start, _ := NewIPV6Addr(parts[0]) + end, _ := NewIPV6Addr(parts[1]) + ranges = append(ranges, NewIPV6AddrRange(start, end)) + } + merged := IPV6AddrRangeList(ranges).Merge() + mergeStrs := make([]string, 0, len(merged)) + for _, m := range merged { + mergeStrs = append(mergeStrs, m.String()) + } + if jsonutils.Marshal(mergeStrs).String() != jsonutils.Marshal(c.wants).String() { + t.Errorf("merge %s expect %s got %s", jsonutils.Marshal(c.ranges).String(), jsonutils.Marshal(c.wants).String(), jsonutils.Marshal(mergeStrs).String()) + } + } +} diff --git a/util/netutils/netutils.go b/util/netutils/netutils.go index f95082d..ce2cbd1 100644 --- a/util/netutils/netutils.go +++ b/util/netutils/netutils.go @@ -15,15 +15,18 @@ package netutils import ( + "encoding/binary" "fmt" "math/bits" "math/rand" "net" + "sort" "strconv" "strings" "yunion.io/x/pkg/errors" "yunion.io/x/pkg/util/regutils" + "yunion.io/x/pkg/util/sortutils" ) const macChars = "0123456789abcdef" @@ -53,7 +56,7 @@ func FormatMacAddr(macAddr string) string { func IP2Number(ipstr string) (uint32, error) { parts := strings.Split(ipstr, ".") if len(parts) == 4 { - var num uint32 + bytes := make([]byte, 4) for i := 0; i < 4; i += 1 { n, e := strconv.Atoi(strings.TrimSpace(parts[i])) if e != nil { @@ -62,37 +65,17 @@ func IP2Number(ipstr string) (uint32, error) { if n < 0 || n > 255 { return 0, ErrOutOfRange } - num = num | (uint32(n) << uint32(24-i*8)) + bytes[i] = byte(n) } - return num, nil + return binary.BigEndian.Uint32(bytes), nil } return 0, ErrInvalidIPAddr // fmt.Errorf("invalid ip address %s", ipstr) } -/*func IP2Bytes(ipstr string) ([]byte, error) { - parts := strings.Split(ipstr, ".") - if len(parts) == 4 { - bytes := make([]byte, 4) - for i := 0; i < 4; i += 1 { - n, e := strconv.Atoi(parts[i]) - if e != nil { - return nil, fmt.Errorf("invalid number %s", parts[i]) - } - bytes[i] = byte(n) - } - return bytes, nil - } - return nil, fmt.Errorf("invalid ip address %s", ipstr) -}*/ - func Number2Bytes(num uint32) []byte { - a := num >> 24 - num -= a << 24 - b := num >> 16 - num -= b << 16 - c := num >> 8 - num -= c << 8 - return []byte{byte(a), byte(b), byte(c), byte(num)} + ret := make([]byte, 4) + binary.BigEndian.PutUint32(ret, num) + return ret } func Number2IP(num uint32) string { @@ -144,11 +127,11 @@ func (addr IPV4Addr) String() string { } func (addr IPV4Addr) ToBytes() []byte { - a := byte((addr & 0xff000000) >> 24) - b := byte((addr & 0x00ff0000) >> 16) - c := byte((addr & 0x0000ff00) >> 8) - d := byte(addr & 0x000000ff) - return []byte{a, b, c, d} + return Number2Bytes(uint32(addr)) +} + +func (addr IPV4Addr) ToIP() net.IP { + return net.IP(addr.ToBytes()) } func (addr IPV4Addr) ToMac(prefix string) string { @@ -214,7 +197,7 @@ func (ar IPV4AddrRange) EndIp() IPV4Addr { return ar.end } -func (ar IPV4AddrRange) Merge(ar2 IPV4AddrRange) (*IPV4AddrRange, bool) { +func (ar IPV4AddrRange) Merge(ar2 IPV4AddrRange) (IPV4AddrRange, bool) { if ar.IsOverlap(ar2) || ar.end+1 == ar2.start || ar2.end+1 == ar.start { if ar2.start < ar.start { ar.start = ar2.start @@ -222,9 +205,9 @@ func (ar IPV4AddrRange) Merge(ar2 IPV4AddrRange) (*IPV4AddrRange, bool) { if ar2.end > ar.end { ar.end = ar2.end } - return &ar, true + return ar, true } - return nil, false + return ar, false } func (ar IPV4AddrRange) IsOverlap(ar2 IPV4AddrRange) bool { @@ -235,7 +218,23 @@ func (ar IPV4AddrRange) IsOverlap(ar2 IPV4AddrRange) bool { } } +func (pref IPV4Prefix) ToIPNet() *net.IPNet { + return &net.IPNet{ + IP: pref.Address.ToIP(), + Mask: net.CIDRMask(int(pref.MaskLen), 32), + } +} + func (ar IPV4AddrRange) ToIPNets() []*net.IPNet { + r := []*net.IPNet{} + mms := ar.ToPrefixes() + for _, mm := range mms { + r = append(r, mm.ToIPNet()) + } + return r +} + +/*func (ar IPV4AddrRange) ToIPNets() []*net.IPNet { r := []*net.IPNet{} mms := ar.ToMaskMatches() for _, mm := range mms { @@ -271,57 +270,86 @@ func (ar IPV4AddrRange) ToMaskMatches() [][2]uint32 { sp = sp + b } return r +}*/ + +func (ar IPV4AddrRange) ToPrefixes() []IPV4Prefix { + prefixes := make([]IPV4Prefix, 0) + sp := ar.StartIp() + ep := ar.EndIp() + for sp <= ep { + masklen := int8(32) + for masklen > 0 && sp.NetAddr(masklen-1) == sp && sp.BroadcastAddr(masklen-1) <= ep { + masklen-- + } + if masklen == 0 { + prefixes = append(prefixes, NewIPV4PrefixFromAddr(sp, 0)) + break + } + prefixes = append(prefixes, NewIPV4PrefixFromAddr(sp, masklen)) + sp = sp.BroadcastAddr(masklen).StepUp() + } + return prefixes +} + +func (ar IPV4AddrRange) Substract(ar2 IPV4AddrRange) ([]IPV4AddrRange, *IPV4AddrRange) { + lefts, overlap, sub := ar.Substract2(ar2) + var subp *IPV4AddrRange + if overlap { + subp = &sub + } + return lefts, subp } -func (ar IPV4AddrRange) Substract(ar2 IPV4AddrRange) (lefts []IPV4AddrRange, sub *IPV4AddrRange) { - lefts = []IPV4AddrRange{} +func (ar IPV4AddrRange) Substract2(ar2 IPV4AddrRange) ([]IPV4AddrRange, bool, IPV4AddrRange) { + lefts := []IPV4AddrRange{} // no intersection, no substract if ar.end < ar2.start || ar.start > ar2.end { lefts = append(lefts, ar) - return + return lefts, false, IPV4AddrRange{} } // ar contains ar2 if ar.ContainsRange(ar2) { - nns := [][2]int64{ - [2]int64{int64(ar.start), int64(ar2.start) - 1}, - [2]int64{int64(ar2.end) + 1, int64(ar.end)}, - } - for _, nn := range nns { - if nn[0] <= nn[1] { - lefts = append(lefts, NewIPV4AddrRange(IPV4Addr(nn[0]), IPV4Addr(nn[1]))) - } + if ar.start == ar2.start && ar.end == ar2.end { + // lefts empty + } else if ar.start < ar2.start && ar.end == ar2.end { + lefts = append(lefts, + NewIPV4AddrRange(ar.start, ar2.start.StepDown()), + ) + } else if ar.start == ar2.start && ar.end > ar2.end { + lefts = append(lefts, + NewIPV4AddrRange(ar2.end.StepUp(), ar.end), + ) + } else { + lefts = append(lefts, + NewIPV4AddrRange(ar.start, ar2.start.StepDown()), + NewIPV4AddrRange(ar2.end.StepUp(), ar.end), + ) } - ar2_ := ar2 - sub = &ar2_ - return + return lefts, true, ar2 } // ar contained by ar2 if ar2.ContainsRange(ar) { - ar_ := ar - sub = &ar_ - return + return lefts, true, ar } // intersect, ar on the left if ar.start < ar2.start && ar.end >= ar2.start { - lefts = append(lefts, NewIPV4AddrRange(ar.start, ar2.start-1)) + lefts = append(lefts, NewIPV4AddrRange(ar.start, ar2.start.StepDown())) sub_ := NewIPV4AddrRange(ar2.start, ar.end) - sub = &sub_ - return + return lefts, true, sub_ } // intersect, ar on the right if ar.start <= ar2.end && ar.end > ar2.end { - lefts = append(lefts, NewIPV4AddrRange(ar2.end+1, ar.end)) + lefts = append(lefts, NewIPV4AddrRange(ar2.end.StepUp(), ar.end)) sub_ := NewIPV4AddrRange(ar.start, ar2.end) - sub = &sub_ - return + return lefts, true, sub_ } // no intersection - return + return lefts, false, IPV4AddrRange{} } func (ar IPV4AddrRange) equals(ar2 IPV4AddrRange) bool { @@ -406,6 +434,15 @@ func NewIPV4Prefix(prefix string) (IPV4Prefix, error) { return pref, nil } +func NewIPV4PrefixFromAddr(addr IPV4Addr, masklen int8) IPV4Prefix { + pref := IPV4Prefix{ + Address: addr.NetAddr(masklen), + MaskLen: masklen, + } + pref.ipRange = pref.ToIPRange() + return pref +} + func (prefix IPV4Prefix) ToIPRange() IPV4AddrRange { start := prefix.Address.NetAddr(prefix.MaskLen) end := prefix.Address.BroadcastAddr(prefix.MaskLen) @@ -553,3 +590,80 @@ func Netlen2Mask(netmasklen int) string { } return mask } + +type IPV4AddrRangeList []IPV4AddrRange + +func (rl IPV4AddrRangeList) Len() int { + return len(rl) +} + +func (rl IPV4AddrRangeList) Swap(i, j int) { + rl[i], rl[j] = rl[j], rl[i] +} + +func (rl IPV4AddrRangeList) Less(i, j int) bool { + return rl[i].Compare(rl[j]) == sortutils.Less +} + +func (v4range IPV4AddrRange) Compare(r2 IPV4AddrRange) sortutils.CompareResult { + if v4range.start < r2.start { + return sortutils.Less + } else if v4range.start > r2.start { + return sortutils.More + } else { + // start equals, compare ends + if v4range.end > r2.end { + return sortutils.Less + } else if v4range.end < r2.end { + return sortutils.More + } else { + return sortutils.Equal + } + } +} + +func (rl IPV4AddrRangeList) Merge() []IPV4AddrRange { + sort.Sort(rl) + ret := make([]IPV4AddrRange, 0, len(rl)) + for i := range rl { + if i == 0 { + ret = append(ret, rl[i]) + } else { + result, isMerged := ret[len(ret)-1].Merge(rl[i]) + if isMerged { + ret[len(ret)-1] = result + } else { + ret = append(ret, rl[i]) + } + } + } + return ret +} + +func (rl IPV4AddrRangeList) String() string { + strs := make([]string, len(rl)) + for i := range rl { + strs[i] = rl[i].String() + } + return strings.Join(strs, ",") +} + +var IPV4Zero = IPV4Addr(0) +var IPV4Ones = IPV4Addr(0xffffffff) +var AllIPV4AddrRange = IPV4AddrRange{ + start: IPV4Zero, + end: IPV4Ones, +} + +func (r IPV4AddrRange) IsAll() bool { + return r.start == IPV4Zero && r.end == IPV4Ones +} + +func (rl IPV4AddrRangeList) Substract(addrRange IPV4AddrRange) []IPV4AddrRange { + ret := make([]IPV4AddrRange, 0) + for i := range rl { + lefts, _ := rl[i].Substract(addrRange) + ret = append(ret, lefts...) + } + return ret +} diff --git a/util/netutils/netutils_test.go b/util/netutils/netutils_test.go index 2850f43..3fad774 100644 --- a/util/netutils/netutils_test.go +++ b/util/netutils/netutils_test.go @@ -16,7 +16,10 @@ package netutils import ( "fmt" + "strings" "testing" + + "yunion.io/x/jsonutils" ) func TestFormatMacAddr(t *testing.T) { @@ -163,8 +166,8 @@ func TestIPV4AddrRange_Substract(t *testing.T) { } return i } - ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) t.Run("disjoint (left)", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.1.2"), ni("192.168.1.255")) lefts, sub := ar.Substract(ar2) if len(lefts) != 1 || !lefts[0].equals(ar) { @@ -175,6 +178,7 @@ func TestIPV4AddrRange_Substract(t *testing.T) { } }) t.Run("overlap (cut right)", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.2.128"), ni("192.168.3.255")) lefts, sub := ar.Substract(ar2) if len(lefts) != 1 || !lefts[0].equals(nir(ni("192.168.2.0"), ni("192.168.2.127"))) { @@ -185,6 +189,7 @@ func TestIPV4AddrRange_Substract(t *testing.T) { } }) t.Run("contains (true subset)", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.2.33"), ni("192.168.2.44")) lefts, sub := ar.Substract(ar2) if len(lefts) != 2 || !lefts[0].equals(nir(ni("192.168.2.0"), ni("192.168.2.32"))) || !lefts[1].equals(nir(ni("192.168.2.45"), ni("192.168.2.255"))) { @@ -195,26 +200,29 @@ func TestIPV4AddrRange_Substract(t *testing.T) { } }) t.Run("contains (align left)", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.2.0"), ni("192.168.2.33")) lefts, sub := ar.Substract(ar2) if len(lefts) != 1 || !lefts[0].equals(nir(ni("192.168.2.34"), ni("192.168.2.255"))) { - t.Fatalf("bad `lefts`") + t.Fatalf("bad ar %s substract ar2 %s `lefts` %s", ar.String(), ar2.String(), IPV4AddrRangeList(lefts).String()) } if !sub.equals(nir(ni("192.168.2.0"), ni("192.168.2.33"))) { t.Fatalf("bad `sub`") } }) t.Run("contains (align right)", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.2.44"), ni("192.168.2.255")) lefts, sub := ar.Substract(ar2) if len(lefts) != 1 || !lefts[0].equals(nir(ni("192.168.2.0"), ni("192.168.2.43"))) { - t.Fatalf("bad `lefts`") + t.Fatalf("bad `lefts` %s", IPV4AddrRangeList(lefts).String()) } if !sub.equals(nir(ni("192.168.2.44"), ni("192.168.2.255"))) { t.Fatalf("bad `sub`") } }) t.Run("contained by", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.255")) ar2 := nir(ni("192.168.1.255"), ni("192.168.3.0")) lefts, sub := ar.Substract(ar2) if len(lefts) != 0 { @@ -224,6 +232,18 @@ func TestIPV4AddrRange_Substract(t *testing.T) { t.Fatalf("bad `sub`") } }) + + t.Run("192.168.2.0/25 - 192.168.2.0/24", func(t *testing.T) { + ar := nir(ni("192.168.2.0"), ni("192.168.2.127")) + ar2 := nir(ni("192.168.2.0"), ni("192.168.2.255")) + lefts, sub := ar.Substract(ar2) + if len(lefts) != 0 { + t.Fatalf("bad ar %s substract ar2 %s `lefts` %s", ar.String(), ar2.String(), IPV4AddrRangeList(lefts).String()) + } + if !sub.equals(ar) { + t.Fatalf("bad `sub`") + } + }) } func TestNetlen2Mask(t *testing.T) { @@ -259,3 +279,168 @@ func TestNetlen2Mask(t *testing.T) { }) } } + +func TestPrefix2Range(t *testing.T) { + cases := []struct { + prefix string + rangeStr string + }{ + { + prefix: "0.0.0.0/0", + rangeStr: "0.0.0.0-255.255.255.255", + }, + { + prefix: "192.168.222.0/24", + rangeStr: "192.168.222.0-192.168.222.255", + }, + } + for _, c := range cases { + pref, err := NewIPV4Prefix(c.prefix) + if err != nil { + t.Errorf("prefix %s fail %s", c.prefix, err) + } else { + ipRange := pref.ToIPRange() + if ipRange.String() != c.rangeStr { + t.Errorf("prefix %s to range got %s want %s", pref.String(), ipRange.String(), c.rangeStr) + } + } + } +} + +func TestV4RangeListMerge(t *testing.T) { + cases := []struct { + ranges []string + wants []string + }{ + { + ranges: []string{ + "192.168.22.1-192.168.22.127", + "192.168.22.127-192.168.22.255", + }, + wants: []string{ + "192.168.22.1-192.168.22.255", + }, + }, + { + ranges: []string{ + "192.168.22.1-192.168.22.127", + "192.168.22.128-192.168.22.255", + }, + wants: []string{ + "192.168.22.1-192.168.22.255", + }, + }, + { + ranges: []string{ + "192.168.22.128-192.168.22.255", + "192.168.22.1-192.168.22.127", + }, + wants: []string{ + "192.168.22.1-192.168.22.255", + }, + }, + } + for _, c := range cases { + ranges := make([]IPV4AddrRange, 0) + for _, r := range c.ranges { + parts := strings.Split(r, "-") + start, _ := NewIPV4Addr(parts[0]) + end, _ := NewIPV4Addr(parts[1]) + ranges = append(ranges, NewIPV4AddrRange(start, end)) + } + merged := IPV4AddrRangeList(ranges).Merge() + mergeStrs := make([]string, 0, len(merged)) + for _, m := range merged { + mergeStrs = append(mergeStrs, m.String()) + } + if jsonutils.Marshal(mergeStrs).String() != jsonutils.Marshal(c.wants).String() { + t.Errorf("merge %s expect %s got %s", jsonutils.Marshal(c.ranges).String(), jsonutils.Marshal(c.wants).String(), jsonutils.Marshal(mergeStrs).String()) + } + } +} + +func TestV4RangeToPrefix(t *testing.T) { + cases := []struct { + start string + end string + prefixes []string + }{ + { + start: "192.168.22.0", + end: "192.168.22.255", + prefixes: []string{ + "192.168.22.0/24", + }, + }, + { + start: "192.168.22.0", + end: "192.168.23.255", + prefixes: []string{ + "192.168.22.0/23", + }, + }, + { + start: "192.168.22.0", + end: "192.168.23.0", + prefixes: []string{ + "192.168.22.0/24", + "192.168.23.0/32", + }, + }, + { + start: "192.168.21.255", + end: "192.168.23.0", + prefixes: []string{ + "192.168.21.255/32", + "192.168.22.0/24", + "192.168.23.0/32", + }, + }, + { + start: "192.168.21.254", + end: "192.168.23.0", + prefixes: []string{ + "192.168.21.254/31", + "192.168.22.0/24", + "192.168.23.0/32", + }, + }, + { + start: "192.168.21.254", + end: "192.168.23.0", + prefixes: []string{ + "192.168.21.254/31", + "192.168.22.0/24", + "192.168.23.0/32", + }, + }, + { + start: "0.0.0.0", + end: "255.255.255.255", + prefixes: []string{ + "0.0.0.0/0", + }, + }, + } + for _, c := range cases { + startIp, err := NewIPV4Addr(c.start) + if err != nil { + t.Errorf("NewIPV4Addr %s fail %s", c.start, err) + } else { + endIp, err := NewIPV4Addr(c.end) + if err != nil { + t.Errorf("NewIPV4Addr %s fail %s", c.end, err) + } else { + v4Range := NewIPV4AddrRange(startIp, endIp) + prefixes := v4Range.ToIPNets() + prefStr := make([]string, 0, len(prefixes)) + for i := range prefixes { + prefStr = append(prefStr, prefixes[i].String()) + } + if jsonutils.Marshal(prefStr).String() != jsonutils.Marshal(c.prefixes).String() { + t.Errorf("expect %s got %s", jsonutils.Marshal(c.prefixes).String(), jsonutils.Marshal(prefStr).String()) + } + } + } + } +} diff --git a/util/secrules/cut.go b/util/secrules/cut.go index cf954a1..6ecf826 100644 --- a/util/secrules/cut.go +++ b/util/secrules/cut.go @@ -21,18 +21,26 @@ import ( "sort" "yunion.io/x/pkg/util/netutils" + "yunion.io/x/pkg/util/regutils" ) type securityRuleCut struct { - r SecurityRule + r SecurityRule + protocolCut bool netCut bool portCut bool + + v4ranges []netutils.IPV4AddrRange + v6ranges []netutils.IPV6AddrRange } func (src *securityRuleCut) String() string { - s := fmt.Sprintf("[%s;protocolCut=%v;netCut=%v;portCut=%v]", - src.r.String(), src.protocolCut, src.netCut, src.portCut) + s := fmt.Sprintf("[%s;v4=%s;v6=%s;protocolCut=%v;netCut=%v;portCut=%v]", + src.r.String(), + netutils.IPV4AddrRangeList(src.v4ranges).String(), + netutils.IPV6AddrRangeList(src.v6ranges).String(), + src.protocolCut, src.netCut, src.portCut) return s } @@ -40,16 +48,62 @@ func (src *securityRuleCut) isCut() bool { return src.protocolCut && src.netCut && src.portCut } -type securityRuleCuts []securityRuleCut +func (src securityRuleCut) genRules() []SecurityRule { + src.v4ranges = netutils.IPV4AddrRangeList(src.v4ranges).Merge() + src.v6ranges = netutils.IPV6AddrRangeList(src.v6ranges).Merge() -func newSecurityRuleSetCuts(srs SecurityRuleSet) securityRuleCuts { - srcs := make(securityRuleCuts, len(srs)) - for i := range srcs { - srcs[i].r = srs[i] + rs := make([]SecurityRule, 0) + + if len(src.v4ranges) == 1 && src.v4ranges[0].IsAll() && len(src.v6ranges) == 1 && src.v6ranges[0].IsAll() { + rule := src.r + rule.IPNet = nil + rs = append(rs, rule) + return rs } - return srcs + for i := range src.v4ranges { + nets := src.v4ranges[i].ToIPNets() + for _, net := range nets { + rule := src.r + rule.IPNet = net + rs = append(rs, rule) + } + } + for i := range src.v6ranges { + nets := src.v6ranges[i].ToIPNets() + for _, net := range nets { + rule := src.r + rule.IPNet = net + rs = append(rs, rule) + } + } + return rs } +func newSecurityRuleSetCuts(r SecurityRule) securityRuleCuts { + var v4ranges []netutils.IPV4AddrRange + var v6ranges []netutils.IPV6AddrRange + if r.IPNet == nil { + // expand + v4ranges = append(v4ranges, netutils.AllIPV4AddrRange) + v6ranges = append(v6ranges, netutils.AllIPV6AddrRange) + } else { + if regutils.MatchCIDR(r.IPNet.String()) { + v4ranges = append(v4ranges, netutils.NewIPV4AddrRangeFromIPNet(r.IPNet)) + } else { + v6ranges = append(v6ranges, netutils.NewIPV6AddrRangeFromIPNet(r.IPNet)) + } + } + return []securityRuleCut{ + { + r: r, + v4ranges: v4ranges, + v6ranges: v6ranges, + }, + } +} + +type securityRuleCuts []securityRuleCut + func (srcs securityRuleCuts) String() string { buf := bytes.Buffer{} for i := range srcs { @@ -67,7 +121,7 @@ func (srcs securityRuleCuts) securityRuleSet() SecurityRuleSet { if src.isCut() { continue } - srs = append(srs, src.r) + srs = append(srs, src.genRules()...) } return srs } @@ -101,37 +155,45 @@ func (srcs securityRuleCuts) cutOutProtocol(protocol string) securityRuleCuts { return r } +func isV6(n *net.IPNet) bool { + return regutils.MatchCIDR6(n.String()) +} + func (srcs securityRuleCuts) cutOutIPNet(protocol string, n *net.IPNet) securityRuleCuts { r := securityRuleCuts{} - ar2 := netutils.NewIPV4AddrRangeFromIPNet(n) - for _, src := range srcs { + isWildMatch := isWildNet(n) + isV6 := false + var v4n netutils.IPV4AddrRange + var v6n netutils.IPV6AddrRange + if !isWildMatch { + if regutils.MatchCIDR6(n.String()) { + isV6 = true + v6n = netutils.NewIPV6AddrRangeFromIPNet(n) + } else { + v4n = netutils.NewIPV4AddrRangeFromIPNet(n) + } + } + for i := range srcs { + src := srcs[i] + if src.netCut { + r = append(r, src) + continue + } if src.r.Protocol != protocol && protocol != PROTO_ANY { - src_ := src - r = append(r, src_) + r = append(r, src) continue } - sr := src.r - ar := netutils.NewIPV4AddrRangeFromIPNet(sr.IPNet) - left, subs := ar.Substract(ar2) - for _, l := range left { - // retain - nets := l.ToIPNets() - for _, net_ := range nets { - src_ := src - src_.r.IPNet = net_ - r = append(r, src_) - } + if isWildMatch { + src.netCut = true + r = append(r, src) + continue } - if subs != nil { - // cut - nets := subs.ToIPNets() - for _, net_ := range nets { - src_ := src - src_.r.IPNet = net_ - src_.netCut = true - r = append(r, src_) - } + if isV6 { + src.v6ranges = netutils.IPV6AddrRangeList(src.v6ranges).Substract(v6n) + } else { + src.v4ranges = netutils.IPV4AddrRangeList(src.v4ranges).Substract(v4n) } + r = append(r, src) } return r } diff --git a/util/secrules/secrules.go b/util/secrules/secrules.go index 6a026e4..f253284 100644 --- a/util/secrules/secrules.go +++ b/util/secrules/secrules.go @@ -53,9 +53,15 @@ const ( ) type SecurityRule struct { - Priority int // [1, 100] - Action TSecurityRuleAction - IPNet *net.IPNet + Priority int // [1, 100] + Action TSecurityRuleAction + + // distinguish between + // * "" (empty) allow all ipv4 and ipv6 + // * 0.0.0.0/0 allow all ipv4 + // * ::/0 allow all IPv6 + IPNet *net.IPNet + Protocol string Direction TSecurityRuleDirection PortStart int @@ -176,26 +182,29 @@ func ParseSecurityRule(pattern string) (*SecurityRule, error) { } func (rule *SecurityRule) ParseCIDR(cidr string) bool { - if regutils.MatchCIDR(cidr) { + if regutils.MatchCIDR(cidr) || regutils.MatchCIDR6(cidr) { _, rule.IPNet, _ = net.ParseCIDR(cidr) return true } - if regutils.MatchIPAddr(cidr) { + if regutils.MatchIP4Addr(cidr) { rule.IPNet = &net.IPNet{ IP: net.ParseIP(cidr), Mask: net.CIDRMask(32, 32), } return true + } else if regutils.MatchIP6Addr(cidr) { + rule.IPNet = &net.IPNet{ + IP: net.ParseIP(cidr), + Mask: net.CIDRMask(128, 128), + } + return true } - rule.IPNet = &net.IPNet{ - IP: net.IPv4zero, - Mask: net.CIDRMask(0, 32), - } + rule.IPNet = nil return false } func (rule *SecurityRule) IsWildMatch() bool { - return rule.IPNet.String() == "0.0.0.0/0" && + return rule.IPNet == nil && rule.Protocol == PROTO_ANY && len(rule.Ports) == 0 && ((rule.PortStart <= 0 && rule.PortEnd <= 0) || (rule.PortStart == 1 && rule.PortEnd == 65535)) @@ -275,8 +284,8 @@ func (rule SecurityRule) merge(r SecurityRule) SecurityRule { } func (rule SecurityRule) getIPKey() string { - if rule.IPNet == nil || rule.IPNet.String() == "0.0.0.0/0" { - return "0.0.0.0/0" + if rule.IPNet == nil { + return "" } return rule.IPNet.String() } @@ -396,12 +405,21 @@ func (rule *SecurityRule) GetPortsString() string { func (rule *SecurityRule) String() (result string) { s := []string{} s = append(s, string(rule.Direction)+":"+string(rule.Action)) - cidr := rule.IPNet.String() - if cidr != "0.0.0.0/0" { - if ones, _ := rule.IPNet.Mask.Size(); ones < 32 { - s = append(s, cidr) - } else { - s = append(s, rule.IPNet.IP.String()) + + if rule.IPNet != nil { + cidr := rule.IPNet.String() + if regutils.MatchCIDR(cidr) { + if ones, _ := rule.IPNet.Mask.Size(); ones < 32 { + s = append(s, cidr) + } else { + s = append(s, rule.IPNet.IP.String()) + } + } else if regutils.MatchCIDR6(cidr) { + if ones, _ := rule.IPNet.Mask.Size(); ones < 128 { + s = append(s, cidr) + } else { + s = append(s, rule.IPNet.IP.String()) + } } } @@ -428,11 +446,13 @@ func (rule *SecurityRule) netEquals(r *SecurityRule) bool { return net0 == net1 } -func (rule *SecurityRule) cutOut(r *SecurityRule) SecurityRuleSet { - srcs := securityRuleCuts{securityRuleCut{r: *rule}} +func (rule *SecurityRule) cutOut(r SecurityRule) SecurityRuleSet { + srcs := newSecurityRuleSetCuts(*rule) // securityRuleCuts{securityRuleCut{r: *rule}} //a := srcs srcs = srcs.cutOutProtocol(r.Protocol) + log.Debugf("cutOutProtocol: rule %s cut %s output %s", rule.String(), r.Protocol, srcs.String()) srcs = srcs.cutOutIPNet(r.Protocol, r.IPNet) + log.Debugf("cutOutIPNet: rule %s cut %s output %s", rule.String(), r.IPNet, srcs.String()) if len(r.Ports) > 0 { srcs = srcs.cutOutPorts(r.Protocol, []uint16(newPortsFromInts(r.Ports...))) } else if r.PortStart > 0 && r.PortEnd > 0 { @@ -443,5 +463,6 @@ func (rule *SecurityRule) cutOut(r *SecurityRule) SecurityRuleSet { //fmt.Printf("a %s\n", a) //fmt.Printf("b %s\n", srcs) srs := srcs.securityRuleSet() + log.Debugf("rule %s cut %s output %s", rule.String(), r.String(), srs.String()) return srs } diff --git a/util/secrules/secrules_test.go b/util/secrules/secrules_test.go index e633388..45d362a 100644 --- a/util/secrules/secrules_test.go +++ b/util/secrules/secrules_test.go @@ -29,20 +29,32 @@ func TestIsFunction(t *testing.T) { {s: "out:allow any", isWildMatch: true}, {s: "in:deny any", isWildMatch: true}, {s: "out:deny any", isWildMatch: true}, - {s: "in:allow 0.0.0.0/0 any", s2: "in:allow any", isWildMatch: true}, - {s: "in:allow 0.0.0.0/0 tcp", s2: "in:allow tcp"}, - {s: "in:allow 0.0.0.0/0 udp", s2: "in:allow udp"}, - {s: "in:allow 0.0.0.0/0 icmp", s2: "in:allow icmp"}, + {s: "in:allow 0.0.0.0/0 any", s2: "in:allow 0.0.0.0/0 any"}, + {s: "in:allow ::/0 any", s2: "in:allow ::/0 any"}, + {s: "in:allow tcp", s2: "in:allow tcp"}, + {s: "in:allow 0.0.0.0/0 tcp", s2: "in:allow 0.0.0.0/0 tcp"}, + {s: "in:allow ::/0 tcp", s2: "in:allow ::/0 tcp"}, + {s: "in:allow udp", s2: "in:allow udp"}, + {s: "in:allow 0.0.0.0/0 udp", s2: "in:allow 0.0.0.0/0 udp"}, + {s: "in:allow ::/0 udp", s2: "in:allow ::/0 udp"}, + {s: "in:allow icmp", s2: "in:allow icmp"}, + {s: "in:allow 0.0.0.0/0 icmp", s2: "in:allow 0.0.0.0/0 icmp"}, + {s: "in:allow ::/0 icmp", s2: "in:allow ::/0 icmp"}, {s: "in:allow 10.0.8.0/24 any", s2: "in:allow 10.0.8.0/24 any"}, + {s: "in:allow fd:3ffe:3200:1220::/64 any", s2: "in:allow fd:3ffe:3200:1220::/64 any"}, {s: "in:allow 10.0.9.0/24 tcp", s2: "in:allow 10.0.9.0/24 tcp"}, {s: "in:allow 10.0.10.0/24 udp", s2: "in:allow 10.0.10.0/24 udp"}, {s: "in:allow 10.0.11.0/24 icmp", s2: "in:allow 10.0.11.0/24 icmp"}, {s: "in:allow 10.0.8.0/24 tcp 1-100", s2: "in:allow 10.0.8.0/24 tcp 1-100"}, + {s: "in:allow fd:3ffe:3200:8::/64 tcp 1-100", s2: "in:allow fd:3ffe:3200:8::/64 tcp 1-100"}, {s: "in:allow 10.0.8.0/24 tcp 100-1", s2: "in:allow 10.0.8.0/24 tcp 1-100"}, {s: "in:allow 10.0.8.0/24 tcp 1,100", s2: "in:allow 10.0.8.0/24 tcp 1,100"}, + {s: "in:allow fd:3ffe:3200:8::/64 tcp 1,100", s2: "in:allow fd:3ffe:3200:8::/64 tcp 1,100"}, {s: "in:allow 10.0.8.0/24 tcp 100", s2: "in:allow 10.0.8.0/24 tcp 100"}, {s: "in:allow 0.0.0.0 tcp", s2: "in:allow 0.0.0.0 tcp"}, - {s: "in:allow 0.0.0.0 tcp", s2: "in:allow 0.0.0.0 tcp"}, + {s: "in:allow :: tcp", s2: "in:allow :: tcp"}, + {s: "in:allow 0.0.0.0 udp", s2: "in:allow 0.0.0.0 udp"}, + {s: "in:allow :: udp", s2: "in:allow :: udp"}, {s: "in:deny", bad: true}, {s: "in:allow", bad: true}, {s: "in:allow 0.0.0.0/0 ip", bad: true}, diff --git a/util/secrules/secruleset.go b/util/secrules/secruleset.go index cf6cd27..32a9fce 100644 --- a/util/secrules/secruleset.go +++ b/util/secrules/secruleset.go @@ -19,9 +19,74 @@ import ( "net" "sort" + "yunion.io/x/log" + + "yunion.io/x/pkg/gotypes" "yunion.io/x/pkg/util/netutils" + "yunion.io/x/pkg/util/regutils" + "yunion.io/x/pkg/util/sortutils" ) +func isWildNet(ipnet *net.IPNet) bool { + return gotypes.IsNil(ipnet) +} + +func compareIPNet(ipnet1, ipnet2 *net.IPNet) sortutils.CompareResult { + srsIPi := ipnet1.String() + srsIPj := ipnet2.String() + if !isWildNet(ipnet1) && !isWildNet(ipnet2) { + if srsIPi != srsIPj { + isIPv6i := regutils.MatchCIDR6(srsIPi) + isIPv6j := regutils.MatchCIDR6(srsIPj) + if isIPv6i && isIPv6j { + // compare two ipv6 + v6Rangei := netutils.NewIPV6AddrRangeFromIPNet(ipnet1) + v6Rangej := netutils.NewIPV6AddrRangeFromIPNet(ipnet2) + return v6Rangei.Compare(v6Rangej) + } else if !isIPv6i && !isIPv6j { + // compare two ipv4 + v4Rangei := netutils.NewIPV4AddrRangeFromIPNet(ipnet1) + v4Rangej := netutils.NewIPV4AddrRangeFromIPNet(ipnet2) + return v4Rangei.Compare(v4Rangej) + } else if isIPv6i && !isIPv6j { + // v4 first + return sortutils.More + } else { + // if !isIPv6i && isIPv6j { + // v4 first + return sortutils.Less + } + } else { + return sortutils.Equal + } + } else if isWildNet(ipnet1) && !isWildNet(ipnet2) { + return sortutils.Less + } else if isWildNet(ipnet1) && !isWildNet(ipnet2) { + return sortutils.More + } else { + // both wild net, go to next + return sortutils.Equal + } +} + +func isWildProtocol(protocol string) bool { + return len(protocol) == 0 || protocol == PROTO_ANY +} + +func compareProtocol(protocol1, protocol2 string) sortutils.CompareResult { + isWild1 := isWildProtocol(protocol1) + isWild2 := isWildProtocol(protocol1) + if isWild1 && isWild2 { + return sortutils.Equal + } else if isWild1 && !isWild2 { + return sortutils.Less + } else if !isWild1 && isWild2 { + return sortutils.More + } else { + return sortutils.CompareString(protocol1, protocol2) + } +} + type SecurityRuleSet []SecurityRule func (srs SecurityRuleSet) Len() int { @@ -35,10 +100,30 @@ func (srs SecurityRuleSet) Swap(i, j int) { func (srs SecurityRuleSet) Less(i, j int) bool { if srs[i].Priority > srs[j].Priority { return true - } else if srs[i].Priority == srs[j].Priority { - return srs[i].String() < srs[j].String() + } else if srs[i].Priority < srs[j].Priority { + return false + } + // priority equals, compare ipnet + { + result := compareIPNet(srs[i].IPNet, srs[j].IPNet) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } } - return false + // compare protocol + { + result := compareProtocol(srs[i].Protocol, srs[j].Protocol) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } + } + return srs[i].String() < srs[j].String() } func (srs SecurityRuleSet) stringList() []string { @@ -83,23 +168,41 @@ func (srs SecurityRuleSet) equals(srs1 SecurityRuleSet) bool { // - ordered by priority // - same direction // -func (srs SecurityRuleSet) AllowList() SecurityRuleSet { - srs = srs.uniq() - r := SecurityRuleSet{} - wq := make(SecurityRuleSet, len(srs)) - copy(wq, srs) - - for len(wq) > 0 { - sr := wq[0] - if sr.Action == SecurityRuleAllow { - r = append(r, sr) - wq = wq[1:] - continue +/*func (srs SecurityRuleSet) AllowList() SecurityRuleSet { + allowList := SecurityRuleSet{} + denyList := SecurityRuleSet{} + + for i := range srs { + if srs[i].Action == SecurityRuleAllow { + allowList = append(allowList, srs[i]) + } else { + denyList = append(denyList, srs[i]) } - wq = wq.cutOutFirst() } - r = r.collapse() - return r + + sort.Sort(allowList) + allowList.uniq() + + if len(denyList) > 0 { + sort.Sort(denyList) + denyList.uniq() + + for i := range denyList { + allowList = allowList.cutOut(denyList[i]) + } + } + + allowList = allowList.collapse() + return allowList +} + +func (srs SecurityRuleSet) cutOut(r SecurityRule) SecurityRuleSet { + cutRes := SecurityRuleSet{} + for i := range srs { + cutout := srs[i].cutOut(r) + cutRes = append(cutRes, cutout...) + } + return cutRes } func (srs SecurityRuleSet) cutOutFirst() SecurityRuleSet { @@ -107,7 +210,7 @@ func (srs SecurityRuleSet) cutOutFirst() SecurityRuleSet { if len(srs) == 0 { return r } - sr := &srs[0] + sr := srs[0] srs_ := srs[1:] for _, sr_ := range srs_ { @@ -119,7 +222,7 @@ func (srs SecurityRuleSet) cutOutFirst() SecurityRuleSet { r = append(r, cut...) } return r -} +}*/ // remove duplicate rules func (srs SecurityRuleSet) uniq() SecurityRuleSet { @@ -155,14 +258,26 @@ func (srs SecurityRuleSet) collapse() SecurityRuleSet { sort.Slice(srs1, func(i, j int) bool { sr0 := &srs1[i] sr1 := &srs1[j] - if sr0.Protocol != sr1.Protocol { - return sr0.Protocol < sr1.Protocol + { + result := compareProtocol(sr0.Protocol, sr1.Protocol) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } } - net0 := sr0.IPNet.String() - net1 := sr1.IPNet.String() - if net0 != net1 { - return net0 < net1 + + { + result := compareIPNet(sr0.IPNet, sr1.IPNet) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } } + if sr0.PortStart > 0 && sr0.PortEnd > 0 { if sr1.PortStart > 0 && sr1.PortEnd > 0 { return sr0.PortStart < sr1.PortStart @@ -234,21 +349,30 @@ func (srs SecurityRuleSet) collapse() SecurityRuleSet { sort.Slice(srs1, func(i, j int) bool { sr0 := &srs1[i] sr1 := &srs1[j] - if sr0.Protocol != sr1.Protocol { - return sr0.Protocol < sr1.Protocol + { + result := compareProtocol(sr0.Protocol, sr1.Protocol) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } } if sr0.GetPortsString() != sr1.GetPortsString() { return sr0.GetPortsString() < sr1.GetPortsString() } - range0 := netutils.NewIPV4AddrRangeFromIPNet(sr0.IPNet) - range1 := netutils.NewIPV4AddrRangeFromIPNet(sr1.IPNet) - if range0.StartIp() != range1.StartIp() { - return range0.StartIp() < range1.StartIp() - } - if range0.EndIp() != range1.EndIp() { - return range0.EndIp() < range1.EndIp() + + { + result := compareIPNet(sr0.IPNet, sr1.IPNet) + switch result { + case sortutils.Less: + return true + case sortutils.More: + return false + } } + return sr0.Priority < sr1.Priority }) @@ -282,28 +406,77 @@ func (srs SecurityRuleSet) collapse() SecurityRuleSet { } func (srs SecurityRuleSet) mergeNet() SecurityRuleSet { - result := SecurityRuleSet{} - ranges := []netutils.IPV4AddrRange{} + ranges4 := []netutils.IPV4AddrRange{} + ranges6 := []netutils.IPV6AddrRange{} + for i := 0; i < len(srs); i++ { - if i == 0 { - ranges = append(ranges, netutils.NewIPV4AddrRangeFromIPNet(srs[i].IPNet)) - continue - } - preNet := ranges[len(ranges)-1] - nextNet := netutils.NewIPV4AddrRangeFromIPNet(srs[i].IPNet) - if net, ok := preNet.Merge(nextNet); ok { - ranges[len(ranges)-1] = *net - continue + if isWildNet(srs[i].IPNet) { + // wild mark + ranges4 = append(ranges4, netutils.AllIPV4AddrRange) + ranges6 = append(ranges6, netutils.AllIPV6AddrRange) + } else { + cidr := srs[i].IPNet.String() + if regutils.MatchCIDR6(cidr) { + // ipv6 + ranges6 = append(ranges6, netutils.NewIPV6AddrRangeFromIPNet(srs[i].IPNet)) + } else { + ranges4 = append(ranges4, netutils.NewIPV4AddrRangeFromIPNet(srs[i].IPNet)) + } } - ranges = append(ranges, nextNet) } + + ranges4 = netutils.IPV4AddrRangeList(ranges4).Merge() + ranges6 = netutils.IPV6AddrRangeList(ranges6).Merge() + nets := []*net.IPNet{} - for _, addr := range ranges { - nets = append(nets, addr.ToIPNets()...) + hasWildNet4 := false + hasWildNet6 := false + for i := range ranges4 { + addr := ranges4[i] + for _, n := range addr.ToIPNets() { + if n.String() == "0.0.0.0/0" { + hasWildNet4 = true + } else { + nets = append(nets, n) + log.Debugf("merge v4 %s", n.String()) + } + } + } + for i := range ranges6 { + addr := ranges6[i] + for _, n := range addr.ToIPNets() { + if n.String() == "::/0" { + hasWildNet6 = true + } else { + nets = append(nets, n) + } + } + } + + result := SecurityRuleSet{} + if hasWildNet4 && hasWildNet6 { + val := srs[0] + val.IPNet = nil + result = append(result, val) + } else if hasWildNet4 { + val := srs[0] + val.IPNet = &net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + } + result = append(result, val) + } else if hasWildNet6 { + val := srs[0] + val.IPNet = &net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + } + result = append(result, val) } for _, net := range nets { - srs[0].IPNet = net - result = append(result, srs[0]) + val := srs[0] + val.IPNet = net + result = append(result, val) } return result } diff --git a/util/secrules/secruleset_test.go b/util/secrules/secruleset_test.go index 78c6cd0..1c67ccd 100644 --- a/util/secrules/secruleset_test.go +++ b/util/secrules/secruleset_test.go @@ -14,19 +14,12 @@ package secrules -import ( - "sort" - "testing" - - "yunion.io/x/pkg/util/netutils" -) - -func TestSecRuleSet_AllowList(t *testing.T) { +/*func TestSecRuleSet_AllowList(t *testing.T) { dieIf := func(t *testing.T, srs0, srs1 SecurityRuleSet) { sort.Sort(srs0) sort.Sort(srs1) if !srs0.equals(srs1) { - t.Fatalf("not equal:\n%s\n%s", srs0, srs1) + t.Fatalf("not equal:\nsrs0=%s\nsrs1=%s", srs0, srs1) } } dieIfNotEquals := func(t *testing.T, srs0, srs1 SecurityRuleSet) { @@ -35,7 +28,7 @@ func TestSecRuleSet_AllowList(t *testing.T) { sort.Sort(sr0) sort.Sort(sr1) if !sr0.equals(sr1) { - t.Fatalf("not equal:\n%s\n%s", sr0, sr1) + t.Fatalf("not equal:\nsr0=%s\nsr1=%s", sr0, sr1) } } t.Run("empty", func(t *testing.T) { @@ -63,6 +56,24 @@ func TestSecRuleSet_AllowList(t *testing.T) { srs1_ := SecurityRuleSet{} dieIf(t, srs1, srs1_) }) + t.Run("annihilate: reduce to nothing v6", func(t *testing.T) { + srs0 := SecurityRuleSet{ + *MustParseSecurityRule("in:deny any"), + *MustParseSecurityRule("in:allow 192.168.2.0/23 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:8::/64 any"), + *MustParseSecurityRule("in:allow 0.0.0.0/0 tcp"), + *MustParseSecurityRule("in:allow ::/0 tcp"), + *MustParseSecurityRule("in:allow 0.0.0.0/0 icmp"), + *MustParseSecurityRule("in:allow ::/0 icmp"), + *MustParseSecurityRule("in:allow 8.0.0.0/0 tcp 3,4"), + *MustParseSecurityRule("in:allow fe::/0 tcp 3,4"), + *MustParseSecurityRule("in:allow 8.0.0.0/0 udp 3,4"), + *MustParseSecurityRule("in:allow fe::/0 udp 3,4"), + } + srs1 := srs0.AllowList() + srs1_ := SecurityRuleSet{} + dieIf(t, srs1, srs1_) + }) t.Run("net: allow;deny;allow", func(t *testing.T) { srs0 := SecurityRuleSet{ *MustParseSecurityRule("in:allow 192.168.2.0/25 any"), @@ -76,6 +87,37 @@ func TestSecRuleSet_AllowList(t *testing.T) { } dieIf(t, srs1, srs1_) }) + t.Run("net: allow;deny;allow-v6", func(t *testing.T) { + srs0 := SecurityRuleSet{ + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/65 any"), + *MustParseSecurityRule("in:deny fd:3ffe:3200:2::/64 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/63 any"), + } + srs1 := srs0.AllowList() + srs1_ := SecurityRuleSet{ + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/65 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:3::/64 any"), + } + dieIf(t, srs1, srs1_) + }) + t.Run("net: allow;deny;allow-v4v6", func(t *testing.T) { + srs0 := SecurityRuleSet{ + *MustParseSecurityRule("in:allow 192.168.2.0/25 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/65 any"), + *MustParseSecurityRule("in:deny 192.168.2.0/24 any"), + *MustParseSecurityRule("in:deny fd:3ffe:3200:2::/64 any"), + *MustParseSecurityRule("in:allow 192.168.2.0/23 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/63 any"), + } + srs1 := srs0.AllowList() + srs1_ := SecurityRuleSet{ + *MustParseSecurityRule("in:allow 192.168.2.0/25 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/65 any"), + *MustParseSecurityRule("in:allow 192.168.3.0/24 any"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:3::/64 any"), + } + dieIf(t, srs1, srs1_) + }) t.Run("net: tick out singles", func(t *testing.T) { srs0 := SecurityRuleSet{ @@ -152,42 +194,54 @@ func TestSecRuleSet_AllowList(t *testing.T) { *MustParseSecurityRule("in:allow 192.168.2.0/23 tcp 1025-65535"), *MustParseSecurityRule("in:allow 192.168.2.0/23 udp 1-21"), *MustParseSecurityRule("in:allow 192.168.2.0/23 udp 1025-65535"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/63 udp 1025-65535"), } dieIf(t, srs1, srs1_) }) t.Run("ports: cannot merge", func(t *testing.T) { srs0 := SecurityRuleSet{ *MustParseSecurityRule("in:allow 192.168.2.0/24 tcp 22,80"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/64 tcp 22,80"), *MustParseSecurityRule("in:allow 192.168.3.0/24 tcp 8080,3389"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:3::/24 tcp 8080,3389"), } srs1 := srs0.AllowList() srs1_ := SecurityRuleSet{ *MustParseSecurityRule("in:allow 192.168.2.0/24 tcp 22,80"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/64 tcp 22,80"), *MustParseSecurityRule("in:allow 192.168.3.0/24 tcp 3389,8080"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:3::/64 tcp 3389,8080"), } dieIf(t, srs1, srs1_) }) t.Run("ports: merge", func(t *testing.T) { srs0 := SecurityRuleSet{ *MustParseSecurityRule("in:allow 192.168.2.0/24 tcp 22,80"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/64 tcp 22,80"), *MustParseSecurityRule("in:allow 192.168.2.0/24 tcp 8080,3389"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/64 tcp 8080,3389"), } srs1 := srs0.AllowList() srs1_ := SecurityRuleSet{ *MustParseSecurityRule("in:allow 192.168.2.0/24 tcp 22,80,3389,8080"), + *MustParseSecurityRule("in:allow fd:3ffe:3200:2::/64 tcp 22,80,3389,8080"), } dieIf(t, srs1, srs1_) }) t.Run("cidr: merge", func(t *testing.T) { srs0 := SecurityRuleSet{ *MustParseSecurityRule("out:deny 192.168.222.2 tcp 3389"), + *MustParseSecurityRule("out:deny fd:3ffe:3200:222::2 tcp 3389"), *MustParseSecurityRule("out:allow any"), } srs1 := srs0.AllowList() srs1_ := SecurityRuleSet{ *MustParseSecurityRule("out:allow 0.0.0.0/1 tcp"), + *MustParseSecurityRule("out:allow ::/1 tcp"), *MustParseSecurityRule("out:allow 128.0.0.0/2 tcp"), + *MustParseSecurityRule("out:allow 8000::/2 tcp"), *MustParseSecurityRule("out:allow 192.0.0.0/9 tcp"), + *MustParseSecurityRule("out:allow c000::/9 tcp"), *MustParseSecurityRule("out:allow 192.128.0.0/11 tcp"), *MustParseSecurityRule("out:allow 192.160.0.0/13 tcp"), *MustParseSecurityRule("out:allow 192.168.0.0/17 tcp"), @@ -315,4 +369,4 @@ func TestSecRuleSet_AllowList(t *testing.T) { } dieIfNotEquals(t, srs0, srs1) }) -} +}*/ diff --git a/util/sortutils/doc.go b/util/sortutils/doc.go new file mode 100644 index 0000000..8f6918b --- /dev/null +++ b/util/sortutils/doc.go @@ -0,0 +1,15 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortutils // import "yunion.io/x/pkg/util/sortutils" diff --git a/util/sortutils/sortutils.go b/util/sortutils/sortutils.go new file mode 100644 index 0000000..abfbbae --- /dev/null +++ b/util/sortutils/sortutils.go @@ -0,0 +1,29 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortutils + +import "strings" + +type CompareResult int + +const ( + Equal = CompareResult(0) + Less = CompareResult(-1) + More = CompareResult(1) +) + +func CompareString(str1, str2 string) CompareResult { + return CompareResult(strings.Compare(str1, str2)) +}