diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 3f1f5a6492..366af3ef33 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -681,9 +681,6 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { p = nil } - // Fail all pending writes and so ensure all pending promises are notified - self.unsetCachedAddressesFromSocket() - // Transition our internal state. let callouts = self.lifecycleManager.close(promise: p) @@ -701,6 +698,9 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.pipeline.removeHandlers() self.closePromise.succeed(result: ()) + + // Now reset the addresses as we notified all handlers / futures. + self.unsetCachedAddressesFromSocket() } } diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 2eb2c8b217..7725abba7f 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -1017,11 +1017,29 @@ public final class ChannelHandlerContext: ChannelInvoker { } public var remoteAddress: SocketAddress? { - return try? self.channel._unsafe.remoteAddress0() + do { + // Fast-path access to the remoteAddress. + return try self.channel._unsafe.remoteAddress0() + } catch ChannelError.ioOnClosedChannel { + // Channel was closed already but we may still have the address cached so try to access it via the Channel + // so we are able to use it in channelInactive(...) / handlerRemoved(...) methods. + return self.channel.remoteAddress + } catch { + return nil + } } public var localAddress: SocketAddress? { - return try? self.channel._unsafe.localAddress0() + do { + // Fast-path access to the localAddress. + return try self.channel._unsafe.localAddress0() + } catch ChannelError.ioOnClosedChannel { + // Channel was closed already but we may still have the address cached so try to access it via the Channel + // so we are able to use it in channelInactive(...) / handlerRemoved(...) methods. + return self.channel.localAddress + } catch { + return nil + } } public var eventLoop: EventLoop { diff --git a/Tests/NIOTests/SocketChannelTest+XCTest.swift b/Tests/NIOTests/SocketChannelTest+XCTest.swift index c0a2f0f830..28ad354337 100644 --- a/Tests/NIOTests/SocketChannelTest+XCTest.swift +++ b/Tests/NIOTests/SocketChannelTest+XCTest.swift @@ -43,6 +43,7 @@ extension SocketChannelTest { ("testWithConfiguredStreamSocket", testWithConfiguredStreamSocket), ("testWithConfiguredDatagramSocket", testWithConfiguredDatagramSocket), ("testPendingConnectNotificationOrder", testPendingConnectNotificationOrder), + ("testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved", testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved), ] } } diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index 1aaf47c104..0c19bc1234 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -446,4 +446,58 @@ public class SocketChannelTest : XCTestCase { XCTAssertNoThrow(try channel.closeFuture.wait()) XCTAssertNoThrow(try promise.futureResult.wait()) } + + public func testLocalAndRemoteAddressNotNilInChannelInactiveAndHandlerRemoved() throws { + + class AddressVerificationHandler: ChannelInboundHandler { + typealias InboundIn = Never + typealias OutboundIn = Never + + enum HandlerState { + case created + case inactive + case removed + } + + let promise: EventLoopPromise + var state = HandlerState.created + + init(promise: EventLoopPromise) { + self.promise = promise + } + + func channelInactive(ctx: ChannelHandlerContext) { + XCTAssertNotNil(ctx.localAddress) + XCTAssertNotNil(ctx.remoteAddress) + XCTAssertEqual(.created, state) + state = .inactive + } + + func handlerRemoved(ctx: ChannelHandlerContext) { + XCTAssertNotNil(ctx.localAddress) + XCTAssertNotNil(ctx.remoteAddress) + XCTAssertEqual(.inactive, state) + state = .removed + + ctx.channel.closeFuture.whenComplete { + XCTAssertNil(ctx.localAddress) + XCTAssertNil(ctx.remoteAddress) + + self.promise.succeed(result: ()) + } + } + } + + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let handler = AddressVerificationHandler(promise: group.next().newPromise()) + let serverChannel = try ServerBootstrap(group: group).childChannelInitializer { $0.pipeline.add(handler: handler) }.bind(host: "127.0.0.1", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + + let clientChannel = try ClientBootstrap(group: group).connect(to: serverChannel.localAddress!).wait() + + XCTAssertNoThrow(try clientChannel.close().wait()) + XCTAssertNoThrow(try handler.promise.futureResult.wait()) + } }