Skip to content
Merged
25 changes: 25 additions & 0 deletions Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ public final class NIOTypedHTTPServerUpgradeHandler<UpgradeResult: Sendable>: Ch
}
}

public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let evt as ChannelEvent where evt == ChannelEvent.inputClosed:
// The remote peer half-closed the channel during the upgrade. Should we close the other side
switch self.stateMachine.inputClosed() {
case .close:
context.close(promise: nil)
self.upgradeResultPromise.fail(ChannelError.inputClosed)
case .continue:
break
case .fireInputClosedEvent:
context.fireUserInboundEventTriggered(event)
}

default:
context.fireUserInboundEventTriggered(event)
}
}

private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) {
switch self.stateMachine.channelReadRequestPart(requestPart) {
case .failUpgradePromise(let error):
Expand Down Expand Up @@ -399,13 +418,19 @@ public final class NIOTypedHTTPServerUpgradeHandler<UpgradeResult: Sendable>: Ch
private func unbuffer(context: ChannelHandlerContext) {
while true {
switch self.stateMachine.unbuffer() {
case .close:
context.close(promise: nil)

case .fireChannelRead(let data):
context.fireChannelRead(data)

case .fireChannelReadCompleteAndRemoveHandler:
context.fireChannelReadComplete()
context.pipeline.removeHandler(self, promise: nil)
return

case .fireInputClosedEvent:
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
}
}
}
Expand Down
83 changes: 72 additions & 11 deletions Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
/// The state before we received a TLSUserEvent. We are just forwarding any read at this point.
case initial

enum BufferedState {
case data(NIOAny)
case inputClosed
}

@usableFromInline
struct AwaitingUpgrader {
var seenFirstRequest: Bool
var buffer: Deque<NIOAny>
var buffer: Deque<BufferedState>
}

/// The request head has been received. We're currently running the future chain awaiting an upgrader.
Expand All @@ -37,22 +42,22 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
var requestHead: HTTPRequestHead
var responseHeaders: HTTPHeaders
var proto: String
var buffer: Deque<NIOAny>
var buffer: Deque<BufferedState>
}

/// We have an upgrader, which means we can begin upgrade we are just waiting for the request end.
case upgraderReady(UpgraderReady)

@usableFromInline
struct Upgrading {
var buffer: Deque<NIOAny>
var buffer: Deque<BufferedState>
}
/// We are either running the upgrading handler.
case upgrading(Upgrading)

@usableFromInline
struct Unbuffering {
var buffer: Deque<NIOAny>
var buffer: Deque<BufferedState>
}
case unbuffering(Unbuffering)

Expand Down Expand Up @@ -99,7 +104,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
if awaitingUpgrader.seenFirstRequest {
// We should buffer the data since we have seen the full request.
self.state = .modifying
awaitingUpgrader.buffer.append(data)
awaitingUpgrader.buffer.append(.data(data))
self.state = .awaitingUpgrader(awaitingUpgrader)
return nil
} else {
Expand All @@ -114,7 +119,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {

case .unbuffering(var unbuffering):
self.state = .modifying
unbuffering.buffer.append(data)
unbuffering.buffer.append(.data(data))
self.state = .unbuffering(unbuffering)
return nil

Expand All @@ -125,7 +130,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
// We got a read while running ugprading.
// We have to buffer the read to unbuffer it afterwards
self.state = .modifying
upgrading.buffer.append(data)
upgrading.buffer.append(.data(data))
self.state = .upgrading(upgrading)
return nil

Expand Down Expand Up @@ -167,8 +172,8 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
guard requestedProtocols.count > 0 else {
// We have to buffer now since we got the request head but are not upgrading.
// The user is configuring the HTTP pipeline now.
var buffer = Deque<NIOAny>()
buffer.append(NIOAny(requestPart))
var buffer = Deque<State.BufferedState>()
buffer.append(.data(NIOAny(requestPart)))
self.state = .upgrading(.init(buffer: buffer))
return .runNotUpgradingInitializer
}
Expand Down Expand Up @@ -364,8 +369,10 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {

@usableFromInline
enum UnbufferAction {
case close
case fireChannelRead(NIOAny)
case fireChannelReadCompleteAndRemoveHandler
case fireInputClosedEvent
}

@inlinable
Expand All @@ -379,8 +386,12 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {

if let element = unbuffering.buffer.popFirst() {
self.state = .unbuffering(unbuffering)

return .fireChannelRead(element)
switch element {
case .data(let data):
return .fireChannelRead(data)
case .inputClosed:
return .fireInputClosedEvent
}
} else {
self.state = .finished

Expand All @@ -393,5 +404,55 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
}
}

@usableFromInline
enum InputClosedAction {
case close
case `continue`
case fireInputClosedEvent
}

@inlinable
mutating func inputClosed() -> InputClosedAction {
switch self.state {
case .initial:
self.state = .finished
return .close

case .awaitingUpgrader(var awaitingUpgrader):
if awaitingUpgrader.seenFirstRequest {
// We should buffer the input close since we have seen the full request.
awaitingUpgrader.buffer.append(.inputClosed)
self.state = .awaitingUpgrader(awaitingUpgrader)
return .continue
} else {
// We shouldn't buffer. This means we were still expecting HTTP parts.
return .close
}

case .upgrading(var upgrading):
upgrading.buffer.append(.inputClosed)
self.state = .upgrading(upgrading)
return .continue

case .upgraderReady:
// if the state is `upgraderReady` we have received a `.head` but not an `.end`.
// If input is closed then there is no way to move this forward so we should
// close.
self.state = .finished
return .close

case .unbuffering(var unbuffering):
unbuffering.buffer.append(.inputClosed)
self.state = .unbuffering(unbuffering)
return .continue

case .finished:
return .fireInputClosedEvent

case .modifying:
fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine")
}
}

}
#endif
87 changes: 86 additions & 1 deletion Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ class HTTPServerUpgradeTestCase: XCTestCase {
upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader],
extraHandlers: [ChannelHandler],
notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture<Bool>)? = nil,
upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil,
_ upgradeCompletionHandler: @escaping UpgradeCompletionHandler
) throws -> (Channel, Channel, Channel) {
let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(
Expand Down Expand Up @@ -1770,11 +1771,13 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader],
extraHandlers: [ChannelHandler],
notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture<Bool>)? = nil,
upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil,
_ upgradeCompletionHandler: @escaping UpgradeCompletionHandler
) throws -> (Channel, Channel, Channel) {
let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self)
let serverChannelFuture = ServerBootstrap(group: Self.eventLoop)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
connectionChannelPromise.succeed(channel)
Expand All @@ -1800,6 +1803,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
return channel.eventLoop.makeSucceededVoidFuture()
}
}
.flatMapErrorThrowing { error in
upgradeErrorHandler?(error)
throw error
}
}
.flatMap { _ in
let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) }
Expand Down Expand Up @@ -2313,5 +2320,83 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
// We also want to confirm that the upgrade handler is no longer in the pipeline.
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
}

func testHalfClosure() throws {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add one or two more tests cases here:

  1. That sends the request head then input close
  2. That sends a full request head & end and then input close to check that we continue the upgrade

Copy link
Contributor Author

@adam-fowler adam-fowler Jul 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FranzBusch
I did 2 (see testSendRequestCloseImmediately)

I'm not sure what the expected result is for 1
First I had to send a request head that included a content-length so a .head would be passed down the pipeline, but not an.end eg

OPTIONS * HTTP/1.1\r\nHost: localhost\r\ncontent-length: 10\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n

With this it gets stuck in the state .upgradeReady because nothing else happens after closeInbound to forward the state machine onwards and the connection stays open. If I return .close when the state is .upgradeReady from the state machine everything works but you explicitly said I shouldn't do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FranzBusch any thoughts

let errorCaught = UnsafeMutableTransferBox<Bool>(false)

let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in
XCTFail("Upgrade cannot be successful if we don't send any data to server")
}
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
upgraders: [upgrader],
extraHandlers: [],
upgradeErrorHandler: { error in
switch error {
case ChannelError.inputClosed:
errorCaught.wrappedValue = true
default:
break
}
},
{ _ in }
)

try client.close(mode: .output).wait()
try connectedServer.closeFuture.wait()
XCTAssertEqual(errorCaught.wrappedValue, true)
}

/// Test that send a request and closing immediately performs a successful upgrade
func testSendRequestCloseImmediately() throws {
let upgradePerformed = UnsafeMutableTransferBox<Bool>(false)

let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in
upgradePerformed.wrappedValue = true
}
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
upgraders: [upgrader],
extraHandlers: [],
upgradeErrorHandler: { error in
XCTFail("Error: \(error)")
},
{ _ in }
)

let request =
"OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait())
try client.close(mode: .output).wait()
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
XCTAssertEqual(upgradePerformed.wrappedValue, true)
}

/// Test that sending an unfinished upgrade request and closing immediately throws
/// an input closed error
func testSendUnfinishedRequestCloseImmediately() throws {
let errorCaught = UnsafeMutableTransferBox<Bool>(false)

let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in
}
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
upgraders: [upgrader],
extraHandlers: [],
upgradeErrorHandler: { error in
switch error {
case ChannelError.inputClosed:
errorCaught.wrappedValue = true
default:
XCTFail("Error: \(error)")
}
},
{ _ in }
)

let request =
"OPTIONS * HTTP/1.1\r\nHost: localhost\r\ncontent-length: 10\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait())
try client.close(mode: .output).wait()
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
XCTAssertEqual(errorCaught.wrappedValue, true)
}
}
#endif
Loading