Skip to content

Commit a033db4

Browse files
adam-fowlerFranzBusch
authored andcommitted
Close channel during upgrade, if client closes input (apple#2756)
Check for half closure during server upgrade and close channel if client closes the channel ### Motivation: This is to fix apple#2742 ### Modifications: Add `userInboundEventTriggered` function to `NIOTypedHTTPServerProtocolUpgrader` which checks for `ChannelEvent.inputClosed` ### Result: Negotiation future now errors when client closes the connection instead of never completing --------- Co-authored-by: Franz Busch <[email protected]> (cherry picked from commit a026de3)
1 parent 4b9cfec commit a033db4

File tree

3 files changed

+183
-12
lines changed

3 files changed

+183
-12
lines changed

Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,25 @@ public final class NIOTypedHTTPServerUpgradeHandler<UpgradeResult: Sendable>: Ch
154154
}
155155
}
156156

157+
public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
158+
switch event {
159+
case let evt as ChannelEvent where evt == ChannelEvent.inputClosed:
160+
// The remote peer half-closed the channel during the upgrade. Should we close the other side
161+
switch self.stateMachine.inputClosed() {
162+
case .close:
163+
context.close(promise: nil)
164+
self.upgradeResultPromise.fail(ChannelError.inputClosed)
165+
case .continue:
166+
break
167+
case .fireInputClosedEvent:
168+
context.fireUserInboundEventTriggered(event)
169+
}
170+
171+
default:
172+
context.fireUserInboundEventTriggered(event)
173+
}
174+
}
175+
157176
private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) {
158177
switch self.stateMachine.channelReadRequestPart(requestPart) {
159178
case .failUpgradePromise(let error):
@@ -399,13 +418,19 @@ public final class NIOTypedHTTPServerUpgradeHandler<UpgradeResult: Sendable>: Ch
399418
private func unbuffer(context: ChannelHandlerContext) {
400419
while true {
401420
switch self.stateMachine.unbuffer() {
421+
case .close:
422+
context.close(promise: nil)
423+
402424
case .fireChannelRead(let data):
403425
context.fireChannelRead(data)
404426

405427
case .fireChannelReadCompleteAndRemoveHandler:
406428
context.fireChannelReadComplete()
407429
context.pipeline.removeHandler(self, promise: nil)
408430
return
431+
432+
case .fireInputClosedEvent:
433+
context.fireUserInboundEventTriggered(ChannelEvent.inputClosed)
409434
}
410435
}
411436
}

Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
2222
/// The state before we received a TLSUserEvent. We are just forwarding any read at this point.
2323
case initial
2424

25+
enum BufferedState {
26+
case data(NIOAny)
27+
case inputClosed
28+
}
29+
2530
@usableFromInline
2631
struct AwaitingUpgrader {
2732
var seenFirstRequest: Bool
28-
var buffer: Deque<NIOAny>
33+
var buffer: Deque<BufferedState>
2934
}
3035

3136
/// The request head has been received. We're currently running the future chain awaiting an upgrader.
@@ -37,22 +42,22 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
3742
var requestHead: HTTPRequestHead
3843
var responseHeaders: HTTPHeaders
3944
var proto: String
40-
var buffer: Deque<NIOAny>
45+
var buffer: Deque<BufferedState>
4146
}
4247

4348
/// We have an upgrader, which means we can begin upgrade we are just waiting for the request end.
4449
case upgraderReady(UpgraderReady)
4550

4651
@usableFromInline
4752
struct Upgrading {
48-
var buffer: Deque<NIOAny>
53+
var buffer: Deque<BufferedState>
4954
}
5055
/// We are either running the upgrading handler.
5156
case upgrading(Upgrading)
5257

5358
@usableFromInline
5459
struct Unbuffering {
55-
var buffer: Deque<NIOAny>
60+
var buffer: Deque<BufferedState>
5661
}
5762
case unbuffering(Unbuffering)
5863

@@ -99,7 +104,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
99104
if awaitingUpgrader.seenFirstRequest {
100105
// We should buffer the data since we have seen the full request.
101106
self.state = .modifying
102-
awaitingUpgrader.buffer.append(data)
107+
awaitingUpgrader.buffer.append(.data(data))
103108
self.state = .awaitingUpgrader(awaitingUpgrader)
104109
return nil
105110
} else {
@@ -114,7 +119,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
114119

115120
case .unbuffering(var unbuffering):
116121
self.state = .modifying
117-
unbuffering.buffer.append(data)
122+
unbuffering.buffer.append(.data(data))
118123
self.state = .unbuffering(unbuffering)
119124
return nil
120125

@@ -125,7 +130,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
125130
// We got a read while running ugprading.
126131
// We have to buffer the read to unbuffer it afterwards
127132
self.state = .modifying
128-
upgrading.buffer.append(data)
133+
upgrading.buffer.append(.data(data))
129134
self.state = .upgrading(upgrading)
130135
return nil
131136

@@ -167,8 +172,8 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
167172
guard requestedProtocols.count > 0 else {
168173
// We have to buffer now since we got the request head but are not upgrading.
169174
// The user is configuring the HTTP pipeline now.
170-
var buffer = Deque<NIOAny>()
171-
buffer.append(NIOAny(requestPart))
175+
var buffer = Deque<State.BufferedState>()
176+
buffer.append(.data(NIOAny(requestPart)))
172177
self.state = .upgrading(.init(buffer: buffer))
173178
return .runNotUpgradingInitializer
174179
}
@@ -364,8 +369,10 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
364369

365370
@usableFromInline
366371
enum UnbufferAction {
372+
case close
367373
case fireChannelRead(NIOAny)
368374
case fireChannelReadCompleteAndRemoveHandler
375+
case fireInputClosedEvent
369376
}
370377

371378
@inlinable
@@ -379,8 +386,12 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
379386

380387
if let element = unbuffering.buffer.popFirst() {
381388
self.state = .unbuffering(unbuffering)
382-
383-
return .fireChannelRead(element)
389+
switch element {
390+
case .data(let data):
391+
return .fireChannelRead(data)
392+
case .inputClosed:
393+
return .fireInputClosedEvent
394+
}
384395
} else {
385396
self.state = .finished
386397

@@ -393,5 +404,55 @@ struct NIOTypedHTTPServerUpgraderStateMachine<UpgradeResult> {
393404
}
394405
}
395406

407+
@usableFromInline
408+
enum InputClosedAction {
409+
case close
410+
case `continue`
411+
case fireInputClosedEvent
412+
}
413+
414+
@inlinable
415+
mutating func inputClosed() -> InputClosedAction {
416+
switch self.state {
417+
case .initial:
418+
self.state = .finished
419+
return .close
420+
421+
case .awaitingUpgrader(var awaitingUpgrader):
422+
if awaitingUpgrader.seenFirstRequest {
423+
// We should buffer the input close since we have seen the full request.
424+
awaitingUpgrader.buffer.append(.inputClosed)
425+
self.state = .awaitingUpgrader(awaitingUpgrader)
426+
return .continue
427+
} else {
428+
// We shouldn't buffer. This means we were still expecting HTTP parts.
429+
return .close
430+
}
431+
432+
case .upgrading(var upgrading):
433+
upgrading.buffer.append(.inputClosed)
434+
self.state = .upgrading(upgrading)
435+
return .continue
436+
437+
case .upgraderReady:
438+
// if the state is `upgraderReady` we have received a `.head` but not an `.end`.
439+
// If input is closed then there is no way to move this forward so we should
440+
// close.
441+
self.state = .finished
442+
return .close
443+
444+
case .unbuffering(var unbuffering):
445+
unbuffering.buffer.append(.inputClosed)
446+
self.state = .unbuffering(unbuffering)
447+
return .continue
448+
449+
case .finished:
450+
return .fireInputClosedEvent
451+
452+
case .modifying:
453+
fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine")
454+
}
455+
}
456+
396457
}
397458
#endif

Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ class HTTPServerUpgradeTestCase: XCTestCase {
489489
upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader],
490490
extraHandlers: [ChannelHandler],
491491
notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture<Bool>)? = nil,
492+
upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil,
492493
_ upgradeCompletionHandler: @escaping UpgradeCompletionHandler
493494
) throws -> (Channel, Channel, Channel) {
494495
let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(
@@ -1770,11 +1771,13 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
17701771
upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader],
17711772
extraHandlers: [ChannelHandler],
17721773
notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture<Bool>)? = nil,
1774+
upgradeErrorHandler: (@Sendable (Error) -> Void)? = nil,
17731775
_ upgradeCompletionHandler: @escaping UpgradeCompletionHandler
17741776
) throws -> (Channel, Channel, Channel) {
17751777
let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self)
17761778
let serverChannelFuture = ServerBootstrap(group: Self.eventLoop)
1777-
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
1779+
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
1780+
.childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
17781781
.childChannelInitializer { channel in
17791782
channel.eventLoop.makeCompletedFuture {
17801783
connectionChannelPromise.succeed(channel)
@@ -1800,6 +1803,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
18001803
return channel.eventLoop.makeSucceededVoidFuture()
18011804
}
18021805
}
1806+
.flatMapErrorThrowing { error in
1807+
upgradeErrorHandler?(error)
1808+
throw error
1809+
}
18031810
}
18041811
.flatMap { _ in
18051812
let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) }
@@ -2313,5 +2320,83 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase {
23132320
// We also want to confirm that the upgrade handler is no longer in the pipeline.
23142321
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
23152322
}
2323+
2324+
func testHalfClosure() throws {
2325+
let errorCaught = UnsafeMutableTransferBox<Bool>(false)
2326+
2327+
let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in
2328+
XCTFail("Upgrade cannot be successful if we don't send any data to server")
2329+
}
2330+
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
2331+
upgraders: [upgrader],
2332+
extraHandlers: [],
2333+
upgradeErrorHandler: { error in
2334+
switch error {
2335+
case ChannelError.inputClosed:
2336+
errorCaught.wrappedValue = true
2337+
default:
2338+
break
2339+
}
2340+
},
2341+
{ _ in }
2342+
)
2343+
2344+
try client.close(mode: .output).wait()
2345+
try connectedServer.closeFuture.wait()
2346+
XCTAssertEqual(errorCaught.wrappedValue, true)
2347+
}
2348+
2349+
/// Test that send a request and closing immediately performs a successful upgrade
2350+
func testSendRequestCloseImmediately() throws {
2351+
let upgradePerformed = UnsafeMutableTransferBox<Bool>(false)
2352+
2353+
let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in
2354+
upgradePerformed.wrappedValue = true
2355+
}
2356+
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
2357+
upgraders: [upgrader],
2358+
extraHandlers: [],
2359+
upgradeErrorHandler: { error in
2360+
XCTFail("Error: \(error)")
2361+
},
2362+
{ _ in }
2363+
)
2364+
2365+
let request =
2366+
"OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
2367+
XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait())
2368+
try client.close(mode: .output).wait()
2369+
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
2370+
XCTAssertEqual(upgradePerformed.wrappedValue, true)
2371+
}
2372+
2373+
/// Test that sending an unfinished upgrade request and closing immediately throws
2374+
/// an input closed error
2375+
func testSendUnfinishedRequestCloseImmediately() throws {
2376+
let errorCaught = UnsafeMutableTransferBox<Bool>(false)
2377+
2378+
let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in
2379+
}
2380+
let (_, client, connectedServer) = try setUpTestWithAutoremoval(
2381+
upgraders: [upgrader],
2382+
extraHandlers: [],
2383+
upgradeErrorHandler: { error in
2384+
switch error {
2385+
case ChannelError.inputClosed:
2386+
errorCaught.wrappedValue = true
2387+
default:
2388+
XCTFail("Error: \(error)")
2389+
}
2390+
},
2391+
{ _ in }
2392+
)
2393+
2394+
let request =
2395+
"OPTIONS * HTTP/1.1\r\nHost: localhost\r\ncontent-length: 10\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n"
2396+
XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait())
2397+
try client.close(mode: .output).wait()
2398+
try connectedServer.pipeline.waitForUpgraderToBeRemoved()
2399+
XCTAssertEqual(errorCaught.wrappedValue, true)
2400+
}
23162401
}
23172402
#endif

0 commit comments

Comments
 (0)