diff --git a/shadowaead_2022/protocol.go b/shadowaead_2022/protocol.go index fcb2f96..70a9763 100644 --- a/shadowaead_2022/protocol.go +++ b/shadowaead_2022/protocol.go @@ -566,11 +566,11 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } if sessionId == c.session.remoteSessionId { - if !c.session.filter.ValidateCounter(packetId, math.MaxUint64) { + if !c.session.filter.ValidateCounter(packetId) { return M.Socksaddr{}, ErrPacketIdNotUnique } } else if sessionId == c.session.lastRemoteSessionId { - if !c.session.lastFilter.ValidateCounter(packetId, math.MaxUint64) { + if !c.session.lastFilter.ValidateCounter(packetId) { return M.Socksaddr{}, ErrPacketIdNotUnique } remoteCipher = c.session.lastRemoteCipher @@ -589,7 +589,7 @@ func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { } c.session.remoteSessionId = sessionId c.session.remoteCipher = remoteCipher - c.session.filter.ValidateCounter(packetId, math.MaxUint64) + c.session.filter.ValidateCounter(packetId) } var clientSessionId uint64 diff --git a/shadowaead_2022/service.go b/shadowaead_2022/service.go index e365bf3..ef20bd4 100644 --- a/shadowaead_2022/service.go +++ b/shadowaead_2022/service.go @@ -361,7 +361,7 @@ returnErr: return err process: - if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + if !session.filter.ValidateCounter(packetId) { err = ErrPacketIdNotUnique goto returnErr } diff --git a/shadowaead_2022/service_multi.go b/shadowaead_2022/service_multi.go index 0d566e2..549d2ca 100644 --- a/shadowaead_2022/service_multi.go +++ b/shadowaead_2022/service_multi.go @@ -299,7 +299,7 @@ returnErr: return err process: - if !session.filter.ValidateCounter(packetId, math.MaxUint64) { + if !session.filter.ValidateCounter(packetId) { err = ErrPacketIdNotUnique goto returnErr } diff --git a/shadowaead_2022/wg_replay/replay.go b/shadowaead_2022/wg_replay/replay.go index 19e93ce..187c2c2 100644 --- a/shadowaead_2022/wg_replay/replay.go +++ b/shadowaead_2022/wg_replay/replay.go @@ -26,18 +26,9 @@ type Filter struct { ring [ringBlocks]block } -// Reset resets the filter to empty state. -func (f *Filter) Reset() { - f.last = 0 - f.ring[0] = 0 -} - // ValidateCounter checks if the counter should be accepted. // Overlimit counters (>= limit) are always rejected. -func (f *Filter) ValidateCounter(counter, limit uint64) bool { - if counter >= limit { - return false - } +func (f *Filter) ValidateCounter(counter uint64) bool { indexBlock := counter >> blockBitLog if counter > f.last { // move window forward current := f.last >> blockBitLog