diff --git a/Sources/NIO/System.swift b/Sources/NIO/System.swift index d13c0863209..bd01f850cf4 100644 --- a/Sources/NIO/System.swift +++ b/Sources/NIO/System.swift @@ -150,6 +150,7 @@ private func preconditionIsNotUnacceptableErrno(err: CInt, where function: Strin @inline(__always) @discardableResult internal func syscall(blocking: Bool, + eprototypeWorkaround: Bool = false, where function: String = #function, _ body: () throws -> T) throws -> IOResult { @@ -157,11 +158,18 @@ internal func syscall(blocking: Bool, let res = try body() if res == -1 { let err = errno - switch (err, blocking) { - case (EINTR, _): + switch (err, blocking, eprototypeWorkaround) { + case (EINTR, _, _): continue - case (EWOULDBLOCK, true): + case (EWOULDBLOCK, true, _): return .wouldBlock(0) + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + case (EPROTOTYPE, _, true): + // EPROTOTYPE can, on Darwin platforms, sometimes fire due to a race in the XNU kernel. + // The socket in question is about to shut down, so we can just retry the syscall and get + // the actual error (usually, but not necessarily, EPIPE). + continue + #endif default: preconditionIsNotUnacceptableErrno(err: err, where: function) throw IOError(errnoCode: err, reason: function) @@ -356,7 +364,7 @@ internal enum Posix { @inline(never) public static func write(descriptor: CInt, pointer: UnsafeRawPointer, size: Int) throws -> IOResult { - return try syscall(blocking: true) { + return try syscall(blocking: true, eprototypeWorkaround: true) { sysWrite(descriptor, pointer, size) } } @@ -371,7 +379,7 @@ internal enum Posix { #if !os(Windows) @inline(never) public static func writev(descriptor: CInt, iovecs: UnsafeBufferPointer) throws -> IOResult { - return try syscall(blocking: true) { + return try syscall(blocking: true, eprototypeWorkaround: true) { sysWritev(descriptor, iovecs.baseAddress!, CInt(iovecs.count)) } } @@ -445,7 +453,7 @@ internal enum Posix { public static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult { var written: off_t = 0 do { - _ = try syscall(blocking: false) { () -> ssize_t in + _ = try syscall(blocking: false, eprototypeWorkaround: true) { () -> ssize_t in #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) var w: off_t = off_t(count) let result: CInt = Darwin.sendfile(fd, descriptor, offset, &w, nil, 0) diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index 1aeb0eeba3b..4da06e45835 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -84,6 +84,7 @@ extension ChannelTests { ("testFixedSizeRecvByteBufferAllocatorSizeIsConstant", testFixedSizeRecvByteBufferAllocatorSizeIsConstant), ("testCloseInConnectPromise", testCloseInConnectPromise), ("testWritabilityChangeDuringReentrantFlushNow", testWritabilityChangeDuringReentrantFlushNow), + ("testTriggerEPROTOTYPE", testTriggerEPROTOTYPE), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 13d22a4df97..2e2481d679d 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -2814,6 +2814,40 @@ public final class ChannelTests: XCTestCase { XCTAssertNoThrow(try handler.becameUnwritable.futureResult.wait()) XCTAssertNoThrow(try handler.becameWritable.futureResult.wait()) } + + func testTriggerEPROTOTYPE() throws { + // This is a probabilistic test for https://github.com/swift-server/async-http-client/issues/322. + // We believe we'll see EPROTOTYPE on write syscalls if we write while the connections are being torn down. + // To check this we create 500 connections and close them, while the server attempts to write AS FAST AS IT CAN. + // As this test is probabilistic, we must not ignore transient failures in it. + let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverLoop = group.next() + let clientLoop = group.next() + XCTAssertFalse(serverLoop === clientLoop) + + let serverFuture = ServerBootstrap(group: serverLoop) + .childChannelInitializer { channel in + return channel.pipeline.addHandler(AlwaysBeWritingHandler(vectorWrites: [true, false].randomElement()!)) + } + .bind(host: "localhost", port: 0) + + let server: Channel = try assertNoThrowWithValue(try serverFuture.wait()) + defer { + XCTAssertNoThrow(try server.close().wait()) + } + + let clientFactory = ClientBootstrap(group: clientLoop) + let serverAddress = server.localAddress! + + for _ in 0..<500 { + let client = try assertNoThrowWithValue(clientFactory.connect(to: serverAddress).wait()) + XCTAssertNoThrow(try client.close().wait()) + } + } } fileprivate final class FailRegistrationAndDelayCloseHandler: ChannelOutboundHandler { @@ -2926,3 +2960,49 @@ final class ReentrantWritabilityChangingHandler: ChannelInboundHandler { } } } + +final class AlwaysBeWritingHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + static let buffer = ByteBuffer(string: "This is some data that I'm sending right now") + + private let doVectorWrite: Bool + + init(vectorWrites: Bool) { + self.doVectorWrite = vectorWrites + } + + func channelActive(context: ChannelHandlerContext) { + self.keepWriting(context: context) + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + if let error = error as? IOError, error.errnoCode == EPROTOTYPE { + XCTFail("Received EPROTOTYPE error") + } + } + + private func keepWriting(context: ChannelHandlerContext) { + if self.doVectorWrite { + context.write(self.wrapOutboundOut(AlwaysBeWritingHandler.buffer)).whenFailure { error in + if let error = error as? IOError, error.errnoCode == EPROTOTYPE { + XCTFail("Received EPROTOTYPE error") + } + } + } + context.writeAndFlush(self.wrapOutboundOut(AlwaysBeWritingHandler.buffer)).whenComplete { result in + switch result { + case .success: + // We unroll the stack here to avoid blowing it apart. + context.eventLoop.execute { + self.keepWriting(context: context) + } + case .failure(let error): + if let error = error as? IOError, error.errnoCode == EPROTOTYPE { + XCTFail("Received EPROTOTYPE error") + } + } + } + } +}