From d655f0422280cca13111e8d77b42bc4e2c431336 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Mon, 23 Apr 2018 11:05:07 +0200 Subject: [PATCH] Ensure localAddress / remoteAddress are still accessible in channelInactive / handlerRemoved Motivation: Often its useful to be still be able to access the local / remote address during channelInactive / handlerRemoved callbacks to for example log it. We should ensure its still accessible during it. Modifications: - Fallback to slow-path in ChannelHandlerContext.localAddress0 / remoteAddress0 if fast-path fails to try accessing the address via the cache. - Clear cached addresses after all callbacks are run. - Add unit test. Result: Be able to access addresses while handlers are notified. --- Sources/NIO/BaseSocketChannel.swift | 6 +-- Sources/NIO/ChannelPipeline.swift | 22 +++++++- Tests/NIOTests/ChannelTests.swift | 13 +++-- Tests/NIOTests/SocketChannelTest+XCTest.swift | 1 + Tests/NIOTests/SocketChannelTest.swift | 54 +++++++++++++++++++ 5 files changed, 88 insertions(+), 8 deletions(-) diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 3f1f5a6492..366af3ef33 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -681,9 +681,6 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { p = nil } - // Fail all pending writes and so ensure all pending promises are notified - self.unsetCachedAddressesFromSocket() - // Transition our internal state. let callouts = self.lifecycleManager.close(promise: p) @@ -701,6 +698,9 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.pipeline.removeHandlers() self.closePromise.succeed(result: ()) + + // Now reset the addresses as we notified all handlers / futures. + self.unsetCachedAddressesFromSocket() } } diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 2eb2c8b217..7725abba7f 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -1017,11 +1017,29 @@ public final class ChannelHandlerContext: ChannelInvoker { } public var remoteAddress: SocketAddress? { - return try? self.channel._unsafe.remoteAddress0() + do { + // Fast-path access to the remoteAddress. + return try self.channel._unsafe.remoteAddress0() + } catch ChannelError.ioOnClosedChannel { + // Channel was closed already but we may still have the address cached so try to access it via the Channel + // so we are able to use it in channelInactive(...) / handlerRemoved(...) methods. + return self.channel.remoteAddress + } catch { + return nil + } } public var localAddress: SocketAddress? { - return try? self.channel._unsafe.localAddress0() + do { + // Fast-path access to the localAddress. + return try self.channel._unsafe.localAddress0() + } catch ChannelError.ioOnClosedChannel { + // Channel was closed already but we may still have the address cached so try to access it via the Channel + // so we are able to use it in channelInactive(...) / handlerRemoved(...) methods. + return self.channel.localAddress + } catch { + return nil + } } public var eventLoop: EventLoop { diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 38250f4805..c93d506436 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -1437,9 +1437,16 @@ public class ChannelTests: XCTestCase { try serverChannel.syncCloseAcceptingAlreadyClosed() try clientChannel.syncCloseAcceptingAlreadyClosed() - for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] { - XCTAssertNil(f) - } + XCTAssertNoThrow(try serverChannel.closeFuture.wait()) + XCTAssertNoThrow(try clientChannel.closeFuture.wait()) + + // Schedule on the EventLoop to ensure we scheduled the cleanup of the cached addresses before. + XCTAssertNoThrow(try group.next().submit { + for f in [ serverChannel.remoteAddress, serverChannel.localAddress, clientChannel.remoteAddress, clientChannel.localAddress ] { + XCTAssertNil(f) + } + }.wait()) + } func testReceiveAddressAfterAccept() throws { diff --git a/Tests/NIOTests/SocketChannelTest+XCTest.swift b/Tests/NIOTests/SocketChannelTest+XCTest.swift index c0a2f0f830..28ad354337 100644 --- a/Tests/NIOTests/SocketChannelTest+XCTest.swift +++ b/Tests/NIOTests/SocketChannelTest+XCTest.swift @@ -43,6 +43,7 @@ extension SocketChannelTest { ("testWithConfiguredStreamSocket", testWithConfiguredStreamSocket), ("testWithConfiguredDatagramSocket", testWithConfiguredDatagramSocket), ("testPendingConnectNotificationOrder", testPendingConnectNotificationOrder), + ("testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved", testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved), ] } } diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index 1aaf47c104..0c19bc1234 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -446,4 +446,58 @@ public class SocketChannelTest : XCTestCase { XCTAssertNoThrow(try channel.closeFuture.wait()) XCTAssertNoThrow(try promise.futureResult.wait()) } + + public func testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved() throws { + + class AddressVerificationHandler: ChannelInboundHandler { + typealias InboundIn = Never + typealias OutboundIn = Never + + enum HandlerState { + case created + case inactive + case removed + } + + let promise: EventLoopPromise + var state = HandlerState.created + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func channelInactive(ctx: ChannelHandlerContext) { + XCTAssertNotNil(ctx.localAddress) + XCTAssertNotNil(ctx.remoteAddress) + XCTAssertEqual(.created, state) + state = .inactive + } + + func handlerRemoved(ctx: ChannelHandlerContext) { + XCTAssertNotNil(ctx.localAddress) + XCTAssertNotNil(ctx.remoteAddress) + XCTAssertEqual(.inactive, state) + state = .removed + + ctx.channel.closeFuture.whenComplete { + XCTAssertNil(ctx.localAddress) + XCTAssertNil(ctx.remoteAddress) + + self.promise.succeed(result: ()) + } + } + } + + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let handler = AddressVerificationHandler(promise: group.next().newPromise()) + let serverChannel = try ServerBootstrap(group: group).childChannelInitializer { $0.pipeline.add(handler: handler) }.bind(host: "127.0.0.1", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + + let clientChannel = try ClientBootstrap(group: group).connect(to: serverChannel.localAddress!).wait() + + XCTAssertNoThrow(try clientChannel.close().wait()) + XCTAssertNoThrow(try handler.promise.futureResult.wait()) + } }