diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 3033a13ac7..84d8520871 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -284,6 +284,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { fileprivate func setOption0(option: T, value: T.OptionType) throws { assert(eventLoop.inEventLoop) + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as SocketOption: let (level, name) = option.value as! (SocketOptionLevel, SocketOptionName) @@ -328,6 +332,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { fileprivate func getOption0(option: T) throws -> T.OptionType { assert(eventLoop.inEventLoop) + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as SocketOption: let (level, name) = option.value as! (SocketOptionLevel, SocketOptionName) @@ -802,6 +810,11 @@ final class SocketChannel: BaseSocketChannel { override fileprivate func setOption0(option: T, value: T.OptionType) throws { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as ConnectTimeoutOption: connectTimeout = value as? TimeAmount @@ -818,6 +831,11 @@ final class SocketChannel: BaseSocketChannel { override fileprivate func getOption0(option: T) throws -> T.OptionType { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as ConnectTimeoutOption: return connectTimeout as! T.OptionType @@ -1046,6 +1064,11 @@ final class ServerSocketChannel: BaseSocketChannel { override fileprivate func setOption0(option: T, value: T.OptionType) throws { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as BacklogOption: backlog = value as! Int32 @@ -1056,6 +1079,11 @@ final class ServerSocketChannel: BaseSocketChannel { override fileprivate func getOption0(option: T) throws -> T.OptionType { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as BacklogOption: return backlog as! T.OptionType @@ -1212,6 +1240,11 @@ final class DatagramChannel: BaseSocketChannel { override fileprivate func setOption0(option: T, value: T.OptionType) throws { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as WriteSpinOption: pendingWrites.writeSpinCount = value as! UInt @@ -1224,6 +1257,11 @@ final class DatagramChannel: BaseSocketChannel { override fileprivate func getOption0(option: T) throws -> T.OptionType { assert(eventLoop.inEventLoop) + + guard isOpen else { + throw ChannelError.alreadyClosed + } + switch option { case _ as WriteSpinOption: return pendingWrites.writeSpinCount as! T.OptionType diff --git a/Tests/NIOTests/SocketChannelTest+XCTest.swift b/Tests/NIOTests/SocketChannelTest+XCTest.swift index 5ab1b42438..a512c02864 100644 --- a/Tests/NIOTests/SocketChannelTest+XCTest.swift +++ b/Tests/NIOTests/SocketChannelTest+XCTest.swift @@ -34,6 +34,7 @@ extension SocketChannelTest { ("testAcceptFailsWithENOBUFS", testAcceptFailsWithENOBUFS), ("testAcceptFailsWithENOMEM", testAcceptFailsWithENOMEM), ("testAcceptFailsWithEFAULT", testAcceptFailsWithEFAULT), + ("testSetGetOptionClosedServerSocketChannel", testSetGetOptionClosedServerSocketChannel), ] } } diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index 811a338deb..fd64a80095 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -164,4 +164,36 @@ public class SocketChannelTest : XCTestCase { let ioError = try promise.futureResult.wait() XCTAssertEqual(error, ioError.errnoCode) } + + public func testSetGetOptionClosedServerSocketChannel() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + // Create two channels with different event loops. + let serverChannel = try ServerBootstrap(group: group).bind(host: "127.0.0.1", port: 0).wait() + let clientChannel = try ClientBootstrap(group: group).connect(to: serverChannel.localAddress!).wait() + + try assertSetGetOptionOnOpenAndClosed(channel: clientChannel, option: ChannelOptions.allowRemoteHalfClosure, value: true) + try assertSetGetOptionOnOpenAndClosed(channel: serverChannel, option: ChannelOptions.backlog, value: 100) + + } + + private func assertSetGetOptionOnOpenAndClosed(channel: Channel, option: T, value: T.OptionType) throws { + _ = try channel.setOption(option: option, value: value).wait() + _ = try channel.getOption(option: option).wait() + try channel.close().wait() + try channel.closeFuture.wait() + + do { + _ = try channel.setOption(option: option, value: value).wait() + } catch let err as ChannelError where err == .alreadyClosed { + // expected + } + + do { + _ = try channel.getOption(option: option).wait() + } catch let err as ChannelError where err == .alreadyClosed { + // expected + } + } }