diff --git a/Sources/CNIODarwin/include/CNIODarwin.h b/Sources/CNIODarwin/include/CNIODarwin.h index d86fe809d7..e096763956 100644 --- a/Sources/CNIODarwin/include/CNIODarwin.h +++ b/Sources/CNIODarwin/include/CNIODarwin.h @@ -18,6 +18,8 @@ #include #include +#include + // Darwin platforms do not have a sendmmsg implementation available to them. This C module // provides a shim that implements sendmmsg on top of sendmsg. It also provides a shim for // recvmmsg, but does not actually implement that shim, instantly throwing errors if called. @@ -34,5 +36,13 @@ typedef struct { int CNIODarwin_sendmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags); int CNIODarwin_recvmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags, struct timespec *timeout); +// cmsghdr handling +struct cmsghdr *CNIODarwin_CMSG_FIRSTHDR(const struct msghdr *); +struct cmsghdr *CNIODarwin_CMSG_NXTHDR(const struct msghdr *, const struct cmsghdr *); +const void *CNIODarwin_CMSG_DATA(const struct cmsghdr *); +void *CNIODarwin_CMSG_DATA_MUTABLE(struct cmsghdr *); +size_t CNIODarwin_CMSG_LEN(size_t); +size_t CNIODarwin_CMSG_SPACE(size_t); + #endif // __APPLE__ #endif // C_NIO_DARWIN_H diff --git a/Sources/CNIODarwin/shim.c b/Sources/CNIODarwin/shim.c index 28639e71ac..a77a380e59 100644 --- a/Sources/CNIODarwin/shim.c +++ b/Sources/CNIODarwin/shim.c @@ -18,6 +18,7 @@ #include #include #include +#include int CNIODarwin_sendmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vlen, int flags) { // Some quick error checking. If vlen can't fit into int, we bail. @@ -50,4 +51,33 @@ int CNIODarwin_recvmmsg(int sockfd, CNIODarwin_mmsghdr *msgvec, unsigned int vle errx(EX_SOFTWARE, "recvmmsg shim not implemented on Darwin platforms\n"); } +struct cmsghdr *CNIODarwin_CMSG_FIRSTHDR(const struct msghdr *mhdr) { + assert(mhdr != NULL); + return CMSG_FIRSTHDR(mhdr); +} + +struct cmsghdr *CNIODarwin_CMSG_NXTHDR(const struct msghdr *mhdr, const struct cmsghdr *cmsg) { + assert(mhdr != NULL); + assert(cmsg != NULL); // Not required by Darwin but Linux needs this so we should match. + return CMSG_NXTHDR(mhdr, cmsg); +} + +const void *CNIODarwin_CMSG_DATA(const struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +void *CNIODarwin_CMSG_DATA_MUTABLE(struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +size_t CNIODarwin_CMSG_LEN(size_t payloadSizeBytes) { + return CMSG_LEN(payloadSizeBytes); +} + +size_t CNIODarwin_CMSG_SPACE(size_t payloadSizeBytes) { + return CMSG_SPACE(payloadSizeBytes); +} + #endif // __APPLE__ diff --git a/Sources/CNIOLinux/include/CNIOLinux.h b/Sources/CNIOLinux/include/CNIOLinux.h index 17293079f5..7a4d02293f 100644 --- a/Sources/CNIOLinux/include/CNIOLinux.h +++ b/Sources/CNIOLinux/include/CNIOLinux.h @@ -25,6 +25,7 @@ #include #include #include +#include // Some explanation is required here. // @@ -64,5 +65,13 @@ void CNIOLinux_CPU_SET(int cpu, cpu_set_t *set); void CNIOLinux_CPU_ZERO(cpu_set_t *set); int CNIOLinux_CPU_ISSET(int cpu, cpu_set_t *set); int CNIOLinux_CPU_SETSIZE(); + +// cmsghdr handling +struct cmsghdr *CNIOLinux_CMSG_FIRSTHDR(const struct msghdr *); +struct cmsghdr *CNIOLinux_CMSG_NXTHDR(struct msghdr *, struct cmsghdr *); +const void *CNIOLinux_CMSG_DATA(const struct cmsghdr *); +void *CNIOLinux_CMSG_DATA_MUTABLE(struct cmsghdr *); +size_t CNIOLinux_CMSG_LEN(size_t); +size_t CNIOLinux_CMSG_SPACE(size_t); #endif #endif diff --git a/Sources/CNIOLinux/shim.c b/Sources/CNIOLinux/shim.c index 353050ca87..acd256a42f 100644 --- a/Sources/CNIOLinux/shim.c +++ b/Sources/CNIOLinux/shim.c @@ -24,6 +24,7 @@ void CNIOLinux_i_do_nothing_just_working_around_a_darwin_toolchain_bug(void) {} #include #include #include +#include _Static_assert(sizeof(CNIOLinux_mmsghdr) == sizeof(struct mmsghdr), "sizes of CNIOLinux_mmsghdr and struct mmsghdr differ"); @@ -112,4 +113,33 @@ int CNIOLinux_CPU_ISSET(int cpu, cpu_set_t *set) { int CNIOLinux_CPU_SETSIZE() { return CPU_SETSIZE; } + +struct cmsghdr *CNIOLinux_CMSG_FIRSTHDR(const struct msghdr *mhdr) { + assert(mhdr != NULL); + return CMSG_FIRSTHDR(mhdr); +} + +struct cmsghdr *CNIOLinux_CMSG_NXTHDR(struct msghdr *mhdr, struct cmsghdr *cmsg) { + assert(mhdr != NULL); + assert(cmsg != NULL); + return CMSG_NXTHDR(mhdr, cmsg); +} + +const void *CNIOLinux_CMSG_DATA(const struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +void *CNIOLinux_CMSG_DATA_MUTABLE(struct cmsghdr *cmsg) { + assert(cmsg != NULL); + return CMSG_DATA(cmsg); +} + +size_t CNIOLinux_CMSG_LEN(size_t payloadSizeBytes) { + return CMSG_LEN(payloadSizeBytes); +} + +size_t CNIOLinux_CMSG_SPACE(size_t payloadSizeBytes) { + return CMSG_SPACE(payloadSizeBytes); +} #endif diff --git a/Sources/NIO/ControlMessage.swift b/Sources/NIO/ControlMessage.swift new file mode 100644 index 0000000000..4cdca0484a --- /dev/null +++ b/Sources/NIO/ControlMessage.swift @@ -0,0 +1,250 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +import CNIODarwin +#elseif os(Linux) || os(FreeBSD) || os(Android) +import CNIOLinux +#endif + +/// Representation of a `cmsghdr` and associated data. +/// Unsafe as captures pointers and must not escape the scope where those pointers are valid. +struct UnsafeControlMessage { + var level: CInt + var type: CInt + var data: UnsafeRawBufferPointer? +} + +/// Collection representation of `cmsghdr` structures and associated data from `recvmsg` +/// Unsafe as captures pointers held in msghdr structure which must not escape scope of validity. +struct UnsafeControlMessageCollection { + private var messageHeader: msghdr + + init(messageHeader: msghdr) { + self.messageHeader = messageHeader + } +} + +// Add the `Collection` functionality to UnsafeControlMessageCollection. +extension UnsafeControlMessageCollection: Collection { + typealias Element = UnsafeControlMessage + + struct Index: Equatable, Comparable { + fileprivate var cmsgPointer: UnsafeMutablePointer? + + static func < (lhs: UnsafeControlMessageCollection.Index, + rhs: UnsafeControlMessageCollection.Index) -> Bool { + // nil is high, as that's the end of the collection. + switch (lhs.cmsgPointer, rhs.cmsgPointer) { + case (.some(let lhs), .some(let rhs)): + return lhs < rhs + case (.some, .none): + return true + case (.none, .some), (.none, .none): + return false + } + } + + fileprivate init(cmsgPointer: UnsafeMutablePointer?) { + self.cmsgPointer = cmsgPointer + } + } + + var startIndex: Index { + var messageHeader = self.messageHeader + return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in + let firstCMsg = Posix.cmsgFirstHeader(inside: messageHeaderPtr) + return Index(cmsgPointer: firstCMsg) + } + } + + var endIndex: Index { return Index(cmsgPointer: nil) } + + func index(after: Index) -> Index { + var msgHdr = messageHeader + return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in + return Index(cmsgPointer: Posix.cmsgNextHeader(inside: messageHeaderPtr, + after: after.cmsgPointer!)) + } + } + + public subscript(position: Index) -> Element { + let cmsg = position.cmsgPointer! + return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, + type: cmsg.pointee.cmsg_type, + data: Posix.cmsgData(for: cmsg)) + } +} + +/// Extract information from a collection of control messages. +struct ControlMessageParser { + var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default + + init(parsing controlMessagesReceived: UnsafeControlMessageCollection) { + for controlMessage in controlMessagesReceived { + self.receiveMessage(controlMessage) + } + } + + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + private static let ipv4TosType = IP_RECVTOS + #else + private static let ipv4TosType = IP_TOS // Linux + #endif + + static func _readCInt(data: UnsafeRawBufferPointer) -> CInt { + assert(data.count == MemoryLayout.size) + precondition(data.count >= MemoryLayout.size) + var readValue = CInt(0) + withUnsafeMutableBytes(of: &readValue) { valuePtr in + valuePtr.copyMemory(from: data) + } + return readValue + } + + mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { + if controlMessage.level == IPPROTO_IP && controlMessage.type == ControlMessageParser.ipv4TosType { + if let data = controlMessage.data { + assert(data.count == 1) + precondition(data.count >= 1) + let readValue = CInt(data[0]) + self.ecnValue = .init(receivedValue: readValue) + } + } else if controlMessage.level == IPPROTO_IPV6 && controlMessage.type == IPV6_TCLASS { + if let data = controlMessage.data { + let readValue = ControlMessageParser._readCInt(data: data) + self.ecnValue = .init(receivedValue: readValue) + } + } + } +} + +extension NIOExplicitCongestionNotificationState { + /// Initialise a NIOExplicitCongestionNotificationState from a value received via either TCLASS or TOS cmsg. + init(receivedValue: CInt) { + switch receivedValue & IPTOS_ECN_MASK { + case IPTOS_ECN_ECT1: + self = .transportCapableFlag1 + case IPTOS_ECN_ECT0: + self = .transportCapableFlag0 + case IPTOS_ECN_CE: + self = .congestionExperienced + default: + self = .transportNotCapable + } + } +} + +extension CInt { + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + private static let notCapableValue = IPTOS_ECN_NOTECT + #else + private static let notCapableValue = IPTOS_ECN_NOT_ECT // Linux + #endif + + /// Create a CInt encoding of ExplicitCongestionNotification suitable for sending in TCLASS or TOS cmsg. + init(ecnValue: NIOExplicitCongestionNotificationState) { + switch ecnValue { + case .transportNotCapable: + self = CInt.notCapableValue + case .transportCapableFlag0: + self = IPTOS_ECN_ECT0 + case .transportCapableFlag1: + self = IPTOS_ECN_ECT1 + case .congestionExperienced: + self = IPTOS_ECN_CE + } + } +} + +struct UnsafeOutboundControlBytes { + private var controlBytes: UnsafeMutableRawBufferPointer + private var writePosition: UnsafeMutableRawBufferPointer.Index + + /// This structure must not outlive `controlBytes` + init(controlBytes: UnsafeMutableRawBufferPointer) { + self.controlBytes = controlBytes + self.writePosition = controlBytes.startIndex + } + + mutating func appendControlMessage(level: CInt, type: CInt, payload: CInt) { + self.appendGenericControlMessage(level: level, type: type, payload: payload) + } + + /// Appends a control message. + /// PayloadType needs to be trivial (eg CInt) + private mutating func appendGenericControlMessage(level: CInt, + type: CInt, + payload: PayloadType) { + let writableBuffer = UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[writePosition...]) + + let requiredSize = Posix.cmsgSpace(payloadSize: MemoryLayout.stride(ofValue: payload)) + precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") + + let bufferBase = writableBuffer.baseAddress! + // Binding to cmsghdr is safe here as this is the only place where we bind to non-Raw. + let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) + cmsghdrPtr.pointee.cmsg_level = level + cmsghdrPtr.pointee.cmsg_type = type + cmsghdrPtr.pointee.cmsg_len = .init(Posix.cmsgLen(payloadSize: MemoryLayout.size(ofValue: payload))) + + let dataPointer = Posix.cmsgData(for: cmsghdrPtr)! + precondition(dataPointer.count >= MemoryLayout.stride) + dataPointer.storeBytes(of: payload, as: PayloadType.self) + + self.writePosition += requiredSize + } + + /// The result is only valid while this is valid. + var validControlBytes: UnsafeMutableRawBufferPointer { + if writePosition == 0 { + return UnsafeMutableRawBufferPointer(start: nil, count: 0) + } + return UnsafeMutableRawBufferPointer(rebasing: self.controlBytes[0 ..< self.writePosition]) + } + +} + +extension UnsafeOutboundControlBytes { + /// Add a message describing the explicit congestion state if requested in metadata and valid for this protocol. + /// Parameters: + /// - metadata: Metadata from the addressed envelope which will describe any desired state. + /// - protocolFamily: The type of protocol to encode for. + internal mutating func appendExplicitCongestionState(metadata: AddressedEnvelope.Metadata?, + protocolFamily: NIOBSDSocket.ProtocolFamily?) { + guard let metadata = metadata else { return } + + switch protocolFamily { + case .some(.inet): + self.appendControlMessage(level: .init(IPPROTO_IP), + type: IP_TOS, + payload: CInt(ecnValue: metadata.ecnState)) + case .some(.inet6): + self.appendControlMessage(level: .init(IPPROTO_IPV6), + type: IPV6_TCLASS, + payload: CInt(ecnValue: metadata.ecnState)) + default: + // Nothing to do - if we get here the user is probably making a mistake. + break + } + } +} + +extension AddressedEnvelope.Metadata { + /// It's assumed the caller has checked that congestion information is required before calling. + internal init(from controlMessagesReceived: UnsafeControlMessageCollection) { + let controlMessageReceiver = ControlMessageParser(parsing: controlMessagesReceived) + self.init(ecnState: controlMessageReceiver.ecnValue) + } +} diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index 337a914ffb..7c8d01d712 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -97,11 +97,30 @@ private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointe private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer) -> CInt = fstat private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg +private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = + CNIOLinux_CMSG_FIRSTHDR +private let sysCmsgNxtHdr: @convention(c) (UnsafeMutablePointer?, UnsafeMutablePointer?) -> + UnsafeMutablePointer? = CNIOLinux_CMSG_NXTHDR +private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = CNIOLinux_CMSG_DATA +private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = + CNIOLinux_CMSG_DATA_MUTABLE +private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_SPACE +private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIOLinux_CMSG_LEN #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat private let sysKevent = kevent private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg +private let sysCmsgFirstHdr: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = + CNIODarwin_CMSG_FIRSTHDR +private let sysCmsgNxtHdr: @convention(c) (UnsafePointer?, UnsafePointer?) -> + UnsafeMutablePointer? = CNIODarwin_CMSG_NXTHDR +private let sysCmsgData: @convention(c) (UnsafePointer?) -> UnsafeRawPointer? = + CNIODarwin_CMSG_DATA +private let sysCmsgDataMutable: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutableRawPointer? = + CNIODarwin_CMSG_DATA_MUTABLE +private let sysCmsgSpace: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_SPACE +private let sysCmsgLen: @convention(c) (size_t) -> size_t = CNIODarwin_CMSG_LEN #elseif os(Windows) private let sysSendMmsg: @convention(c) (NIOBSDSocket.Handle, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOWindows_sendmmsg private let sysRecvMmsg: @convention(c) (NIOBSDSocket.Handle, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOWindows_recvmmsg @@ -510,6 +529,39 @@ internal enum Posix { sysSocketpair(domain.rawValue, type.rawValue, `protocol`, socketVector) } } + + static func cmsgFirstHeader(inside msghdr: UnsafePointer) -> UnsafeMutablePointer? { + return sysCmsgFirstHdr(msghdr) + } + + static func cmsgNextHeader(inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer) -> UnsafeMutablePointer? { + return sysCmsgNxtHdr(msghdr, after) + } + + static func cmsgData(for header: UnsafePointer) -> UnsafeRawBufferPointer? { + let dataPointer = sysCmsgData(header) + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } + + static func cmsgData(for header: UnsafeMutablePointer) -> UnsafeMutableRawBufferPointer? { + let dataPointer = sysCmsgDataMutable(header) + // Linux and Darwin use different types for cmsg_len. + let length = size_t(header.pointee.cmsg_len) - cmsgLen(payloadSize: 0) + let buffer = UnsafeMutableRawBufferPointer(start: dataPointer, count: Int(length)) + return buffer + } + + static func cmsgLen(payloadSize: size_t) -> size_t { + return sysCmsgLen(payloadSize) + } + + static func cmsgSpace(payloadSize: size_t) -> size_t { + return sysCmsgSpace(payloadSize) + } } /// `NIOFailedToSetSocketNonBlockingError` indicates that NIO was unable to set a socket to non-blocking mode, either diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index e4fdcbfcc3..1b449b4d61 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -57,6 +57,7 @@ class LinuxMainRunnerImpl: LinuxMainRunner { testCase(ChannelTests.allTests), testCase(CircularBufferTests.allTests), testCase(CodableByteBufferTest.allTests), + testCase(ControlMessageTests.allTests), testCase(CustomChannelTests.allTests), testCase(DatagramChannelTests.allTests), testCase(EchoServerClientTest.allTests), diff --git a/Tests/NIOTests/ControlMessageTests+XCTest.swift b/Tests/NIOTests/ControlMessageTests+XCTest.swift new file mode 100644 index 0000000000..59063af217 --- /dev/null +++ b/Tests/NIOTests/ControlMessageTests+XCTest.swift @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// ControlMessageTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension ControlMessageTests { + + @available(*, deprecated, message: "not actually deprecated. Just deprecated to allow deprecated tests (which test deprecated functionality) without warnings") + static var allTests : [(String, (ControlMessageTests) -> () throws -> Void)] { + return [ + ("testEmptyEncode", testEmptyEncode), + ("testEncodeDecode1", testEncodeDecode1), + ("testEncodeDecode2", testEncodeDecode2), + ] + } +} + diff --git a/Tests/NIOTests/ControlMessageTests.swift b/Tests/NIOTests/ControlMessageTests.swift new file mode 100644 index 0000000000..5a9d2f4328 --- /dev/null +++ b/Tests/NIOTests/ControlMessageTests.swift @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +@testable import NIO + +fileprivate extension UnsafeControlMessageCollection { + init(controlBytes: UnsafeMutableRawBufferPointer) { + let msgHdr = msghdr(msg_name: nil, + msg_namelen: 0, + msg_iov: nil, + msg_iovlen: 0, + msg_control: controlBytes.baseAddress, + msg_controllen: .init(controlBytes.count), + msg_flags: 0) + self.init(messageHeader: msgHdr) + } +} + +class ControlMessageTests: XCTestCase { + var encoderBytes: UnsafeMutableRawBufferPointer? + var encoder: UnsafeOutboundControlBytes! + + override func setUp() { + self.encoderBytes = UnsafeMutableRawBufferPointer.allocate(byteCount: 1000, + alignment: MemoryLayout.alignment) + self.encoder = UnsafeOutboundControlBytes(controlBytes: self.encoderBytes!) + } + + override func tearDown() { + if let encoderBytes = self.encoderBytes { + self.encoderBytes = nil + encoderBytes.deallocate() + } + } + + func testEmptyEncode() { + XCTAssertEqual(self.encoder.validControlBytes.count, 0) + } + + struct DecodedMessage: Equatable { + var level: CInt + var type: CInt + var payload: CInt + } + + func testEncodeDecode1() { + self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) + let expected = [DecodedMessage(level: 1, type: 2, payload: 3)] + let encodedBytes = self.encoder.validControlBytes + + let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) + XCTAssertEqual(decoder.count, 1) + var decoded: [DecodedMessage] = [] + for cmsg in decoder { + XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) + let payload = ControlMessageParser._readCInt(data: cmsg.data!) + decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) + } + XCTAssertEqual(expected, decoded) + } + + func testEncodeDecode2() { + self.encoder.appendControlMessage(level: 1, type: 2, payload: 3) + self.encoder.appendControlMessage(level: 4, type: 5, payload: 6) + let expected = [ + DecodedMessage(level: 1, type: 2, payload: 3), + DecodedMessage(level: 4, type: 5, payload: 6) + ] + let encodedBytes = self.encoder.validControlBytes + + let decoder = UnsafeControlMessageCollection(controlBytes: encodedBytes) + XCTAssertEqual(decoder.count, 2) + var decoded: [DecodedMessage] = [] + for cmsg in decoder { + XCTAssertEqual(cmsg.data!.count, MemoryLayout.size) + let payload = ControlMessageParser._readCInt(data: cmsg.data!) + decoded.append(DecodedMessage(level: cmsg.level, type: cmsg.type, payload: payload)) + } + XCTAssertEqual(expected, decoded) + } +} diff --git a/Tests/NIOTests/SystemTest+XCTest.swift b/Tests/NIOTests/SystemTest+XCTest.swift index 56ae6aada1..053bdcf1fd 100644 --- a/Tests/NIOTests/SystemTest+XCTest.swift +++ b/Tests/NIOTests/SystemTest+XCTest.swift @@ -29,6 +29,10 @@ extension SystemTest { return [ ("testSystemCallWrapperPerformance", testSystemCallWrapperPerformance), ("testErrorsWorkCorrectly", testErrorsWorkCorrectly), + ("testCmsgFirstHeader", testCmsgFirstHeader), + ("testCMsgNextHeader", testCMsgNextHeader), + ("testCMsgData", testCMsgData), + ("testCMsgCollection", testCMsgCollection), ] } } diff --git a/Tests/NIOTests/SystemTest.swift b/Tests/NIOTests/SystemTest.swift index f20659c389..3ff315f531 100644 --- a/Tests/NIOTests/SystemTest.swift +++ b/Tests/NIOTests/SystemTest.swift @@ -46,4 +46,116 @@ class SystemTest: XCTestCase { return [readFD, writeFD] } } + + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + // Example twin data options captured on macOS + private static let cmsghdrExample: [UInt8] = [0x10, 0x00, 0x00, 0x00, // Length 16 including header + 0x00, 0x00, 0x00, 0x00, // IPPROTO_IP + 0x07, 0x00, 0x00, 0x00, // IP_RECVDSTADDR + 0x7F, 0x00, 0x00, 0x01, // 127.0.0.1 + 0x0D, 0x00, 0x00, 0x00, // Length 13 including header + 0x00, 0x00, 0x00, 0x00, // IPPROTO_IP + 0x1B, 0x00, 0x00, 0x00, // IP_RECVTOS + 0x01, 0x00, 0x00, 0x00] // ECT-1 (1 byte) + private static let cmsghdr_secondStartPosition = 16 + private static let cmsghdr_firstDataStart = 12 + private static let cmsghdr_firstDataCount = 4 + private static let cmsghdr_secondDataCount = 1 + private static let cmsghdr_firstType = IP_RECVDSTADDR + private static let cmsghdr_secondType = IP_RECVTOS + #elseif os(Linux) + // Example twin data options captured on Linux + private static let cmsghdrExample: [UInt8] = [ + 0x1C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Length 28 including header. + 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, // IPPROTO_IP, IP_PKTINFO + 0x01, 0x00, 0x00, 0x00, 0x7F, 0x00, 0x00, 0x01, // interface number, 127.0.0.1 (local) + 0x7F, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, // 127.0.0.1 (destination), 4 bytes to align length + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Length 17 + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, // IPPROTO_IP, IP_TOS + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 // ECT-1 (1 byte) + ] + private static let cmsghdr_secondStartPosition = 32 + private static let cmsghdr_firstDataStart = 16 + private static let cmsghdr_firstDataCount = 12 + private static let cmsghdr_secondDataCount = 1 + private static let cmsghdr_firstType = IP_PKTINFO + private static let cmsghdr_secondType = IP_TOS + #else + #error("No cmsg support on this platform.") + #endif + + func testCmsgFirstHeader() { + var exampleCmsgHdr = SystemTest.cmsghdrExample + exampleCmsgHdr.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafePointer(to: msgHdr) { pMsgHdr in + let result = Posix.cmsgFirstHeader(inside: pMsgHdr) + XCTAssertEqual(pCmsgHdr.baseAddress, result) + } + } + } + + func testCMsgNextHeader() { + var exampleCmsgHdr = SystemTest.cmsghdrExample + exampleCmsgHdr.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafeMutablePointer(to: &msgHdr) { pMsgHdr in + let first = Posix.cmsgFirstHeader(inside: pMsgHdr) + let second = Posix.cmsgNextHeader(inside: pMsgHdr, after: first!) + let expectedSecondStart = pCmsgHdr.baseAddress! + SystemTest.cmsghdr_secondStartPosition + XCTAssertEqual(expectedSecondStart, second!) + let third = Posix.cmsgNextHeader(inside: pMsgHdr, after: second!) + XCTAssertEqual(third, nil) + } + } + } + + func testCMsgData() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + + withUnsafePointer(to: msgHdr) { pMsgHdr in + let first = Posix.cmsgFirstHeader(inside: pMsgHdr) + let firstData = Posix.cmsgData(for: first!) + let expecedFirstData = UnsafeRawBufferPointer( + rebasing: pCmsgHdr[SystemTest.cmsghdr_firstDataStart..<( + SystemTest.cmsghdr_firstDataStart + SystemTest.cmsghdr_firstDataCount)]) + XCTAssertEqual(expecedFirstData.baseAddress, firstData?.baseAddress) + XCTAssertEqual(expecedFirstData.count, firstData?.count) + } + } + } + + func testCMsgCollection() { + var exampleCmsgHrd = SystemTest.cmsghdrExample + exampleCmsgHrd.withUnsafeMutableBytes { pCmsgHdr in + var msgHdr = msghdr() + msgHdr.msg_control = pCmsgHdr.baseAddress + msgHdr.msg_controllen = .init(pCmsgHdr.count) + let collection = UnsafeControlMessageCollection(messageHeader: msgHdr) + var msgNum = 0 + for cmsg in collection { + if msgNum == 0 { + XCTAssertEqual(cmsg.level, .init(IPPROTO_IP)) + XCTAssertEqual(cmsg.type, .init(SystemTest.cmsghdr_firstType)) + XCTAssertEqual(cmsg.data?.count, SystemTest.cmsghdr_firstDataCount) + } else if msgNum == 1 { + XCTAssertEqual(cmsg.level, .init(IPPROTO_IP)) + XCTAssertEqual(cmsg.type, .init(SystemTest.cmsghdr_secondType)) + XCTAssertEqual(cmsg.data?.count, SystemTest.cmsghdr_secondDataCount) + } + msgNum += 1 + } + XCTAssertEqual(msgNum, 2) + } + } }