diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 6d533b419a..36bcc31000 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -446,9 +446,16 @@ public final class ChannelPipeline: ChannelInvoker { return promise.futureResult } + /// Returns a `ChannelHandlerContext` which matches. + /// + /// This skips head and tail (as these are internal and should not be accessible by the user. + /// + /// - parameters: + /// - body: The predicate to execute per `ChannelHandlerContext` in the `ChannelPipeline`. + /// -returns: The `ChannelHandlerContext` that matches or `nil` if non did. private func contextForPredicate0(_ body: @escaping((ChannelHandlerContext) -> Bool)) -> ChannelHandlerContext? { - var curCtx: ChannelHandlerContext? = self.head - while let ctx = curCtx { + var curCtx: ChannelHandlerContext? = self.head?.next + while let ctx = curCtx, ctx !== self.tail { if body(ctx) { return ctx } @@ -1422,8 +1429,8 @@ public final class ChannelHandlerContext: ChannelInvoker { extension ChannelPipeline: CustomDebugStringConvertible { public var debugDescription: String { var desc = "ChannelPipeline (\(ObjectIdentifier(self))):\n" - var node = self.head - while let ctx = node { + var node = self.head?.next + while let ctx = node, ctx !== self.tail { let inboundStr = ctx.handler is _ChannelInboundHandler ? "I" : "" let outboundStr = ctx.handler is _ChannelOutboundHandler ? "O" : "" desc += " \(ctx.name) (\(type(of: ctx.handler))) [\(inboundStr)\(outboundStr)]\n" diff --git a/Tests/NIOTests/ChannelPipelineTest.swift b/Tests/NIOTests/ChannelPipelineTest.swift index d2a1a292f7..ab558a58f9 100644 --- a/Tests/NIOTests/ChannelPipelineTest.swift +++ b/Tests/NIOTests/ChannelPipelineTest.swift @@ -620,4 +620,48 @@ class ChannelPipelineTest: XCTestCase { XCTAssertTrue(try h1 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler) XCTAssertFalse(try h2 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler) } + + func testContextForHeadOrTail() throws { + let channel = EmbeddedChannel() + + defer { + XCTAssertFalse(try channel.finish()) + } + + do { + _ = try channel.pipeline.context(name: "head").wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.context(name: "tail").wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + } + + func testRemoveHeadOrTail() throws { + let channel = EmbeddedChannel() + + defer { + XCTAssertFalse(try channel.finish()) + } + + do { + _ = try channel.pipeline.remove(name: "head").wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.remove(name: "tail").wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + } }