From 4165c0453e78c7a44dbebfeab89daa48be585c5d Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Mon, 13 Jul 2020 14:27:23 +0100 Subject: [PATCH 1/6] Add methods to create and parse cmsghdr structured data. Motivation: cmsghdrs can be used to send an receive extra data on UDP packets. For example ECN data. Modifications: Map in Linux and Darwin versions of cmsghdr macros as functions. Create strctures for holding a received collection; parsing ecn data from a received collection and building a collection suitable for sending. Result: Functions to manipulate cmsghdr and data exist. --- Sources/CNIODarwin/include/CNIODarwin.h | 10 + Sources/CNIODarwin/shim.c | 30 +++ Sources/CNIOLinux/include/CNIOLinux.h | 9 + Sources/CNIOLinux/shim.c | 30 +++ Sources/NIO/ControlMessage.swift | 224 ++++++++++++++++++ Sources/NIO/System.swift | 67 +++++- Tests/LinuxMain.swift | 1 + .../NIOTests/ControlMessageTests+XCTest.swift | 36 +++ Tests/NIOTests/ControlMessageTests.swift | 70 ++++++ Tests/NIOTests/SystemTest+XCTest.swift | 4 + Tests/NIOTests/SystemTest.swift | 111 +++++++++ 11 files changed, 591 insertions(+), 1 deletion(-) create mode 100644 Sources/NIO/ControlMessage.swift create mode 100644 Tests/NIOTests/ControlMessageTests+XCTest.swift create mode 100644 Tests/NIOTests/ControlMessageTests.swift diff --git a/Sources/CNIODarwin/include/CNIODarwin.h b/Sources/CNIODarwin/include/CNIODarwin.h index d86fe809d7..e096763956 100644 --- a/Sources/CNIODarwin/include/CNIODarwin.h +++ b/Sources/CNIODarwin/include/CNIODarwin.h @@ -18,6 +18,8 @@ #include #include +#include + // Darwin platforms do not have a sendmmsg implementation available to them. This C module // provides a shim that implements sendmmsg on top of sendmsg. It also provides a shim for // recvmmsg, but does not actually implement that shim, instantly throwing errors if called. @@ -34,5 +36,13 @@ typedef struct { int CNIODarwin_sendmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags); int CNIODarwin_recvmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags, struct timespec *timeout); +// cmsghdr handling +struct cmsghdr *CNIODarwin_CMSG_FIRSTHDR(const struct msghdr *); +struct cmsghdr *CNIODarwin_CMSG_NXTHDR(const struct msghdr *, const struct cmsghdr *); +const void *CNIODarwin_CMSG_DATA(const struct cmsghdr *); +void *CNIODarwin_CMSG_DATA_MUTABLE(struct cmsghdr *); +size_t CNIODarwin_CMSG_LEN(size_t); +size_t CNIODarwin_CMSG_SPACE(size_t); + #endif // __APPLE__ #endif // C_NIO_DARWIN_H diff --git a/Sources/CNIODarwin/shim.c b/Sources/CNIODarwin/shim.c index 28639e71ac..a77a380e59 100644 --- a/Sources/CNIODarwin/shim.c +++ b/Sources/CNIODarwin/shim.c @@ -18,6 +18,7 @@ #include #include #include +#include int CNIODarwin_sendmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags) { // Some quick error checking. If vlen can't fit into int, we bail. @@ -50,4 +51,33 @@ int CNIODarwin_recvmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vle errx(EX_SOFTWARE, "recvmmsg shim not implemented on Darwin platforms\n"); } +struct cmsghdr *CNIODarwin_CMSG_FIRSTHDR(const struct msghdr *mhdr) { + assert(mhdr != NULL); + return CMSG_FIRSTHDR(mhdr); +} + +struct cmsghdr *CNIODarwin_CMSG_NXTHDR(const struct msghdr *mhdr, const struct cmsghdr *cmsg) { + assert(mhdr != NULL); + assert(cmsg != NULL); // Not required by Darwin but Linux needs this so we should match. + return CMSG_NXTHDR(mhdr, cmsg); +} + +const void *CNIODarwin_CMSG_DATA(const struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +void *CNIODarwin_CMSG_DATA_MUTABLE(struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +size_t CNIODarwin_CMSG_LEN(size_t payloadSizeBytes) { + return CMSG_LEN(payloadSizeBytes); +} + +size_t CNIODarwin_CMSG_SPACE(size_t payloadSizeBytes) { + return CMSG_SPACE(payloadSizeBytes); +} + #endif // __APPLE__ diff --git a/Sources/CNIOLinux/include/CNIOLinux.h b/Sources/CNIOLinux/include/CNIOLinux.h index 17293079f5..7a4d02293f 100644 --- a/Sources/CNIOLinux/include/CNIOLinux.h +++ b/Sources/CNIOLinux/include/CNIOLinux.h @@ -25,6 +25,7 @@ #include #include #include +#include // Some explanation is required here. // @@ -64,5 +65,13 @@ void CNIOLinux_CPU_SET(int cpu, cpu_set_t *set); void CNIOLinux_CPU_ZERO(cpu_set_t *set); int CNIOLinux_CPU_ISSET(int cpu, cpu_set_t *set); int CNIOLinux_CPU_SETSIZE(); + +// cmsghdr handling +struct cmsghdr *CNIOLinux_CMSG_FIRSTHDR(const struct msghdr *); +struct cmsghdr *CNIOLinux_CMSG_NXTHDR(struct msghdr *, struct cmsghdr *); +const void *CNIOLinux_CMSG_DATA(const struct cmsghdr *); +void *CNIOLinux_CMSG_DATA_MUTABLE(struct cmsghdr *); +size_t CNIOLinux_CMSG_LEN(size_t); +size_t CNIOLinux_CMSG_SPACE(size_t); #endif #endif diff --git a/Sources/CNIOLinux/shim.c b/Sources/CNIOLinux/shim.c index 353050ca87..acd256a42f 100644 --- a/Sources/CNIOLinux/shim.c +++ b/Sources/CNIOLinux/shim.c @@ -24,6 +24,7 @@ void CNIOLinux_i_do_nothing_just_working_around_a_darwin_toolchain_bug(void) {} #include #include #include +#include _Static_assert(sizeof(CNIOLinux_mmsghdr) == sizeof(struct mmsghdr), "sizes of CNIOLinux_mmsghdr and struct mmsghdr differ"); @@ -112,4 +113,33 @@ int CNIOLinux_CPU_ISSET(int cpu, cpu_set_t *set) { int CNIOLinux_CPU_SETSIZE() { return CPU_SETSIZE; } + +struct cmsghdr *CNIOLinux_CMSG_FIRSTHDR(const struct msghdr *mhdr) { + assert(mhdr != NULL); + return CMSG_FIRSTHDR(mhdr); +} + +struct cmsghdr *CNIOLinux_CMSG_NXTHDR(struct msghdr *mhdr, struct cmsghdr *cmsg) { + assert(mhdr != NULL); + assert(cmsg != NULL); + return CMSG_NXTHDR(mhdr, cmsg); +} + +const void *CNIOLinux_CMSG_DATA(const struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +void *CNIOLinux_CMSG_DATA_MUTABLE(struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +size_t CNIOLinux_CMSG_LEN(size_t payloadSizeBytes) { + return CMSG_LEN(payloadSizeBytes); +} + +size_t CNIOLinux_CMSG_SPACE(size_t payloadSizeBytes) { + return CMSG_SPACE(payloadSizeBytes); +} #endif diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift new file mode 100644 index 0000000000..4b5b8c2346 --- /dev/null +++ b/Sources/NIO/ControlMessage.swift @@ -0,0 +1,224 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import CNIODarwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CNIOLinux +#endif + +/// Representation of a `cmsghdr` and associated data. +/// Unsafe as captures pointers and must not escape the scope where those pointers are valid. +struct UnsafeControlMessage { + var level: Int32 + var type: Int32 + var data: UnsafeRawBufferPointer? +} + +/// Collection representation of `cmsghdr` structures and associated data from `recvmsg` +/// Unsafe as captures pointers held in msghdr structure which must not escape scope of validity. +struct UnsafeControlMessageCollection: Collection { + typealias Index = ControlMessageIndex + typealias Element = UnsafeControlMessage + + struct ControlMessageIndex: Equatable, Comparable { + fileprivate var cmsgPointer: UnsafeMutablePointer? + + static func < (lhs: UnsafeControlMessageCollection.ControlMessageIndex, + rhs: UnsafeControlMessageCollection.ControlMessageIndex) -> Bool { + // Nil must be high as it represents the end of the collection. + if let lhsPointer = lhs.cmsgPointer { + if let rhsPointer = rhs.cmsgPointer { + return lhsPointer < rhsPointer + } + return true + } + return false + } + + fileprivate init(cmsgPointer: UnsafeMutablePointer?) { + self.cmsgPointer = cmsgPointer + } + } + + var startIndex: Index { + var messageHeader = self.messageHeader + return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in + let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr) + return ControlMessageIndex(cmsgPointer: firstCMsg) + } + } + + let endIndex = ControlMessageIndex(cmsgPointer: nil) + + func index(after: Index) -> Index { + var msgHdr = messageHeader + return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in + return ControlMessageIndex(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, + from: after.cmsgPointer)) + } + } + + public subscript(position: Index) -> Element { + let cmsg = position.cmsgPointer! + return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, + type: cmsg.pointee.cmsg_type, + data: Posix.cmsgData(for: cmsg)) + } + + private var messageHeader: msghdr + + init(messageHeader: msghdr) { + self.messageHeader = messageHeader + } +} + +/// Extract information from a collection of control messages. +struct ControlMessageReceiver { + var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default + + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + private static let ipv4TosType = IP_RECVTOS + #else + private static let ipv4TosType = IP_TOS // Linux + #endif + + mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { + if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageReceiver.ipv4TosType { + if let data = controlMessage.data { + assert(data.count == 1) + precondition(data.count >= 1) + let readValue: Int32 = .init(data[0]) + self.ecnValue = ControlMessageReceiver.parseEcn(receivedValue: readValue) + } + } else if controlMessage.level == IPPROTO_IPV6 && controlMessage.type == IPV6_TCLASS { + if let data = controlMessage.data { + assert(data.count == 4) + precondition(data.count >= 4) + var readValue: Int32 = 0 + withUnsafeMutableBytes(of: &readValue) { valuePtr in + valuePtr.copyMemory(from: data) + } + self.ecnValue = ControlMessageReceiver.parseEcn(receivedValue: readValue) + } + } + } + + private static func parseEcn(receivedValue: Int32) -> NIOExplicitCongestionNotificationState { + switch receivedValue & IPTOS_ECN_MASK { + case IPTOS_ECN_ECT1: + return .transportCapableFlag1 + case IPTOS_ECN_ECT0: + return .transportCapableFlag0 + case IPTOS_ECN_CE: + return .congestionExperienced + default: + return .transportNotCapable + } + } +} + +extension NIOExplicitCongestionNotificationState { + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + fileprivate static let notCapableValue = IPTOS_ECN_NOTECT + #else + fileprivate static let notCapableValue = IPTOS_ECN_NOT_ECT // Linux + #endif + + func asCInt() -> CInt { + switch self { + case .transportNotCapable: + return .init(NIOExplicitCongestionNotificationState.notCapableValue) + case .transportCapableFlag0: + return .init(IPTOS_ECN_ECT0) + case .transportCapableFlag1: + return .init(IPTOS_ECN_ECT1) + case .congestionExperienced: + return .init(IPTOS_ECN_CE) + } + } +} + +struct UnsafeOutboundControlBytes { + private var controlBytes: UnsafeMutableRawBufferPointer + private var writePosition: size_t = 0 + + /// Control bytes are captured - this structure must not have a lifetime exceeded them. + init(controlBytes: UnsafeMutableRawBufferPointer) { + self.controlBytes = controlBytes + } + + mutating func appendControlMessage(level: Int32, + type: Int32, + payload: PayloadType) { + let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...]) + + let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload)) + precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") + + let bufferBase = writableBuffer.baseAddress! + let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) + cmsghdrPtr.pointee.cmsg_level = level + cmsghdrPtr.pointee.cmsg_type = type + cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) + + let dataPointer = Posix.cmsgData(for: .some(cmsghdrPtr)) + precondition(dataPointer!.count >= MemoryLayout.stride) + let dataPointerBase = dataPointer!.baseAddress! + let dataPointerTyped = dataPointerBase.bindMemory(to: PayloadType.self, capacity: 1) + dataPointerTyped.pointee = payload + + self.writePosition += requiredSize + } + + /// The result is only valid while this is valid. + var validControlBytes: UnsafeMutableRawBufferPointer { + if writePosition == 0 { + return UnsafeMutableRawBufferPointer(start: nil, count: 0) + } + return UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[0 ..< self.writePosition]) + } + +} + +extension UnsafeOutboundControlBytes { + /// - address: Either local or remote will do, we just use it for extracting the right protocol. + internal mutating func appendExplicitCongestionState(metadata: AddressedEnvelope.Metadata?, + address: SocketAddress?) { + if let metadata = metadata { + switch address { + case .some(.v4): + self.appendControlMessage(level: .init(IPPROTO_IP), + type: IP_TOS, + payload: metadata.ecnState.asCInt()) + case .some(.v6): + self.appendControlMessage(level: .init(IPPROTO_IPV6), + type: IPV6_TCLASS, + payload: metadata.ecnState.asCInt()) + default: + // Nothing to do - if we get here the user is probably making a mistake. + break + } + } + } +} + +extension AddressedEnvelope.Metadata { + /// It's assumed the caller has checked that congestion information is required before calling. + internal init(from controlMessagesReceived: UnsafeControlMessageCollection) { + var controlMessageReceiver = ControlMessageReceiver() + controlMessagesReceived.forEach { controlMessage in controlMessageReceiver.receiveMessage(controlMessage) } + self.init(ecnState: controlMessageReceiver.ecnValue) + } +} diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index 337a914ffb..34b1fdd5c0 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -97,11 +97,30 @@ private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointe private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer) -> CInt = fstat private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg +private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = + CNIOLinux_CMSG_FIRSTHDR +private let sysCmsgNxtHdr: @convention(c) (UnsafeMutablePointer?, UnsafeMutablePointer?) -> + UnsafeMutablePointer? = CNIOLinux_CMSG_NXTHDR +private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = CNIOLinux_CMSG_DATA +private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = + CNIOLinux_CMSG_DATA_MUTABLE +private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPACE +private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat private let sysKevent = kevent private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg +private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = + CNIODarwin_CMSG_FIRSTHDR +private let sysCmsgNxtHdr: @convention(c) (UnsafePointer?, UnsafePointer?) -> + UnsafeMutablePointer? = CNIODarwin_CMSG_NXTHDR +private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = + CNIODarwin_CMSG_DATA +private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = + CNIODarwin_CMSG_DATA_MUTABLE +private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_SPACE +private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_LEN #elseif os(Windows) private let sysSendMmsg: @convention(c) (NIOBSDSocket.Handle, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOWindows_sendmmsg private let sysRecvMmsg: @convention(c) (NIOBSDSocket.Handle, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOWindows_recvmmsg @@ -373,7 +392,9 @@ internal enum Posix { } @inline(never) - public static func sendmsg(descriptor: CInt, msgHdr: UnsafePointer, flags: CInt) throws -> IOResult { + public static func sendmsg(descriptor: CInt, + msgHdr: UnsafePointer, + flags: CInt) throws -> IOResult { return try syscall(blocking: true) { sysSendMsg(descriptor, msgHdr, flags) } @@ -535,6 +556,50 @@ internal extension Posix { } } +internal extension Posix { + static func cmsgFirstHeader(inside msghdr: UnsafePointer?) -> UnsafeMutablePointer? { + return sysCmsgFirstHdr(msghdr) + } + + static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer?, from: UnsafeMutablePointer?) + -> UnsafeMutablePointer? { + return sysCmsgNxtHdr(msghdr, from) + } + + static func cmsgData(for header: UnsafePointer?) -> UnsafeRawBufferPointer? { + let dataPointer = sysCmsgData(header) + if let header = header { + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } else { + return nil + } + } + + static func cmsgData(for header: UnsafeMutablePointer?) -> UnsafeMutableRawBufferPointer? { + let dataPointer = sysCmsgDataMutable(header) + if let header = header { + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } else { + return nil + } + } + + static func cmsgLen(payloadSize: size_t) -> size_t { + return sysCmsgLen(payloadSize) + } + + static func cmsgSpace(payloadSize: size_t) -> size_t { + return sysCmsgSpace(payloadSize) + } + +} + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) internal enum KQueue { diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index e4fdcbfcc3..1b449b4d61 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -57,6 +57,7 @@ class LinuxMainRunnerImpl: LinuxMainRunner { testCase(ChannelTests.allTests), testCase(CircularBufferTests.allTests), testCase(CodableByteBufferTest.allTests), + testCase(ControlMessageTests.allTests), testCase(CustomChannelTests.allTests), testCase(DatagramChannelTests.allTests), testCase(EchoServerClientTest.allTests), diff --git a/Tests/NIOTests/ControlMessageTests+XCTest.swift b/Tests/NIOTests/ControlMessageTests+XCTest.swift new file mode 100644 index 0000000000..59063af217 --- /dev/null +++ b/Tests/NIOTests/ControlMessageTests+XCTest.swift @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// ControlMessageTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension ControlMessageTests { + + @available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings") + static var allTests : [(String, (ControlMessageTests) -> () throws -> Void)] { + return [ + ("testEmptyEncode", testEmptyEncode), + ("testEncodeDecode1", testEncodeDecode1), + ("testEncodeDecode2", testEncodeDecode2), + ] + } +} + diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift new file mode 100644 index 0000000000..34607f97c2 --- /dev/null +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -0,0 +1,70 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +@testable import NIO + +fileprivate extension UnsafeControlMessageCollection { + init(controlBytes: UnsafeMutableRawBufferPointer) { + let msgHdr = msghdr(msg_name: nil, + msg_namelen: 0, + msg_iov: nil, + msg_iovlen: 0, + msg_control: controlBytes.baseAddress, + msg_controllen: .init(controlBytes.count), + msg_flags: 0) + self.init(messageHeader: msgHdr) + } +} + +class ControlMessageTests: XCTestCase { + var encoder: UnsafeOutboundControlBytes! + + override func setUp() { + let encoderBytes = UnsafeMutableRawBufferPointer.allocate(byteCount: 1000, + alignment: MemoryLayout.alignment) + self.encoder = UnsafeOutboundControlBytes(controlBytes: encoderBytes) + } + + func testEmptyEncode() { + XCTAssertEqual(self.encoder.validControlBytes.count, 0) + } + + func testEncodeDecode1() { + self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) + let encodedBytes = self.encoder.validControlBytes + + let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) + XCTAssertEqual(decoder.count, 1) + XCTAssertEqual(decoder.first!.level, 1) + XCTAssertEqual(decoder.first!.type, 2) + XCTAssertEqual(decoder.first!.data!.count, MemoryLayout.size) + } + + func testEncodeDecode2() { + self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) + self.encoder.appendControlMessage(level: 4, type: 5, payload: 6) + let encodedBytes = self.encoder.validControlBytes + + let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) + XCTAssertEqual(decoder.count, 2) + XCTAssertEqual(decoder.first!.level, 1) + XCTAssertEqual(decoder.first!.type, 2) + XCTAssertEqual(decoder.first!.data!.count, MemoryLayout.size) + XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].level, 4) + XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].type, 5) + XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].data!.count, MemoryLayout.size) + + } +} diff --git a/Tests/NIOTests/SystemTest+XCTest.swift b/Tests/NIOTests/SystemTest+XCTest.swift index 56ae6aada1..053bdcf1fd 100644 --- a/Tests/NIOTests/SystemTest+XCTest.swift +++ b/Tests/NIOTests/SystemTest+XCTest.swift @@ -29,6 +29,10 @@ extension SystemTest { return [ ("testSystemCallWrapperPerformance", testSystemCallWrapperPerformance), ("testErrorsWorkCorrectly", testErrorsWorkCorrectly), + ("testCmsgFirstHeader", testCmsgFirstHeader), + ("testCMsgNextHeader", testCMsgNextHeader), + ("testCMsgData", testCMsgData), + ("testCMsgCollection", testCMsgCollection), ] } } diff --git a/Tests/NIOTests/SystemTest.swift b/Tests/NIOTests/SystemTest.swift index f20659c389..9ca334e678 100644 --- a/Tests/NIOTests/SystemTest.swift +++ b/Tests/NIOTests/SystemTest.swift @@ -46,4 +46,115 @@ class SystemTest: XCTestCase { return [readFD, writeFD] } } + + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + // Example twin data options captured on macOS + private static let cmsghdrExample: [UInt8] = [0x10, 0x00, 0x00, 0x00, // Length 16 including header + 0x00, 0x00, 0x00, 0x00, // IPPROTO_IP + 0x07, 0x00, 0x00, 0x00, // IP_RECVDSTADDR + 0x7F, 0x00, 0x00, 0x01, // 127.0.0.1 + 0x0D, 0x00, 0x00, 0x00, // Length 13 including header + 0x00, 0x00, 0x00, 0x00, // IPPROTO_IP + 0x1B, 0x00, 0x00, 0x00, // IP_RECVTOS + 0x01, 0x00, 0x00, 0x00] // ECT-1 (1 byte) + private static let cmsghdr_secondStartPosition = 16 + private static let cmsghdr_firstDataStart = 12 + private static let cmsghdr_firstDataCount = 4 + private static let cmsghdr_secondDataCount = 1 + private static let cmsghdr_firstType = IP_RECVDSTADDR + private static let cmsghdr_secondType = IP_RECVTOS + #elseif os(Linux) + // Example twin data options captured on Linux + private static let cmsghdrExample: [UInt8] = [ + 0x1C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Length 28 including header. + 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, // IPPROTO_IP, IP_PKTINFO + 0x01, 0x00, 0x00, 0x00, 0x7F, 0x00, 0x00, 0x01, // interface number, 127.0.0.1 (local) + 0x7F, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, // 127.0.0.1 (destination), 4 bytes to align length + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Length 17 + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, // IPPROTO_IP, IP_TOS + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 // ECT-1 (1 byte) + ] + private static let cmsghdr_secondStartPosition = 32 + private static let cmsghdr_firstDataStart = 16 + private static let cmsghdr_firstDataCount = 12 + private static let cmsghdr_secondDataCount = 1 + private static let cmsghdr_firstType = IP_PKTINFO + private static let cmsghdr_secondType = IP_TOS + #endif + + func testCmsgFirstHeader() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafePointer(to: msgHdr) { pMsgHdr in + let result = Posix.cmsgFirstHeader(inside: pMsgHdr) + XCTAssertEqual(pCmsgHdr.baseAddress, result) + } + } + } + + func testCMsgNextHeader() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in + let first = Posix.cmsgFirstHeader(inside: pMsgHdr) + let second = Posix.cmsgNextHeader(inside: pMsgHdr, from: first) + let expectedSecondSlice = UnsafeMutableRawBufferPointer( + rebasing: pCmsgHdr[SystemTest.cmsghdr_secondStartPosition...]) + XCTAssertEqual(expectedSecondSlice.baseAddress, second) + let third = Posix.cmsgNextHeader(inside: pMsgHdr, from: second) + XCTAssertEqual(third, nil) + } + } + } + + func testCMsgData() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafePointer(to: msgHdr) { pMsgHdr in + let first = Posix.cmsgFirstHeader(inside: pMsgHdr) + let firstData = Posix.cmsgData(for: first) + let expecedFirstData = UnsafeRawBufferPointer( + rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<( + SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)]) + XCTAssertEqual(expecedFirstData.baseAddress, firstData?.baseAddress) + XCTAssertEqual(expecedFirstData.count, firstData?.count) + } + } + } + + func testCMsgCollection() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + let collection = UnsafeControlMessageCollection(messageHeader: msgHdr) + var msgNum = 0 + for cmsg in collection { + if msgNum == 0 { + XCTAssertEqual(cmsg.level, .init(IPPROTO_IP)) + XCTAssertEqual(cmsg.type, .init(SystemTest.cmsghdr_firstType)) + XCTAssertEqual(cmsg.data?.count, SystemTest.cmsghdr_firstDataCount) + } else if msgNum == 1 { + XCTAssertEqual(cmsg.level, .init(IPPROTO_IP)) + XCTAssertEqual(cmsg.type, .init(SystemTest.cmsghdr_secondType)) + XCTAssertEqual(cmsg.data?.count, SystemTest.cmsghdr_secondDataCount) + } + msgNum += 1 + } + XCTAssertEqual(msgNum, 2) + } + } } From 4be601f03e53631ea367e406d5d11b2d149be0f2 Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Mon, 13 Jul 2020 18:25:31 +0100 Subject: [PATCH 2/6] Cory review comments. --- Sources/NIO/ControlMessage.swift | 163 +++++++++++++---------- Sources/NIO/System.swift | 81 +++++------ Tests/NIOTests/ControlMessageTests.swift | 47 +++++-- Tests/NIOTests/SystemTest.swift | 21 +-- 4 files changed, 172 insertions(+), 140 deletions(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 4b5b8c2346..603d47152d 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -21,30 +21,39 @@ import CNIOLinux /// Representation of a `cmsghdr` and associated data. /// Unsafe as captures pointers and must not escape the scope where those pointers are valid. struct UnsafeControlMessage { - var level: Int32 - var type: Int32 + var level: CInt + var type: CInt var data: UnsafeRawBufferPointer? } /// Collection representation of `cmsghdr` structures and associated data from `recvmsg` /// Unsafe as captures pointers held in msghdr structure which must not escape scope of validity. -struct UnsafeControlMessageCollection: Collection { - typealias Index = ControlMessageIndex +struct UnsafeControlMessageCollection { + private var messageHeader: msghdr + + init(messageHeader: msghdr) { + self.messageHeader = messageHeader + } +} + +// Add the `Collection` functionality to UnsafeControlMessageCollection. +extension UnsafeControlMessageCollection: Collection { typealias Element = UnsafeControlMessage - struct ControlMessageIndex: Equatable, Comparable { + struct Index: Equatable, Comparable { fileprivate var cmsgPointer: UnsafeMutablePointer? - static func < (lhs: UnsafeControlMessageCollection.ControlMessageIndex, - rhs: UnsafeControlMessageCollection.ControlMessageIndex) -> Bool { - // Nil must be high as it represents the end of the collection. - if let lhsPointer = lhs.cmsgPointer { - if let rhsPointer = rhs.cmsgPointer { - return lhsPointer < rhsPointer - } + static func < (lhs: UnsafeControlMessageCollection.Index, + rhs: UnsafeControlMessageCollection.Index) -> Bool { + // nil is high, as that's the end of the collection. + switch (lhs.cmsgPointer, rhs.cmsgPointer) { + case (.some(let lhs), .some(let rhs)): + return lhs < rhs + case (.some, .none): return true + case (.none, .some), (.none, .none): + return false } - return false } fileprivate init(cmsgPointer: UnsafeMutablePointer?) { @@ -56,17 +65,18 @@ struct UnsafeControlMessageCollection: Collection { var messageHeader = self.messageHeader return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr) - return ControlMessageIndex(cmsgPointer: firstCMsg) + return Index(cmsgPointer: firstCMsg) } } - let endIndex = ControlMessageIndex(cmsgPointer: nil) + var endIndex: Index { return Index(cmsgPointer: nil) } func index(after: Index) -> Index { + precondition(after.cmsgPointer != nil) var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in - return ControlMessageIndex(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, - from: after.cmsgPointer)) + return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, + after: after.cmsgPointer!)) } } @@ -76,12 +86,6 @@ struct UnsafeControlMessageCollection: Collection { type: cmsg.pointee.cmsg_type, data: Posix.cmsgData(for: cmsg)) } - - private var messageHeader: msghdr - - init(messageHeader: msghdr) { - self.messageHeader = messageHeader - } } /// Extract information from a collection of control messages. @@ -93,91 +97,105 @@ struct ControlMessageReceiver { #else private static let ipv4TosType = IP_TOS // Linux #endif + + static func readCInt(data: UnsafeRawBufferPointer) -> CInt { + assert(data.count == MemoryLayout.size) + precondition(data.count >= MemoryLayout.size) + var readValue = CInt(0) + withUnsafeMutableBytes(of: &readValue) { valuePtr in + valuePtr.copyMemory(from: data) + } + return readValue + } mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageReceiver.ipv4TosType { if let data = controlMessage.data { assert(data.count == 1) precondition(data.count >= 1) - let readValue: Int32 = .init(data[0]) - self.ecnValue = ControlMessageReceiver.parseEcn(receivedValue: readValue) + let readValue = CInt(data[0]) + self.ecnValue = .init(receivedValue: readValue) } } else if controlMessage.level == IPPROTO_IPV6 && controlMessage.type == IPV6_TCLASS { if let data = controlMessage.data { - assert(data.count == 4) - precondition(data.count >= 4) - var readValue: Int32 = 0 - withUnsafeMutableBytes(of: &readValue) { valuePtr in - valuePtr.copyMemory(from: data) - } - self.ecnValue = ControlMessageReceiver.parseEcn(receivedValue: readValue) + let readValue = ControlMessageReceiver.readCInt(data: data) + self.ecnValue = .init(receivedValue: readValue) } } } +} - private static func parseEcn(receivedValue: Int32) -> NIOExplicitCongestionNotificationState { +extension NIOExplicitCongestionNotificationState { + /// Initialise a NIOExplicitCongestionNotificationState from a value received via either TCLASS or TOS cmsg. + init(receivedValue: CInt) { switch receivedValue & IPTOS_ECN_MASK { case IPTOS_ECN_ECT1: - return .transportCapableFlag1 + self = .transportCapableFlag1 case IPTOS_ECN_ECT0: - return .transportCapableFlag0 + self = .transportCapableFlag0 case IPTOS_ECN_CE: - return .congestionExperienced + self = .congestionExperienced default: - return .transportNotCapable + self = .transportNotCapable } } } -extension NIOExplicitCongestionNotificationState { +extension CInt { #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) fileprivate static let notCapableValue = IPTOS_ECN_NOTECT #else fileprivate static let notCapableValue = IPTOS_ECN_NOT_ECT // Linux #endif - - func asCInt() -> CInt { - switch self { + + /// Create a CInt encoding of ExplicitCongestionNotification suitable for sending in TCLASS or TOS cmsg. + init(ecnValue: NIOExplicitCongestionNotificationState) { + switch ecnValue { case .transportNotCapable: - return .init(NIOExplicitCongestionNotificationState.notCapableValue) + self = CInt.notCapableValue case .transportCapableFlag0: - return .init(IPTOS_ECN_ECT0) + self = IPTOS_ECN_ECT0 case .transportCapableFlag1: - return .init(IPTOS_ECN_ECT1) + self = IPTOS_ECN_ECT1 case .congestionExperienced: - return .init(IPTOS_ECN_CE) + self = IPTOS_ECN_CE } } } struct UnsafeOutboundControlBytes { private var controlBytes: UnsafeMutableRawBufferPointer - private var writePosition: size_t = 0 + private var writePosition: UnsafeMutableRawBufferPointer.Index = 0 - /// Control bytes are captured - this structure must not have a lifetime exceeded them. + /// This structure must not outlive `controlBytes` init(controlBytes: UnsafeMutableRawBufferPointer) { self.controlBytes = controlBytes } - - mutating func appendControlMessage(level: Int32, - type: Int32, - payload: PayloadType) { + + mutating func appendControlMessage(level: CInt, type: CInt, payload: CInt) { + self.appendGenericControlMessage(level: level, type: type, payload: payload) + } + + /// Appends a control message. + /// PayloadType needs to be trivial (eg CInt) + private mutating func appendGenericControlMessage(level: CInt, + type: CInt, + payload: PayloadType) { let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...]) let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload)) precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") let bufferBase = writableBuffer.baseAddress! + // Binding to cmsghdr is safe here as this is the only place where we bind to non-Raw. let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) cmsghdrPtr.pointee.cmsg_level = level cmsghdrPtr.pointee.cmsg_type = type cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) - let dataPointer = Posix.cmsgData(for: .some(cmsghdrPtr)) + let dataPointer = Posix.cmsgData(for: cmsghdrPtr) precondition(dataPointer!.count >= MemoryLayout.stride) - let dataPointerBase = dataPointer!.baseAddress! - let dataPointerTyped = dataPointerBase.bindMemory(to: PayloadType.self, capacity: 1) - dataPointerTyped.pointee = payload + dataPointer!.storeBytes(of: payload, as: PayloadType.self) self.writePosition += requiredSize } @@ -193,23 +211,26 @@ struct UnsafeOutboundControlBytes { } extension UnsafeOutboundControlBytes { - /// - address: Either local or remote will do, we just use it for extracting the right protocol. + /// Add a message describing the explicit congestion state if requested in metadata and valid for this protocol. + /// Parameters: + /// - metadata: Metadata from the addressed envelope which will describe any desired state. + /// - protocolFamily: The type of protocol to encode for. internal mutating func appendExplicitCongestionState(metadata: AddressedEnvelope.Metadata?, - address: SocketAddress?) { - if let metadata = metadata { - switch address { - case .some(.v4): - self.appendControlMessage(level: .init(IPPROTO_IP), - type: IP_TOS, - payload: metadata.ecnState.asCInt()) - case .some(.v6): - self.appendControlMessage(level: .init(IPPROTO_IPV6), - type: IPV6_TCLASS, - payload: metadata.ecnState.asCInt()) - default: - // Nothing to do - if we get here the user is probably making a mistake. - break - } + protocolFamily: NIOBSDSocket.ProtocolFamily?) { + guard let metadata = metadata else { return } + + switch protocolFamily { + case .some(.inet): + self.appendControlMessage(level: .init(IPPROTO_IP), + type: IP_TOS, + payload: CInt(ecnValue: metadata.ecnState)) + case .some(.inet6): + self.appendControlMessage(level: .init(IPPROTO_IPV6), + type: IPV6_TCLASS, + payload: CInt(ecnValue: metadata.ecnState)) + default: + // Nothing to do - if we get here the user is probably making a mistake. + break } } } diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index 34b1fdd5c0..7c8d01d712 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -392,9 +392,7 @@ internal enum Posix { } @inline(never) - public static func sendmsg(descriptor: CInt, - msgHdr: UnsafePointer, - flags: CInt) throws -> IOResult { + public static func sendmsg(descriptor: CInt, msgHdr: UnsafePointer, flags: CInt) throws -> IOResult { return try syscall(blocking: true) { sysSendMsg(descriptor, msgHdr, flags) } @@ -531,6 +529,39 @@ internal enum Posix { sysSocketpair(domain.rawValue, type.rawValue, `protocol`, socketVector) } } + + static func cmsgFirstHeader(inside msghdr: UnsafePointer) -> UnsafeMutablePointer? { + return sysCmsgFirstHdr(msghdr) + } + + static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) -> UnsafeMutablePointer? { + return sysCmsgNxtHdr(msghdr, after) + } + + static func cmsgData(for header: UnsafePointer) -> UnsafeRawBufferPointer? { + let dataPointer = sysCmsgData(header) + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } + + static func cmsgData(for header: UnsafeMutablePointer) -> UnsafeMutableRawBufferPointer? { + let dataPointer = sysCmsgDataMutable(header) + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } + + static func cmsgLen(payloadSize: size_t) -> size_t { + return sysCmsgLen(payloadSize) + } + + static func cmsgSpace(payloadSize: size_t) -> size_t { + return sysCmsgSpace(payloadSize) + } } /// `NIOFailedToSetSocketNonBlockingError` indicates that NIO was unable to set a socket to non-blocking mode, either @@ -556,50 +587,6 @@ internal extension Posix { } } -internal extension Posix { - static func cmsgFirstHeader(inside msghdr: UnsafePointer?) -> UnsafeMutablePointer? { - return sysCmsgFirstHdr(msghdr) - } - - static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer?, from: UnsafeMutablePointer?) - -> UnsafeMutablePointer? { - return sysCmsgNxtHdr(msghdr, from) - } - - static func cmsgData(for header: UnsafePointer?) -> UnsafeRawBufferPointer? { - let dataPointer = sysCmsgData(header) - if let header = header { - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } else { - return nil - } - } - - static func cmsgData(for header: UnsafeMutablePointer?) -> UnsafeMutableRawBufferPointer? { - let dataPointer = sysCmsgDataMutable(header) - if let header = header { - // Linux and Darwin use different types for cmsg_len. - let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) - let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) - return buffer - } else { - return nil - } - } - - static func cmsgLen(payloadSize: size_t) -> size_t { - return sysCmsgLen(payloadSize) - } - - static func cmsgSpace(payloadSize: size_t) -> size_t { - return sysCmsgSpace(payloadSize) - } - -} - #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) internal enum KQueue { diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift index 34607f97c2..8358869b5b 100644 --- a/Tests/NIOTests/ControlMessageTests.swift +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -29,42 +29,65 @@ fileprivate extension UnsafeControlMessageCollection { } class ControlMessageTests: XCTestCase { + var encoderBytes: UnsafeMutableRawBufferPointer? var encoder: UnsafeOutboundControlBytes! override func setUp() { - let encoderBytes = UnsafeMutableRawBufferPointer.allocate(byteCount: 1000, + self.encoderBytes = UnsafeMutableRawBufferPointer.allocate(byteCount: 1000, alignment: MemoryLayout.alignment) - self.encoder = UnsafeOutboundControlBytes(controlBytes: encoderBytes) + self.encoder = UnsafeOutboundControlBytes(controlBytes: self.encoderBytes!) + } + + override func tearDown() { + if let encoderBytes = self.encoderBytes { + self.encoderBytes = nil + encoderBytes.deallocate() + } } func testEmptyEncode() { XCTAssertEqual(self.encoder.validControlBytes.count, 0) } + + struct DecodedMessage: Equatable { + var level: CInt + var type: CInt + var payload: CInt + } func testEncodeDecode1() { self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) + let expected = [DecodedMessage(level: 1, type: 2, payload: 3)] let encodedBytes = self.encoder.validControlBytes let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) XCTAssertEqual(decoder.count, 1) - XCTAssertEqual(decoder.first!.level, 1) - XCTAssertEqual(decoder.first!.type, 2) - XCTAssertEqual(decoder.first!.data!.count, MemoryLayout.size) + var decoded: [DecodedMessage] = [] + for cmsg in decoder { + XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) + let payload = ControlMessageReceiver.readCInt(data: cmsg.data!) + decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) + } + XCTAssertEqual(expected, decoded) } func testEncodeDecode2() { self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) self.encoder.appendControlMessage(level: 4, type: 5, payload: 6) + let expected = [ + DecodedMessage(level: 1, type: 2, payload: 3), + DecodedMessage(level: 4, type: 5, payload: 6) + ] let encodedBytes = self.encoder.validControlBytes let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) XCTAssertEqual(decoder.count, 2) - XCTAssertEqual(decoder.first!.level, 1) - XCTAssertEqual(decoder.first!.type, 2) - XCTAssertEqual(decoder.first!.data!.count, MemoryLayout.size) - XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].level, 4) - XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].type, 5) - XCTAssertEqual(decoder[decoder.index(after: decoder.startIndex)].data!.count, MemoryLayout.size) - + var decoded: [DecodedMessage] = [] + for cmsg in decoder { + XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) + let payload = ControlMessageReceiver.readCInt(data: cmsg.data!) + decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) + } + XCTAssertEqual(expected, decoded) } } diff --git a/Tests/NIOTests/SystemTest.swift b/Tests/NIOTests/SystemTest.swift index 9ca334e678..3ff315f531 100644 --- a/Tests/NIOTests/SystemTest.swift +++ b/Tests/NIOTests/SystemTest.swift @@ -80,11 +80,13 @@ class SystemTest: XCTestCase { private static let cmsghdr_secondDataCount = 1 private static let cmsghdr_firstType = IP_PKTINFO private static let cmsghdr_secondType = IP_TOS + #else + #error("No cmsg support on this platform.") #endif func testCmsgFirstHeader() { - var exampleCmsgHrd = SystemTest.cmsghdrExample - exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var exampleCmsgHdr = SystemTest.cmsghdrExample + exampleCmsgHdr.withUnsafeMutableBytes { pCmsgHdr in var msgHdr = msghdr() msgHdr.msg_control = pCmsgHdr.baseAddress msgHdr.msg_controllen = .init(pCmsgHdr.count) @@ -97,19 +99,18 @@ class SystemTest: XCTestCase { } func testCMsgNextHeader() { - var exampleCmsgHrd = SystemTest.cmsghdrExample - exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var exampleCmsgHdr = SystemTest.cmsghdrExample + exampleCmsgHdr.withUnsafeMutableBytes { pCmsgHdr in var msgHdr = msghdr() msgHdr.msg_control = pCmsgHdr.baseAddress msgHdr.msg_controllen = .init(pCmsgHdr.count) withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let second = Posix.cmsgNextHeader(inside: pMsgHdr, from: first) - let expectedSecondSlice = UnsafeMutableRawBufferPointer( - rebasing: pCmsgHdr[SystemTest.cmsghdr_secondStartPosition...]) - XCTAssertEqual(expectedSecondSlice.baseAddress, second) - let third = Posix.cmsgNextHeader(inside: pMsgHdr, from: second) + let second = Posix.cmsgNextHeader(inside: pMsgHdr, after: first!) + let expectedSecondStart = pCmsgHdr.baseAddress! + SystemTest.cmsghdr_secondStartPosition + XCTAssertEqual(expectedSecondStart, second!) + let third = Posix.cmsgNextHeader(inside: pMsgHdr, after: second!) XCTAssertEqual(third, nil) } } @@ -124,7 +125,7 @@ class SystemTest: XCTestCase { withUnsafePointer(to: msgHdr) { pMsgHdr in let first = Posix.cmsgFirstHeader(inside: pMsgHdr) - let firstData = Posix.cmsgData(for: first) + let firstData = Posix.cmsgData(for: first!) let expecedFirstData = UnsafeRawBufferPointer( rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<( SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)]) From 866660aa8f513284d0e250d29a41e3ba7b8a2aee Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Mon, 13 Jul 2020 18:32:40 +0100 Subject: [PATCH 3/6] Missed Cory review comment. --- Sources/NIO/ControlMessage.swift | 8 ++++---- Tests/NIOTests/ControlMessageTests.swift | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 603d47152d..42f9c89e53 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -89,7 +89,7 @@ extension UnsafeControlMessageCollection: Collection { } /// Extract information from a collection of control messages. -struct ControlMessageReceiver { +struct ControlMessageParser { var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) @@ -109,7 +109,7 @@ struct ControlMessageReceiver { } mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { - if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageReceiver.ipv4TosType { + if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageParser.ipv4TosType { if let data = controlMessage.data { assert(data.count == 1) precondition(data.count >= 1) @@ -118,7 +118,7 @@ struct ControlMessageReceiver { } } else if controlMessage.level == IPPROTO_IPV6 && controlMessage.type == IPV6_TCLASS { if let data = controlMessage.data { - let readValue = ControlMessageReceiver.readCInt(data: data) + let readValue = ControlMessageParser.readCInt(data: data) self.ecnValue = .init(receivedValue: readValue) } } @@ -238,7 +238,7 @@ extension UnsafeOutboundControlBytes { extension AddressedEnvelope.Metadata { /// It's assumed the caller has checked that congestion information is required before calling. internal init(from controlMessagesReceived: UnsafeControlMessageCollection) { - var controlMessageReceiver = ControlMessageReceiver() + var controlMessageReceiver = ControlMessageParser() controlMessagesReceived.forEach { controlMessage in controlMessageReceiver.receiveMessage(controlMessage) } self.init(ecnState: controlMessageReceiver.ecnValue) } diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift index 8358869b5b..f0d33ca323 100644 --- a/Tests/NIOTests/ControlMessageTests.swift +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -65,7 +65,7 @@ class ControlMessageTests: XCTestCase { var decoded: [DecodedMessage] = [] for cmsg in decoder { XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) - let payload = ControlMessageReceiver.readCInt(data: cmsg.data!) + let payload = ControlMessageParser.readCInt(data: cmsg.data!) decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) } XCTAssertEqual(expected, decoded) @@ -85,7 +85,7 @@ class ControlMessageTests: XCTestCase { var decoded: [DecodedMessage] = [] for cmsg in decoder { XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) - let payload = ControlMessageReceiver.readCInt(data: cmsg.data!) + let payload = ControlMessageParser.readCInt(data: cmsg.data!) decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) } XCTAssertEqual(expected, decoded) From 3ad5e06b7bf4d82e8d4a701851628924e8dc3648 Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Tue, 14 Jul 2020 10:20:11 +0100 Subject: [PATCH 4/6] More Cory and George reviewing. --- Sources/NIO/ControlMessage.swift | 29 ++++++++++++++---------- Tests/NIOTests/ControlMessageTests.swift | 4 ++-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 42f9c89e53..b6f7a1559b 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -91,6 +91,12 @@ extension UnsafeControlMessageCollection: Collection { /// Extract information from a collection of control messages. struct ControlMessageParser { var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default + + init(parsing controlMessagesReceived: UnsafeControlMessageCollection) { + for controlMessage in controlMessagesReceived { + self.receiveMessage(controlMessage) + } + } #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private static let ipv4TosType = IP_RECVTOS @@ -98,7 +104,7 @@ struct ControlMessageParser { private static let ipv4TosType = IP_TOS // Linux #endif - static func readCInt(data: UnsafeRawBufferPointer) -> CInt { + static func _readCInt(data: UnsafeRawBufferPointer) -> CInt { assert(data.count == MemoryLayout.size) precondition(data.count >= MemoryLayout.size) var readValue = CInt(0) @@ -111,14 +117,13 @@ struct ControlMessageParser { mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageParser.ipv4TosType { if let data = controlMessage.data { - assert(data.count == 1) - precondition(data.count >= 1) + precondition(data.count == 1) let readValue = CInt(data[0]) self.ecnValue = .init(receivedValue: readValue) } } else if controlMessage.level == IPPROTO_IPV6 && controlMessage.type == IPV6_TCLASS { if let data = controlMessage.data { - let readValue = ControlMessageParser.readCInt(data: data) + let readValue = ControlMessageParser._readCInt(data: data) self.ecnValue = .init(receivedValue: readValue) } } @@ -143,9 +148,9 @@ extension NIOExplicitCongestionNotificationState { extension CInt { #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) - fileprivate static let notCapableValue = IPTOS_ECN_NOTECT + private static let notCapableValue = IPTOS_ECN_NOTECT #else - fileprivate static let notCapableValue = IPTOS_ECN_NOT_ECT // Linux + private static let notCapableValue = IPTOS_ECN_NOT_ECT // Linux #endif /// Create a CInt encoding of ExplicitCongestionNotification suitable for sending in TCLASS or TOS cmsg. @@ -165,11 +170,12 @@ extension CInt { struct UnsafeOutboundControlBytes { private var controlBytes: UnsafeMutableRawBufferPointer - private var writePosition: UnsafeMutableRawBufferPointer.Index = 0 + private var writePosition: UnsafeMutableRawBufferPointer.Index /// This structure must not outlive `controlBytes` init(controlBytes: UnsafeMutableRawBufferPointer) { self.controlBytes = controlBytes + self.writePosition = controlBytes.startIndex } mutating func appendControlMessage(level: CInt, type: CInt, payload: CInt) { @@ -193,9 +199,9 @@ struct UnsafeOutboundControlBytes { cmsghdrPtr.pointee.cmsg_type = type cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) - let dataPointer = Posix.cmsgData(for: cmsghdrPtr) - precondition(dataPointer!.count >= MemoryLayout.stride) - dataPointer!.storeBytes(of: payload, as: PayloadType.self) + let dataPointer = Posix.cmsgData(for: cmsghdrPtr)! + precondition(dataPointer.count >= MemoryLayout.stride) + dataPointer.storeBytes(of: payload, as: PayloadType.self) self.writePosition += requiredSize } @@ -238,8 +244,7 @@ extension UnsafeOutboundControlBytes { extension AddressedEnvelope.Metadata { /// It's assumed the caller has checked that congestion information is required before calling. internal init(from controlMessagesReceived: UnsafeControlMessageCollection) { - var controlMessageReceiver = ControlMessageParser() - controlMessagesReceived.forEach { controlMessage in controlMessageReceiver.receiveMessage(controlMessage) } + let controlMessageReceiver = ControlMessageParser(parsing: controlMessagesReceived) self.init(ecnState: controlMessageReceiver.ecnValue) } } diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift index f0d33ca323..5a9d2f4328 100644 --- a/Tests/NIOTests/ControlMessageTests.swift +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -65,7 +65,7 @@ class ControlMessageTests: XCTestCase { var decoded: [DecodedMessage] = [] for cmsg in decoder { XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) - let payload = ControlMessageParser.readCInt(data: cmsg.data!) + let payload = ControlMessageParser._readCInt(data: cmsg.data!) decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) } XCTAssertEqual(expected, decoded) @@ -85,7 +85,7 @@ class ControlMessageTests: XCTestCase { var decoded: [DecodedMessage] = [] for cmsg in decoder { XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) - let payload = ControlMessageParser.readCInt(data: cmsg.data!) + let payload = ControlMessageParser._readCInt(data: cmsg.data!) decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) } XCTAssertEqual(expected, decoded) From a590a9882fe020dc51f842b21895ea3c6365ed53 Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Tue, 14 Jul 2020 10:23:08 +0100 Subject: [PATCH 5/6] Bring back mismatched pre-condition which Cory and I are ok with. --- Sources/NIO/ControlMessage.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index b6f7a1559b..8ad2dd1da0 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -117,7 +117,8 @@ struct ControlMessageParser { mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageParser.ipv4TosType { if let data = controlMessage.data { - precondition(data.count == 1) + assert(data.count == 1) + precondition(data.count >= 1) let readValue = CInt(data[0]) self.ecnValue = .init(receivedValue: readValue) } From db4f63900244bbc92ac3dfed8f9919ef2667ca76 Mon Sep 17 00:00:00 2001 From: Peter Adams Date: Tue, 14 Jul 2020 12:08:00 +0100 Subject: [PATCH 6/6] Remove unnecessary precondition. --- Sources/NIO/ControlMessage.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift index 8ad2dd1da0..4cdca0484a 100644 --- a/Sources/NIO/ControlMessage.swift +++ b/Sources/NIO/ControlMessage.swift @@ -72,7 +72,6 @@ extension UnsafeControlMessageCollection: Collection { var endIndex: Index { return Index(cmsgPointer: nil) } func index(after: Index) -> Index { - precondition(after.cmsgPointer != nil) var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr,