Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ let package = Package(
"NIOCore",
"NIOEmbedded",
"NIOFoundationCompat",
"NIOTestUtils",
swiftAtomics,
],
swiftSettings: strictConcurrencySettings
Expand Down Expand Up @@ -520,6 +521,7 @@ let package = Package(
dependencies: [
"NIOTestUtils",
"NIOCore",
"NIOConcurrencyHelpers",
"NIOEmbedded",
"NIOPosix",
]
Expand Down
116 changes: 116 additions & 0 deletions Sources/NIOTestUtils/NIOThreadPoolTaskExecutor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#if compiler(>=6)

import NIOPosix

/// Run a `NIOThreadPool` based `TaskExecutor` while executing the given `body`.
///
/// This function provides a `TaskExecutor`, **not** a `SerialExecutor`. The executor can be
/// used for setting the executor preference of a task.
///
/// Example usage:
/// ```swift
/// await withNIOThreadPoolTaskExecutor(numberOfThreads: 2) { taskExecutor in
/// await withDiscardingTaskGroup { group in
/// group.addTask(executorPreference: taskExecutor) { ... }
/// }
/// }
/// ```
///
/// - warning: Do not escape the task executor from the closure for later use and make sure that
/// all tasks running on the executor are completely finished before `body` returns.
/// For unstructured tasks, this means awaiting their results. If any task is still
/// running on the executor when `body` returns, this results in a fatalError.
/// It is highly recommended to use structured concurrency with this task executor.
///
/// - Parameters:
/// - numberOfThreads: The number of threads in the pool.
/// - body: The closure that will accept the task executor.
///
/// - Throws: When `body` throws.
///
/// - Returns: The value returned by `body`.
@inlinable
public func withNIOThreadPoolTaskExecutor<T, Failure>(
numberOfThreads: Int,
body: (NIOThreadPoolTaskExecutor) async throws(Failure) -> T
) async throws(Failure) -> T {
let taskExecutor = NIOThreadPoolTaskExecutor(numberOfThreads: numberOfThreads)
taskExecutor.start()

let result: Result<T, Failure>
do {
result = .success(try await body(taskExecutor))
} catch {
result = .failure(error)
}

await taskExecutor.shutdownGracefully()

return try result.get()
}

/// A task executor based on NIOThreadPool.
///
/// Provides a `TaskExecutor`, **not** a `SerialExecutor`. The executor can be
/// used for setting the executor preference of a task.
///
public final class NIOThreadPoolTaskExecutor: TaskExecutor {
let nioThreadPool: NIOThreadPool

/// Initialize a `NIOThreadPoolTaskExecutor`, using a thread pool with `numberOfThreads` threads.
///
/// - Parameters:
/// - numberOfThreads: The number of threads to use for the thread pool.
public init(numberOfThreads: Int) {
self.nioThreadPool = NIOThreadPool(numberOfThreads: numberOfThreads)
}

/// Start the `NIOThreadPoolTaskExecutor`.
public func start() {
nioThreadPool.start()
}

/// Gracefully shutdown this `NIOThreadPoolTaskExecutor`.
///
/// Make sure that all tasks running on the executor are finished before shutting down.
///
/// - warning: If any task is still running on the executor, this results in a fatalError.
public func shutdownGracefully() async {
do {
try await nioThreadPool.shutdownGracefully()
} catch {
fatalError("Failed to shutdown NIOThreadPool")
}
}

/// Enqueue a job.
///
/// Called by the concurrency runtime.
///
/// - Parameter job: The job to enqueue.
public func enqueue(_ job: consuming ExecutorJob) {
let unownedJob = UnownedJob(job)
self.nioThreadPool.submit { shouldRun in
guard case shouldRun = NIOThreadPool.WorkItemState.active else {
fatalError("Shutdown before all tasks finished")
}
unownedJob.runSynchronously(on: self.asUnownedTaskExecutor())
}
}
}

#endif // compiler(>=6)
90 changes: 50 additions & 40 deletions Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import DequeModule
import NIOConcurrencyHelpers
import NIOTestUtils
import XCTest

@testable import NIOCore
Expand Down Expand Up @@ -606,48 +607,57 @@ final class NIOAsyncWriterTests: XCTestCase {
self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1)
}

func testSuspendingBufferedYield_whenWriterFinished() async throws {
self.sink.setWritability(to: false)

let bothSuspended = expectation(description: "suspended on both yields")
let suspendedAgain = ConditionLock(value: false)
self.delegate.didSuspendHandler = {
if self.delegate.didSuspendCallCount == 2 {
bothSuspended.fulfill()
} else if self.delegate.didSuspendCallCount > 2 {
suspendedAgain.lock()
suspendedAgain.unlock(withValue: true)
}
}

self.delegate.didYieldHandler = { _ in
if self.delegate.didYieldCallCount == 1 {
// Delay this yield until the other yield is suspended again.
suspendedAgain.lock(whenValue: true)
suspendedAgain.unlock()
func testWriterFinish_AndSuspendBufferedYield() async throws {
#if compiler(>=6)
try await withNIOThreadPoolTaskExecutor(numberOfThreads: 2) { taskExecutor in
try await withThrowingTaskGroup(of: Void.self) { group in
self.sink.setWritability(to: false)

let bothSuspended = expectation(description: "suspended on both yields")
let suspendedAgain = ConditionLock(value: false)
self.delegate.didSuspendHandler = {
if self.delegate.didSuspendCallCount == 2 {
bothSuspended.fulfill()
} else if self.delegate.didSuspendCallCount > 2 {
suspendedAgain.lock()
suspendedAgain.unlock(withValue: true)
}
}

self.delegate.didYieldHandler = { _ in
if self.delegate.didYieldCallCount == 1 {
// Delay this yield until the other yield is suspended again.
if suspendedAgain.lock(whenValue: true, timeoutSeconds: 5) {
suspendedAgain.unlock()
} else {
XCTFail("Timeout while waiting for other yield to suspend again.")
}
}
}

group.addTask(executorPreference: taskExecutor) { [writer] in
try await writer!.yield("message1")
}
group.addTask(executorPreference: taskExecutor) { [writer] in
try await writer!.yield("message2")
}

await fulfillment(of: [bothSuspended], timeout: 5)
self.writer.finish()

self.assert(suspendCallCount: 2, yieldCallCount: 0, terminateCallCount: 0)

// We have to become writable again to unbuffer the yields
// The first call to didYield will pause, so that the other yield will be suspended again.
self.sink.setWritability(to: true)

await XCTAssertNoThrow(try await group.next())
await XCTAssertNoThrow(try await group.next())

self.assert(suspendCallCount: 3, yieldCallCount: 2, terminateCallCount: 1)
}
}

let task1 = Task { [writer] in
try await writer!.yield("message1")
}
let task2 = Task { [writer] in
try await writer!.yield("message2")
}

await fulfillment(of: [bothSuspended], timeout: 1)
self.writer.finish()

self.assert(suspendCallCount: 2, yieldCallCount: 0, terminateCallCount: 0)

// We have to become writable again to unbuffer the yields
// The first call to didYield will pause, so that the other yield will be suspended again.
self.sink.setWritability(to: true)

await XCTAssertNoThrow(try await task1.value)
await XCTAssertNoThrow(try await task2.value)

self.assert(suspendCallCount: 3, yieldCallCount: 2, terminateCallCount: 1)
#endif // compiler(>=6)
}

func testWriterFinish_whenFinished() {
Expand Down
81 changes: 81 additions & 0 deletions Tests/NIOTestUtilsTests/NIOThreadPoolTaskExecutorTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019-2025 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOConcurrencyHelpers
import NIOTestUtils
import XCTest

class NIOThreadPoolTaskExecutorTest: XCTestCase {
struct TestError: Error {}

func runTasksSimultaneously(numberOfTasks: Int) async {
await withNIOThreadPoolTaskExecutor(numberOfThreads: numberOfTasks) { taskExecutor in
await withDiscardingTaskGroup { group in
var taskBlockers = [ConditionLock<Bool>]()
defer {
// Unblock all tasks
for taskBlocker in taskBlockers {
taskBlocker.lock()
taskBlocker.unlock(withValue: true)
}
}

for taskNumber in 1...numberOfTasks {
let taskStarted = ConditionLock(value: false)
let taskBlocker = ConditionLock(value: false)
taskBlockers.append(taskBlocker)

// Start task and block it
group.addTask(executorPreference: taskExecutor) {
taskStarted.lock()
taskStarted.unlock(withValue: true)
taskBlocker.lock(whenValue: true)
taskBlocker.unlock()
}

// Verify that task was able to start
if taskStarted.lock(whenValue: true, timeoutSeconds: 5) {
taskStarted.unlock()
} else {
XCTFail("Task \(taskNumber) failed to start.")
break
}
}
}
}
}

func testRunsTaskOnSingleThread() async {
await runTasksSimultaneously(numberOfTasks: 1)
}

func testRunsMultipleTasksOnMultipleThreads() async {
await runTasksSimultaneously(numberOfTasks: 3)
}

func testReturnsBodyResult() async {
let expectedResult = "result"
let result = await withNIOThreadPoolTaskExecutor(numberOfThreads: 1) { _ in return expectedResult }
XCTAssertEqual(result, expectedResult)
}

func testRethrows() async {
do {
try await withNIOThreadPoolTaskExecutor(numberOfThreads: 1) { _ in throw TestError() }
XCTFail("Function did not rethrow.")
} catch {
XCTAssertTrue(error is TestError)
}
}
}
Loading