diff --git a/Sources/NIOPosix/BaseSocketChannel.swift b/Sources/NIOPosix/BaseSocketChannel.swift index 2a1951531f6..37e32b25b06 100644 --- a/Sources/NIOPosix/BaseSocketChannel.swift +++ b/Sources/NIOPosix/BaseSocketChannel.swift @@ -595,8 +595,16 @@ class BaseSocketChannel: SelectableChannel, Chan switch writeResult.writeResult { case .couldNotWriteEverything: newWriteRegistrationState = .register - case .writtenCompletely: + case .writtenCompletely(let closeState): newWriteRegistrationState = .unregister + switch closeState { + case .open: + () + case .readyForClose: + self.close0(error: ChannelError.outputClosed, mode: .output, promise: nil) + case .closed: + () // we can be flushed before becoming active + } } if !self.isOpen || !self.hasFlushedPendingWrites() { diff --git a/Sources/NIOPosix/BaseStreamSocketChannel.swift b/Sources/NIOPosix/BaseStreamSocketChannel.swift index 7878415d3f2..0ec10d7fff9 100644 --- a/Sources/NIOPosix/BaseStreamSocketChannel.swift +++ b/Sources/NIOPosix/BaseStreamSocketChannel.swift @@ -194,13 +194,35 @@ class BaseStreamSocketChannel: BaseSocketChannel self.close0(error: error, mode: .all, promise: promise) return } - try self.shutdownSocket(mode: mode) - // Fail all pending writes and so ensure all pending promises are notified - self.pendingWrites.failAll(error: error, close: false) - self.unregisterForWritable() - promise?.succeed(()) - self.pipeline.fireUserInboundEventTriggered(ChannelEvent.outputClosed) + let result = self.pendingWrites.closeOutbound(promise) + switch result { + case .pending: + () // promise is stored in `pendingWrites` state for completing later + + case .readyForClose(let closePromise): + // Shutdown the socket only when the pending writes are dealt with + do { + try self.shutdownSocket(mode: mode) + closePromise?.succeed(()) + } catch let err { + closePromise?.fail(err) + } + self.unregisterForWritable() + self.pipeline.fireUserInboundEventTriggered(ChannelEvent.outputClosed) + + case .closed(let closePromise): + closePromise?.succeed(()) + + case .errored(let err, let closePromise): + assertionFailure("Close errored: \(err)") + closePromise?.fail(err) + + // Escalate to full closure + // promise is nil here because we have used the supplied promise to convey failure of the half-close + self.close0(error: err, mode: .all, promise: nil) + } + case .input: if self.inputShutdown { promise?.fail(ChannelError._inputClosed) @@ -224,6 +246,7 @@ class BaseStreamSocketChannel: BaseSocketChannel promise?.succeed(()) self.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + case .all: if let timeout = self.connectTimeoutScheduled { self.connectTimeoutScheduled = nil @@ -247,7 +270,9 @@ class BaseStreamSocketChannel: BaseSocketChannel } final override func cancelWritesOnClose(error: Error) { - self.pendingWrites.failAll(error: error, close: true) + if let eventLoopPromise = self.pendingWrites.failAll(error: error) { + eventLoopPromise.fail(error) + } } @discardableResult diff --git a/Sources/NIOPosix/PendingDatagramWritesManager.swift b/Sources/NIOPosix/PendingDatagramWritesManager.swift index e9f201ce90d..8dfca58284d 100644 --- a/Sources/NIOPosix/PendingDatagramWritesManager.swift +++ b/Sources/NIOPosix/PendingDatagramWritesManager.swift @@ -419,6 +419,13 @@ final class PendingDatagramWritesManager: PendingWritesManager { internal var publishedWritability = true internal var writeSpinCount: UInt = 16 private(set) var isOpen = true + var outboundCloseState: CloseState { + if self.isOpen { + .open + } else { + .closed + } + } /// Initialize with a pre-allocated array of message headers and storage references. We pass in these pre-allocated /// objects to save allocations. They can be safely be re-used for all `Channel`s on a given `EventLoop` as an diff --git a/Sources/NIOPosix/PendingWritesManager.swift b/Sources/NIOPosix/PendingWritesManager.swift index 6e6b3fff7d2..c5c8b8d33eb 100644 --- a/Sources/NIOPosix/PendingWritesManager.swift +++ b/Sources/NIOPosix/PendingWritesManager.swift @@ -21,6 +21,15 @@ private struct PendingStreamWrite { var promise: Optional> } +/// Write result is `.couldNotWriteEverything` but we have no more writes to perform. +public struct NIOReportedIncompleteWritesWhenNoMoreToPerform: Error {} + +/// Close result is `.open`. +public struct NIOReportedOpenAfterClose: Error {} + +/// There are buffered writes after it should have been cleared, `.readyForClose` or `.closed` state +public struct NIOReportedPendingWritesInInvalidState: Error {} + /// Does the setup required to issue a writev. /// /// - Parameters: @@ -97,12 +106,33 @@ internal enum OneWriteOperationResult { /// The result of trying to write all the outstanding flushed data. That naturally includes all `ByteBuffer`s and /// `FileRegions` and the individual writes have potentially been retried (see `WriteSpinOption`). internal struct OverallWriteResult { - enum WriteOutcome { + enum WriteOutcome: Equatable { /// Wrote all the data that was flushed. When receiving this result, we can unsubscribe from 'writable' notification. - case writtenCompletely + case writtenCompletely(WrittenCompletelyResult) /// Could not write everything. Before attempting further writes the eventing system should send a 'writable' notification. case couldNotWriteEverything + + /// The resulting status of a `PendingWritesManager` after a completely-written write + /// + /// This type is subtly different to `CloseState` so that it only surfaces the close promise when the caller + /// is expected to fulfill it + internal enum WrittenCompletelyResult: Equatable { + case open + case readyForClose(EventLoopPromise?) + case closed(EventLoopPromise?) + + init(_ closeState: CloseState) { + switch closeState { + case .open: + self = .open + case .pending(let closePromise), .readyForClose(let closePromise): + self = .readyForClose(closePromise) + case .closed: + self = .closed(nil) + } + } + } } internal var writeResult: WriteOutcome @@ -152,7 +182,7 @@ private struct PendingStreamWritesState { self.subtractOutstanding(bytes: bytes) } - /// Initialise a new, empty `PendingWritesState`. + /// Initialize a new, empty `PendingWritesState`. public init() {} /// Check if there are no outstanding writes. @@ -310,6 +340,8 @@ final class PendingStreamWritesManager: PendingWritesManager { private(set) var isOpen = true + private(set) var outboundCloseState: CloseState = .open + /// Mark the flush checkpoint. func markFlushCheckpoint() { self.state.markFlushCheckpoint() @@ -337,7 +369,7 @@ final class PendingStreamWritesManager: PendingWritesManager { /// - result: If the `Channel` is still writable after adding the write of `data`. func add(data: IOData, promise: EventLoopPromise?) -> Bool { assert(self.isOpen) - self.state.append(.init(data: data, promise: promise)) + self.state.append(PendingStreamWrite(data: data, promise: promise)) if self.state.bytes > waterMark.high && channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged @@ -463,16 +495,101 @@ final class PendingStreamWritesManager: PendingWritesManager { return self.didWrite(itemCount: result.itemCount, result: result.writeResult) } - /// Fail all the outstanding writes. This is useful if for example the `Channel` is closed. - func failAll(error: Error, close: Bool) { - if close { - assert(self.isOpen) - self.isOpen = false + /// Fail all the outstanding writes. + func failAll(error: Error) -> EventLoopPromise? { + assert(self.isOpen) + + let promise: EventLoopPromise? + self.isOpen = false + switch self.outboundCloseState { + case .open, .closed: + self.outboundCloseState = .closed + promise = nil + case .pending(let closePromise), .readyForClose(let closePromise): + self.outboundCloseState = .closed + promise = closePromise } self.state.removeAll()?.fail(error) assert(self.state.isEmpty) + return promise + } + + // The result of calling `closeOutbound` + enum CloseOutboundResult { + case pending + case readyForClose(EventLoopPromise?) + case closed(EventLoopPromise?) + case errored(Error, EventLoopPromise?) + + init(_ closeState: CloseState, _ isEmpty: Bool, _ promise: EventLoopPromise?) { + switch closeState { + case .open: + assertionFailure( + "We are in .open state after being asked to close. This should never happen." + ) + self = .errored(NIOReportedOpenAfterClose(), promise) + case .pending: + // `promise` has already been taken care of in the pending state for later completion + self = .pending + case .readyForClose(let closePromise): + if isEmpty { + self = .readyForClose(closePromise) + } else { + assertionFailure( + "We are in .readyForClose state but we still have pending writes. This should never happen." + ) + // `promise` has already been cascaded off `closePromise` + self = .errored(NIOReportedPendingWritesInInvalidState(), closePromise) + } + case .closed: + if isEmpty { + self = .closed(promise) + } else { + assertionFailure( + "We are in .closed state but we still have pending writes. This should never happen." + ) + self = .errored(NIOReportedPendingWritesInInvalidState(), promise) + } + } + } + } + + /// Signal the intention to close. Takes a promise which will be returned for completing when pending writes are dealt with + /// + /// - Parameters: + /// - promise: Optionally an `EventLoopPromise` which is stored and is returned to be completed by the caller once + /// all outstanding writes have been dealt with or an error condition is encountered. + func closeOutbound(_ promise: EventLoopPromise?) -> CloseOutboundResult { + assert(self.isOpen) + + // Update our internal state + switch self.outboundCloseState { + case .open: + if self.isEmpty { + self.outboundCloseState = .readyForClose(promise) + } else { + self.outboundCloseState = .pending(promise) + } + case .readyForClose(var closePromise): + closePromise.setOrCascade(to: promise) + self.outboundCloseState = .readyForClose(closePromise) + case .pending(var closePromise): + closePromise.setOrCascade(to: promise) + if self.isEmpty { + self.outboundCloseState = .readyForClose(closePromise) + } else { + self.outboundCloseState = .pending(closePromise) + } + case .closed: + () + } + + // Decide on the result + let result = CloseOutboundResult(self.outboundCloseState, self.isEmpty, promise) + + return result } /// Initialize with a pre-allocated array of IO vectors and storage references. We pass in these pre-allocated @@ -496,6 +613,8 @@ internal enum WriteMechanism { internal protocol PendingWritesManager: AnyObject { var isOpen: Bool { get } + var isEmpty: Bool { get } + var outboundCloseState: CloseState { get } var isFlushPending: Bool { get } var writeSpinCount: UInt { get } var currentBestWriteMechanism: WriteMechanism { get } @@ -507,6 +626,18 @@ internal protocol PendingWritesManager: AnyObject { var publishedWritability: Bool { get set } } +/// Describes the state that a `PendingWritesManager` closure state machine will step through when instructed to close +internal enum CloseState { + /// The manager will accept new writes + case open + /// The manager has been asked to close but cannot because its write buffer is not empty + case pending(EventLoopPromise?) + /// The manager has been asked to close and is ready to be closed because its write buffer is empty + case readyForClose(EventLoopPromise?) + /// The manager is closed + case closed +} + extension PendingWritesManager { // This is called from `Channel` API so must be thread-safe. var isWritable: Bool { @@ -522,7 +653,7 @@ extension PendingWritesManager { var oneResult: OneWriteOperationResult repeat { guard self.isOpen && self.isFlushPending else { - result.writeResult = .writtenCompletely + result.writeResult = .writtenCompletely(.init(self.outboundCloseState)) break writeSpinLoop } diff --git a/Tests/NIOPosixTests/ChannelTests.swift b/Tests/NIOPosixTests/ChannelTests.swift index fd939f43d11..76e5483ed3c 100644 --- a/Tests/NIOPosixTests/ChannelTests.swift +++ b/Tests/NIOPosixTests/ChannelTests.swift @@ -528,7 +528,7 @@ final class ChannelTests: XCTestCase { XCTAssertFalse(pwm.isEmpty) XCTAssertFalse(pwm.isFlushPending) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) result = try assertExpectedWritability( pendingWritesManager: pwm, @@ -539,7 +539,7 @@ final class ChannelTests: XCTestCase { returns: [], promiseStates: [[true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -554,7 +554,7 @@ final class ChannelTests: XCTestCase { promiseStates: [[true, true]] ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -584,7 +584,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(8)], promiseStates: [[true, true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -598,7 +598,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(0)], promiseStates: [[true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -655,7 +655,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(8)], promiseStates: [[true, true, true, true], [true, true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(totalBytes - 1 - 7 - 8, pwm.bufferedBytes) } } @@ -704,7 +704,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(1)], promiseStates: [[true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -768,7 +768,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(1)], promiseStates: [Array(repeating: true, count: numberOfBytes)] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -819,7 +819,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -853,7 +853,7 @@ final class ChannelTests: XCTestCase { XCTAssertEqual(.couldNotWriteEverything, result.writeResult) XCTAssertEqual(totalBytes - 2, pwm.bufferedBytes) - pwm.failAll(error: ChannelError.operationUnsupported, close: true) + _ = pwm.failAll(error: ChannelError.operationUnsupported) XCTAssertTrue(ps.map { $0.futureResult.isFulfilled }.allSatisfy { $0 }) } @@ -892,7 +892,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(2 * halfTheWriteVLimit), .processed(halfTheWriteVLimit)], promiseStates: [[true, true, false], [true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -962,7 +962,7 @@ final class ChannelTests: XCTestCase { [true, true, true], ] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() } @@ -1000,7 +1000,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(2)], promiseStates: [[true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) totalBytes -= Int64(fr1.readableBytes) XCTAssertEqual(totalBytes, pwm.bufferedBytes) @@ -1013,7 +1013,7 @@ final class ChannelTests: XCTestCase { returns: [], promiseStates: [[true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(totalBytes, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -1029,7 +1029,7 @@ final class ChannelTests: XCTestCase { totalBytes -= Int64(fr2.readableBytes) XCTAssertEqual(totalBytes, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1059,7 +1059,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1134,7 +1134,7 @@ final class ChannelTests: XCTestCase { totalBytes -= 4 XCTAssertEqual(totalBytes, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1180,7 +1180,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(8)], promiseStates: [[true, true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -1196,7 +1196,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1223,7 +1223,7 @@ final class ChannelTests: XCTestCase { returns: [.processed(0)], promiseStates: [[true, true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -1239,7 +1239,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1259,7 +1259,7 @@ final class ChannelTests: XCTestCase { XCTAssertEqual(Int64(buffer.readableBytes * 3), pwm.bufferedBytes) ps[0].futureResult.assumeIsolated().whenComplete { (_: Result) in - pwm.failAll(error: ChannelError.inputClosed, close: true) + _ = pwm.failAll(error: ChannelError.inputClosed) } let result = try assertExpectedWritability( @@ -1273,7 +1273,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.closed(nil)), result.writeResult) XCTAssertNoThrow(try ps[0].futureResult.wait()) XCTAssertThrowsError(try ps[1].futureResult.wait()) XCTAssertThrowsError(try ps[2].futureResult.wait()) @@ -1322,7 +1322,7 @@ final class ChannelTests: XCTestCase { promiseStates: [Array(repeating: true, count: Socket.writevLimitIOVectors + 1)] ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } @@ -1353,7 +1353,7 @@ final class ChannelTests: XCTestCase { ) XCTAssertEqual(0, pwm.bufferedBytes) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) } } diff --git a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift index a91da0879b4..8771e21b27f 100644 --- a/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift +++ b/Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift @@ -346,7 +346,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { XCTAssertFalse(pwm.isEmpty) XCTAssertFalse(pwm.isFlushPending) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) result = try assertExpectedWritability( pendingWritesManager: pwm, @@ -356,7 +356,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [], promiseStates: [[true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(Int64(buffer.readableBytes), pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -369,7 +369,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(0))], promiseStates: [[true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(Int64(buffer.readableBytes), pwm.bufferedBytes) } } @@ -401,7 +401,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(2))], promiseStates: [[true, true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -414,7 +414,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(0))], promiseStates: [[true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -474,7 +474,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(4))], promiseStates: [[true, true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -527,7 +527,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(12))], promiseStates: [Array(repeating: true, count: ps.count - 1) + [true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -600,7 +600,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(2)), .success(.processed(1))], promiseStates: [[true, true, false], [true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -653,7 +653,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { promiseStates: [[true, false, false], [true, true, false], [true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) XCTAssertNoThrow(try ps[1].futureResult.wait()) @@ -693,7 +693,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(2))], promiseStates: [[true, true, false]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) pwm.markFlushCheckpoint() @@ -706,7 +706,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(0))], promiseStates: [[true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -739,7 +739,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(1))], promiseStates: [[true, true, true]] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.closed(nil)), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) XCTAssertNoThrow(try ps[0].futureResult.wait()) XCTAssertThrowsError(try ps[1].futureResult.wait()) @@ -784,7 +784,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(4))], promiseStates: [Array(repeating: true, count: Socket.writevLimitIOVectors + 1)] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } @@ -845,7 +845,7 @@ class PendingDatagramWritesManagerTests: XCTestCase { returns: [.success(.processed(1))], promiseStates: [Array(repeating: true, count: 5)] ) - XCTAssertEqual(.writtenCompletely, result.writeResult) + XCTAssertEqual(.writtenCompletely(.open), result.writeResult) XCTAssertEqual(0, pwm.bufferedBytes) } } diff --git a/Tests/NIOPosixTests/StreamChannelsTest.swift b/Tests/NIOPosixTests/StreamChannelsTest.swift index dfbe8a943f5..c7b268deb60 100644 --- a/Tests/NIOPosixTests/StreamChannelsTest.swift +++ b/Tests/NIOPosixTests/StreamChannelsTest.swift @@ -14,6 +14,7 @@ import Atomics import CNIOLinux +import NIOConcurrencyHelpers import NIOCore import NIOTestUtils import XCTest @@ -261,6 +262,160 @@ class StreamChannelTest: XCTestCase { XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(runTest)) } + func testHalfCloseOwnOutputWithPopulatedBuffer() throws { + func runTest(chan1: Channel, chan2: Channel) throws { + let readPromise = chan2.eventLoop.makePromise(of: Void.self) + + XCTAssertNoThrow(try chan1.setOption(.allowRemoteHalfClosure, value: true).wait()) + + self.buffer.writeString("X") + XCTAssertNoThrow( + try chan2.pipeline.addHandler(FulfillOnFirstEventHandler(channelReadPromise: readPromise)).wait() + ) + + // let's write a byte from chan1 to chan2 which we leave in the buffer. + let writeFuture = chan1.write(self.buffer) + + // close chan1's output, this shouldn't take effect until the buffer is empty + let closeFuture = chan1.close(mode: .output) + + // flush chan1's output + chan1.flush() + + // Attempt to write a byte from chan1 to chan2 which should be refused after the close + XCTAssertThrowsError(try chan1.write(self.buffer).wait()) { error in + XCTAssertEqual(ChannelError.outputClosed, error as? ChannelError, "\(chan1)") + } + + // wait for the write to complete + XCTAssertNoThrow(try writeFuture.wait(), "chan1 write failed") + + // and wait for it to arrive + XCTAssertNoThrow(try readPromise.futureResult.wait()) + + // wait for the close to complete + XCTAssertNoThrow(try closeFuture.wait(), "chan1 close failed") + + XCTAssertNoThrow(try chan1.syncCloseAcceptingAlreadyClosed()) + XCTAssertNoThrow(try chan2.syncCloseAcceptingAlreadyClosed()) + } + XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(runTest)) + } + + func testHalfCloseOwnOutputWithWritabilityChange() throws { + final class BytesReadCountingHandler: ChannelInboundHandler, Sendable { + typealias InboundIn = ByteBuffer + + private let numBytes = NIOLockedValueBox(0) + private let numBytesReadAtInputClose = NIOLockedValueBox(0) + + var bytesRead: Int { + self.numBytes.withLockedValue { $0 } + } + var bytesReadAtInputClose: Int { + self.numBytesReadAtInputClose.withLockedValue { $0 } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let currentBuffer = Self.unwrapInboundIn(data) + self.numBytes.withLockedValue { numBytes in + numBytes += currentBuffer.readableBytes + } + context.fireChannelRead(data) + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if event as? ChannelEvent == .some(.inputClosed) { + let numBytes = self.numBytes.withLockedValue { $0 } + self.numBytesReadAtInputClose.withLockedValue { $0 = numBytes } + context.close(mode: .all, promise: nil) + } + context.fireUserInboundEventTriggered(event) + } + } + + final class BytesWrittenCountingHandler: ChannelInboundHandler, Sendable { + typealias InboundIn = ByteBuffer + + public typealias OutboundIn = ByteBuffer + public typealias OutboundOut = ByteBuffer + + private let numBytes = NIOLockedValueBox(0) + private let seenOutputClosed = NIOLockedValueBox(false) + + func setup(_ context: ChannelHandlerContext) { + let bufferLength = 1024 + let bytesToWrite = ByteBuffer.init(repeating: 0x42, count: bufferLength) + + // write until the kernel buffer and the pendingWrites buffer are full + while context.channel.isWritable { + XCTAssertNoThrow(context.writeAndFlush(self.wrapOutboundOut(bytesToWrite), promise: nil)) + self.numBytes.withLockedValue { numBytes in + numBytes += bufferLength + } + } + } + + var bytesWritten: Int { + self.numBytes.withLockedValue { $0 } + } + + var seenOutputClosedEvent: Bool { + self.seenOutputClosed.withLockedValue { $0 } + } + + func channelActive(context: ChannelHandlerContext) { + self.setup(context) + context.fireChannelActive() + } + + func handlerAdded(context: ChannelHandlerContext) { + self.setup(context) + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if event as? ChannelEvent == .some(.outputClosed) { + self.seenOutputClosed.withLockedValue { $0 = true } + } + context.fireUserInboundEventTriggered(event) + } + } + + func runTest(chan1: Channel, chan2: Channel) throws { + try chan1.setOption(.autoRead, value: false).wait() + try chan1.setOption(.allowRemoteHalfClosure, value: true).wait() + + let bytesReadCountingHandler = BytesReadCountingHandler() + try chan1.pipeline.addHandler(bytesReadCountingHandler).wait() + + let bytesWrittenCountingHandler = BytesWrittenCountingHandler() + try chan2.pipeline.addHandler(bytesWrittenCountingHandler).wait() + + XCTAssertFalse(bytesWrittenCountingHandler.seenOutputClosedEvent) + + // close the writing side + let chan2ClosePromise = chan2.eventLoop.makePromise(of: Void.self) + chan2.close(mode: .output, promise: chan2ClosePromise) + + XCTAssertFalse(bytesWrittenCountingHandler.seenOutputClosedEvent) + + // tell the read side to begin reading leading to the write buffers draining + try chan1.setOption(.autoRead, value: true).wait() + + // wait for the reading-side close to complete + try chan1.closeFuture.wait() + + XCTAssertTrue(bytesWrittenCountingHandler.seenOutputClosedEvent) + + // now the dust has settled all the bytes should be accounted for + XCTAssertNotEqual(bytesWrittenCountingHandler.bytesWritten, 0) + XCTAssertEqual(bytesReadCountingHandler.bytesRead, bytesWrittenCountingHandler.bytesWritten) + XCTAssertEqual(bytesReadCountingHandler.bytesRead, bytesReadCountingHandler.bytesReadAtInputClose) + + } + XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(forceSeparateEventLoops: false, runTest)) + } + func testHalfCloseOwnInput() { func runTest(chan1: Channel, chan2: Channel) throws {