Skip to content
2 changes: 1 addition & 1 deletion Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ private func _upgrade<UpgradeResult>(
ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))
)
if enableAutomaticErrorHandling {
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false))
}
}
.flatMap {
Expand Down
5 changes: 1 addition & 4 deletions Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ extension HTTPHeaders {
///
/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to
/// remove the HTTP `ChannelHandler`s.
public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable {
// This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime
// the conformance is `@unchecked`.
public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, Sendable {

// FIXME: remove @unchecked when 5.7 is the minimum supported version.
private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<HTTPHeaders?>
private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture<Void>
/// RFC 6455 specs this as the required entry in the Upgrade header.
Expand Down
22 changes: 21 additions & 1 deletion Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,20 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
public typealias InboundIn = Never
public typealias OutboundOut = WebSocketFrame

public init() {}
/// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true.
private let isServer: Bool

public init() {
self.isServer = true
}

/// Initialize this `ChannelHandler` to be used by a WebSocket server or client.
///
/// - Parameters:
/// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client.
public init(isServer: Bool) {
self.isServer = isServer
}

public func errorCaught(context: ChannelHandlerContext, error: Error) {
let loopBoundContext = context.loopBound
Expand All @@ -32,6 +45,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
let frame = WebSocketFrame(
fin: true,
opcode: .connectionClose,
maskKey: self.makeMaskingKey(),
data: data
)
context.writeAndFlush(Self.wrapOutboundOut(frame)).whenComplete { (_: Result<Void, Error>) in
Expand All @@ -44,6 +58,12 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
// forward the error on to let others see it.
context.fireErrorCaught(error)
}

private func makeMaskingKey() -> WebSocketMaskingKey? {
// According to RFC 6455 Section 5, a client *must* mask all frames that it sends to the server.
// A server *must not* mask any frames that it sends to the client
self.isServer ? nil : .random()
}
}

@available(*, unavailable)
Expand Down
29 changes: 29 additions & 0 deletions Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,35 @@ class WebSocketClientEndToEndTests: XCTestCase {
// Close the pipeline.
XCTAssertNoThrow(try clientChannel.close().wait())
}

func testErrorHandlerMaskFrameForClient() throws {

let (clientChannel, _) = try self.runSuccessfulUpgrade()
let maskBitMask: UInt8 = 0x80

var data = clientChannel.allocator.buffer(capacity: 4)
// A fake frame header that claims that the length of the frame is 16385 bytes,
// larger than the frame max.
data.writeBytes([0x81, 0xFE, 0x40, 0x01])

XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in
XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError)
}

clientChannel.embeddedEventLoop.run()
var buffer = try clientChannel.readAllOutboundBuffers()

guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else {
XCTFail("Insufficient bytes from WebSocket frame")
return
}

let maskedBit = (secondByte & maskBitMask)
XCTAssertEqual(0x80, maskedBit)

XCTAssertNoThrow(!clientChannel.isActive)
XCTAssertTrue(try clientChannel.finish(acceptAlreadyClosed: true).isClean)
}
}

#if !canImport(Darwin) || swift(>=5.10)
Expand Down
Loading