Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ let package = Package(
"NIOCore",
"NIOEmbedded",
"NIOWebSocket",
]
],
swiftSettings: strictConcurrencySettings
),
.testTarget(
name: "NIOTestUtilsTests",
Expand Down
2 changes: 1 addition & 1 deletion Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public typealias NIOWebClientSocketUpgrader = NIOWebSocketClientUpgrader
/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request.
/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the
/// pipeline to remove the HTTP `ChannelHandler`s.
public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader {
public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader, Sendable {
/// RFC 6455 specs this as the required entry in the Upgrade header.
public let supportedProtocol: String = "websocket"
/// None of the websocket headers are actually defined as 'required'.
Expand Down
19 changes: 10 additions & 9 deletions Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import NIOConcurrencyHelpers
import NIOEmbedded
import NIOHTTP1
import XCTest
Expand Down Expand Up @@ -54,17 +55,17 @@ extension ChannelPipeline {
}

private func setUpClientChannel(
clientHTTPHandler: RemovableChannelHandler,
clientUpgraders: [NIOHTTPClientProtocolUpgrader],
_ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void
clientHTTPHandler: RemovableChannelHandler & Sendable,
clientUpgraders: [NIOHTTPClientProtocolUpgrader & Sendable],
_ upgradeCompletionHandler: @escaping @Sendable (ChannelHandlerContext) -> Void
) throws -> EmbeddedChannel {

let channel = EmbeddedChannel()

let config: NIOHTTPClientUpgradeSendableConfiguration = (
upgraders: clientUpgraders,
completionHandler: { context in
channel.pipeline.removeHandler(clientHTTPHandler, promise: nil)
channel.pipeline.syncOperations.removeHandler(clientHTTPHandler, promise: nil)
upgradeCompletionHandler(context)
}
)
Expand All @@ -80,7 +81,7 @@ private func setUpClientChannel(
}

// A HTTP handler that will send an initial request which can be augmented by the upgrade handler.
private final class BasicHTTPHandler: ChannelInboundHandler, RemovableChannelHandler {
private final class BasicHTTPHandler: ChannelInboundHandler, RemovableChannelHandler, Sendable {
fileprivate typealias InboundIn = HTTPClientResponsePart
fileprivate typealias OutboundOut = HTTPClientRequestPart

Expand All @@ -92,7 +93,7 @@ private final class BasicHTTPHandler: ChannelInboundHandler, RemovableChannelHan

// A HTTP handler that will send a request and then fail if it receives a response or an error.
// It can be used when there is a successful upgrade as the handler should be removed by the upgrader.
private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChannelHandler {
private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChannelHandler, Sendable {
fileprivate typealias InboundIn = HTTPClientResponsePart
fileprivate typealias OutboundOut = HTTPClientRequestPart

Expand Down Expand Up @@ -163,7 +164,7 @@ private func basicRequest(path: String = "/") -> String {
class WebSocketClientEndToEndTests: XCTestCase {
func testSimpleUpgradeSucceeds() throws {

var upgradeHandlerCallbackFired = false
let upgradeHandlerCallbackFired = NIOLockedValueBox(false)
let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM="
let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk="

Expand All @@ -183,7 +184,7 @@ class WebSocketClientEndToEndTests: XCTestCase {
) { _ in

// This is called before the upgrader gets called.
upgradeHandlerCallbackFired = true
upgradeHandlerCallbackFired.withLockedValue { $0 = true }
}

// Read the server request.
Expand Down Expand Up @@ -233,7 +234,7 @@ class WebSocketClientEndToEndTests: XCTestCase {
.assertContains(handlerType: WebSocketRecorderHandler.self)
)

XCTAssert(upgradeHandlerCallbackFired)
XCTAssert(upgradeHandlerCallbackFired.withLockedValue { $0 })

// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
Expand Down
73 changes: 36 additions & 37 deletions Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ private class CloseSwallower: ChannelOutboundHandler, RemovableChannelHandler {
private var closePromise: EventLoopPromise<Void>? = nil
private var context: ChannelHandlerContext? = nil

public func allowClose() {
func allowClose() {
self.context!.close(promise: self.closePromise)
self.context = nil
}
Expand Down Expand Up @@ -58,12 +58,12 @@ private final class SynchronousCloser: ChannelInboundHandler {
}
}

public final class WebSocketFrameDecoderTest: XCTestCase {
public var decoderChannel: EmbeddedChannel!
public var encoderChannel: EmbeddedChannel!
public var buffer: ByteBuffer!
final class WebSocketFrameDecoderTest: XCTestCase {
var decoderChannel: EmbeddedChannel!
var encoderChannel: EmbeddedChannel!
var buffer: ByteBuffer!

public override func setUp() {
override func setUp() {
self.decoderChannel = EmbeddedChannel()
self.encoderChannel = EmbeddedChannel()
self.buffer = decoderChannel.allocator.buffer(capacity: 128)
Expand All @@ -73,7 +73,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(try self.encoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()))
}

public override func tearDown() {
override func tearDown() {
XCTAssertNoThrow(try self.encoderChannel.finish())
_ = try? self.decoderChannel.finish()
self.encoderChannel = nil
Expand Down Expand Up @@ -114,32 +114,31 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
let f = self.decoderChannel.pipeline.context(handlerType: ByteToMessageHandler<WebSocketFrameDecoder>.self)
.flatMapThrowing {
if let handler = $0.handler as? RemovableChannelHandler {
return handler
.assumeIsolated()
.flatMap { context in
if let handler = context.handler as? RemovableChannelHandler {
return self.decoderChannel.pipeline.syncOperations.removeHandler(handler)
} else {
throw ChannelError.unremovableHandler
return context.eventLoop.makeFailedFuture(ChannelError.unremovableHandler)
}
}.flatMap {
self.decoderChannel.pipeline.removeHandler($0)
}

// we need to run the event loop here because removal is not synchronous
(self.decoderChannel.eventLoop as! EmbeddedEventLoop).run()

XCTAssertNoThrow(
try f.flatMap {
self.decoderChannel.pipeline.addHandler(handler)
}.wait()
try f.flatMapThrowing {
try self.decoderChannel.pipeline.syncOperations.addHandler(handler)
}.nonisolated().wait()
)
}

public func testFramesWithoutBodies() throws {
func testFramesWithoutBodies() throws {
let frame = WebSocketFrame(fin: true, opcode: .ping, data: self.buffer)
assertFrameRoundTrips(frame: frame)
}

public func testFramesWithExtensionDataDontRoundTrip() throws {
func testFramesWithExtensionDataDontRoundTrip() throws {
// We don't know what the extensions are, so all data goes in...well...data.
self.buffer.writeBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let frame = WebSocketFrame(
Expand All @@ -151,7 +150,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
assertFrameDoesNotRoundTrip(frame: frame)
}

public func testFramesWithExtensionDataCanBeRecovered() throws {
func testFramesWithExtensionDataCanBeRecovered() throws {
self.buffer.writeBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let frame = WebSocketFrame(
fin: false,
Expand All @@ -165,7 +164,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertEqual(newFrame, frame)
}

public func testFramesWithReservedBitsSetRoundTrip() throws {
func testFramesWithReservedBitsSetRoundTrip() throws {
self.buffer.writeBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let frame = WebSocketFrame(
fin: false,
Expand All @@ -178,7 +177,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
assertFrameRoundTrips(frame: frame)
}

public func testFramesWith16BitLengthsRoundTrip() throws {
func testFramesWith16BitLengthsRoundTrip() throws {
self.buffer.writeBytes(Array(repeating: UInt8(4), count: 300))
let frame = WebSocketFrame(
fin: true,
Expand All @@ -188,7 +187,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
assertFrameRoundTrips(frame: frame)
}

public func testFramesWith64BitLengthsRoundTrip() throws {
func testFramesWith64BitLengthsRoundTrip() throws {
// We need a new decoder channel here, because the max length would otherwise trigger an error.
_ = try! self.decoderChannel.finish()
self.decoderChannel = EmbeddedChannel()
Expand All @@ -207,7 +206,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
assertFrameRoundTrips(frame: frame)
}

public func testMaskedFramesRoundTripWithMaskingIntact() throws {
func testMaskedFramesRoundTripWithMaskingIntact() throws {
self.buffer.writeBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let frame = WebSocketFrame(
fin: false,
Expand All @@ -232,7 +231,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertEqual(producedFrame.unmaskedData, self.buffer)
}

public func testMaskedFramesRoundTripWithMaskingIntactEvenWithExtensions() throws {
func testMaskedFramesRoundTripWithMaskingIntactEvenWithExtensions() throws {
self.buffer.writeBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
let frame = WebSocketFrame(
fin: false,
Expand Down Expand Up @@ -266,7 +265,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
)
}

public func testDecoderRejectsOverlongFrames() throws {
func testDecoderRejectsOverlongFrames() throws {
XCTAssertNoThrow(
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
)
Expand All @@ -284,7 +283,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsFragmentedControlFrames() throws {
func testDecoderRejectsFragmentedControlFrames() throws {
XCTAssertNoThrow(
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
)
Expand All @@ -301,7 +300,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xEA], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsMultibyteControlFrameLengths() throws {
func testDecoderRejectsMultibyteControlFrameLengths() throws {
XCTAssertNoThrow(
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
)
Expand Down Expand Up @@ -357,12 +356,12 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
swallower.allowClose()

// Take the handler out for cleanliness.
XCTAssertNoThrow(try self.decoderChannel.pipeline.removeHandler(swallower).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.removeHandler(swallower).wait())
}

public func testClosingSynchronouslyOnChannelRead() throws {
func testClosingSynchronouslyOnChannelRead() throws {
// We're going to send a connectionClose frame and confirm we only see it once.
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(SynchronousCloser()).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(SynchronousCloser()))

var errorCodeBuffer = self.encoderChannel.allocator.buffer(capacity: 4)
errorCodeBuffer.write(webSocketErrorCode: .normalClosure)
Expand All @@ -382,7 +381,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertNil(try self.decoderChannel.readInbound(as: WebSocketFrame.self)))
}

public func testDecoderRejectsOverlongFramesWithNoAutomaticErrorHandling() {
func testDecoderRejectsOverlongFramesWithNoAutomaticErrorHandling() {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand All @@ -402,7 +401,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsFragmentedControlFramesWithNoAutomaticErrorHandling() throws {
func testDecoderRejectsFragmentedControlFramesWithNoAutomaticErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand All @@ -421,7 +420,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsMultibyteControlFrameLengthsWithNoAutomaticErrorHandling() throws {
func testDecoderRejectsMultibyteControlFrameLengthsWithNoAutomaticErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand Down Expand Up @@ -476,7 +475,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertNil(try self.decoderChannel.readOutbound()))
}

public func testDecoderRejectsOverlongFramesWithSeparateErrorHandling() throws {
func testDecoderRejectsOverlongFramesWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand All @@ -497,7 +496,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsFragmentedControlFramesWithSeparateErrorHandling() throws {
func testDecoderRejectsFragmentedControlFramesWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand All @@ -517,7 +516,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xEA], try self.decoderChannel.readAllOutboundBytes()))
}

public func testDecoderRejectsMultibyteControlFrameLengthsWithSeparateErrorHandling() throws {
func testDecoderRejectsMultibyteControlFrameLengthsWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
Expand Down Expand Up @@ -579,7 +578,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
swallower.allowClose()

// Take the handler out for cleanliness.
XCTAssertNoThrow(try self.decoderChannel.pipeline.removeHandler(swallower).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.removeHandler(swallower).wait())
}

func testErrorHandlerDoesNotSwallowRandomErrors() throws {
Expand Down
Loading
Loading