diff --git a/Sources/NIOPosix/BaseStreamSocketChannel.swift b/Sources/NIOPosix/BaseStreamSocketChannel.swift index 41e6143f20..ca3c652f83 100644 --- a/Sources/NIOPosix/BaseStreamSocketChannel.swift +++ b/Sources/NIOPosix/BaseStreamSocketChannel.swift @@ -183,6 +183,11 @@ class BaseStreamSocketChannel: BaseSocketChannel promise?.fail(ChannelError.outputClosed) return } + if self.inputShutdown { + // Escalate to full closure + self.close0(error: error, mode: .all, promise: promise) + return + } try self.shutdownSocket(mode: mode) // Fail all pending writes and so ensure all pending promises are notified self.pendingWrites.failAll(error: error, close: false) @@ -195,6 +200,11 @@ class BaseStreamSocketChannel: BaseSocketChannel promise?.fail(ChannelError.inputClosed) return } + if self.outputShutdown { + // Escalate to full closure + self.close0(error: error, mode: .all, promise: promise) + return + } switch error { case ChannelError.eof: // No need to explicit call socket.shutdown(...) as we received an EOF and the call would only cause diff --git a/Tests/NIOPosixTests/ChannelTests.swift b/Tests/NIOPosixTests/ChannelTests.swift index 48886773ff..c7bb029914 100644 --- a/Tests/NIOPosixTests/ChannelTests.swift +++ b/Tests/NIOPosixTests/ChannelTests.swift @@ -1249,6 +1249,97 @@ public final class ChannelTests: XCTestCase { try channel.writeAndFlush(NIOAny(buffer)).wait() } + func testInputAndOutputClosedResultsInFullClosure() throws { + final class PromiseOnChildChannelInitHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + private let promise: EventLoopPromise + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func channelActive(context: ChannelHandlerContext) { + self.promise.succeed(context.channel) + context.fireChannelActive() + } + } + + final class ChannelInactiveHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + private let promise: EventLoopPromise + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func channelInactive(context: ChannelHandlerContext) { + self.promise.succeed() + context.fireChannelActive() + } + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChildChannelInitPromise: EventLoopPromise = group.next().makePromise() + let serverChildChannelInactivePromise: EventLoopPromise = group.next().makePromise() + let serverChannel: Channel = try ServerBootstrap(group: group) + .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true) // Important! + .childChannelInitializer { channel in + channel.pipeline.addHandlers( + PromiseOnChildChannelInitHandler(promise: serverChildChannelInitPromise), + ChannelInactiveHandler(promise: serverChildChannelInactivePromise) + ) + } + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + let clientChannelInactivePromise: EventLoopPromise = group.next().makePromise() + let clientChannel = try ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHandler(ChannelInactiveHandler(promise: clientChannelInactivePromise)) + } + .connect(to: serverChannel.localAddress!) + .wait() + + XCTAssertNoThrow(try clientChannel.setOption(ChannelOptions.allowRemoteHalfClosure, value: true).wait()) + + // Ok, the connection is definitely up. + // Now retrieve the client channel that our server opened for the connection to our client. + let serverConnectionChildChannel = try serverChildChannelInitPromise.futureResult.wait() + + // First we close the output of the connection channel on the server. + // This results in the input of the clientChannel being closed. + XCTAssertNoThrow(try serverConnectionChildChannel.close(mode: .output).wait()) + // Now we close the output of the clientChannel. + // Given that the the input of the clientChannel is already closed, + // this should escalate to a full closure of the clientChannel. + XCTAssertNoThrow(try clientChannel.close(mode: .output).wait()) + + // Assert that full closure of client channel occured by verifying + // that channelInactive was invoked on the channel. + XCTAssertNoThrow(try clientChannelInactivePromise.futureResult.wait()) + + // Assert that the server child channel becomes inactive now that the + // client channel has been closed completely. + XCTAssertNoThrow(try serverChildChannelInactivePromise.futureResult.wait()) + + // Additional assertion: trying to close the clientChannel manually + // should fail as it is closed already. + XCTAssertThrowsError(try clientChannel.close().wait()) { error in + if let error = error as? ChannelError { + XCTAssertEqual(ChannelError.alreadyClosed, error) + } else { + XCTFail("unexpected error: \(error)") + } + } + } + enum ShutDownEvent { case input case output