Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions Sources/NIOWebSocket/WebSocketFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import NIOCore

extension UInt8 {
fileprivate func isAnyBitSetInMask(_ mask: UInt8) -> Bool {
@usableFromInline
internal func isAnyBitSetInMask(_ mask: UInt8) -> Bool {
self & mask != 0
}

fileprivate mutating func changingBitsInMask(_ mask: UInt8, to: Bool) {
@usableFromInline
internal mutating func changingBitsInMask(_ mask: UInt8, to: Bool) {
if to {
self |= mask
} else {
Expand Down Expand Up @@ -159,11 +161,13 @@ public struct WebSocketFrame {
/// Rather than unpack all the fields from the first byte, and thus take up loads
/// of storage in the structure, we keep them in their packed form in this byte and
/// use computed properties to unpack them.
@usableFromInline
internal var firstByte: UInt8 = 0

/// The value of the `fin` bit. If set, this is the last frame in a fragmented frame. If not
/// set, this frame is one of the intermediate frames in a fragmented frame. Must be set if
/// a frame is not fragmented at all.
@inlinable
public var fin: Bool {
get {
self.firstByte.isAnyBitSetInMask(0x80)
Expand All @@ -174,6 +178,7 @@ public struct WebSocketFrame {
}

/// The value of the first reserved bit. Must be `false` unless using an extension that defines its use.
@inlinable
public var rsv1: Bool {
get {
self.firstByte.isAnyBitSetInMask(0x40)
Expand All @@ -184,6 +189,7 @@ public struct WebSocketFrame {
}

/// The value of the second reserved bit. Must be `false` unless using an extension that defines its use.
@inlinable
public var rsv2: Bool {
get {
self.firstByte.isAnyBitSetInMask(0x20)
Expand All @@ -194,6 +200,7 @@ public struct WebSocketFrame {
}

/// The value of the third reserved bit. Must be `false` unless using an extension that defines its use.
@inlinable
public var rsv3: Bool {
get {
self.firstByte.isAnyBitSetInMask(0x10)
Expand All @@ -204,6 +211,7 @@ public struct WebSocketFrame {
}

/// The opcode for this frame.
@inlinable
public var opcode: WebSocketOpcode {
get {
// this is a public initialiser which only fails if the opcode is invalid. But all opcodes in 0...0xF
Expand All @@ -216,6 +224,7 @@ public struct WebSocketFrame {
}

/// The total length of the data in the frame.
@inlinable
public var length: Int {
data.readableBytes + (extensionData?.readableBytes ?? 0)
}
Expand Down Expand Up @@ -406,3 +415,35 @@ extension WebSocketFrame: CustomDebugStringConvertible {
"(\(self.description))"
}
}

extension WebSocketFrame {
/// WebSocketFrame reserved bits option set
public struct ReservedBits: OptionSet, Sendable {
public var rawValue: UInt8

@inlinable
public init(rawValue: UInt8) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this @inlinable and all the static vars as well. That unlocks the meaningful performance improvement.

Copy link
Contributor Author

@adam-fowler adam-fowler Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed the get/set accessors on firstByte are not @inlinable. Is it worthwhile tagging all of those? eg fin, opcode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tagged them anyway

self.rawValue = rawValue
}

@inlinable
public static var rsv1: Self { .init(rawValue: 0x40) }
@inlinable
public static var rsv2: Self { .init(rawValue: 0x20) }
@inlinable
public static var rsv3: Self { .init(rawValue: 0x10) }
@inlinable
public static var all: Self { .init(rawValue: 0x70) }
}

/// The value of all the reserved bits. Must be `empty` unless using an extension that defines their use.
@inlinable
public var reservedBits: ReservedBits {
get {
.init(rawValue: self.firstByte & 0x70)
}
set {
self.firstByte = (self.firstByte & 0x8F) + newValue.rawValue
}
}
}
37 changes: 37 additions & 0 deletions Tests/NIOWebSocketTests/WebSocketMaskingKeyTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,41 @@ final class WebSocketMaskingKeyTests: XCTestCase {
"at least 1 of 1000 random masking keys with default generator should not be all zeros"
)
}

func testGetReservedBits() {
let frame = WebSocketFrame(rsv1: true, opcode: .binary, data: .init())
XCTAssertEqual(frame.reservedBits.contains(.rsv1), true)
XCTAssertEqual(frame.reservedBits.contains(.rsv2), false)
XCTAssertEqual(frame.reservedBits.contains(.rsv3), false)
let frame2 = WebSocketFrame(rsv2: true, opcode: .binary, data: .init())
XCTAssertEqual(frame2.reservedBits.contains(.rsv1), false)
XCTAssertEqual(frame2.reservedBits.contains(.rsv2), true)
XCTAssertEqual(frame2.reservedBits.contains(.rsv3), false)
let frame3 = WebSocketFrame(rsv3: true, opcode: .binary, data: .init())
XCTAssertEqual(frame3.reservedBits.contains(.rsv1), false)
XCTAssertEqual(frame3.reservedBits.contains(.rsv2), false)
XCTAssertEqual(frame3.reservedBits.contains(.rsv3), true)
}

func testSetReservedBits() {
var frame = WebSocketFrame(opcode: .binary, data: .init())
frame.reservedBits = .rsv1
XCTAssertEqual(frame.reservedBits.contains(.rsv1), true)
XCTAssertEqual(frame.reservedBits.contains(.rsv2), false)
XCTAssertEqual(frame.reservedBits.contains(.rsv3), false)
XCTAssertEqual(frame.fin, false)
XCTAssertEqual(frame.opcode, .binary)
frame.reservedBits = .rsv2
XCTAssertEqual(frame.reservedBits.contains(.rsv1), false)
XCTAssertEqual(frame.reservedBits.contains(.rsv2), true)
XCTAssertEqual(frame.reservedBits.contains(.rsv3), false)
XCTAssertEqual(frame.fin, false)
XCTAssertEqual(frame.opcode, .binary)
frame.reservedBits = .rsv3
XCTAssertEqual(frame.reservedBits.contains(.rsv1), false)
XCTAssertEqual(frame.reservedBits.contains(.rsv2), false)
XCTAssertEqual(frame.reservedBits.contains(.rsv3), true)
XCTAssertEqual(frame.fin, false)
XCTAssertEqual(frame.opcode, .binary)
}
}
Loading