diff --git a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift index 45b533472e..cf5ac80065 100644 --- a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift @@ -20,6 +20,15 @@ import NIO +/// A utility function that runs the body code only in debug builds, without +/// emitting compiler warnings. +/// +/// This is currently the only way to do this in Swift: see +/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. +internal func debugOnly(_ body: () -> Void) { + assert({ body(); return true }()) +} + /// A `ChannelHandler` that handles HTTP pipelining by buffering inbound data until a /// response has been sent. /// @@ -105,6 +114,8 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { private enum BufferedEvent { /// A channelRead event. case channelRead(NIOAny) + + case error(HTTPParserError) /// A TCP half-close. This is buffered to ensure that subsequent channel /// handlers that are aware of TCP half-close are informed about it in @@ -123,6 +134,13 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { /// don't pipeline, so this initially allocates no space for data at all. Clients that /// do pipeline will cause dynamic resizing of the buffer, which is generally acceptable. private var eventBuffer = CircularBuffer(initialRingCapacity: 0) + + enum NextExpectedMessageType { + case head + case bodyOrEnd + } + + private var nextExpectedOutboundMessage = NextExpectedMessageType.head public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { if case .responseEndPending = self.state { @@ -151,10 +169,36 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { ctx.fireUserInboundEventTriggered(event) } } + + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { + guard let httpError = error as? HTTPParserError else { + ctx.fireErrorCaught(error) + return + } + if case .responseEndPending = self.state { + self.eventBuffer.append(.error(httpError)) + return + } + ctx.fireErrorCaught(error) + } public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { assert(self.state != .requestEndPending, "Received second response while waiting for first one to complete") + debugOnly { + let res = self.unwrapOutboundIn(data) + switch res { + case .head: + assert(self.nextExpectedOutboundMessage == .head) + self.nextExpectedOutboundMessage = .bodyOrEnd + case .body: + assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + case .end: + assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + self.nextExpectedOutboundMessage = .head + } + } + var startReadingAgain = false if case .end = self.unwrapOutboundIn(data) { startReadingAgain = true @@ -199,6 +243,8 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler { case .channelRead(let read): self.channelRead(ctx: ctx, data: read) deliveredRead = true + case .error(let error): + ctx.fireErrorCaught(error) case .halfClose: // When we fire the half-close, we want to forget all prior reads. // They will just trigger further half-close notifications we don't diff --git a/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift b/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift index 77557ef3f6..6d8581dd26 100644 --- a/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift @@ -21,11 +21,14 @@ import NIO /// servers want. This handler does not suppress the parser errors: it allows them to /// continue to pass through the pipeline so that other handlers (e.g. logging ones) can /// deal with the error. -public final class HTTPServerProtocolErrorHandler: ChannelInboundHandler { +public final class HTTPServerProtocolErrorHandler: ChannelDuplexHandler { public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = HTTPServerRequestPart + public typealias OutboundIn = HTTPServerResponsePart public typealias OutboundOut = HTTPServerResponsePart + private var hasUnterminatedResponse: Bool = false + public func errorCaught(ctx: ChannelHandlerContext, error: Error) { guard error is HTTPParserError else { ctx.fireErrorCaught(error) @@ -34,16 +37,34 @@ public final class HTTPServerProtocolErrorHandler: ChannelInboundHandler { // Any HTTPParserError is automatically fatal, and we don't actually need (or want) to // provide that error to the client: we just want to tell it that it screwed up and then - // let the rest of the pipeline shut the door in its face. + // let the rest of the pipeline shut the door in its face. However, we can only send an + // HTTP error response if another response hasn't started yet. // // A side note here: we cannot block or do any delayed work. ByteToMessageDecoder is going // to come along and close the channel right after we return from this function. - let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")]) - let head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .badRequest, headers: headers) - ctx.write(self.wrapOutboundOut(.head(head)), promise: nil) - ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + if !self.hasUnterminatedResponse { + let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")]) + let head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .badRequest, headers: headers) + ctx.write(self.wrapOutboundOut(.head(head)), promise: nil) + ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } // Now pass the error on in case someone else wants to see it. ctx.fireErrorCaught(error) } + + public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let res = self.unwrapOutboundIn(data) + switch res { + case .head: + precondition(!self.hasUnterminatedResponse) + self.hasUnterminatedResponse = true + case .body: + precondition(self.hasUnterminatedResponse) + case .end: + precondition(self.hasUnterminatedResponse) + self.hasUnterminatedResponse = false + } + ctx.write(data, promise: promise) + } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest+XCTest.swift index f64d75f1cc..69d4d1eb9c 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest+XCTest.swift @@ -28,6 +28,8 @@ extension HTTPServerProtocolErrorHandlerTest { return [ ("testHandlesBasicErrors", testHandlesBasicErrors), ("testIgnoresNonParserErrors", testIgnoresNonParserErrors), + ("testDoesNotSendAResponseIfResponseHasAlreadyStarted", testDoesNotSendAResponseIfResponseHasAlreadyStarted), + ("testCanHandleErrorsWhenResponseHasStarted", testCanHandleErrorsWhenResponseHasStarted), ] } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift index 295ff528c8..5372ede2df 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift @@ -66,4 +66,77 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.finish()) } + + func testDoesNotSendAResponseIfResponseHasAlreadyStarted() throws { + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait()) + let res = HTTPServerResponsePart.head(.init(version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: .init([("Content-Length", "0")]))) + XCTAssertNoThrow(try channel.writeAndFlush(res).wait()) + // now we have started a response but it's not complete yet, let's inject a parser error + channel.pipeline.fireErrorCaught(HTTPParserError.invalidEOFState) + var allOutbound = channel.readAllOutboundBuffers() + let allOutboundString = allOutbound.readString(length: allOutbound.readableBytes) + // there should be no HTTP/1.1 400 or anything in here + XCTAssertEqual("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", allOutboundString) + } + + func testCanHandleErrorsWhenResponseHasStarted() throws { + enum NextExpectedState { + case head + case end + case none + } + class DelayWriteHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private var nextExpected: NextExpectedState = .head + + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + let req = self.unwrapInboundIn(data) + switch req { + case .head: + XCTAssertEqual(.head, self.nextExpected) + self.nextExpected = .end + let res = HTTPServerResponsePart.head(.init(version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: .init([("Content-Length", "0")]))) + ctx.writeAndFlush(self.wrapOutboundOut(res), promise: nil) + default: + XCTAssertEqual(.end, self.nextExpected) + self.nextExpected = .none + } + } + + + } + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).then { + channel.pipeline.add(handler: DelayWriteHandler()) + }.wait()) + + var buffer = channel.allocator.buffer(capacity: 1024) + buffer.write(staticString: "GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\nGET / HT") + XCTAssertNoThrow(try channel.writeInbound(buffer)) + XCTAssertNoThrow(try channel.close().wait()) + (channel.eventLoop as! EmbeddedEventLoop).run() + + // The channel should be closed at this stage. + XCTAssertNoThrow(try channel.closeFuture.wait()) + + // We expect exactly one ByteBuffer in the output. + guard case .some(.byteBuffer(var written)) = channel.readOutbound() else { + XCTFail("No writes") + return + } + + XCTAssertNil(channel.readOutbound()) + + // Check the response. + assertResponseIs(response: written.readString(length: written.readableBytes)!, + expectedResponseLine: "HTTP/1.1 200 OK", + expectedResponseHeaders: ["Content-Length: 0"]) + } }