diff --git a/Sources/NIOPosix/BSDSocketAPICommon.swift b/Sources/NIOPosix/BSDSocketAPICommon.swift index 77a855f28f..539136cc26 100644 --- a/Sources/NIOPosix/BSDSocketAPICommon.swift +++ b/Sources/NIOPosix/BSDSocketAPICommon.swift @@ -136,21 +136,33 @@ extension NIOBSDSocket.Option { } extension NIOBSDSocket { - struct ProtocolSubtype: RawRepresentable, Hashable { - typealias RawValue = CInt - var rawValue: RawValue - - init(rawValue: RawValue) { + /// Defines a protocol subtype. + /// + /// Protocol subtypes are the third argument passed to the `socket` system call. + /// They aren't necessarily protocols in their own right: for example, ``mptcp`` + /// is not. They act to modify the socket type instead: thus, ``mptcp`` acts + /// to modify `SOCK_STREAM` to ask for ``mptcp`` support. + public struct ProtocolSubtype: RawRepresentable, Hashable { + public typealias RawValue = CInt + + /// The underlying value of the protocol subtype. + public var rawValue: RawValue + + /// Construct a protocol subtype from its underlying value. + public init(rawValue: RawValue) { self.rawValue = rawValue } } } extension NIOBSDSocket.ProtocolSubtype { - static let `default` = Self(rawValue: 0) + /// Refers to the "default" protocol subtype for a given socket type. + public static let `default` = Self(rawValue: 0) + /// The protocol subtype for MPTCP. + /// /// - returns: nil if MPTCP is not supported. - static var mptcp: Self? { + public static var mptcp: Self? { #if os(Linux) // Defined by the linux kernel, this is IPPROTO_MPTCP. return .init(rawValue: 262) @@ -161,7 +173,8 @@ extension NIOBSDSocket.ProtocolSubtype { } extension NIOBSDSocket.ProtocolSubtype { - init(_ protocol: NIOIPProtocol) { + /// Construct a protocol subtype from an IP protocol. + public init(_ protocol: NIOIPProtocol) { self.rawValue = CInt(`protocol`.rawValue) } } diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 37e20ff970..0e90bb85d6 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -1403,6 +1403,7 @@ public final class DatagramBootstrap { private var channelInitializer: Optional @usableFromInline internal var _channelOptions: ChannelOptions.Storage + private var proto: NIOBSDSocket.ProtocolSubtype = .default /// Create a `DatagramBootstrap` on the `EventLoopGroup` `group`. /// @@ -1468,6 +1469,11 @@ public final class DatagramBootstrap { return self } + public func protocolSubtype(_ subtype: NIOBSDSocket.ProtocolSubtype) -> Self { + self.proto = subtype + return self + } + #if !os(Windows) /// Use the existing bound socket file descriptor. /// @@ -1542,6 +1548,7 @@ public final class DatagramBootstrap { } private func bind0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + let subtype = self.proto let address: SocketAddress do { address = try makeSocketAddress() @@ -1551,7 +1558,7 @@ public final class DatagramBootstrap { func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { return try DatagramChannel(eventLoop: eventLoop, protocolFamily: address.protocol, - protocolSubtype: .default) + protocolSubtype: subtype) } return withNewChannel(makeChannel: makeChannel) { _, channel in channel.register().flatMap { @@ -1590,6 +1597,7 @@ public final class DatagramBootstrap { } private func connect0(_ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + let subtype = self.proto let address: SocketAddress do { address = try makeSocketAddress() @@ -1599,7 +1607,7 @@ public final class DatagramBootstrap { func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { return try DatagramChannel(eventLoop: eventLoop, protocolFamily: address.protocol, - protocolSubtype: .default) + protocolSubtype: subtype) } return withNewChannel(makeChannel: makeChannel) { _, channel in channel.register().flatMap { @@ -1839,12 +1847,13 @@ extension DatagramBootstrap { postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture ) async throws -> PostRegistrationTransformationResult { let address = try makeSocketAddress() + let subtype = self.proto func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { return try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, - protocolSubtype: .default + protocolSubtype: subtype ) } @@ -1867,12 +1876,13 @@ extension DatagramBootstrap { postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture ) async throws -> PostRegistrationTransformationResult { let address = try makeSocketAddress() + let subtype = self.proto func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { return try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, - protocolSubtype: .default + protocolSubtype: subtype ) } diff --git a/Tests/NIOPosixTests/DatagramChannelTests.swift b/Tests/NIOPosixTests/DatagramChannelTests.swift index 8cfe8de4a2..700bb5ea16 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests.swift @@ -930,6 +930,118 @@ class DatagramChannelTests: XCTestCase { testEcnAndPacketInfoReceive(address: "::1", vectorRead: true, vectorSend: true, receivePacketInfo: true) } + func testDoingICMPWithoutRoot() throws { + // This test validates we can send ICMP messages on a datagram socket without having root privilege. + // + // This doesn't always work: ability to do this on Linux is gated behind a sysctl (net.ipv4.ping_group_range) + // which may exclude us. So we have to tolerate this throwing EPERM as well. + + final class EchoRequestHandler: ChannelInboundHandler { + typealias InboundIn = AddressedEnvelope + typealias OutboundOut = AddressedEnvelope + + let completePromise: EventLoopPromise + + init(completePromise: EventLoopPromise) { + self.completePromise = completePromise + } + + func channelActive(context: ChannelHandlerContext) { + var buffer = context.channel.allocator.buffer(capacity: 32) + + // We're going to write an ICMP echo packet from scratch, like heroes. + // Echo request is type 8, code 0. + // The checksum is tricky: on Linux, the kernel doesn't care what we set, it'll + // calculate it. On macOS, however, we have to calculate it. For both platforms, then, + // we calculate it. + // Identifier is irrelevant. + // Sequence number does matter, but we'll set to 0. + let type = UInt8(8) + let code = UInt8(0) + let fakeChecksum = UInt16(0) + let identifier = UInt16(0) + let sequenceNumber = UInt16(0) + buffer.writeMultipleIntegers(type, code, fakeChecksum, identifier, sequenceNumber) + + // Then we write a payload, which will be "hello from NIO". + buffer.writeString("Hello from NIO") + + // Now calculate the checksum, and store it back at offset 2. + let checksum = buffer.readableBytesView.computeIPChecksum() + buffer.setInteger(checksum, at: 2) + + // Now wrap it into an addressed envelope pointed at localhost. + let envelope = AddressedEnvelope( + remoteAddress: try! SocketAddress(ipAddress: "127.0.0.1", port: 0), + data: buffer + ) + + context.writeAndFlush(self.wrapOutboundOut(envelope)).cascadeFailure(to: self.completePromise) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let envelope = self.unwrapInboundIn(data) + + // Complete with the payload. + self.completePromise.succeed(envelope.data) + } + } + + let loop = self.group.next() + let completePromise = loop.makePromise(of: ByteBuffer.self) + do { + let channel = try DatagramBootstrap(group: group) + .protocolSubtype(.init(.icmp)) + .channelInitializer { channel in + channel.pipeline.addHandler(EchoRequestHandler(completePromise: completePromise)) + } + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try channel.close().wait()) + } + + // Let's try to send an ICMP echo request and get a response. + var response = try completePromise.futureResult.wait() + + #if canImport(Darwin) + // Again, a platform difference. On Darwin, this returns a complete IP packet. On Linux, it does not. + // We assume the Linux platform is the more general approach, but if this test fails on your platform + // it is _probably_ because it behaves differently. To make this general, we can skip the IPv4 header. + // + // To do that, we have to work out how long that header is. That's held in bottom 4 bits of the first + // byte, which is the IHL field. This is in "number of 32-bit words". + guard let firstByte = response.getInteger(at: response.readerIndex, as: UInt8.self), + let _ = response.readSlice(length: Int(firstByte & 0x0F) * 4) else { + XCTFail("Insufficient bytes for IPv4 header") + return + } + #endif + + // Now we've got the ICMP packet. Let's parse this. + guard let header = response.readMultipleIntegers(as: (UInt8, UInt8, UInt16, UInt16, UInt16).self) else { + XCTFail("Insufficient bytes for ICMP header") + return + } + + // Echo response has type 0, code 0, unpredictable checksum and identifier, same sequence number we sent. + XCTAssertEqual(header.0 /* type */, 0) + XCTAssertEqual(header.1 /* code */, 0) + XCTAssertEqual(header.4 /* sequence number */, 0) + + // Remaining payload should have been our string. + XCTAssertEqual(String(buffer: response), "Hello from NIO") + } catch let error as IOError { + // Firstly, fail this promise in case it leaks. + completePromise.fail(error) + if error.errnoCode == EACCES { + // Acceptable + return + } + XCTFail("Unexpected IOError: \(error)") + } + } + func assertSending( data: ByteBuffer, from sourceChannel: Channel, diff --git a/Tests/NIOPosixTests/IPv4Header.swift b/Tests/NIOPosixTests/IPv4Header.swift index 9acd00451a..2bbc3618a1 100644 --- a/Tests/NIOPosixTests/IPv4Header.swift +++ b/Tests/NIOPosixTests/IPv4Header.swift @@ -314,6 +314,22 @@ extension IPv4Header { } } +extension Sequence where Element == UInt8 { + func computeIPChecksum() -> UInt16 { + var sum = UInt16(0) + + var iterator = self.makeIterator() + + while let nextHigh = iterator.next() { + let nextLow = iterator.next() ?? 0 + let next = (UInt16(nextHigh) << 8) | UInt16(nextLow) + sum = onesComplementAdd(lhs: sum, rhs: next) + } + + return ~sum + } +} + private func onesComplementAdd(lhs: Integer, rhs: Integer) -> Integer { var (sum, overflowed) = lhs.addingReportingOverflow(rhs) if overflowed {