diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 7cee2d2a70..48c075804d 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -263,7 +263,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { /// Returned by the `private func readable0()` to inform the caller about the current state of the underlying read stream. /// This is mostly useful when receiving `.readEOF` as we then need to drain the read stream fully (ie. until we receive EOF or error of course) - private enum ReadStreamState { + private enum ReadStreamState: Equatable { /// Everything seems normal case normal(ReadResult) @@ -450,11 +450,11 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it. var readAtLeastOnce = false - while let read = try? readFromSocket(), read == .some { + while let read = try? self.readFromSocket(), read == .some { readAtLeastOnce = true } if readAtLeastOnce && self.lifecycleManager.isActive { - pipeline.fireChannelReadComplete() + self.pipeline.fireChannelReadComplete() } } @@ -959,7 +959,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { let readResult: ReadResult do { - readResult = try readFromSocket() + readResult = try self.readFromSocket() } catch let err { let readStreamState: ReadStreamState // ChannelError.eof is not something we want to fire through the pipeline as it just means the remote @@ -970,7 +970,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // getOption0 can only fail if the channel is not active anymore but we assert further up that it is. If // that's not the case this is a precondition failure and we would like to know. - if self.lifecycleManager.isActive, try! getOption0(ChannelOptions.allowRemoteHalfClosure) { + if self.lifecycleManager.isActive, try! self.getOption0(ChannelOptions.allowRemoteHalfClosure) { // If we want to allow half closure we will just mark the input side of the Channel // as closed. assert(self.lifecycleManager.isActive) @@ -987,7 +987,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } // Call before triggering the close of the Channel. - if self.lifecycleManager.isActive { + if readStreamState != .error, self.lifecycleManager.isActive { self.pipeline.fireChannelReadComplete0() } @@ -997,6 +997,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return readStreamState } + assert(readResult == .some) if self.lifecycleManager.isActive { self.pipeline.fireChannelReadComplete0() } diff --git a/Sources/NIO/ServerSocket.swift b/Sources/NIO/ServerSocket.swift index 79ae9f8e01..5705548bfb 100644 --- a/Sources/NIO/ServerSocket.swift +++ b/Sources/NIO/ServerSocket.swift @@ -89,7 +89,13 @@ let sock = Socket(descriptor: fd) #if !os(Linux) if setNonBlocking { - try sock.setNonBlocking() + do { + try sock.setNonBlocking() + } catch { + // best effort + try? sock.close() + throw error + } } #endif return sock diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 8f4c2aaa24..a0742c7be6 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -406,16 +406,18 @@ final class ServerSocketChannel: BaseSocketChannel { guard self.isOpen else { throw ChannelError.eof } - if let accepted = try self.socket.accept(setNonBlocking: true) { + if let accepted = try self.socket.accept(setNonBlocking: true) { readPending = false result = .some do { - let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop) + let chan = try SocketChannel(socket: accepted, + parent: self, + eventLoop: group.next() as! SelectableEventLoop) assert(self.isActive) pipeline.fireChannelRead0(NIOAny(chan)) - } catch let err { + } catch { try? accepted.close() - throw err + throw error } } else { break @@ -436,6 +438,16 @@ final class ServerSocketChannel: BaseSocketChannel { // These are errors we may be able to recover from. The user may just want to stop accepting connections for example // or provide some other means of back-pressure. This could be achieved by a custom ChannelDuplexHandler. return false + case EINVAL: + if case .function(let function) = err.reason { + // see https://github.com/apple/swift-nio/issues/1030 + // on Darwin, fcntl(fd, F_SETFL, O_NONBLOCK) sometimes returns EINVAL... + // + // the comparison is so weird because StaticString isn't Equatable and doesn't have an == method... + return String(describing: function) != "fcntl(descriptor:command:value:)" + } else { + return true + } default: return true } diff --git a/Tests/NIOTests/SocketChannelTest+XCTest.swift b/Tests/NIOTests/SocketChannelTest+XCTest.swift index 4934bb2800..9307d29cff 100644 --- a/Tests/NIOTests/SocketChannelTest+XCTest.swift +++ b/Tests/NIOTests/SocketChannelTest+XCTest.swift @@ -49,6 +49,7 @@ extension SocketChannelTest { ("testUnprocessedOutboundUserEventFailsOnServerSocketChannel", testUnprocessedOutboundUserEventFailsOnServerSocketChannel), ("testUnprocessedOutboundUserEventFailsOnSocketChannel", testUnprocessedOutboundUserEventFailsOnSocketChannel), ("testSetSockOptDoesNotOverrideExistingFlags", testSetSockOptDoesNotOverrideExistingFlags), + ("testServerChannelDoesNotBreakIfAcceptingFailsWithEINVAL", testServerChannelDoesNotBreakIfAcceptingFailsWithEINVAL), ] } } diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index ff0053db70..44efffab7b 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import XCTest @testable import NIO +import NIOTestUtils import NIOConcurrencyHelpers private extension Array { @@ -679,4 +680,65 @@ public final class SocketChannelTest : XCTestCase { XCTAssertNoThrow(try s.close()) } + + func testServerChannelDoesNotBreakIfAcceptingFailsWithEINVAL() throws { + // regression test for https://github.com/apple/swift-nio/issues/1030 + class HandsOutMoodySocketsServerSocket: ServerSocket { + let shouldAcceptsFail: Atomic = .init(value: true) + override func accept(setNonBlocking: Bool = false) throws -> Socket? { + XCTAssertTrue(setNonBlocking) + if self.shouldAcceptsFail.load() { + throw IOError(errnoCode: EINVAL, function: "fcntl(descriptor:command:value:)") + } else { + return try Socket(protocolFamily: PF_INET, + type: Posix.SOCK_STREAM, + setNonBlocking: false) + } + } + } + + class CloseAcceptedSocketsHandler: ChannelInboundHandler { + typealias InboundIn = Channel + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.unwrapInboundIn(data).close(promise: nil) + } + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let serverSock = try HandsOutMoodySocketsServerSocket(protocolFamily: PF_INET, setNonBlocking: true) + + let serverChan = try ServerSocketChannel(serverSocket: serverSock, + eventLoop: group.next() as! SelectableEventLoop, + group: group) + try serverChan.setOption(ChannelOptions.maxMessagesPerRead, value: 1).wait() + try serverChan.setOption(ChannelOptions.autoRead, value: false).wait() + try serverChan.register().wait() + try serverChan.bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait() + + let eventCounter = EventCounterHandler() + try serverChan.pipeline.addHandler(eventCounter).wait() + try serverChan.pipeline.addHandler(CloseAcceptedSocketsHandler()).wait() + + XCTAssertEqual([], eventCounter.allTriggeredEvents()) + try serverChan.eventLoop.submit { + serverChan.readable() + }.wait() + XCTAssertEqual(["errorCaught"], eventCounter.allTriggeredEvents()) + XCTAssertEqual(1, eventCounter.errorCaughtCalls) + + serverSock.shouldAcceptsFail.store(false) + + try serverChan.eventLoop.submit { + serverChan.readable() + }.wait() + XCTAssertEqual(["errorCaught", "channelRead", "channelReadComplete"], + eventCounter.allTriggeredEvents()) + XCTAssertEqual(1, eventCounter.errorCaughtCalls) + XCTAssertEqual(1, eventCounter.channelReadCalls) + XCTAssertEqual(1, eventCounter.channelReadCompleteCalls) + } }