diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift index e22a57a806a..b430cf1be44 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -154,6 +154,25 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch } } + public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case let evt as ChannelEvent where evt == ChannelEvent.inputClosed: + // The remote peer half-closed the channel during the upgrade. Should we close the other side + switch self.stateMachine.inputClosed() { + case .close: + context.close(promise: nil) + self.upgradeResultPromise.fail(ChannelError.inputClosed) + case .continue: + break + case .fireInputClosedEvent: + context.fireUserInboundEventTriggered(event) + } + + default: + context.fireUserInboundEventTriggered(event) + } + } + private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) { switch self.stateMachine.channelReadRequestPart(requestPart) { case .failUpgradePromise(let error): @@ -399,6 +418,9 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch private func unbuffer(context: ChannelHandlerContext) { while true { switch self.stateMachine.unbuffer() { + case .close: + context.close(promise: nil) + case .fireChannelRead(let data): context.fireChannelRead(data) @@ -406,6 +428,9 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch context.fireChannelReadComplete() context.pipeline.removeHandler(self, promise: nil) return + + case .fireInputClosedEvent: + context.fireUserInboundEventTriggered(ChannelEvent.inputClosed) } } } diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift index 80fd018944f..ff228b9834c 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -22,10 +22,15 @@ struct NIOTypedHTTPServerUpgraderStateMachine { /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. case initial + enum BufferedState { + case data(NIOAny) + case inputClosed + } + @usableFromInline struct AwaitingUpgrader { var seenFirstRequest: Bool - var buffer: Deque + var buffer: Deque } /// The request head has been received. We're currently running the future chain awaiting an upgrader. @@ -37,7 +42,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine { var requestHead: HTTPRequestHead var responseHeaders: HTTPHeaders var proto: String - var buffer: Deque + var buffer: Deque } /// We have an upgrader, which means we can begin upgrade we are just waiting for the request end. @@ -45,14 +50,14 @@ struct NIOTypedHTTPServerUpgraderStateMachine { @usableFromInline struct Upgrading { - var buffer: Deque + var buffer: Deque } /// We are either running the upgrading handler. case upgrading(Upgrading) @usableFromInline struct Unbuffering { - var buffer: Deque + var buffer: Deque } case unbuffering(Unbuffering) @@ -99,7 +104,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine { if awaitingUpgrader.seenFirstRequest { // We should buffer the data since we have seen the full request. self.state = .modifying - awaitingUpgrader.buffer.append(data) + awaitingUpgrader.buffer.append(.data(data)) self.state = .awaitingUpgrader(awaitingUpgrader) return nil } else { @@ -114,7 +119,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine { case .unbuffering(var unbuffering): self.state = .modifying - unbuffering.buffer.append(data) + unbuffering.buffer.append(.data(data)) self.state = .unbuffering(unbuffering) return nil @@ -125,7 +130,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine { // We got a read while running ugprading. // We have to buffer the read to unbuffer it afterwards self.state = .modifying - upgrading.buffer.append(data) + upgrading.buffer.append(.data(data)) self.state = .upgrading(upgrading) return nil @@ -167,8 +172,8 @@ struct NIOTypedHTTPServerUpgraderStateMachine { guard requestedProtocols.count > 0 else { // We have to buffer now since we got the request head but are not upgrading. // The user is configuring the HTTP pipeline now. - var buffer = Deque() - buffer.append(NIOAny(requestPart)) + var buffer = Deque() + buffer.append(.data(NIOAny(requestPart))) self.state = .upgrading(.init(buffer: buffer)) return .runNotUpgradingInitializer } @@ -364,8 +369,10 @@ struct NIOTypedHTTPServerUpgraderStateMachine { @usableFromInline enum UnbufferAction { + case close case fireChannelRead(NIOAny) case fireChannelReadCompleteAndRemoveHandler + case fireInputClosedEvent } @inlinable @@ -379,8 +386,12 @@ struct NIOTypedHTTPServerUpgraderStateMachine { if let element = unbuffering.buffer.popFirst() { self.state = .unbuffering(unbuffering) - - return .fireChannelRead(element) + switch element { + case .data(let data): + return .fireChannelRead(data) + case .inputClosed: + return .fireInputClosedEvent + } } else { self.state = .finished @@ -393,5 +404,55 @@ struct NIOTypedHTTPServerUpgraderStateMachine { } } + @usableFromInline + enum InputClosedAction { + case close + case `continue` + case fireInputClosedEvent + } + + @inlinable + mutating func inputClosed() -> InputClosedAction { + switch self.state { + case .initial: + self.state = .finished + return .close + + case .awaitingUpgrader(var awaitingUpgrader): + if awaitingUpgrader.seenFirstRequest { + // We should buffer the input close since we have seen the full request. + awaitingUpgrader.buffer.append(.inputClosed) + self.state = .awaitingUpgrader(awaitingUpgrader) + return .continue + } else { + // We shouldn't buffer. This means we were still expecting HTTP parts. + return .close + } + + case .upgrading(var upgrading): + upgrading.buffer.append(.inputClosed) + self.state = .upgrading(upgrading) + return .continue + + case .upgraderReady: + // if the state is `upgraderReady` we have received a `.head` but not an `.end`. + // If input is closed then there is no way to move this forward so we should + // close. + self.state = .finished + return .close + + case .unbuffering(var unbuffering): + unbuffering.buffer.append(.inputClosed) + self.state = .unbuffering(unbuffering) + return .continue + + case .finished: + return .fireInputClosedEvent + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + } #endif diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 166b7b0934f..3a69f806e35 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -489,6 +489,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], extraHandlers: [ChannelHandler], notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil, _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler ) throws -> (Channel, Channel, Channel) { let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval( @@ -1770,11 +1771,13 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], extraHandlers: [ChannelHandler], notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil, _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler ) throws -> (Channel, Channel, Channel) { let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self) let serverChannelFuture = ServerBootstrap(group: Self.eventLoop) - .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true) .childChannelInitializer { channel in channel.eventLoop.makeCompletedFuture { connectionChannelPromise.succeed(channel) @@ -1800,6 +1803,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { return channel.eventLoop.makeSucceededVoidFuture() } } + .flatMapErrorThrowing { error in + upgradeErrorHandler?(error) + throw error + } } .flatMap { _ in let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } @@ -2313,5 +2320,83 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // We also want to confirm that the upgrade handler is no longer in the pipeline. try connectedServer.pipeline.waitForUpgraderToBeRemoved() } + + func testHalfClosure() throws { + let errorCaught = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + XCTFail("Upgrade cannot be successful if we don't send any data to server") + } + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + upgradeErrorHandler: { error in + switch error { + case ChannelError.inputClosed: + errorCaught.wrappedValue = true + default: + break + } + }, + { _ in } + ) + + try client.close(mode: .output).wait() + try connectedServer.closeFuture.wait() + XCTAssertEqual(errorCaught.wrappedValue, true) + } + + /// Test that send a request and closing immediately performs a successful upgrade + func testSendRequestCloseImmediately() throws { + let upgradePerformed = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in + upgradePerformed.wrappedValue = true + } + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + upgradeErrorHandler: { error in + XCTFail("Error: \(error)") + }, + { _ in } + ) + + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait()) + try client.close(mode: .output).wait() + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + XCTAssertEqual(upgradePerformed.wrappedValue, true) + } + + /// Test that sending an unfinished upgrade request and closing immediately throws + /// an input closed error + func testSendUnfinishedRequestCloseImmediately() throws { + let errorCaught = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in + } + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + upgradeErrorHandler: { error in + switch error { + case ChannelError.inputClosed: + errorCaught.wrappedValue = true + default: + XCTFail("Error: \(error)") + } + }, + { _ in } + ) + + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\ncontent-length: 10\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait()) + try client.close(mode: .output).wait() + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + XCTAssertEqual(errorCaught.wrappedValue, true) + } } #endif