Skip to content

Commit

Permalink
fix errors received when holding messages in HTTPServerPipelineHandler (
Browse files Browse the repository at this point in the history
#314)

Motivation:

We had a bug which is happens in the combination of these states:
- we held a request in the pipelining handler (because we're procesing a
  previous one)
- a http handler error happened whilst a response's `.head` had already
  been sent (but not yet the `.end`)
- the HTTPServerProtocolErrors handler is in use

That would lead to this situation:
- the error isn't held by the pipelining handler
- the error handler then just sends a full response (`.head` and `.end`)
  but the actual http server already send a `.head`. So all in all, we
  sent `.head`, `.head`, `.end` which is illegal
- the pipelining handler didn't notice this and beause it saw an `.end`
  it would send through the next requst
- now the http server handler is in the situation that it gets `.head`,
  `.head` too (which is illegal)

Modifications:

- hold HTTP errors in the pipelining handler too

Result:

- more correctness
  • Loading branch information
weissi authored Apr 16, 2018
1 parent 7758140 commit dc0d731
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 6 deletions.
46 changes: 46 additions & 0 deletions Sources/NIOHTTP1/HTTPServerPipelineHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@

import NIO

/// A utility function that runs the body code only in debug builds, without
/// emitting compiler warnings.
///
/// This is currently the only way to do this in Swift: see
/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion.
internal func debugOnly(_ body: () -> Void) {
assert({ body(); return true }())
}

/// A `ChannelHandler` that handles HTTP pipelining by buffering inbound data until a
/// response has been sent.
///
Expand Down Expand Up @@ -105,6 +114,8 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler {
private enum BufferedEvent {
/// A channelRead event.
case channelRead(NIOAny)

case error(HTTPParserError)

/// A TCP half-close. This is buffered to ensure that subsequent channel
/// handlers that are aware of TCP half-close are informed about it in
Expand All @@ -123,6 +134,13 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler {
/// don't pipeline, so this initially allocates no space for data at all. Clients that
/// do pipeline will cause dynamic resizing of the buffer, which is generally acceptable.
private var eventBuffer = CircularBuffer<BufferedEvent>(initialRingCapacity: 0)

enum NextExpectedMessageType {
case head
case bodyOrEnd
}

private var nextExpectedOutboundMessage = NextExpectedMessageType.head

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
if case .responseEndPending = self.state {
Expand Down Expand Up @@ -151,10 +169,36 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler {
ctx.fireUserInboundEventTriggered(event)
}
}

public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
guard let httpError = error as? HTTPParserError else {
ctx.fireErrorCaught(error)
return
}
if case .responseEndPending = self.state {
self.eventBuffer.append(.error(httpError))
return
}
ctx.fireErrorCaught(error)
}

public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
assert(self.state != .requestEndPending,
"Received second response while waiting for first one to complete")
debugOnly {
let res = self.unwrapOutboundIn(data)
switch res {
case .head:
assert(self.nextExpectedOutboundMessage == .head)
self.nextExpectedOutboundMessage = .bodyOrEnd
case .body:
assert(self.nextExpectedOutboundMessage == .bodyOrEnd)
case .end:
assert(self.nextExpectedOutboundMessage == .bodyOrEnd)
self.nextExpectedOutboundMessage = .head
}
}

var startReadingAgain = false
if case .end = self.unwrapOutboundIn(data) {
startReadingAgain = true
Expand Down Expand Up @@ -199,6 +243,8 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler {
case .channelRead(let read):
self.channelRead(ctx: ctx, data: read)
deliveredRead = true
case .error(let error):
ctx.fireErrorCaught(error)
case .halfClose:
// When we fire the half-close, we want to forget all prior reads.
// They will just trigger further half-close notifications we don't
Expand Down
33 changes: 27 additions & 6 deletions Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ import NIO
/// servers want. This handler does not suppress the parser errors: it allows them to
/// continue to pass through the pipeline so that other handlers (e.g. logging ones) can
/// deal with the error.
public final class HTTPServerProtocolErrorHandler: ChannelInboundHandler {
public final class HTTPServerProtocolErrorHandler: ChannelDuplexHandler {
public typealias InboundIn = HTTPServerRequestPart
public typealias InboundOut = HTTPServerRequestPart
public typealias OutboundIn = HTTPServerResponsePart
public typealias OutboundOut = HTTPServerResponsePart

private var hasUnterminatedResponse: Bool = false

public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
guard error is HTTPParserError else {
ctx.fireErrorCaught(error)
Expand All @@ -34,16 +37,34 @@ public final class HTTPServerProtocolErrorHandler: ChannelInboundHandler {

// Any HTTPParserError is automatically fatal, and we don't actually need (or want) to
// provide that error to the client: we just want to tell it that it screwed up and then
// let the rest of the pipeline shut the door in its face.
// let the rest of the pipeline shut the door in its face. However, we can only send an
// HTTP error response if another response hasn't started yet.
//
// A side note here: we cannot block or do any delayed work. ByteToMessageDecoder is going
// to come along and close the channel right after we return from this function.
let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")])
let head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .badRequest, headers: headers)
ctx.write(self.wrapOutboundOut(.head(head)), promise: nil)
ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
if !self.hasUnterminatedResponse {
let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")])
let head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .badRequest, headers: headers)
ctx.write(self.wrapOutboundOut(.head(head)), promise: nil)
ctx.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}

// Now pass the error on in case someone else wants to see it.
ctx.fireErrorCaught(error)
}

public func write(ctx: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let res = self.unwrapOutboundIn(data)
switch res {
case .head:
precondition(!self.hasUnterminatedResponse)
self.hasUnterminatedResponse = true
case .body:
precondition(self.hasUnterminatedResponse)
case .end:
precondition(self.hasUnterminatedResponse)
self.hasUnterminatedResponse = false
}
ctx.write(data, promise: promise)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ extension HTTPServerProtocolErrorHandlerTest {
return [
("testHandlesBasicErrors", testHandlesBasicErrors),
("testIgnoresNonParserErrors", testIgnoresNonParserErrors),
("testDoesNotSendAResponseIfResponseHasAlreadyStarted", testDoesNotSendAResponseIfResponseHasAlreadyStarted),
("testCanHandleErrorsWhenResponseHasStarted", testCanHandleErrorsWhenResponseHasStarted),
]
}
}
Expand Down
73 changes: 73 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,77 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase {

XCTAssertNoThrow(try channel.finish())
}

func testDoesNotSendAResponseIfResponseHasAlreadyStarted() throws {
let channel = EmbeddedChannel()
XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait())
let res = HTTPServerResponsePart.head(.init(version: HTTPVersion(major: 1, minor: 1),
status: .ok,
headers: .init([("Content-Length", "0")])))
XCTAssertNoThrow(try channel.writeAndFlush(res).wait())
// now we have started a response but it's not complete yet, let's inject a parser error
channel.pipeline.fireErrorCaught(HTTPParserError.invalidEOFState)
var allOutbound = channel.readAllOutboundBuffers()
let allOutboundString = allOutbound.readString(length: allOutbound.readableBytes)
// there should be no HTTP/1.1 400 or anything in here
XCTAssertEqual("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", allOutboundString)
}

func testCanHandleErrorsWhenResponseHasStarted() throws {
enum NextExpectedState {
case head
case end
case none
}
class DelayWriteHandler: ChannelInboundHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart

private var nextExpected: NextExpectedState = .head

func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
let req = self.unwrapInboundIn(data)
switch req {
case .head:
XCTAssertEqual(.head, self.nextExpected)
self.nextExpected = .end
let res = HTTPServerResponsePart.head(.init(version: HTTPVersion(major: 1, minor: 1),
status: .ok,
headers: .init([("Content-Length", "0")])))
ctx.writeAndFlush(self.wrapOutboundOut(res), promise: nil)
default:
XCTAssertEqual(.end, self.nextExpected)
self.nextExpected = .none
}
}


}
let channel = EmbeddedChannel()
XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).then {
channel.pipeline.add(handler: DelayWriteHandler())
}.wait())

var buffer = channel.allocator.buffer(capacity: 1024)
buffer.write(staticString: "GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\nGET / HT")
XCTAssertNoThrow(try channel.writeInbound(buffer))
XCTAssertNoThrow(try channel.close().wait())
(channel.eventLoop as! EmbeddedEventLoop).run()

// The channel should be closed at this stage.
XCTAssertNoThrow(try channel.closeFuture.wait())

// We expect exactly one ByteBuffer in the output.
guard case .some(.byteBuffer(var written)) = channel.readOutbound() else {
XCTFail("No writes")
return
}

XCTAssertNil(channel.readOutbound())

// Check the response.
assertResponseIs(response: written.readString(length: written.readableBytes)!,
expectedResponseLine: "HTTP/1.1 200 OK",
expectedResponseHeaders: ["Content-Length: 0"])
}
}

0 comments on commit dc0d731

Please # to comment.