Skip to content

Commit

Permalink
make Channel lifecycle statemachine explicit
Browse files Browse the repository at this point in the history
Motivation:

We had a lot of problems with the Channel lifecycle statemachine as it
wasn't explicit, this fixes this. Additionally, it asserts a lot more.

Modifications:

- made Channel lifecycle statemachine explicit
- lots of asserts

Result:

- hopefully the state machine works better
- the asserts should guide our way to work on in corner cases as well
  • Loading branch information
weissi committed Mar 29, 2018
1 parent 224120c commit 7abbb03
Show file tree
Hide file tree
Showing 8 changed files with 434 additions and 68 deletions.
289 changes: 225 additions & 64 deletions Sources/NIO/BaseSocketChannel.swift

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions Sources/NIO/Channel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ public enum ChannelError: Error {
case writeHostUnreachable
}

/// This should be inside of `ChannelError` but we keep it separate to not break API.
// TODO: For 2.0: bring this inside of `ChannelError`
public enum ChannelLifecycleError: Error {
/// An operation that was inappropriate given the current `Channel` state was attempted.
case inappropriateOperationForState
}

extension ChannelError: Equatable {
public static func ==(lhs: ChannelError, rhs: ChannelError) -> Bool {
switch (lhs, rhs) {
Expand Down
1 change: 1 addition & 0 deletions Sources/NIO/ChannelPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ public final class ChannelPipeline: ChannelInvoker {
}

func fireErrorCaught0(error: Error) {
assert((error as? ChannelError).map { $0 != .eof } ?? true)
if let firstInboundCtx = firstInboundCtx {
firstInboundCtx.invokeErrorCaught(error)
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/NIO/Selector.swift
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,14 @@ internal extension Selector where R == NIORegistration {
case .datagramChannel(let chan, _):
return closeChannel(chan)
}
}.map { future in
future.thenIfErrorThrowing { error in
if let error = error as? ChannelError, error == .alreadyClosed {
return ()
} else {
throw error
}
}
}

guard futures.count > 0 else {
Expand Down
23 changes: 19 additions & 4 deletions Sources/NIO/SocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ final class SocketChannel: BaseSocketChannel<Socket> {

override var isOpen: Bool {
assert(eventLoop.inEventLoop)
return pendingWrites.isOpen
assert(super.isOpen == self.pendingWrites.isOpen)
return super.isOpen
}

init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws {
Expand Down Expand Up @@ -133,7 +134,7 @@ final class SocketChannel: BaseSocketChannel<Socket> {

readPending = false

assert(self.isOpen)
assert(self.isActive)
pipeline.fireChannelRead0(NIOAny(buffer))
if mayGrow && i < maxMessagesPerRead {
// if the ByteBuffer may grow on the next allocation due we used all the writable bytes we should allocate a new `ByteBuffer` to allow ramping up how much data
Expand Down Expand Up @@ -356,9 +357,14 @@ final class ServerSocketChannel: BaseSocketChannel<ServerSocket> {
return
}

guard self.isRegistered else {
promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState)
return
}

let p: EventLoopPromise<Void> = eventLoop.newPromise()
p.futureResult.map {
// Its important to call the methods before we actual notify the original promise for ordering reasons.
// Its important to call the methods before we actually notify the original promise for ordering reasons.
self.becomeActive0(promise: promise)
}.whenFailure{ error in
promise?.fail(error: error)
Expand Down Expand Up @@ -389,6 +395,7 @@ final class ServerSocketChannel: BaseSocketChannel<ServerSocket> {
result = .some
do {
let chan = try SocketChannel(socket: accepted, parent: self, eventLoop: group.next() as! SelectableEventLoop)
assert(self.isActive)
pipeline.fireChannelRead0(NIOAny(chan))
} catch let err {
_ = try? accepted.close()
Expand Down Expand Up @@ -473,7 +480,8 @@ final class DatagramChannel: BaseSocketChannel<Socket> {

override var isOpen: Bool {
assert(eventLoop.inEventLoop)
return pendingWrites.isOpen
assert(super.isOpen == self.pendingWrites.isOpen)
return super.isOpen
}

init(eventLoop: SelectableEventLoop, protocolFamily: Int32) throws {
Expand Down Expand Up @@ -573,6 +581,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
readPending = false

let msg = AddressedEnvelope(remoteAddress: rawAddress.convert(), data: buffer)
assert(self.isActive)
pipeline.fireChannelRead0(NIOAny(msg))
if mayGrow && i < maxMessagesPerRead {
buffer = recvAllocator.buffer(allocator: allocator)
Expand Down Expand Up @@ -610,6 +619,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
}

if !self.pendingWrites.add(envelope: data, promise: promise) {
assert(self.isActive)
pipeline.fireChannelWritabilityChanged0()
}
}
Expand Down Expand Up @@ -642,6 +652,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
})
if result.writable {
// writable again
assert(self.isActive)
self.pipeline.fireChannelWritabilityChanged0()
}
return result.writeResult
Expand All @@ -651,6 +662,10 @@ final class DatagramChannel: BaseSocketChannel<Socket> {

override func bind0(to address: SocketAddress, promise: EventLoopPromise<Void>?) {
assert(self.eventLoop.inEventLoop)
guard self.isRegistered else {
promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState)
return
}
do {
try socket.bind(to: address)
self.updateCachedAddressesFromSocket(updateRemote: false)
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOTests/ChannelTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ extension ChannelTests {
("testEOFOnlyReceivedOnceReadRequested", testEOFOnlyReceivedOnceReadRequested),
("testAcceptsAfterCloseDontCauseIssues", testAcceptsAfterCloseDontCauseIssues),
("testChannelReadsDoesNotHappenAfterRegistration", testChannelReadsDoesNotHappenAfterRegistration),
("testAppropriateAndInappropriateOperationsForUnregisteredSockets", testAppropriateAndInappropriateOperationsForUnregisteredSockets),
("testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives", testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives),
]
}
}
Expand Down
157 changes: 157 additions & 0 deletions Tests/NIOTests/ChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1825,4 +1825,161 @@ public class ChannelTests: XCTestCase {
try sc.syncCloseAcceptingAlreadyClosed()
try clientHasUnregistered.futureResult.wait()
}

func testAppropriateAndInappropriateOperationsForUnregisteredSockets() throws {
func checkThatItThrowsInappropriateOperationForState(file: StaticString = #file, line: UInt = #line, _ body: () throws -> Void) {
do {
try body()
XCTFail("didn't throw", file: file, line: line)
} catch let error as ChannelLifecycleError where error == .inappropriateOperationForState {
//OK
} catch {
XCTFail("unexpected error \(error)", file: file, line: line)
}
}
let elg = MultiThreadedEventLoopGroup(numThreads: 1)

func withChannel(skipDatagram: Bool = false, skipStream: Bool = false, skipServerSocket: Bool = false, file: StaticString = #file, line: UInt = #line, _ body: (Channel) throws -> Void) {
XCTAssertNoThrow(try {
let el = elg.next() as! SelectableEventLoop
let channels: [Channel] = (skipDatagram ? [] : [try DatagramChannel(eventLoop: el, protocolFamily: PF_INET)]) +
/* Xcode need help */ (skipStream ? []: [try SocketChannel(eventLoop: el, protocolFamily: PF_INET)]) +
/* Xcode need help */ (skipServerSocket ? []: [try ServerSocketChannel(eventLoop: el, group: elg, protocolFamily: PF_INET)])
for channel in channels {
try body(channel)
XCTAssertNoThrow(try channel.close().wait(), file: file, line: line)
}
}(), file: file, line: line)
}
withChannel { channel in
checkThatItThrowsInappropriateOperationForState {
try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1234)).wait()
}
}
withChannel { channel in
checkThatItThrowsInappropriateOperationForState {
try channel.writeAndFlush("foo").wait()
}
}
withChannel { channel in
XCTAssertNoThrow(try channel.triggerUserOutboundEvent("foo").wait())
}
withChannel { channel in
XCTAssertFalse(channel.isActive)
}
withChannel(skipServerSocket: true) { channel in
// should probably be changed
XCTAssertTrue(channel.isWritable)
}
withChannel(skipDatagram: true, skipStream: true) { channel in
// this should probably be the default for all types
XCTAssertFalse(channel.isWritable)
}

withChannel { channel in
checkThatItThrowsInappropriateOperationForState {
XCTAssertEqual(0, channel.localAddress?.port ?? 0xffff)
XCTAssertNil(channel.remoteAddress)
try channel.bind(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()
}
}
}

func testCloseSocketWhenReadErrorWasReceivedAndMakeSureNoReadCompleteArrives() throws {
class SocketThatHasTheFirstReadSucceedButFailsTheNextWithECONNRESET: Socket {
private var firstReadHappened = false
init(protocolFamily: CInt) throws {
try super.init(protocolFamily: protocolFamily, type: Posix.SOCK_STREAM, setNonBlocking: true)
}
override func read(pointer: UnsafeMutablePointer<UInt8>, size: Int) throws -> IOResult<Int> {
defer {
self.firstReadHappened = true
}
XCTAssertGreaterThan(size, 0)
if self.firstReadHappened {
// this is a copy of the exact error that'd come out of the real Socket.read
throw IOError.init(errnoCode: ECONNRESET, function: "read(descriptor:pointer:size:)")
} else {
pointer.pointee = 0xff
return .processed(1)
}
}
}
class VerifyThingsAreRightHandler: ChannelInboundHandler {
typealias InboundIn = ByteBuffer
private let allDone: EventLoopPromise<Void>
enum State {
case fresh
case active
case read
case error
case readComplete
case inactive
}
private var state: State = .fresh

init(allDone: EventLoopPromise<Void>) {
self.allDone = allDone
}
func channelActive(ctx: ChannelHandlerContext) {
XCTAssertEqual(.fresh, self.state)
self.state = .active
}

func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
XCTAssertEqual(.active, self.state)
self.state = .read
var buffer = self.unwrapInboundIn(data)
XCTAssertEqual(1, buffer.readableBytes)
XCTAssertEqual([0xff], buffer.readBytes(length: 1)!)
}

func channelReadComplete(ctx: ChannelHandlerContext) {
XCTFail("channelReadComplete unexpected")
self.state = .readComplete
}

func errorCaught(ctx: ChannelHandlerContext, error: Error) {
XCTAssertEqual(.read, self.state)
self.state = .error
ctx.close(promise: nil)
}

func channelInactive(ctx: ChannelHandlerContext) {
XCTAssertEqual(.error, self.state)
self.state = .inactive
self.allDone.succeed(result: ())
}
}
let group = MultiThreadedEventLoopGroup(numThreads: 2)
defer {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
let serverEL = group.next()
let clientEL = group.next()
precondition(serverEL !== clientEL)
let sc = try SocketChannel(socket: SocketThatHasTheFirstReadSucceedButFailsTheNextWithECONNRESET(protocolFamily: PF_INET), eventLoop: clientEL as! SelectableEventLoop)

let serverChannel = try ServerBootstrap(group: serverEL)
.childChannelInitializer { channel in
var buffer = channel.allocator.buffer(capacity: 4)
buffer.write(string: "foo")
return channel.write(NIOAny(buffer))
}
.bind(host: "127.0.0.1", port: 0)
.wait()
defer {
XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed())
}

let allDone: EventLoopPromise<Void> = clientEL.newPromise()

try sc.register().then {
sc.pipeline.add(handler: VerifyThingsAreRightHandler(allDone: allDone))
}.then {
sc.connect(to: serverChannel.localAddress!)
}.wait()
try allDone.futureResult.wait()
XCTAssertNoThrow(try sc.syncCloseAcceptingAlreadyClosed())
}
}
15 changes: 15 additions & 0 deletions Tests/NIOTests/DatagramChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,31 @@ private class DatagramReadRecorder<DataType>: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<DataType>
typealias InboundOut = AddressedEnvelope<DataType>

enum State {
case fresh
case registered
case active
}

var reads: [AddressedEnvelope<DataType>] = []
var loop: EventLoop? = nil
var state: State = .fresh

var readWaiters: [Int: EventLoopPromise<[AddressedEnvelope<DataType>]>] = [:]

func channelRegistered(ctx: ChannelHandlerContext) {
XCTAssertEqual(.fresh, self.state)
self.state = .registered
self.loop = ctx.eventLoop
}

func channelActive(ctx: ChannelHandlerContext) {
XCTAssertEqual(.registered, self.state)
self.state = .active
}

func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
XCTAssertEqual(.active, self.state)
let data = self.unwrapInboundIn(data)
reads.append(data)

Expand Down

0 comments on commit 7abbb03

Please # to comment.