From 7abbb038980f2d27ce4cf0edb2ef2d3953f4813c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Wei=C3=9F?= Date: Thu, 22 Mar 2018 16:08:14 +0000 Subject: [PATCH] make Channel lifecycle statemachine explicit Motivation: We had a lot of problems with the Channel lifecycle statemachine as it wasn't explicit, this fixes this. Additionally, it asserts a lot more. Modifications: - made Channel lifecycle statemachine explicit - lots of asserts Result: - hopefully the state machine works better - the asserts should guide our way to work on in corner cases as well --- Sources/NIO/BaseSocketChannel.swift | 289 +++++++++++++++++----- Sources/NIO/Channel.swift | 7 + Sources/NIO/ChannelPipeline.swift | 1 + Sources/NIO/Selector.swift | 8 + Sources/NIO/SocketChannel.swift | 23 +- Tests/NIOTests/ChannelTests+XCTest.swift | 2 + Tests/NIOTests/ChannelTests.swift | 157 ++++++++++++ Tests/NIOTests/DatagramChannelTests.swift | 15 ++ 8 files changed, 434 insertions(+), 68 deletions(-) diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index b9690bcc80..1eedf1e9a8 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -14,6 +14,151 @@ import NIOConcurrencyHelpers +private struct SocketChannelLifecycleManager { + // MARK: Types + private enum State { + case fresh + case registered + case activated + case closed + } + + private enum Event { + case activate + case register + case close + } + + // MARK: properties + // this is queried from the Channel, ie. must be thread-safe + internal let isActiveAtomic = Atomic(value: false) + // these are only to be accessed on the EventLoop + internal let channelPipeline: ChannelPipeline + private var currentState: State = .fresh { + didSet { + assert(self.eventLoop.inEventLoop) + switch (oldValue, self.currentState) { + case (_, .activated): + self.isActiveAtomic.store(true) + case (.activated, _): + self.isActiveAtomic.store(false) + default: + () + } + } + } + + private var eventLoop: EventLoop { + return self.channelPipeline.eventLoop + } + + // MARK: API + internal init(channelPipeline: ChannelPipeline) { + self.channelPipeline = channelPipeline + } + + // this is called from Channel's deinit, so don't assert we're on the EventLoop! + internal var canBeDestroyed: Bool { + return self.currentState == .closed + } + + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + internal mutating func register(promise: EventLoopPromise?) -> (() -> Void) { + return self.moveState(event: .register, promise: promise) + } + + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + internal mutating func close(promise: EventLoopPromise?) -> (() -> Void) { + return self.moveState(event: .close, promise: promise) + } + + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + internal mutating func activate(promise: EventLoopPromise?) -> (() -> Void) { + return self.moveState(event: .activate, promise: promise) + } + + // MARK: private API + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + private mutating func moveState(event: Event, promise: EventLoopPromise?) -> (() -> Void) { + assert(self.eventLoop.inEventLoop) + + switch (self.currentState, event) { + // origin: .fresh + case (.fresh, .register): + return self.doStateTransfer(newState: .registered, promise: promise) { pipeline in + pipeline.fireChannelRegistered0() + } + + case (.fresh, .close): + return self.doStateTransfer(newState: .closed, promise: promise) { (_: ChannelPipeline) in } + + // origin: .registered + case (.registered, .activate): + return self.doStateTransfer(newState: .activated, promise: promise) { pipeline in + pipeline.fireChannelActive0() + } + + case (.registered, .close): + return self.doStateTransfer(newState: .closed, promise: promise) { pipeline in + pipeline.fireChannelUnregistered0() + } + + // origin: .activated + case (.activated, .close): + return self.doStateTransfer(newState: .closed, promise: promise) { pipeline in + pipeline.fireChannelInactive0() + pipeline.fireChannelUnregistered0() + } + + // bad transitions + case (.fresh, .activate), // should go through .registered first + (.registered, .register), // already registered + (.activated, .activate), // already activated + (.activated, .register), // already registered (and activated) + (.closed, _): // already closed + self.badTransition(event: event) + } + } + + private func badTransition(event: Event) -> Never { + fatalError("illegal transition: state=\(self.currentState), event=\(event)") + } + + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + private mutating func doStateTransfer(newState: State, promise: EventLoopPromise?, _ callouts: @escaping (ChannelPipeline) -> Void) -> (() -> Void) { + self.currentState = newState + + let pipeline = self.channelPipeline + return { + promise?.succeed(result: ()) + callouts(pipeline) + } + } + + // MARK: convenience properties + internal var isActive: Bool { + assert(self.eventLoop.inEventLoop) + return self.currentState == .activated + } + + internal var isRegistered: Bool { + assert(self.eventLoop.inEventLoop) + switch self.currentState { + case .fresh, .closed: + return false + case .registered, .activated: + return true + } + } + + /// Returns whether the underlying file descriptor is open. This property will always be true (even before registration) + /// until the Channel is closed. + internal var isOpen: Bool { + assert(self.eventLoop.inEventLoop) + return self.currentState != .closed + } +} + /// The base class for all socket-based channels in NIO. /// /// There are many types of specialised socket-based channel in NIO. Each of these @@ -44,12 +189,9 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { var maxMessagesPerRead: UInt = 4 private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method. - private var neverRegistered = true - private var neverActivated = true - private var active: Atomic = Atomic(value: false) - private var _isOpen: Bool = true private var autoRead: Bool = true - private var _pipeline: ChannelPipeline! + private var lifecycleManager: SocketChannelLifecycleManager! + private var bufferAllocator: ByteBufferAllocator = ByteBufferAllocator() { didSet { assert(self.eventLoop.inEventLoop) @@ -92,7 +234,12 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { /// `false` if the whole `Channel` is closed and so no more IO operation can be done. var isOpen: Bool { assert(eventLoop.inEventLoop) - return self._isOpen + return self.lifecycleManager.isOpen + } + + var isRegistered: Bool { + assert(self.eventLoop.inEventLoop) + return self.lifecycleManager.isRegistered } internal var selectable: T { @@ -101,7 +248,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // This is `Channel` API so must be thread-safe. public var isActive: Bool { - return self.active.load() + return self.lifecycleManager.isActiveAtomic.load() } // This is `Channel` API so must be thread-safe. @@ -129,7 +276,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // This is `Channel` API so must be thread-safe. public final var pipeline: ChannelPipeline { - return _pipeline + return self.lifecycleManager.channelPipeline } // MARK: Methods to override in subclasses. @@ -187,15 +334,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.selectableEventLoop = eventLoop self.closePromise = eventLoop.newPromise() self.parent = parent - self.active.store(false) self.recvAllocator = recvAllocator - self._pipeline = ChannelPipeline(channel: self) + self.lifecycleManager = SocketChannelLifecycleManager(channelPipeline: ChannelPipeline(channel: self)) // As the socket may already be connected we should ensure we start with the correct addresses cached. self.addressesCached.store(Box((local: try? socket.localAddress(), remote: try? socket.remoteAddress()))) } deinit { - assert(!self._isOpen, "leak of open Channel") + assert(self.lifecycleManager.canBeDestroyed, "leak of open Channel") } public final func localAddress0() throws -> SocketAddress { @@ -230,6 +376,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { inFlushNow = true do { + assert(self.lifecycleManager.isActive) switch try self.writeToSocket() { case .couldNotWriteEverything: return .register @@ -240,9 +387,11 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // If there is a write error we should try drain the inbound before closing the socket as there may be some data pending. // We ignore any error that is thrown as we will use the original err to close the channel and notify the user. if readIfNeeded0() { + assert(self.lifecycleManager.isActive) // We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it. while let read = try? readFromSocket(), read == .some { + assert(self.lifecycleManager.isActive) pipeline.fireChannelReadComplete() } } @@ -287,7 +436,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // We only want to call read0() or pauseRead0() if we already registered to the EventLoop if not this will be automatically done // once register0 is called. Beside this we also only need to do it when the value actually change. - if !neverRegistered && old != auto { + if self.lifecycleManager.isRegistered && old != auto { if auto { read0() } else { @@ -342,6 +491,9 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { /// - returns: `true` if `readPending` is `true`, `false` otherwise. @discardableResult func readIfNeeded0() -> Bool { assert(eventLoop.inEventLoop) + if !self.lifecycleManager.isActive { + return false + } if !readPending && autoRead { pipeline.read0() @@ -357,6 +509,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { promise?.fail(error: ChannelError.ioOnClosedChannel) return } + guard self.isRegistered else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } executeAndComplete(promise) { try socket.bind(to: address) @@ -373,6 +529,11 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } + guard self.lifecycleManager.isActive else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } + bufferPendingWrite(data: data, promise: promise) } @@ -410,7 +571,12 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.markFlushPoint() + guard self.lifecycleManager.isActive else { + return + } + if !isWritePending() && flushNow() == .register { + assert(self.lifecycleManager.isRegistered) registerForWritable() } } @@ -423,7 +589,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } readPending = true - if !neverRegistered { + if self.lifecycleManager.isRegistered { registerForReadable() } } @@ -431,13 +597,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private final func pauseRead0() { assert(eventLoop.inEventLoop) - if self.isOpen && !neverRegistered{ + if self.lifecycleManager.isRegistered { unregisterForReadable() } } private func registerForReadable() { assert(eventLoop.inEventLoop) + assert(self.lifecycleManager.isRegistered) switch interestedEvent { case .write: @@ -451,6 +618,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { func unregisterForReadable() { assert(eventLoop.inEventLoop) + assert(self.lifecycleManager.isRegistered) switch interestedEvent { case .read: @@ -493,19 +661,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } // Fail all pending writes and so ensure all pending promises are notified - self._isOpen = false self.unsetCachedAddressesFromSocket() self.cancelWritesOnClose(error: error) - if !self.neverActivated { - becomeInactive0(promise: p) - } else if let p = p { - p.succeed(result: ()) - } - - if !self.neverRegistered { - pipeline.fireChannelUnregistered0() - } + self.lifecycleManager.close(promise: p)() eventLoop.execute { // ensure this is executed in a delayed fashion as the users code may still traverse the pipeline @@ -529,13 +688,16 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } + guard !self.lifecycleManager.isRegistered else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } + // Was not registered yet so do it now. do { // We always register with interested .none and will just trigger readIfNeeded0() later to re-register if needed. try self.safeRegister(interested: .none) - neverRegistered = false - promise?.succeed(result: ()) - pipeline.fireChannelRegistered0() + self.lifecycleManager.register(promise: promise)() } catch { promise?.fail(error: error) } @@ -559,8 +721,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private func finishConnect() { assert(eventLoop.inEventLoop) + assert(self.lifecycleManager.isRegistered) if let connectPromise = pendingConnect { + assert(!self.lifecycleManager.isActive) pendingConnect = nil do { @@ -572,6 +736,8 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // We already know what the local address is. self.updateCachedAddressesFromSocket(updateLocal: false, updateRemote: true) becomeActive0(promise: connectPromise) + } else { + assert(self.lifecycleManager.isActive) } } @@ -579,13 +745,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { assert(eventLoop.inEventLoop) if self.isOpen { + assert(self.lifecycleManager.isRegistered) unregisterForWritable() } } public final func readable() { assert(eventLoop.inEventLoop) - assert(self.isOpen) + assert(self.lifecycleManager.isActive) defer { if self.isOpen && !self.readPending { @@ -603,6 +770,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { if try! getOption0(option: ChannelOptions.allowRemoteHalfClosure) { // If we want to allow half closure we will just mark the input side of the Channel // as closed. + assert(self.lifecycleManager.isActive) pipeline.fireChannelReadComplete0() if shouldCloseOnReadError(err) { close0(error: err, mode: .input, promise: nil) @@ -611,19 +779,23 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } } else { - pipeline.fireErrorCaught0(error: err) + self.pipeline.fireErrorCaught0(error: err) } // Call before triggering the close of the Channel. - pipeline.fireChannelReadComplete0() + if self.lifecycleManager.isActive { + pipeline.fireChannelReadComplete0() + } if shouldCloseOnReadError(err) { - close0(error: err, mode: .all, promise: nil) + self.close0(error: err, mode: .all, promise: nil) } return } - pipeline.fireChannelReadComplete0() + if self.lifecycleManager.isActive { + pipeline.fireChannelReadComplete0() + } readIfNeeded0() } @@ -662,7 +834,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { promise?.fail(error: ChannelError.connectPending) return } + + guard self.lifecycleManager.isRegistered else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } + do { + assert(self.lifecycleManager.isRegistered) if try !connectSocket(to: address) { // We aren't connected, we'll get the remote address later. self.updateCachedAddressesFromSocket(updateLocal: true, updateRemote: false) @@ -682,6 +861,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } public func channelRead0(_ data: NIOAny) { + assert(self.lifecycleManager.isActive) // Do nothing by default } @@ -695,6 +875,8 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private func safeReregister(interested: IOEvent) { assert(eventLoop.inEventLoop) + assert(self.lifecycleManager.isRegistered) + guard self.isOpen else { interestedEvent = .none return @@ -707,53 +889,32 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { do { try selectableEventLoop.reregister(channel: self) } catch let err { - pipeline.fireErrorCaught0(error: err) - close0(error: err, mode: .all, promise: nil) + self.pipeline.fireErrorCaught0(error: err) + self.close0(error: err, mode: .all, promise: nil) } } private func safeRegister(interested: IOEvent) throws { assert(eventLoop.inEventLoop) + assert(!self.lifecycleManager.isRegistered) guard self.isOpen else { - interestedEvent = .none throw ChannelError.ioOnClosedChannel } - interestedEvent = interested + + self.interestedEvent = interested do { - try selectableEventLoop.register(channel: self) - } catch let err { - pipeline.fireErrorCaught0(error: err) - close0(error: err, mode: .all, promise: nil) - throw err + try self.selectableEventLoop.register(channel: self) + } catch { + self.pipeline.fireErrorCaught0(error: error) + self.close0(error: error, mode: .all, promise: nil) + throw error } } final func becomeActive0(promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) - assert(!self.active.load()) - assert(self._isOpen) - - self.neverActivated = false - active.store(true) - - // Notify the promise before firing the inbound event through the pipeline. - if let promise = promise { - promise.succeed(result: ()) - } - pipeline.fireChannelActive0() + self.lifecycleManager.activate(promise: promise)() self.readIfNeeded0() } - - func becomeInactive0(promise: EventLoopPromise?) { - assert(eventLoop.inEventLoop) - assert(self.active.load()) - active.store(false) - - // Notify the promise before firing the inbound event through the pipeline. - if let promise = promise { - promise.succeed(result: ()) - } - pipeline.fireChannelInactive0() - } } diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index 6ec9658e5a..cf1249f0b1 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -280,6 +280,13 @@ public enum ChannelError: Error { case writeHostUnreachable } +/// This should be inside of `ChannelError` but we keep it separate to not break API. +// TODO: For 2.0: bring this inside of `ChannelError` +public enum ChannelLifecycleError: Error { + /// An operation that was inappropriate given the current `Channel` state was attempted. + case inappropriateOperationForState +} + extension ChannelError: Equatable { public static func ==(lhs: ChannelError, rhs: ChannelError) -> Bool { switch (lhs, rhs) { diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 701e72355b..1513d0473c 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -823,6 +823,7 @@ public final class ChannelPipeline: ChannelInvoker { } func fireErrorCaught0(error: Error) { + assert((error as? ChannelError).map { $0 != .eof } ?? true) if let firstInboundCtx = firstInboundCtx { firstInboundCtx.invokeErrorCaught(error) } diff --git a/Sources/NIO/Selector.swift b/Sources/NIO/Selector.swift index f2c224f720..2864cbb4dc 100644 --- a/Sources/NIO/Selector.swift +++ b/Sources/NIO/Selector.swift @@ -519,6 +519,14 @@ internal extension Selector where R == NIORegistration { case .datagramChannel(let chan, _): return closeChannel(chan) } + }.map { future in + future.thenIfErrorThrowing { error in + if let error = error as? ChannelError, error == .alreadyClosed { + return () + } else { + throw error + } + } } guard futures.count > 0 else { diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index ebe795b29e..1c03c48cde 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -49,7 +49,8 @@ final class SocketChannel: BaseSocketChannel { override var isOpen: Bool { assert(eventLoop.inEventLoop) - return pendingWrites.isOpen + assert(super.isOpen == self.pendingWrites.isOpen) + return super.isOpen } init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws { @@ -133,7 +134,7 @@ final class SocketChannel: BaseSocketChannel { readPending = false - assert(self.isOpen) + assert(self.isActive) pipeline.fireChannelRead0(NIOAny(buffer)) if mayGrow && i < maxMessagesPerRead { // if the ByteBuffer may grow on the next allocation due we used all the writable bytes we should allocate a new `ByteBuffer` to allow ramping up how much data @@ -356,9 +357,14 @@ final class ServerSocketChannel: BaseSocketChannel { return } + guard self.isRegistered else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } + let p: EventLoopPromise = eventLoop.newPromise() p.futureResult.map { - // Its important to call the methods before we actual notify the original promise for ordering reasons. + // Its important to call the methods before we actually notify the original promise for ordering reasons. self.becomeActive0(promise: promise) }.whenFailure{ error in promise?.fail(error: error) @@ -389,6 +395,7 @@ final class ServerSocketChannel: BaseSocketChannel { result = .some do { let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop) + assert(self.isActive) pipeline.fireChannelRead0(NIOAny(chan)) } catch let err { _ = try? accepted.close() @@ -473,7 +480,8 @@ final class DatagramChannel: BaseSocketChannel { override var isOpen: Bool { assert(eventLoop.inEventLoop) - return pendingWrites.isOpen + assert(super.isOpen == self.pendingWrites.isOpen) + return super.isOpen } init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws { @@ -573,6 +581,7 @@ final class DatagramChannel: BaseSocketChannel { readPending = false let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer) + assert(self.isActive) pipeline.fireChannelRead0(NIOAny(msg)) if mayGrow && i < maxMessagesPerRead { buffer = recvAllocator.buffer(allocator: allocator) @@ -610,6 +619,7 @@ final class DatagramChannel: BaseSocketChannel { } if !self.pendingWrites.add(envelope: data, promise: promise) { + assert(self.isActive) pipeline.fireChannelWritabilityChanged0() } } @@ -642,6 +652,7 @@ final class DatagramChannel: BaseSocketChannel { }) if result.writable { // writable again + assert(self.isActive) self.pipeline.fireChannelWritabilityChanged0() } return result.writeResult @@ -651,6 +662,10 @@ final class DatagramChannel: BaseSocketChannel { override func bind0(to address: SocketAddress, promise: EventLoopPromise?) { assert(self.eventLoop.inEventLoop) + guard self.isRegistered else { + promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) + return + } do { try socket.bind(to: address) self.updateCachedAddressesFromSocket(updateRemote: false) diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index d0e72f407f..bc0fd36841 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -61,6 +61,8 @@ extension ChannelTests { ("testEOFOnlyReceivedOnceReadRequested", testEOFOnlyReceivedOnceReadRequested), ("testAcceptsAfterCloseDontCauseIssues", testAcceptsAfterCloseDontCauseIssues), ("testChannelReadsDoesNotHappenAfterRegistration", testChannelReadsDoesNotHappenAfterRegistration), + ("testAppropriateAndInappropriateOperationsForUnregisteredSockets", testAppropriateAndInappropriateOperationsForUnregisteredSockets), + ("testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives", testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 3efb2d6be6..305afbf543 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -1825,4 +1825,161 @@ public class ChannelTests: XCTestCase { try sc.syncCloseAcceptingAlreadyClosed() try clientHasUnregistered.futureResult.wait() } + + func testAppropriateAndInappropriateOperationsForUnregisteredSockets() throws { + func checkThatItThrowsInappropriateOperationForState(file: StaticString = #file, line: UInt = #line, _ body: () throws -> Void) { + do { + try body() + XCTFail("didn't throw", file: file, line: line) + } catch let error as ChannelLifecycleError where error == .inappropriateOperationForState { + //OK + } catch { + XCTFail("unexpected error \(error)", file: file, line: line) + } + } + let elg = MultiThreadedEventLoopGroup(numThreads: 1) + + func withChannel(skipDatagram: Bool = false, skipStream: Bool = false, skipServerSocket: Bool = false, file: StaticString = #file, line: UInt = #line, _ body: (Channel) throws -> Void) { + XCTAssertNoThrow(try { + let el = elg.next() as! SelectableEventLoop + let channels: [Channel] = (skipDatagram ? [] : [try DatagramChannel(eventLoop: el, protocolFamily: PF_INET)]) + + /* Xcode need help */ (skipStream ? []: [try SocketChannel(eventLoop: el, protocolFamily: PF_INET)]) + + /* Xcode need help */ (skipServerSocket ? []: [try ServerSocketChannel(eventLoop: el, group: elg, protocolFamily: PF_INET)]) + for channel in channels { + try body(channel) + XCTAssertNoThrow(try channel.close().wait(), file: file, line: line) + } + }(), file: file, line: line) + } + withChannel { channel in + checkThatItThrowsInappropriateOperationForState { + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1234)).wait() + } + } + withChannel { channel in + checkThatItThrowsInappropriateOperationForState { + try channel.writeAndFlush("foo").wait() + } + } + withChannel { channel in + XCTAssertNoThrow(try channel.triggerUserOutboundEvent("foo").wait()) + } + withChannel { channel in + XCTAssertFalse(channel.isActive) + } + withChannel(skipServerSocket: true) { channel in + // should probably be changed + XCTAssertTrue(channel.isWritable) + } + withChannel(skipDatagram: true, skipStream: true) { channel in + // this should probably be the default for all types + XCTAssertFalse(channel.isWritable) + } + + withChannel { channel in + checkThatItThrowsInappropriateOperationForState { + XCTAssertEqual(0, channel.localAddress?.port ?? 0xffff) + XCTAssertNil(channel.remoteAddress) + try channel.bind(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait() + } + } + } + + func testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives() throws { + class SocketThatHasTheFirstReadSucceedButFailsTheNextWithECONNRESET: Socket { + private var firstReadHappened = false + init(protocolFamily: CInt) throws { + try super.init(protocolFamily: protocolFamily, type: Posix.SOCK_STREAM, setNonBlocking: true) + } + override func read(pointer: UnsafeMutablePointer, size: Int) throws -> IOResult { + defer { + self.firstReadHappened = true + } + XCTAssertGreaterThan(size, 0) + if self.firstReadHappened { + // this is a copy of the exact error that'd come out of the real Socket.read + throw IOError.init(errnoCode: ECONNRESET, function: "read(descriptor:pointer:size:)") + } else { + pointer.pointee = 0xff + return .processed(1) + } + } + } + class VerifyThingsAreRightHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + private let allDone: EventLoopPromise + enum State { + case fresh + case active + case read + case error + case readComplete + case inactive + } + private var state: State = .fresh + + init(allDone: EventLoopPromise) { + self.allDone = allDone + } + func channelActive(ctx: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .active + } + + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + XCTAssertEqual(.active, self.state) + self.state = .read + var buffer = self.unwrapInboundIn(data) + XCTAssertEqual(1, buffer.readableBytes) + XCTAssertEqual([0xff], buffer.readBytes(length: 1)!) + } + + func channelReadComplete(ctx: ChannelHandlerContext) { + XCTFail("channelReadComplete unexpected") + self.state = .readComplete + } + + func errorCaught(ctx: ChannelHandlerContext, error: Error) { + XCTAssertEqual(.read, self.state) + self.state = .error + ctx.close(promise: nil) + } + + func channelInactive(ctx: ChannelHandlerContext) { + XCTAssertEqual(.error, self.state) + self.state = .inactive + self.allDone.succeed(result: ()) + } + } + let group = MultiThreadedEventLoopGroup(numThreads: 2) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let serverEL = group.next() + let clientEL = group.next() + precondition(serverEL !== clientEL) + let sc = try SocketChannel(socket: SocketThatHasTheFirstReadSucceedButFailsTheNextWithECONNRESET(protocolFamily: PF_INET), eventLoop: clientEL as! SelectableEventLoop) + + let serverChannel = try ServerBootstrap(group: serverEL) + .childChannelInitializer { channel in + var buffer = channel.allocator.buffer(capacity: 4) + buffer.write(string: "foo") + return channel.write(NIOAny(buffer)) + } + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) + } + + let allDone: EventLoopPromise = clientEL.newPromise() + + try sc.register().then { + sc.pipeline.add(handler: VerifyThingsAreRightHandler(allDone: allDone)) + }.then { + sc.connect(to: serverChannel.localAddress!) + }.wait() + try allDone.futureResult.wait() + XCTAssertNoThrow(try sc.syncCloseAcceptingAlreadyClosed()) + } } diff --git a/Tests/NIOTests/DatagramChannelTests.swift b/Tests/NIOTests/DatagramChannelTests.swift index a5e8111014..a5115064fd 100644 --- a/Tests/NIOTests/DatagramChannelTests.swift +++ b/Tests/NIOTests/DatagramChannelTests.swift @@ -36,16 +36,31 @@ private class DatagramReadRecorder: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope typealias InboundOut = AddressedEnvelope + enum State { + case fresh + case registered + case active + } + var reads: [AddressedEnvelope] = [] var loop: EventLoop? = nil + var state: State = .fresh var readWaiters: [Int: EventLoopPromise<[AddressedEnvelope]>] = [:] func channelRegistered(ctx: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .registered self.loop = ctx.eventLoop } + func channelActive(ctx: ChannelHandlerContext) { + XCTAssertEqual(.registered, self.state) + self.state = .active + } + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + XCTAssertEqual(.active, self.state) let data = self.unwrapInboundIn(data) reads.append(data)