diff --git a/Sources/NIOWebSocket/WebSocketFrame.swift b/Sources/NIOWebSocket/WebSocketFrame.swift index d6b8a1ff2fd..dc180dff867 100644 --- a/Sources/NIOWebSocket/WebSocketFrame.swift +++ b/Sources/NIOWebSocket/WebSocketFrame.swift @@ -15,11 +15,13 @@ import NIOCore extension UInt8 { - fileprivate func isAnyBitSetInMask(_ mask: UInt8) -> Bool { + @usableFromInline + internal func isAnyBitSetInMask(_ mask: UInt8) -> Bool { self & mask != 0 } - fileprivate mutating func changingBitsInMask(_ mask: UInt8, to: Bool) { + @usableFromInline + internal mutating func changingBitsInMask(_ mask: UInt8, to: Bool) { if to { self |= mask } else { @@ -159,11 +161,13 @@ public struct WebSocketFrame { /// Rather than unpack all the fields from the first byte, and thus take up loads /// of storage in the structure, we keep them in their packed form in this byte and /// use computed properties to unpack them. + @usableFromInline internal var firstByte: UInt8 = 0 /// The value of the `fin` bit. If set, this is the last frame in a fragmented frame. If not /// set, this frame is one of the intermediate frames in a fragmented frame. Must be set if /// a frame is not fragmented at all. + @inlinable public var fin: Bool { get { self.firstByte.isAnyBitSetInMask(0x80) @@ -174,6 +178,7 @@ public struct WebSocketFrame { } /// The value of the first reserved bit. Must be `false` unless using an extension that defines its use. + @inlinable public var rsv1: Bool { get { self.firstByte.isAnyBitSetInMask(0x40) @@ -184,6 +189,7 @@ public struct WebSocketFrame { } /// The value of the second reserved bit. Must be `false` unless using an extension that defines its use. + @inlinable public var rsv2: Bool { get { self.firstByte.isAnyBitSetInMask(0x20) @@ -194,6 +200,7 @@ public struct WebSocketFrame { } /// The value of the third reserved bit. Must be `false` unless using an extension that defines its use. + @inlinable public var rsv3: Bool { get { self.firstByte.isAnyBitSetInMask(0x10) @@ -204,6 +211,7 @@ public struct WebSocketFrame { } /// The opcode for this frame. + @inlinable public var opcode: WebSocketOpcode { get { // this is a public initialiser which only fails if the opcode is invalid. But all opcodes in 0...0xF @@ -216,6 +224,7 @@ public struct WebSocketFrame { } /// The total length of the data in the frame. + @inlinable public var length: Int { data.readableBytes + (extensionData?.readableBytes ?? 0) } @@ -406,3 +415,35 @@ extension WebSocketFrame: CustomDebugStringConvertible { "(\(self.description))" } } + +extension WebSocketFrame { + /// WebSocketFrame reserved bits option set + public struct ReservedBits: OptionSet, Sendable { + public var rawValue: UInt8 + + @inlinable + public init(rawValue: UInt8) { + self.rawValue = rawValue + } + + @inlinable + public static var rsv1: Self { .init(rawValue: 0x40) } + @inlinable + public static var rsv2: Self { .init(rawValue: 0x20) } + @inlinable + public static var rsv3: Self { .init(rawValue: 0x10) } + @inlinable + public static var all: Self { .init(rawValue: 0x70) } + } + + /// The value of all the reserved bits. Must be `empty` unless using an extension that defines their use. + @inlinable + public var reservedBits: ReservedBits { + get { + .init(rawValue: self.firstByte & 0x70) + } + set { + self.firstByte = (self.firstByte & 0x8F) + newValue.rawValue + } + } +} diff --git a/Tests/NIOWebSocketTests/WebSocketMaskingKeyTests.swift b/Tests/NIOWebSocketTests/WebSocketMaskingKeyTests.swift index ff0cdf148b4..de928497dad 100644 --- a/Tests/NIOWebSocketTests/WebSocketMaskingKeyTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketMaskingKeyTests.swift @@ -43,4 +43,41 @@ final class WebSocketMaskingKeyTests: XCTestCase { "at least 1 of 1000 random masking keys with default generator should not be all zeros" ) } + + func testGetReservedBits() { + let frame = WebSocketFrame(rsv1: true, opcode: .binary, data: .init()) + XCTAssertEqual(frame.reservedBits.contains(.rsv1), true) + XCTAssertEqual(frame.reservedBits.contains(.rsv2), false) + XCTAssertEqual(frame.reservedBits.contains(.rsv3), false) + let frame2 = WebSocketFrame(rsv2: true, opcode: .binary, data: .init()) + XCTAssertEqual(frame2.reservedBits.contains(.rsv1), false) + XCTAssertEqual(frame2.reservedBits.contains(.rsv2), true) + XCTAssertEqual(frame2.reservedBits.contains(.rsv3), false) + let frame3 = WebSocketFrame(rsv3: true, opcode: .binary, data: .init()) + XCTAssertEqual(frame3.reservedBits.contains(.rsv1), false) + XCTAssertEqual(frame3.reservedBits.contains(.rsv2), false) + XCTAssertEqual(frame3.reservedBits.contains(.rsv3), true) + } + + func testSetReservedBits() { + var frame = WebSocketFrame(opcode: .binary, data: .init()) + frame.reservedBits = .rsv1 + XCTAssertEqual(frame.reservedBits.contains(.rsv1), true) + XCTAssertEqual(frame.reservedBits.contains(.rsv2), false) + XCTAssertEqual(frame.reservedBits.contains(.rsv3), false) + XCTAssertEqual(frame.fin, false) + XCTAssertEqual(frame.opcode, .binary) + frame.reservedBits = .rsv2 + XCTAssertEqual(frame.reservedBits.contains(.rsv1), false) + XCTAssertEqual(frame.reservedBits.contains(.rsv2), true) + XCTAssertEqual(frame.reservedBits.contains(.rsv3), false) + XCTAssertEqual(frame.fin, false) + XCTAssertEqual(frame.opcode, .binary) + frame.reservedBits = .rsv3 + XCTAssertEqual(frame.reservedBits.contains(.rsv1), false) + XCTAssertEqual(frame.reservedBits.contains(.rsv2), false) + XCTAssertEqual(frame.reservedBits.contains(.rsv3), true) + XCTAssertEqual(frame.fin, false) + XCTAssertEqual(frame.opcode, .binary) + } }