diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c02831bd9..0780103e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,6 +71,35 @@ jobs: chmod a+x builder ./builder build -p ${{ env.PACKAGE_NAME }} + # Test that everything works with Swift 6 strict concurrency. We can remove this CI in the future once we raise our minimum Swift version to Swift 6. + macos-swift6: + runs-on: macos-15 + env: + DEVELOPER_DIR: /Applications/Xcode.app + XCODE_DESTINATION: 'OS X' + NSUnbufferedIO: YES + strategy: + fail-fast: false + steps: + - uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ env.CRT_CI_ROLE }} + aws-region: ${{ env.AWS_DEFAULT_REGION }} + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - name: Set Swift version to 6 + run: | + sed -i '' 's/swift-tools-version:5\.[0-9][0-9]*/swift-tools-version:6\.0/' Package.swift + # Verify that substitution was successful + grep -q 'swift-tools-version:6\.0' Package.swift || (echo "No version 5.x found to update" && exit 1) + - name: Build ${{ env.PACKAGE_NAME }} + consumers + run: | + swift build + swift test + devices: runs-on: ${{ matrix.runner }} env: diff --git a/Package.swift b/Package.swift index dcbda2167..ce904c9b9 100644 --- a/Package.swift +++ b/Package.swift @@ -11,6 +11,10 @@ var package = Package(name: "aws-crt-swift", products: [ .library(name: "AwsCommonRuntimeKit", targets: ["AwsCommonRuntimeKit"]), .executable(name: "Elasticurl", targets: ["Elasticurl"]) + ], + dependencies: [ + // Arugment Parser Dependency for ElasticCurl + .package(url: "https://github.com/apple/swift-argument-parser.git", .upToNextMajor(from: "1.5.0")) ] ) @@ -332,7 +336,10 @@ packageTargets.append(contentsOf: [ ), .executableTarget( name: "Elasticurl", - dependencies: ["AwsCommonRuntimeKit"], + dependencies: [ + "AwsCommonRuntimeKit", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ], path: "Source/Elasticurl" ) ] ) diff --git a/Source/AwsCommonRuntimeKit/CommonRuntimeKit.swift b/Source/AwsCommonRuntimeKit/CommonRuntimeKit.swift index 6ca707f29..85047666c 100644 --- a/Source/AwsCommonRuntimeKit/CommonRuntimeKit.swift +++ b/Source/AwsCommonRuntimeKit/CommonRuntimeKit.swift @@ -1,12 +1,10 @@ -import AwsCEventStream import AwsCAuth +import AwsCEventStream import AwsCMqtt import LibNative -/** - * Initializes the library. - * `CommonRuntimeKit.initialize` must be called before using any other functionality. - */ +/// Initializes the library. +/// `CommonRuntimeKit.initialize` must be called before using any other functionality. public struct CommonRuntimeKit { /// Initializes the library. @@ -15,7 +13,9 @@ public struct CommonRuntimeKit { aws_auth_library_init(allocator.rawValue) aws_event_stream_library_init(allocator.rawValue) aws_mqtt_library_init(allocator.rawValue) - aws_register_error_info(&s_crt_swift_error_list) + withUnsafePointer(to: s_crt_swift_error_list) { ptr in + aws_register_error_info(ptr) + } } /** @@ -23,12 +23,13 @@ public struct CommonRuntimeKit { * Use this function only if you want to make sure that there are no memory leaks at the end of the application. * Warning: It will hang if you are still holding references to any CRT objects such as HostResolver. */ - public static func cleanUp() { - aws_unregister_error_info(&s_crt_swift_error_list) + public static nonisolated func cleanUp() { + withUnsafePointer(to: s_crt_swift_error_list) { ptr in + aws_unregister_error_info(ptr) + } aws_mqtt_library_clean_up() aws_event_stream_library_clean_up() aws_auth_library_clean_up() - } private init() {} diff --git a/Source/AwsCommonRuntimeKit/auth/credentials/Credentials.swift b/Source/AwsCommonRuntimeKit/auth/credentials/Credentials.swift index 35066360f..32347d8b2 100644 --- a/Source/AwsCommonRuntimeKit/auth/credentials/Credentials.swift +++ b/Source/AwsCommonRuntimeKit/auth/credentials/Credentials.swift @@ -4,7 +4,9 @@ import AwsCAuth import Foundation -public final class Credentials { +// We can't mutate this class after initialization. Swift can not verify the sendability due to OpaquePointer, +// So mark it unchecked Sendable +public final class Credentials: @unchecked Sendable { let rawValue: OpaquePointer diff --git a/Source/AwsCommonRuntimeKit/auth/credentials/CredentialsProvider.swift b/Source/AwsCommonRuntimeKit/auth/credentials/CredentialsProvider.swift index 47ec29638..dc3c6b98b 100644 --- a/Source/AwsCommonRuntimeKit/auth/credentials/CredentialsProvider.swift +++ b/Source/AwsCommonRuntimeKit/auth/credentials/CredentialsProvider.swift @@ -34,8 +34,9 @@ public struct CognitoLoginPair: CStruct { } } -public class CredentialsProvider: CredentialsProviding { - +// We can't mutate this class after initialization. Swift can not verify the sendability due to pointer, +// So mark it unchecked Sendable +public class CredentialsProvider: CredentialsProviding, @unchecked Sendable { let rawValue: UnsafeMutablePointer init(credentialsProvider: UnsafeMutablePointer) { @@ -667,25 +668,19 @@ private func onGetCredentials(credentials: OpaquePointer?, continuationCore.continuation.resume(returning: Credentials(rawValue: credentials!)) } -// We need to share this pointer to C in a task block but Swift compiler complains -// that Pointer does not conform to Sendable. Wrap the pointer in a @unchecked Sendable block -// for Swift compiler to stop complaining. -struct SendablePointer: @unchecked Sendable { - let pointer: UnsafeMutableRawPointer -} - private func getCredentialsDelegateFn(_ delegatePtr: UnsafeMutableRawPointer!, _ callbackFn: (@convention(c) ( OpaquePointer?, Int32, UnsafeMutableRawPointer?) -> Void)!, _ userData: UnsafeMutableRawPointer!) -> Int32 { - let delegate = Unmanaged> - .fromOpaque(delegatePtr) - .takeUnretainedValue().contents - let userData = SendablePointer(pointer: userData) + let userData = SendableRawPointer(pointer: userData) + let delegatePtr = SendableRawPointer(pointer: delegatePtr) Task { do { + let delegate = Unmanaged> + .fromOpaque(delegatePtr.pointer!) + .takeUnretainedValue().contents let credentials = try await delegate.getCredentials() callbackFn(credentials.rawValue, AWS_OP_SUCCESS, userData.pointer) } catch CommonRunTimeError.crtError(let crtError) { diff --git a/Source/AwsCommonRuntimeKit/auth/imds/IAMProfile.swift b/Source/AwsCommonRuntimeKit/auth/imds/IAMProfile.swift index 95a604d7d..9e5d6819f 100644 --- a/Source/AwsCommonRuntimeKit/auth/imds/IAMProfile.swift +++ b/Source/AwsCommonRuntimeKit/auth/imds/IAMProfile.swift @@ -4,7 +4,7 @@ import AwsCAuth import Foundation -public struct IAMProfile { +public struct IAMProfile: @unchecked Sendable { public let lastUpdated: Date public let profileArn: String public let profileId: String diff --git a/Source/AwsCommonRuntimeKit/auth/imds/IMDSClient.swift b/Source/AwsCommonRuntimeKit/auth/imds/IMDSClient.swift index de3622a94..4e58ab73f 100644 --- a/Source/AwsCommonRuntimeKit/auth/imds/IMDSClient.swift +++ b/Source/AwsCommonRuntimeKit/auth/imds/IMDSClient.swift @@ -4,7 +4,9 @@ import AwsCAuth // swiftlint:disable type_body_length -public class IMDSClient { +// We can't mutate this class after initialization. Swift can not verify the sendability due to OpaquePointer, +// So mark it unchecked Sendable +public class IMDSClient: @unchecked Sendable { let rawValue: OpaquePointer /// Creates an IMDSClient that always uses IMDSv2 diff --git a/Source/AwsCommonRuntimeKit/auth/imds/IMDSInstanceInfo.swift b/Source/AwsCommonRuntimeKit/auth/imds/IMDSInstanceInfo.swift index e10fd4b2d..a26e2c852 100644 --- a/Source/AwsCommonRuntimeKit/auth/imds/IMDSInstanceInfo.swift +++ b/Source/AwsCommonRuntimeKit/auth/imds/IMDSInstanceInfo.swift @@ -4,7 +4,7 @@ import AwsCAuth import Foundation -public struct IMDSInstanceInfo { +public struct IMDSInstanceInfo: @unchecked Sendable { public let marketPlaceProductCodes: [String] public let availabilityZone: String public let privateIp: String @@ -21,7 +21,8 @@ public struct IMDSInstanceInfo { public let region: String init(instanceInfo: aws_imds_instance_info) { - self.marketPlaceProductCodes = instanceInfo.marketplace_product_codes.byteCursorListToStringArray() + self.marketPlaceProductCodes = instanceInfo.marketplace_product_codes + .byteCursorListToStringArray() self.availabilityZone = instanceInfo.availability_zone.toString() self.privateIp = instanceInfo.private_ip.toString() self.version = instanceInfo.version.toString() diff --git a/Source/AwsCommonRuntimeKit/auth/signing/Signer.swift b/Source/AwsCommonRuntimeKit/auth/signing/Signer.swift index e815f071a..e35d4250f 100644 --- a/Source/AwsCommonRuntimeKit/auth/signing/Signer.swift +++ b/Source/AwsCommonRuntimeKit/auth/signing/Signer.swift @@ -164,9 +164,11 @@ public class Signer { } } -class SignRequestCore { +// After signing, we mutate the request and resume the continuation, which may result in a thread change. +// We won't modify it after continuation.resume is called. So we can mark it @unchecked Sendable +class SignRequestCore: @unchecked Sendable { let request: HTTPRequestBase - var continuation: CheckedContinuation + let continuation: CheckedContinuation let shouldSignHeader: ((String) -> Bool)? init(request: HTTPRequestBase, continuation: CheckedContinuation, @@ -207,6 +209,7 @@ private func onRequestSigningComplete(signingResult: UnsafeMutablePointer?, @@ -220,9 +223,10 @@ private func onSigningComplete(signingResult: UnsafeMutablePointer! + let signature = AWSString("signature") guard aws_signing_result_get_property( signingResult!, - g_aws_signature_property_name, + signature.rawValue, &awsStringPointer) == AWS_OP_SUCCESS else { chunkSignerCore.continuation.resume(throwing: CommonRunTimeError.crtError(.makeFromLastError())) return diff --git a/Source/AwsCommonRuntimeKit/auth/signing/SigningConfig.swift b/Source/AwsCommonRuntimeKit/auth/signing/SigningConfig.swift index 3b69a5fbb..ba34a4385 100644 --- a/Source/AwsCommonRuntimeKit/auth/signing/SigningConfig.swift +++ b/Source/AwsCommonRuntimeKit/auth/signing/SigningConfig.swift @@ -4,7 +4,7 @@ import AwsCAuth import Foundation -public struct SigningConfig: CStructWithUserData { +public struct SigningConfig: CStructWithUserData, @unchecked Sendable { /// What signing algorithm to use. public var algorithm: SigningAlgorithmType @@ -136,7 +136,7 @@ private func onShouldSignHeader(nameCursor: UnsafePointer!, return signRequestCore.shouldSignHeader!(name) } -public enum SignatureType { +public enum SignatureType: Sendable { /// A signature for a full http request should be computed, with header updates applied to the signing result. case requestHeaders @@ -162,7 +162,7 @@ public enum SignatureType { case requestEvent } -public enum SignedBodyHeaderType { +public enum SignedBodyHeaderType: Sendable { /// Do not add a header case none @@ -174,7 +174,7 @@ public enum SignedBodyHeaderType { /// Optional string to use as the canonical request's body value. /// Typically, this is the SHA-256 of the (request/chunk/event) payload, written as lowercase hex. /// If this has been precalculated, it can be set here. Special values used by certain services can also be set. -public enum SignedBodyValue: CustomStringConvertible, Equatable { +public enum SignedBodyValue: CustomStringConvertible, Equatable, Sendable { /// if empty, a public value will be calculated from the payload during signing case empty /// For empty sha256 @@ -226,7 +226,7 @@ public enum SignedBodyValue: CustomStringConvertible, Equatable { } } -public enum SigningAlgorithmType { +public enum SigningAlgorithmType: Sendable { case signingV4 case signingV4Asymmetric case signingV4S3Express diff --git a/Source/AwsCommonRuntimeKit/crt/Allocator.swift b/Source/AwsCommonRuntimeKit/crt/Allocator.swift index 1923ff779..63bc9cac6 100644 --- a/Source/AwsCommonRuntimeKit/crt/Allocator.swift +++ b/Source/AwsCommonRuntimeKit/crt/Allocator.swift @@ -2,11 +2,17 @@ // SPDX-License-Identifier: Apache-2.0. import AwsCCommon -/** - The default allocator. - You are probably looking to use `allocator` instead. +/* + * The default allocator. + * We need to declare `allocator` as mutable (`var`) instead of `let` because we override it with a tracing allocator in tests. This is not mutated anywhere else apart from the start of tests. + * Swift compiler doesn't let us compile this code in Swift 6 due to global shared mutable state without locks, and complains that this is not safe. Disable the safety here since we won't modify it. + * Remove the Ifdef once our minimum supported Swift version reaches 5.10 */ +#if swift(>=5.10) +nonisolated(unsafe) var allocator = aws_default_allocator()! +#else var allocator = aws_default_allocator()! +#endif /// An allocator is used to allocate memory on the heap. protocol Allocator { diff --git a/Source/AwsCommonRuntimeKit/crt/CommonRuntimeError.swift b/Source/AwsCommonRuntimeKit/crt/CommonRuntimeError.swift index c57b21820..258b020a4 100644 --- a/Source/AwsCommonRuntimeKit/crt/CommonRuntimeError.swift +++ b/Source/AwsCommonRuntimeKit/crt/CommonRuntimeError.swift @@ -8,7 +8,7 @@ public enum CommonRunTimeError: Error { case crtError(CRTError) } -public struct CRTError: Equatable { +public struct CRTError: Equatable, Sendable { public let code: Int32 public let message: String public let name: String diff --git a/Source/AwsCommonRuntimeKit/crt/Logger.swift b/Source/AwsCommonRuntimeKit/crt/Logger.swift index 03d24ba6b..1c25fd5af 100644 --- a/Source/AwsCommonRuntimeKit/crt/Logger.swift +++ b/Source/AwsCommonRuntimeKit/crt/Logger.swift @@ -10,7 +10,7 @@ public enum LogTarget { case filePath(String) } -public struct Logger { +public actor Logger { private static var logger: aws_logger? private static let lock = NSLock() @@ -55,7 +55,7 @@ public struct Logger { } } -public enum LogLevel { +public enum LogLevel: Sendable { case none case fatal case error diff --git a/Source/AwsCommonRuntimeKit/crt/Utilities.swift b/Source/AwsCommonRuntimeKit/crt/Utilities.swift index 327e6760b..1d8b129a6 100644 --- a/Source/AwsCommonRuntimeKit/crt/Utilities.swift +++ b/Source/AwsCommonRuntimeKit/crt/Utilities.swift @@ -28,6 +28,13 @@ class Box { } } +// We need to share this pointer to C in a task block but Swift compiler complains +// that Pointer does not conform to Sendable. Wrap the pointer in a @unchecked Sendable block +// for Swift compiler to stop complaining. +struct SendableRawPointer: @unchecked Sendable { + let pointer: UnsafeMutableRawPointer? +} + extension String { func withByteCursor(_ body: (aws_byte_cursor) -> Result @@ -194,6 +201,9 @@ extension aws_byte_cursor { } let data = Data(bytesNoCopy: self.ptr, count: self.len, deallocator: .none) + // Using non-failable String(decoding:as:) to handle invalid UTF-8 with replacement characters. + // Disable warning as it's a false positive. + // swiftlint:disable:next optional_data_string_conversion return String(decoding: data, as: UTF8.self) } diff --git a/Source/AwsCommonRuntimeKit/http/HTTP1Stream.swift b/Source/AwsCommonRuntimeKit/http/HTTP1Stream.swift index 089b3a8e3..5efca6efd 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTP1Stream.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTP1Stream.swift @@ -3,8 +3,10 @@ import AwsCHttp import Foundation +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// An HTTP1Stream represents a single HTTP/1.1 specific Http Request/Response. -public class HTTP1Stream: HTTPStream { +public class HTTP1Stream: HTTPStream, @unchecked Sendable { /// Stream keeps a reference to HttpConnection to keep it alive private let httpConnection: HTTPClientConnection diff --git a/Source/AwsCommonRuntimeKit/http/HTTP2ClientConnection.swift b/Source/AwsCommonRuntimeKit/http/HTTP2ClientConnection.swift index c4e2c1321..01265d6f1 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTP2ClientConnection.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTP2ClientConnection.swift @@ -5,7 +5,9 @@ import AwsCHttp import AwsCIo import Foundation -public class HTTP2ClientConnection: HTTPClientConnection { +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. +public class HTTP2ClientConnection: HTTPClientConnection, @unchecked Sendable { /// Creates a new http2 stream from the `HTTPRequestOptions` given. /// - Parameter requestOptions: An `HTTPRequestOptions` struct containing callbacks on diff --git a/Source/AwsCommonRuntimeKit/http/HTTP2Stream.swift b/Source/AwsCommonRuntimeKit/http/HTTP2Stream.swift index 64db4d5e1..1142cc122 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTP2Stream.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTP2Stream.swift @@ -4,8 +4,10 @@ import AwsCHttp import Foundation +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// An HTTP2Stream represents a single HTTP/2 specific HTTP Request/Response. -public class HTTP2Stream: HTTPStream { +public class HTTP2Stream: HTTPStream, @unchecked Sendable { private let httpConnection: HTTPClientConnection? // Called by Connection Manager diff --git a/Source/AwsCommonRuntimeKit/http/HTTP2StreamManager.swift b/Source/AwsCommonRuntimeKit/http/HTTP2StreamManager.swift index 34d4e3e8b..c827e8492 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTP2StreamManager.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTP2StreamManager.swift @@ -2,8 +2,10 @@ // SPDX-License-Identifier: Apache-2.0. import AwsCHttp +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// Manages a Pool of HTTP/2 Streams. Creates and manages HTTP/2 connections under the hood. -public class HTTP2StreamManager { +public class HTTP2StreamManager: @unchecked Sendable { let rawValue: UnsafeMutablePointer public init(options: HTTP2StreamManagerOptions) throws { diff --git a/Source/AwsCommonRuntimeKit/http/HTTPClientConnection.swift b/Source/AwsCommonRuntimeKit/http/HTTPClientConnection.swift index a38643b2f..0cff7721c 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTPClientConnection.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTPClientConnection.swift @@ -6,7 +6,9 @@ import AwsCIo import Foundation // swiftlint:disable force_try -public class HTTPClientConnection { +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. +public class HTTPClientConnection: @unchecked Sendable { let rawValue: UnsafeMutablePointer /// This will keep the connection manager alive until connection is alive let manager: HTTPClientConnectionManager diff --git a/Source/AwsCommonRuntimeKit/http/HTTPClientConnectionManager.swift b/Source/AwsCommonRuntimeKit/http/HTTPClientConnectionManager.swift index 0d2f5e543..d331f4b78 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTPClientConnectionManager.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTPClientConnectionManager.swift @@ -4,7 +4,7 @@ import AwsCHttp typealias OnConnectionAcquired = (HTTPClientConnection?, Int32) -> Void -public class HTTPClientConnectionManager { +public class HTTPClientConnectionManager: @unchecked Sendable { let rawValue: OpaquePointer public init(options: HTTPClientConnectionOptions) throws { diff --git a/Source/AwsCommonRuntimeKit/http/HTTPRequest.swift b/Source/AwsCommonRuntimeKit/http/HTTPRequest.swift index d948b22e2..50591e1ac 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTPRequest.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTPRequest.swift @@ -5,7 +5,7 @@ import AwsCIo import AwsCCommon /// Represents a single client request to be sent on a HTTP 1.1 connection -public class HTTPRequest: HTTPRequestBase { +public class HTTPRequest: HTTPRequestBase, @unchecked Sendable { public var method: String { get { diff --git a/Source/AwsCommonRuntimeKit/http/HTTPRequestOptions.swift b/Source/AwsCommonRuntimeKit/http/HTTPRequestOptions.swift index 62b9554f4..32fec80a9 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTPRequestOptions.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTPRequestOptions.swift @@ -3,7 +3,7 @@ import Foundation /// Definition for outgoing request and callbacks to receive response. -public struct HTTPRequestOptions { +public struct HTTPRequestOptions: @unchecked Sendable { /// Callback to receive interim response public typealias OnInterimResponse = (_ statusCode: UInt32, diff --git a/Source/AwsCommonRuntimeKit/http/HTTPStream.swift b/Source/AwsCommonRuntimeKit/http/HTTPStream.swift index 6ee20ce66..3c9fca30f 100644 --- a/Source/AwsCommonRuntimeKit/http/HTTPStream.swift +++ b/Source/AwsCommonRuntimeKit/http/HTTPStream.swift @@ -3,9 +3,11 @@ import AwsCHttp import Foundation +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// An base abstract class that represents a single Http Request/Response for both HTTP/1.1 and HTTP/2. /// Can be used to update the Window size, and get status code. -public class HTTPStream { +public class HTTPStream: @unchecked Sendable { let rawValue: UnsafeMutablePointer var callbackData: HTTPStreamCallbackCore diff --git a/Source/AwsCommonRuntimeKit/io/HostAddress.swift b/Source/AwsCommonRuntimeKit/io/HostAddress.swift index 1ce7646d1..907e072fa 100644 --- a/Source/AwsCommonRuntimeKit/io/HostAddress.swift +++ b/Source/AwsCommonRuntimeKit/io/HostAddress.swift @@ -4,7 +4,7 @@ import AwsCIo /// Represents a single HostAddress resolved by the Host Resolver -public struct HostAddress: CStruct { +public struct HostAddress: CStruct, Sendable { /// Address type is ipv4 or ipv6 public let addressType: HostAddressType diff --git a/Source/AwsCommonRuntimeKit/io/HostAddressType.swift b/Source/AwsCommonRuntimeKit/io/HostAddressType.swift index cc973853d..3140e6971 100644 --- a/Source/AwsCommonRuntimeKit/io/HostAddressType.swift +++ b/Source/AwsCommonRuntimeKit/io/HostAddressType.swift @@ -4,7 +4,7 @@ import AwsCIo /// Type of Host Address (ipv4 or ipv6) -public enum HostAddressType { +public enum HostAddressType: Sendable { case A case AAAA } diff --git a/Source/AwsCommonRuntimeKit/io/HostResolver.swift b/Source/AwsCommonRuntimeKit/io/HostResolver.swift index 4a1bf35b0..1750b04bf 100644 --- a/Source/AwsCommonRuntimeKit/io/HostResolver.swift +++ b/Source/AwsCommonRuntimeKit/io/HostResolver.swift @@ -23,8 +23,10 @@ public protocol HostResolverProtocol { func purgeCache() async } +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// CRT Host Resolver which performs async DNS lookups -public class HostResolver: HostResolverProtocol { +public class HostResolver: HostResolverProtocol, @unchecked Sendable { let rawValue: UnsafeMutablePointer let maxTTL: Int diff --git a/Source/AwsCommonRuntimeKit/io/retryer/RetryToken.swift b/Source/AwsCommonRuntimeKit/io/retryer/RetryToken.swift index 745c60d20..32300ab80 100644 --- a/Source/AwsCommonRuntimeKit/io/retryer/RetryToken.swift +++ b/Source/AwsCommonRuntimeKit/io/retryer/RetryToken.swift @@ -3,8 +3,10 @@ import AwsCIo +// Swift cannot verify the sendability due to a pointer, and thread safety is handled in the C layer. +// So mark it as unchecked Sendable. /// This is just a wrapper for aws_retry_token which user can not create themself but pass around once acquired. -public class RetryToken { +public class RetryToken: @unchecked Sendable { let rawValue: UnsafeMutablePointer init(rawValue: UnsafeMutablePointer) { diff --git a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift index 509696bb6..ba3932b8d 100644 --- a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift +++ b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift @@ -34,7 +34,7 @@ public class ClientOperationStatistics { } /// Class containing data related to a Publish Received Callback -public class PublishReceivedData { +public class PublishReceivedData: @unchecked Sendable { /// Data model of an `MQTT5 PUBLISH `_ packet. public let publishPacket: PublishPacket @@ -51,7 +51,7 @@ public class LifecycleStoppedData { } public class LifecycleAttemptingConnectData { } /// Class containing results of a Connect Success Lifecycle Event. -public class LifecycleConnectionSuccessData { +public class LifecycleConnectionSuccessData: @unchecked Sendable { /// Data model of an `MQTT5 CONNACK `_ packet. public let connackPacket: ConnackPacket @@ -66,7 +66,7 @@ public class LifecycleConnectionSuccessData { } /// Dataclass containing results of a Connect Failure Lifecycle Event. -public class LifecycleConnectionFailureData { +public class LifecycleConnectionFailureData: @unchecked Sendable { /// Error which caused connection failure. public let crtError: CRTError @@ -81,7 +81,7 @@ public class LifecycleConnectionFailureData { } /// Dataclass containing results of a Disconnect Lifecycle Event -public class LifecycleDisconnectData { +public class LifecycleDisconnectData: @unchecked Sendable { /// Error which caused disconnection. public let crtError: CRTError @@ -119,7 +119,7 @@ public typealias OnLifecycleEventConnectionFailure = @Sendable (LifecycleConnect public typealias OnLifecycleEventDisconnection = @Sendable (LifecycleDisconnectData) -> Void /// Callback for users to invoke upon completion of, presumably asynchronous, OnWebSocketHandshakeIntercept callback's initiated process. -public typealias OnWebSocketHandshakeInterceptComplete = (HTTPRequestBase, Int32) -> Void +public typealias OnWebSocketHandshakeInterceptComplete = @Sendable (HTTPRequestBase, Int32) -> Void /// Invoked during websocket handshake to give users opportunity to transform an http request for purposes /// such as signing/authorization etc... Returning from this function does not continue the websocket @@ -128,8 +128,8 @@ public typealias OnWebSocketHandshakeInterceptComplete = (HTTPRequestBase, Int32 public typealias OnWebSocketHandshakeIntercept = @Sendable (HTTPRequest, @escaping OnWebSocketHandshakeInterceptComplete) -> Void // MARK: - Mqtt5 Client -public class Mqtt5Client { - internal var clientCore: Mqtt5ClientCore +public final class Mqtt5Client: Sendable { + internal let clientCore: Mqtt5ClientCore /// Creates a Mqtt5Client instance using the provided MqttClientOptions. /// @@ -212,8 +212,10 @@ public class Mqtt5Client { // MARK: - Internal/Private +// IMPORTANT: You are responsible for concurrency correctness of Mqtt5ClientCore. +// The rawValue is mutable cross threads and protected by the rwlock. /// Mqtt5 Client Core, internal class to handle Mqtt5 Client operations -public class Mqtt5ClientCore { +internal class Mqtt5ClientCore: @unchecked Sendable { fileprivate var rawValue: UnsafeMutablePointer? fileprivate let rwlock = ReadWriteLock() @@ -521,6 +523,7 @@ internal func MqttClientWebsocketTransform( _ user_data: UnsafeMutableRawPointer?, _ complete_fn: (@convention(c) (OpaquePointer?, Int32, UnsafeMutableRawPointer?) -> Void)?, _ complete_ctx: UnsafeMutableRawPointer?) { + let complete_ctx = SendableRawPointer(pointer: complete_ctx) let clientCore = Unmanaged.fromOpaque(user_data!).takeUnretainedValue() @@ -533,7 +536,7 @@ internal func MqttClientWebsocketTransform( } let httpRequest = HTTPRequest(nativeHttpMessage: request) @Sendable func signerTransform(request: HTTPRequestBase, errorCode: Int32) { - complete_fn?(request.rawValue, errorCode, complete_ctx) + complete_fn?(request.rawValue, errorCode, complete_ctx.pointer) } if clientCore.onWebsocketInterceptor != nil { diff --git a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Enums.swift b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Enums.swift index ddecb5be0..77c7161f4 100644 --- a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Enums.swift +++ b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Enums.swift @@ -5,7 +5,7 @@ import AwsCMqtt /// MQTT message delivery quality of service. /// Enum values match `MQTT5 spec `__ encoding values. -public enum QoS { +public enum QoS: Sendable { /// The message is delivered according to the capabilities of the underlying network. No response is sent by the /// receiver and no retry is performed by the sender. The message arrives at the receiver either once or not at all. @@ -47,7 +47,7 @@ extension QoS { /// Server return code for connect attempts. /// Enum values match `MQTT5 spec `__ encoding values. -public enum ConnectReasonCode: Int { +public enum ConnectReasonCode: Int, Sendable { /// Returned when the connection is accepted. case success = 0 @@ -260,7 +260,7 @@ public enum DisconnectReasonCode: Int { /// Reason code inside PUBACK packets that indicates the result of the associated PUBLISH request. /// Enum values match `MQTT5 spec `__ encoding values. -public enum PubackReasonCode: Int { +public enum PubackReasonCode: Int, Sendable { /// Returned when the (QoS 1) publish was accepted by the recipient. /// May be sent by the client or the server. @@ -304,7 +304,7 @@ public enum PubackReasonCode: Int { /// Reason code inside SUBACK packet payloads. /// Enum values match `MQTT5 spec `__ encoding values. /// This will only be sent by the server and not the client. -public enum SubackReasonCode: Int { +public enum SubackReasonCode: Int, Sendable { /// Returned when the subscription was accepted and the maximum QoS sent will be QoS 0. case grantedQos0 = 0 @@ -350,7 +350,7 @@ public enum SubackReasonCode: Int { /// Reason codes inside UNSUBACK packet payloads that specify the results for each topic filter in the associated /// UNSUBSCRIBE packet. /// Enum values match `MQTT5 spec `__ encoding values. -public enum UnsubackReasonCode: Int { +public enum UnsubackReasonCode: Int, Sendable { /// Returned when the unsubscribe was successful and the client is no longer subscribed to the topic filter on the server. case success = 0 @@ -474,7 +474,7 @@ extension ClientOperationQueueBehaviorType { /// Optional property describing a PUBLISH payload's format. /// Enum values match `MQTT5 spec `__ encoding values. -public enum PayloadFormatIndicator { +public enum PayloadFormatIndicator: Sendable { /// The payload is arbitrary binary data case bytes @@ -507,7 +507,7 @@ extension PayloadFormatIndicator { /// Configures how retained messages should be handled when subscribing with a topic filter that matches topics with /// associated retained messages. /// Enum values match `MQTT5 spec `_ encoding values. -public enum RetainHandlingType { +public enum RetainHandlingType: Sendable { /// The server should always send all retained messages on topics that match a subscription's filter. case sendOnSubscribe diff --git a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift index 7854e5d5b..a5e7a658f 100644 --- a/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift +++ b/Source/AwsCommonRuntimeKit/mqtt/Mqtt5Packets.swift @@ -5,8 +5,10 @@ import Foundation import AwsCHttp import AwsCMqtt +// We can't mutate this class after initialization. Swift can not verify the sendability due to direct use of c pointer, +// so mark it unchecked Sendable /// Mqtt5 User Property -public class UserProperty: CStruct { +public class UserProperty: CStruct, @unchecked Sendable { /// Property name public let name: String @@ -19,8 +21,6 @@ public class UserProperty: CStruct { self.value = value withByteCursorFromStrings(self.name, self.value) { cNameCursor, cValueCursor in - aws_byte_buf_clean_up(&name_buffer) - aws_byte_buf_clean_up(&value_buffer) aws_byte_buf_init_copy_from_cursor(&name_buffer, allocator, cNameCursor) aws_byte_buf_init_copy_from_cursor(&value_buffer, allocator, cValueCursor) } @@ -64,8 +64,10 @@ func convertOptionalUserProperties(count: size_t, userPropertiesPointer: UnsafeP return userProperties } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Data model of an `MQTT5 PUBLISH `_ packet -public class PublishPacket: CStruct { +public class PublishPacket: CStruct, @unchecked Sendable { /// The payload of the publish message in a byte buffer format public let payload: Data? @@ -226,10 +228,12 @@ public class PublishPacket: CStruct { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Publish result returned by Publish operation. /// - Members /// - puback: returned PublishPacket for qos 1 publish; nil for qos 0 packet. -public class PublishResult { +public class PublishResult: @unchecked Sendable { public let puback: PubackPacket? public init (puback: PubackPacket? = nil) { @@ -237,8 +241,10 @@ public class PublishResult { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// "Data model of an `MQTT5 PUBACK `_ packet -public class PubackPacket { +public class PubackPacket: @unchecked Sendable { /// Success indicator or failure reason for the associated PUBLISH packet. public let reasonCode: PubackReasonCode @@ -267,8 +273,10 @@ public class PubackPacket { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to direct use of c pointer, +// so mark it unchecked Sendable /// Configures a single subscription within a Subscribe operation -public class Subscription: CStruct { +public class Subscription: CStruct, @unchecked Sendable { /// The topic filter to subscribe to public let topicFilter: String @@ -326,8 +334,10 @@ public class Subscription: CStruct { } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Data model of an `MQTT5 SUBSCRIBE `_ packet. -public class SubscribePacket: CStruct { +public class SubscribePacket: CStruct, @unchecked Sendable { /// Array of topic filters that the client wishes to listen to public let subscriptions: [Subscription] @@ -396,8 +406,10 @@ public class SubscribePacket: CStruct { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Data model of an `MQTT5 SUBACK `_ packet. -public class SubackPacket { +public class SubackPacket: @unchecked Sendable { /// Array of reason codes indicating the result of each individual subscription entry in the associated SUBSCRIBE packet. public let reasonCodes: [SubackReasonCode] @@ -427,8 +439,10 @@ public class SubackPacket { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to direct use of c pointer, +// so mark it unchecked Sendable /// Data model of an `MQTT5 UNSUBSCRIBE `_ packet. -public class UnsubscribePacket: CStruct { +public class UnsubscribePacket: CStruct, @unchecked Sendable { /// Array of topic filters that the client wishes to unsubscribe from. public let topicFilters: [String] @@ -490,8 +504,10 @@ public class UnsubscribePacket: CStruct { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Data model of an `MQTT5 UNSUBACK `_ packet. -public class UnsubackPacket { +public class UnsubackPacket: @unchecked Sendable { /// Array of reason codes indicating the result of unsubscribing from each individual topic filter entry in the associated UNSUBSCRIBE packet. public let reasonCodes: [UnsubackReasonCode] @@ -522,7 +538,7 @@ public class UnsubackPacket { } /// Data model of an `MQTT5 DISCONNECT `_ packet. -public class DisconnectPacket: CStruct { +public class DisconnectPacket: CStruct, @unchecked Sendable { /// Value indicating the reason that the sender is closing the connection public let reasonCode: DisconnectReasonCode @@ -602,8 +618,10 @@ public class DisconnectPacket: CStruct { } } +// We can't mutate this class after initialization. Swift can not verify the sendability due to the class is non-final, +// so mark it unchecked Sendable /// Data model of an `MQTT5 CONNACK `_ packet. -public class ConnackPacket { +public class ConnackPacket: @unchecked Sendable { /// True if the client rejoined an existing session on the server, false otherwise. public let sessionPresent: Bool diff --git a/Source/AwsCommonRuntimeKit/sdkutils/FileBasedConfiguration.swift b/Source/AwsCommonRuntimeKit/sdkutils/FileBasedConfiguration.swift index 803157178..b5d7cc17a 100644 --- a/Source/AwsCommonRuntimeKit/sdkutils/FileBasedConfiguration.swift +++ b/Source/AwsCommonRuntimeKit/sdkutils/FileBasedConfiguration.swift @@ -4,7 +4,7 @@ import AwsCSdkUtils import struct Foundation.Data -public class FileBasedConfiguration { +public class FileBasedConfiguration: @unchecked Sendable { var rawValue: OpaquePointer /// If the `AWS_PROFILE` environment variable is set, use it. Otherwise, return "default". @@ -145,7 +145,7 @@ extension FileBasedConfiguration { } /// Represents a section in the FileBasedConfiguration - public class Section { + public class Section: @unchecked Sendable { let rawValue: OpaquePointer // Keep a reference of configuration to keep it alive let fileBasedConfiguration: FileBasedConfiguration @@ -205,7 +205,7 @@ extension FileBasedConfiguration.SectionType { extension FileBasedConfiguration.Section { /// Represents a section property in the file based configuration. - public class Property { + public class Property: @unchecked Sendable { let rawValue: OpaquePointer // Keep a reference of configuration to keep it alive let fileBasedConfiguration: FileBasedConfiguration diff --git a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointProperty.swift b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointProperty.swift index 675a59b66..329cda544 100644 --- a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointProperty.swift +++ b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointProperty.swift @@ -2,29 +2,16 @@ // SPDX-License-Identifier: Apache-2.0. /// Struct that represents endpoint property which can be a boolean, string or array of endpoint properties -enum EndpointProperty { +public enum EndpointProperty: Sendable, Equatable { case bool(Bool) case string(String) indirect case array([EndpointProperty]) indirect case dictionary([String: EndpointProperty]) - - func toAnyHashable() -> AnyHashable { - switch self { - case .bool(let value): - return AnyHashable(value) - case .string(let value): - return AnyHashable(value) - case .array(let value): - return AnyHashable(value.map { $0.toAnyHashable() }) - case .dictionary(let value): - return AnyHashable(value.mapValues { $0.toAnyHashable() }) - } - } } /// Decodable conformance extension EndpointProperty: Decodable { - init(from decoder: Decoder) throws { + public init(from decoder: Decoder) throws { if let container = try? decoder.container(keyedBy: EndpointPropertyCodingKeys.self) { self = EndpointProperty(from: container) } else if let container = try? decoder.unkeyedContainer() { @@ -84,18 +71,6 @@ extension EndpointProperty: Decodable { } } -extension Dictionary where Key == String, Value == EndpointProperty { - /// Converts EndpointProperty to a dictionary of `String`: `AnyHashable` - /// - Returns: Dictionary of `String`: `AnyHashable` - func toStringHashableDictionary() -> [String: AnyHashable] { - var dict: [String: AnyHashable] = [:] - for (key, value) in self { - dict[key] = value.toAnyHashable() - } - return dict - } -} - /// Coding keys for `EndpointProperty` struct EndpointPropertyCodingKeys: CodingKey { var stringValue: String diff --git a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointsRuleEngine.swift b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointsRuleEngine.swift index 18b23a486..c8f768280 100644 --- a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointsRuleEngine.swift +++ b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/EndpointsRuleEngine.swift @@ -109,16 +109,16 @@ public class EndpointsRuleEngine { /// Get the properties of the resolved endpoint /// - Returns: The properties of the resolved endpoint - func getProperties(rawValue: OpaquePointer) throws -> [String: AnyHashable] { + func getProperties(rawValue: OpaquePointer) throws -> [String: EndpointProperty] { var properties = aws_byte_cursor() guard aws_endpoints_resolved_endpoint_get_properties(rawValue, &properties) == AWS_OP_SUCCESS else { throw CommonRunTimeError.crtError(.makeFromLastError()) } guard properties.len > 0 else { - return [String: AnyHashable]() + return [String: EndpointProperty]() } let data = Data(bytes: properties.ptr, count: properties.len) - return try JSONDecoder().decode([String: EndpointProperty].self, from: data).toStringHashableDictionary() + return try JSONDecoder().decode([String: EndpointProperty].self, from: data) } /// Get the error of the resolved endpoint diff --git a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/ResolvedEndpointType.swift b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/ResolvedEndpointType.swift index e92d2fa69..239dc95ba 100644 --- a/Source/AwsCommonRuntimeKit/sdkutils/endpoint/ResolvedEndpointType.swift +++ b/Source/AwsCommonRuntimeKit/sdkutils/endpoint/ResolvedEndpointType.swift @@ -6,7 +6,7 @@ import AwsCSdkUtils /// Resolved endpoint type public enum ResolvedEndpoint { /// Used for endpoints that are resolved successfully - case endpoint(url: String, headers: [String: [String]], properties: [String: AnyHashable]) + case endpoint(url: String, headers: [String: [String]], properties: [String: EndpointProperty]) /// Used for endpoints that resolve to an error case error(message: String) diff --git a/Source/Elasticurl/CommandLine.swift b/Source/Elasticurl/CommandLine.swift deleted file mode 100644 index 4b2f3748f..000000000 --- a/Source/Elasticurl/CommandLine.swift +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0. - -import Foundation -import AwsCCommon -// swiftlint:disable trailing_whitespace - -struct CommandLineParser { - /// A function to parse command line arguments - /// - Parameters: - /// - argc: The number of arguments - /// - arguments: A pointer to a string pointer of the arguments - /// - optionString: a `String` with all the possible options that could be passed in - /// - options: An array of `[aws_cli_option]` containing all the possible option keys as objects - /// with additional metadata - /// - Returns: A dictionary of`[String: Any] ` with `String` as the name of the flag and `Any` as the - /// value passed in - public static func parseArguments(argc: Int32, - arguments: UnsafeMutablePointer?>, - optionString: String, - options: [aws_cli_option]) -> [String: Any] { - var argumentsDict = [String: Any]() - while true { - var optionIndex: Int32 = 0 - let opt = aws_cli_getopt_long(argc, arguments, optionString, options, &optionIndex) - if opt == -1 || opt == 0 { - break - } - if let char = UnicodeScalar(Int(opt)) { - if aws_cli_optarg != nil { - argumentsDict[String(char)] = String(cString: aws_cli_optarg) - } else { - // if argument doesnt have a value just mark it as present in the dictionary - argumentsDict[String(char)] = true - } - } - } - - return argumentsDict - } -} - -enum CLIHasArg { - case none - case required - case optional -} - -extension CLIHasArg: RawRepresentable, CaseIterable { - public init(rawValue: aws_cli_options_has_arg) { - let value = Self.allCases.first(where: {$0.rawValue == rawValue}) - self = value ?? .none - } - public var rawValue: aws_cli_options_has_arg { - switch self { - case .none: return AWS_CLI_OPTIONS_NO_ARGUMENT - case .required: return AWS_CLI_OPTIONS_REQUIRED_ARGUMENT - case .optional: return AWS_CLI_OPTIONS_OPTIONAL_ARGUMENT - } - } -} - -class AWSCLIOption { - let rawValue: aws_cli_option - let name: UnsafeMutablePointer - init(name: String, hasArg: CLIHasArg, flag: UnsafeMutablePointer? = nil, val: String) { - self.name = strdup(name)! - self.rawValue = aws_cli_option( - name: self.name, - has_arg: hasArg.rawValue, - flag: flag, - val: Int32(bitPattern: UnicodeScalar(val)?.value ?? 0)) - } - - deinit { - free(name) - } -} diff --git a/Source/Elasticurl/Elasticurl.swift b/Source/Elasticurl/Elasticurl.swift index bc782e230..3dbbe401d 100644 --- a/Source/Elasticurl/Elasticurl.swift +++ b/Source/Elasticurl/Elasticurl.swift @@ -1,14 +1,15 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0. -import _Concurrency +import ArgumentParser import AwsCommonRuntimeKit import Foundation +import _Concurrency // swiftlint:disable cyclomatic_complexity function_body_length -struct Context { +struct Context: @unchecked Sendable { // args - public var logLevel: LogLevel = .trace + public var logLevel: LogLevel = .error public var verb: String = "GET" public var caCert: String? public var caPath: String? @@ -27,213 +28,188 @@ struct Context { } @main -struct Elasticurl { - private static let version = "0.1.0" - private static var context = Context() - - static func parseArguments() { - let optionString = "a:b:c:e:f:H:d:g:j:l:m:M:GPHiko:t:v:VwWh" - let options = [ElasticurlOptions.caCert.rawValue, - ElasticurlOptions.caPath.rawValue, - ElasticurlOptions.cert.rawValue, - ElasticurlOptions.connectTimeout.rawValue, - ElasticurlOptions.data.rawValue, - ElasticurlOptions.dataFile.rawValue, - ElasticurlOptions.get.rawValue, - ElasticurlOptions.head.rawValue, - ElasticurlOptions.header.rawValue, - ElasticurlOptions.help.rawValue, - ElasticurlOptions.http2.rawValue, - ElasticurlOptions.http1_1.rawValue, - ElasticurlOptions.include.rawValue, - ElasticurlOptions.insecure.rawValue, - ElasticurlOptions.key.rawValue, - ElasticurlOptions.method.rawValue, - ElasticurlOptions.output.rawValue, - ElasticurlOptions.post.rawValue, - ElasticurlOptions.signingContext.rawValue, - ElasticurlOptions.signingFunc.rawValue, - ElasticurlOptions.signingLib.rawValue, - ElasticurlOptions.trace.rawValue, - ElasticurlOptions.version.rawValue, - ElasticurlOptions.verbose.rawValue, - ElasticurlOptions.lastOption.rawValue] - - let argumentsDict = CommandLineParser.parseArguments(argc: CommandLine.argc, - arguments: CommandLine.unsafeArgv, - optionString: optionString, options: options) - - if let caCert = argumentsDict["a"] as? String { - context.caCert = caCert - } +struct Elasticurl: AsyncParsableCommand { + @Option(name: .long, help: "Path to a CA certificate file") + var cacert: String? - if let caPath = argumentsDict["b"] as? String { - context.caPath = caPath - } + @Option(name: .long, help: "Path to a directory containing CA files") + var capath: String? - if let certificate = argumentsDict["c"] as? String { - context.certificate = certificate - } + @Option(name: .long, help: "Path to a PEM encoded certificate to use with mTLS") + var cert: String? - if let privateKey = argumentsDict["e"] as? String { - context.privateKey = privateKey - } + @Option(name: .long, help: "Path to a PEM encoded private key that matches cert") + var key: String? - if let connectTimeout = argumentsDict["f"] as? Int { - context.connectTimeout = connectTimeout - } + @Option(name: .long, help: "Time in milliseconds to wait for a connection") + var connectTimeout: Int = 3000 - if let headers = argumentsDict["H"] as? String { - let keyValues = headers.components(separatedBy: ",") - for headerKeyValuePair in keyValues { - let keyValuePair = headerKeyValuePair.components(separatedBy: ":") - let key = keyValuePair[0] - let value = keyValuePair[1] - context.headers[key] = value - } - } + // Options with both short and long forms + @Option( + name: [.customShort("H"), .long], + help: "Line to send as a header in format [header-key]: [header-value]") + var header: [String] = [] - if let stringData = argumentsDict["d"] as? String { - context.data = stringData.data(using: .utf8) - } + @Option(name: [.short, .long], help: "Data to POST or PUT") + var data: String? - if let dataFilePath = argumentsDict["g"] as? String { - guard let url = URL(string: dataFilePath) else { - print("path to data file is incorrect or does not exist") - exit(-1) - } - do { - context.data = try Data(contentsOf: url) - } catch { - exit(-1) - } - } + @Option(name: .long, help: "File to read from file and POST or PUT") + var dataFile: String? - if let method = argumentsDict["M"] as? String { - context.verb = method - } + @Option(name: [.customShort("M"), .long], help: "HTTP Method verb to use for the request") + var method: String? - if argumentsDict["G"] != nil { - context.verb = "GET" - } + @Flag(name: [.customShort("G"), .long], help: "Uses GET for the verb") + var get: Bool = false - if argumentsDict["P"] != nil { - context.verb = "POST" - } + @Flag(name: [.customShort("P"), .long], help: "Uses POST for the verb") + var post: Bool = false - if argumentsDict["I"] != nil { - context.verb = "HEAD" - } + @Flag(name: [.customShort("I"), .long], help: "Uses HEAD for the verb") + var head: Bool = false - if argumentsDict["i"] != nil { - context.includeHeaders = true - } + @Flag(name: [.short, .long], help: "Includes headers in output") + var include: Bool = false - if argumentsDict["k"] != nil { - context.insecure = true - } + @Flag(name: [.customShort("k"), .long], help: "Turns off SSL/TLS validation") + var insecure: Bool = false - if let fileName = argumentsDict["o"] as? String { - context.outputFileName = fileName - } + @Option(name: .long, help: "Path to a shared library with an exported signing function to use") + var signingLib: String? - if let traceFile = argumentsDict["t"] as? String { - context.traceFile = traceFile - } + @Option(name: .long, help: "Name of the signing function to use within the signing library") + var signingFunc: String? - if let logLevel = argumentsDict["v"] as? String { - context.logLevel = LogLevel.fromString(string: logLevel) - } + @Option( + name: .long, + help: "Key=value pair to pass to the signing function; may be used multiple times") + var signingContext: [String] = [] + + @Option(name: [.short, .long], help: "Dumps content-body to FILE instead of stdout") + var output: String? + + @Option(name: [.short, .long], help: "Dumps logs to FILE instead of stderr") + var trace: String? + + @Option( + name: [.short, .long], + help: "ERROR|INFO|DEBUG|TRACE: log level to configure. Default is ERROR") + var verbose: String? + + @Flag(name: .long, help: "HTTP/2 connection required") + var http2: Bool = false + + @Flag(name: .customLong("http1_1"), help: "HTTP/1.1 connection required") + var http1_1: Bool = false - if argumentsDict["V"] != nil { - print("elasticurl \(version)") - exit(0) + @Argument(help: "URL to make a request to") + var urlString: String + + func run() async { + let context = buildContext() + if let traceFile = context.traceFile { + print("enable logging with trace file") + try? Logger.initialize(target: .filePath(traceFile), level: context.logLevel) + } else { + print("enable logging with stdout") + try? Logger.initialize(target: .standardOutput, level: context.logLevel) } + await Elasticurl.run(context) + } - if argumentsDict["W"] != nil { - context.alpnList.append("http/1.1") + func buildContext() -> Context { + var context = Context() + + // Convert command-line args to Context + context.caCert = cacert + context.caPath = capath + context.certificate = cert + context.privateKey = key + context.connectTimeout = connectTimeout + context.includeHeaders = include + context.outputFileName = output + context.traceFile = trace + context.insecure = insecure + + // Process verbose/log level + if let verboseLevel = verbose { + context.logLevel = LogLevel.fromString(string: verboseLevel) } - if argumentsDict["w"] != nil { - context.alpnList.append("h2") + // Process headers + for headerString in header { + let components = headerString.components(separatedBy: ":") + if components.count >= 2 { + let key = components[0].trimmingCharacters(in: .whitespaces) + let value = components[1...].joined(separator: ":").trimmingCharacters( + in: .whitespaces) + context.headers[key] = value + } } - if argumentsDict["w"] == nil, argumentsDict["W"] == nil { - context.alpnList.append("h2") - context.alpnList.append("http/1.1") + // Process data + if let stringData = data { + context.data = stringData.data(using: .utf8) + } else if let dataFilePath = dataFile { + guard let url = URL(string: dataFilePath) else { + print("Path to data file is incorrect or does not exist") + Foundation.exit(1) + } + do { + context.data = try Data(contentsOf: url) + } catch { + print("Failed to read data file: \(error)") + Foundation.exit(1) + } } - if argumentsDict["h"] != nil { - showHelp() - exit(0) + // Determine HTTP verb + if let method = method { + context.verb = method + } else if get { + context.verb = "GET" + } else if post { + context.verb = "POST" + } else if head { + context.verb = "HEAD" } - // make sure a url was given before we do anything else - guard let urlString = CommandLine.arguments.last, - let url = URL(string: urlString) - else { - print("Invalid URL: \(CommandLine.arguments.last!)") - exit(-1) + // Set ALPN list + if http2 && !http1_1 { + context.alpnList = ["h2"] + } else if http1_1 && !http2 { + context.alpnList = ["http/1.1"] + } else { + context.alpnList = ["h2", "http/1.1"] } - context.url = url - } - static func showHelp() { - print("usage: elasticurl [options] url") - print("url: url to make a request to. The default is a GET request") - print("Options:") - print(" --cacert FILE: path to a CA certficate file.") - print(" --capath PATH: path to a directory containing CA files.") - print(" -c, --cert FILE: path to a PEM encoded certificate to use with mTLS") - print(" --key FILE: Path to a PEM encoded private key that matches cert.") - print(" --connect-timeout INT: time in milliseconds to wait for a connection.") - print(" -H, --header LINE: line to send as a header in format [header-key]: [header-value]") - print(" -d, --data STRING: Data to POST or PUT") - print(" --data-file FILE: File to read from file and POST or PUT") - print(" -M, --method STRING: Http Method verb to use for the request") - print(" -G, --get: uses GET for the verb.") - print(" -P, --post: uses POST for the verb.") - print(" -I, --head: uses HEAD for the verb.") - print(" -i, --include: includes headers in output.") - print(" -k, --insecure: turns off SSL/TLS validation.") - print(" -o, --output FILE: dumps content-body to FILE instead of stdout.") - print(" -t, --trace FILE: dumps logs to FILE instead of stderr.") - print(" -v, --verbose ERROR|INFO|DEBUG|TRACE: log level to configure. Default is none.") - print(" -h, --help: Display this message and quit.") - } + // Set URL + guard let parsedURL = URL(string: urlString) else { + print("Invalid URL: \(urlString)") + Foundation.exit(-1) + } + context.url = parsedURL - static func createOutputFile() { if let fileName = context.outputFileName { let fileManager = FileManager.default let path = FileManager.default.currentDirectoryPath + "/" + fileName fileManager.createFile(atPath: path, contents: nil, attributes: nil) - context.outputStream = FileHandle(forWritingAtPath: fileName) ?? FileHandle.standardOutput + context.outputStream = + FileHandle(forWritingAtPath: fileName) ?? FileHandle.standardOutput } - } - static func writeData(data: Data) { - context.outputStream.write(data) + return context } - static func main() async { - parseArguments() - createOutputFile() - if let traceFile = context.traceFile { - print("enable logging with trace file") - try? Logger.initialize(target: .filePath(traceFile), level: context.logLevel) - } else { - print("enable logging with stdout") - try? Logger.initialize(target: .standardOutput, level: context.logLevel) - } - - await run() + static func writeData(data: Data, context: Context) { + context.outputStream.write(data) } - static func run() async { + static func run(_ context: Context) async { do { guard let host = context.url.host else { print("no proper host was parsed from the url. quitting.") - exit(EXIT_FAILURE) + Foundation.exit(EXIT_FAILURE) } CommonRuntimeKit.initialize() @@ -249,10 +225,12 @@ struct Elasticurl { tlsConnectionOptions.serverName = host let elg = try EventLoopGroup(threadCount: 1) - let hostResolver = try HostResolver.makeDefault(eventLoopGroup: elg, maxHosts: 8, maxTTL: 30) + let hostResolver = try HostResolver.makeDefault( + eventLoopGroup: elg, maxHosts: 8, maxTTL: 30) - let bootstrap = try ClientBootstrap(eventLoopGroup: elg, - hostResolver: hostResolver) + let bootstrap = try ClientBootstrap( + eventLoopGroup: elg, + hostResolver: hostResolver) let socketOptions = SocketOptions(socketType: .stream) @@ -272,14 +250,15 @@ struct Elasticurl { } httpRequest.addHeaders(headers: headers) - let httpClientOptions = HTTPClientConnectionOptions(clientBootstrap: bootstrap, - hostName: context.url.host!, - initialWindowSize: Int.max, - port: port, - proxyOptions: nil, - socketOptions: socketOptions, - tlsOptions: tlsConnectionOptions, - monitoringOptions: nil) + let httpClientOptions = HTTPClientConnectionOptions( + clientBootstrap: bootstrap, + hostName: context.url.host!, + initialWindowSize: Int.max, + port: port, + proxyOptions: nil, + socketOptions: socketOptions, + tlsOptions: tlsConnectionOptions, + monitoringOptions: nil) let connectionManager = try HTTPClientConnectionManager(options: httpClientOptions) let connection = try await connectionManager.acquireConnection() @@ -291,7 +270,7 @@ struct Elasticurl { } let onBody: HTTPRequestOptions.OnIncomingBody = { bodyChunk in - writeData(data: bodyChunk) + writeData(data: bodyChunk, context: context) } let onComplete: HTTPRequestOptions.OnStreamComplete = { result in @@ -305,10 +284,11 @@ struct Elasticurl { } do { - let requestOptions = HTTPRequestOptions(request: httpRequest, - onResponse: onResponse, - onIncomingBody: onBody, - onStreamComplete: onComplete) + let requestOptions = HTTPRequestOptions( + request: httpRequest, + onResponse: onResponse, + onIncomingBody: onBody, + onStreamComplete: onComplete) stream = try connection.makeRequest(requestOptions: requestOptions) try stream!.activate() } catch { @@ -316,11 +296,10 @@ struct Elasticurl { } } - exit(EXIT_SUCCESS) + Foundation.exit(EXIT_SUCCESS) } catch let err { - showHelp() print(err) - exit(EXIT_FAILURE) + Foundation.exit(EXIT_FAILURE) } } } diff --git a/Source/Elasticurl/ElasticurlOptions.swift b/Source/Elasticurl/ElasticurlOptions.swift deleted file mode 100644 index 8e24ee62e..000000000 --- a/Source/Elasticurl/ElasticurlOptions.swift +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0. - -import AwsCommonRuntimeKit - -// swiftlint:disable identifier_name -public struct ElasticurlOptions { - static let caCert = AWSCLIOption(name: "cacert", - hasArg: .required, - flag: nil, - val: "a") - static let caPath = AWSCLIOption(name: "capath", - hasArg: .required, - flag: nil, - val: "b") - static let cert = AWSCLIOption(name: "cert", - hasArg: .required, - flag: nil, - val: "c") - static let key = AWSCLIOption(name: "key", - hasArg: .required, - flag: nil, - val: "e") - static let connectTimeout = AWSCLIOption(name: "connect-timeout", - hasArg: .required, - flag: nil, - val: "f") - static let header = AWSCLIOption(name: "header", - hasArg: .required, - flag: nil, - val: "H") - static let data = AWSCLIOption(name: "data", - hasArg: .required, - flag: nil, - val: "d") - static let dataFile = AWSCLIOption(name: "data-file", - hasArg: .required, - flag: nil, - val: "g") - static let method = AWSCLIOption(name: "method", - hasArg: .required, - flag: nil, - val: "M") - static let get = AWSCLIOption(name: "get", - hasArg: .none, - flag: nil, - val: "G") - static let post = AWSCLIOption(name: "post", - hasArg: .none, - flag: nil, - val: "P") - static let head = AWSCLIOption(name: "head", - hasArg: .none, - flag: nil, - val: "I") - static let signingLib = AWSCLIOption(name: "signing-lib", - hasArg: .required, - flag: nil, - val: "j") - static let include = AWSCLIOption(name: "include", - hasArg: .none, - flag: nil, - val: "i") - static let insecure = AWSCLIOption(name: "insecure", - hasArg: .none, - flag: nil, - val: "k") - static let signingFunc = AWSCLIOption(name: "signing-func", - hasArg: .required, - flag: nil, - val: "l") - static let signingContext = AWSCLIOption(name: "signing-context", - hasArg: .required, - flag: nil, - val: "m") - static let output = AWSCLIOption(name: "output", - hasArg: .required, - flag: nil, - val: "o") - static let trace = AWSCLIOption(name: "trace", - hasArg: .required, - flag: nil, - val: "t") - static let verbose = AWSCLIOption(name: "verbose", - hasArg: .required, - flag: nil, - val: "v") - static let version = AWSCLIOption(name: "version", - hasArg: .none, - flag: nil, - val: "V") - static let http2 = AWSCLIOption(name: "http2", - hasArg: .none, - flag: nil, - val: "w") - static let http1_1 = AWSCLIOption(name: "http1_1", - hasArg: .none, - flag: nil, - val: "W") - static let help = AWSCLIOption(name: "help", - hasArg: .none, - flag: nil, - val: "h") - static let lastOption = AWSCLIOption(name: "", - hasArg: .none, - flag: nil, - val: "0") -} diff --git a/Source/LibNative/CommonRuntimeError.h b/Source/LibNative/CommonRuntimeError.h index c2d70a451..4813c3ec0 100644 --- a/Source/LibNative/CommonRuntimeError.h +++ b/Source/LibNative/CommonRuntimeError.h @@ -23,13 +23,13 @@ enum aws_swift_errors { }; -static struct aws_error_info s_crt_swift_errors[] = { +static const struct aws_error_info s_crt_swift_errors[] = { AWS_DEFINE_ERROR_INFO_CRT_SWIFT( AWS_CRT_SWIFT_MQTT_CLIENT_CLOSED, "The Mqtt Client is closed.") }; -static struct aws_error_info_list s_crt_swift_error_list = { +static const struct aws_error_info_list s_crt_swift_error_list = { .error_list = s_crt_swift_errors, .count = AWS_ARRAY_SIZE(s_crt_swift_errors), }; diff --git a/Test/AwsCommonRuntimeKitTests/XCBaseTestCase.swift b/Test/AwsCommonRuntimeKitTests/XCBaseTestCase.swift index 150172cc8..758fe5e8b 100644 --- a/Test/AwsCommonRuntimeKitTests/XCBaseTestCase.swift +++ b/Test/AwsCommonRuntimeKitTests/XCBaseTestCase.swift @@ -1,13 +1,14 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0. +import AwsCCommon import XCTest + @testable import AwsCommonRuntimeKit -import AwsCCommon class XCBaseTestCase: XCTestCase { internal let tracingAllocator = TracingAllocator(tracingStacksOf: allocator) - + override func setUp() { super.setUp() // XCode currently lacks a way to enable logs exclusively for failed tests only. @@ -24,17 +25,18 @@ class XCBaseTestCase: XCTestCase { CommonRuntimeKit.cleanUp() tracingAllocator.dump() - XCTAssertEqual(tracingAllocator.count, 0, - "Memory was leaked: \(tracingAllocator.bytes) bytes in \(tracingAllocator.count) allocations") + XCTAssertEqual( + tracingAllocator.count, 0, + "Memory was leaked: \(tracingAllocator.bytes) bytes in \(tracingAllocator.count) allocations" + ) super.tearDown() } } - extension XCTestCase { func skipTest(message: String) throws { - throw XCTSkip(message) + throw XCTSkip(message) } func skipIfiOS() throws { @@ -67,6 +69,24 @@ extension XCTestCase { #endif } + func awaitExpectationResult(_ expectations: [XCTestExpectation], _ timeout: TimeInterval = 5) async -> XCTWaiter.Result { + let waiter = XCTWaiter() + // Remove the Ifdef once our minimum supported Swift version reaches 5.10 + #if swift(>=5.10) + return await waiter.fulfillment(of: expectations, timeout: timeout) + #else + return waiter.wait(for: expectations, timeout: timeout) + #endif + } + + func awaitExpectation(_ expectations: [XCTestExpectation], _ timeout: TimeInterval = 5) async { + // Remove the Ifdef once our minimum supported Swift version reaches 5.10 + #if swift(>=5.10) + await fulfillment(of: expectations, timeout: timeout) + #else + wait(for: expectations, timeout: timeout) + #endif + } func skipIfPlatformDoesntSupportTLS() throws { // Skipped for secitem support as the unit tests requires enetitlement setup to have acces to // the data protection keychain. @@ -83,3 +103,33 @@ extension XCTestCase { return result } } + +/* + * Async Semaphore compatible with Swift's structured concurrency. Swift complains about the normal sync Semaphore since it's a blocking wait. + * See: https://forums.swift.org/t/semaphore-alternatives-for-structured-concurrency/59353 + */ +actor TestSemaphore { + private var count: Int + private var waiters: [CheckedContinuation] = [] + + init(value: Int = 0) { + self.count = value + } + + func wait() async { + count -= 1 + if count >= 0 { return } + await withCheckedContinuation { + waiters.append($0) + } + } + + func signal(count: Int = 1) { + assert(count >= 1) + self.count += count + for _ in 0..(closure: () async throws -> T) async rethrows -> T { @@ -142,7 +142,7 @@ class CredentialsProviderTests: XCBaseTestCase { XCTAssertEqual("accessKey", credentials.getAccessKey()) XCTAssertEqual("secretKey", credentials.getSecret()) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateCredentialsProviderProcess() async throws { @@ -158,7 +158,7 @@ class CredentialsProviderTests: XCBaseTestCase { XCTAssertEqual("SecretAccessKey123", credentials.getSecret()) XCTAssertEqual("SessionToken123", credentials.getSessionToken()) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateCredentialsProviderSSO() async throws { @@ -175,7 +175,7 @@ class CredentialsProviderTests: XCBaseTestCase { // get credentials will fail in CI due to expired token, so do not assert on credentials. _ = try? await provider.getCredentials() } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateCredentialsProviderImds() async throws { @@ -183,7 +183,7 @@ class CredentialsProviderTests: XCBaseTestCase { _ = try CredentialsProvider(source: .imds(bootstrap: getClientBootstrap(), shutdownCallback: getShutdownCallback())) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateCredentialsProviderCache() async throws { @@ -197,7 +197,7 @@ class CredentialsProviderTests: XCBaseTestCase { XCTAssertNotNil(credentials) assertCredentials(credentials: credentials) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateAWSCredentialsProviderDefaultChain() async throws { @@ -215,7 +215,7 @@ class CredentialsProviderTests: XCBaseTestCase { assertCredentials(credentials: credentials) } } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } @@ -242,7 +242,7 @@ class CredentialsProviderTests: XCBaseTestCase { } catch { exceptionWasThrown.fulfill() } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } // Http proxy related tests could only run behind vpc to access the proxy @@ -259,7 +259,7 @@ class CredentialsProviderTests: XCBaseTestCase { environmentVarName: "AWS_TEST_HTTP_PROXY_HOST") let httpproxyPort = try getEnvironmentVarOrSkipTest( environmentVarName: "AWS_TEST_HTTP_PROXY_PORT") - + let httpProxys = HTTPProxyOptions( hostName: httpproxyHost, port: UInt32(httpproxyPort)!, connectionType: .tunnel) @@ -269,6 +269,7 @@ class CredentialsProviderTests: XCBaseTestCase { bootstrap: getClientBootstrap(), tlsContext: getTlsContext(), endpoint: cognitoEndpoint, identity: cognitoIdentity, + proxyOptions: httpProxys, shutdownCallback: getShutdownCallback())) let credentials = try await provider.getCredentials() XCTAssertNotNil(credentials) @@ -278,7 +279,7 @@ class CredentialsProviderTests: XCBaseTestCase { } catch { exceptionWasThrown.fulfill() } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCreateDestroyStsWebIdentityInvalidEnv() async throws { @@ -288,7 +289,7 @@ class CredentialsProviderTests: XCBaseTestCase { fileBasedConfiguration: FileBasedConfiguration())) ) } - + func testCreateDestroyStsWebIdentity() async throws { _ = try! CredentialsProvider(source: .stsWebIdentity( bootstrap: getClientBootstrap(), @@ -326,6 +327,6 @@ class CredentialsProviderTests: XCBaseTestCase { } catch { exceptionWasThrown.fulfill() } - wait(for: [exceptionWasThrown], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/crt/ShutDownCallbackOptionsTests.swift b/Test/AwsCommonRuntimeKitTests/crt/ShutDownCallbackOptionsTests.swift index 9b9e00ae6..01fea4c96 100644 --- a/Test/AwsCommonRuntimeKitTests/crt/ShutDownCallbackOptionsTests.swift +++ b/Test/AwsCommonRuntimeKitTests/crt/ShutDownCallbackOptionsTests.swift @@ -13,6 +13,6 @@ class ShutdownCallbackOptionsTests: XCBaseTestCase { shutdownWasCalled.fulfill() } } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift b/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift index 18bfd7b9a..b8ef55be4 100644 --- a/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift +++ b/Test/AwsCommonRuntimeKitTests/event-stream/EventStreamTests.swift @@ -6,7 +6,6 @@ import AwsCEventStream @testable import AwsCommonRuntimeKit class EventStreamTests: XCBaseTestCase { - let semaphore = DispatchSemaphore(value: 0) func testEncodeDecodeHeaders() async throws { let onCompleteWasCalled = XCTestExpectation(description: "OnComplete was called") @@ -48,7 +47,7 @@ class EventStreamTests: XCBaseTestCase { }) try decoder.decode(data: encoded) XCTAssertTrue(headers.elementsEqual(decodedHeaders)) - wait(for: [onCompleteWasCalled], timeout: 1) + await awaitExpectation([onCompleteWasCalled]) } func testEncodeDecodePayload() async throws { @@ -76,7 +75,7 @@ class EventStreamTests: XCBaseTestCase { }) try decoder.decode(data: encoded) XCTAssertEqual(payload, decodedPayload) - wait(for: [onCompleteWasCalled], timeout: 1) + await awaitExpectation([onCompleteWasCalled]) } func testEncodeOutOfScope() async throws { @@ -114,7 +113,7 @@ class EventStreamTests: XCBaseTestCase { let expectedHeaders = [EventStreamHeader(name: "int16", value: .int32(value: 16))] XCTAssertTrue(expectedHeaders.elementsEqual(decodedHeaders)) - wait(for: [onCompleteWasCalled], timeout: 1) + await awaitExpectation([onCompleteWasCalled]) } func testDecodeByteByByte() async throws { @@ -150,7 +149,7 @@ class EventStreamTests: XCBaseTestCase { XCTAssertEqual(payload, decodedPayload) XCTAssertTrue(headers.elementsEqual(decodedHeaders)) - wait(for: [onCompleteWasCalled], timeout: 1) + await awaitExpectation([onCompleteWasCalled]) } func testEmpty() async throws { @@ -175,6 +174,6 @@ class EventStreamTests: XCBaseTestCase { XCTFail("Error occurred. Code: \(code)\nMessage:\(message)") }) try decoder.decode(data: encoded) - wait(for: [onCompleteWasCalled], timeout: 1) + await awaitExpectation([onCompleteWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTP2ClientConnectionTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTP2ClientConnectionTests.swift index fc41097ed..38c602e1a 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTP2ClientConnectionTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTP2ClientConnectionTests.swift @@ -3,19 +3,19 @@ import XCTest @testable import AwsCommonRuntimeKit -class HTTP2ClientConnectionTests: HTTPClientTestFixture { +class HTTP2ClientConnectionTests: XCBaseTestCase { let expectedVersion = HTTPVersion.version_2 func testGetHTTP2RequestVersion() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) let connection = try await connectionManager.acquireConnection() XCTAssertEqual(connection.httpVersion, HTTPVersion.version_2) } // Test that the binding works not the actual functionality. C part has tests for functionality func testHTTP2UpdateSetting() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) let connection = try await connectionManager.acquireConnection() if let connection = connection as? HTTP2ClientConnection { try await connection.updateSetting(setting: HTTP2Settings(enablePush: false)) @@ -26,7 +26,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { // Test that the binding works not the actual functionality. C part has tests for functionality func testHTTP2UpdateSettingEmpty() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) let connection = try await connectionManager.acquireConnection() if let connection = connection as? HTTP2ClientConnection { try await connection.updateSetting(setting: HTTP2Settings()) @@ -37,7 +37,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { // Test that the binding works not the actual functionality. C part has tests for functionality func testHTTP2SendPing() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) let connection = try await connectionManager.acquireConnection() if let connection = connection as? HTTP2ClientConnection { var time = try await connection.sendPing() @@ -51,7 +51,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { // Test that the binding works not the actual functionality. C part has tests for functionality func testHTTP2SendGoAway() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) let connection = try await connectionManager.acquireConnection() if let connection = connection as? HTTP2ClientConnection { connection.sendGoAway(error: .internalError, allowMoreStreams: false) @@ -61,8 +61,8 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { } func testGetHttpsRequest() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) - let response = try await sendHTTPRequest( + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let response = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "httpbin.org", path: "/get", @@ -71,7 +71,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { requestVersion: .version_2) // The first header of response has to be ":status" for HTTP/2 response XCTAssertEqual(response.headers[0].name, ":status") - let response2 = try await sendHTTPRequest( + let response2 = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "httpbin.org", path: "/delete", @@ -84,8 +84,8 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { func testGetHttpsRequestWithHTTP1_1Request() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) - let response = try await sendHTTPRequest( + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "httpbin.org", alpnList: ["h2","http/1.1"]) + let response = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "httpbin.org", path: "/get", @@ -94,7 +94,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { requestVersion: .version_1_1) // The first header of response has to be ":status" for HTTP/2 response XCTAssertEqual(response.headers[0].name, ":status") - let response2 = try await sendHTTPRequest( + let response2 = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "httpbin.org", path: "/delete", @@ -106,8 +106,8 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { } func testHTTP2Download() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "d1cz66xoahf9cl.cloudfront.net", alpnList: ["h2","http/1.1"]) - let response = try await sendHTTPRequest( + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "d1cz66xoahf9cl.cloudfront.net", alpnList: ["h2","http/1.1"]) + let response = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "d1cz66xoahf9cl.cloudfront.net", path: "/http_test_doc.txt", @@ -121,8 +121,8 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { } func testHTTP2DownloadWithHTTP1_1Request() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "d1cz66xoahf9cl.cloudfront.net", alpnList: ["h2","http/1.1"]) - let response = try await sendHTTPRequest( + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "d1cz66xoahf9cl.cloudfront.net", alpnList: ["h2","http/1.1"]) + let response = try await HTTPClientTestFixture.sendHTTPRequest( method: "GET", endpoint: "d1cz66xoahf9cl.cloudfront.net", path: "/http_test_doc.txt", @@ -136,12 +136,12 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { } func testHTTP2StreamUpload() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: "nghttp2.org", alpnList: ["h2"]) - let semaphore = DispatchSemaphore(value: 0) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: "nghttp2.org", alpnList: ["h2"]) + let semaphore = TestSemaphore(value: 0) var httpResponse = HTTPResponse() var onCompleteCalled = false let testBody = "testBody" - let http2RequestOptions = try getHTTP2RequestOptions( + let http2RequestOptions = try HTTPClientTestFixture.getHTTP2RequestOptions( method: "PUT", path: "/httpbin/put", authority: "nghttp2.org", @@ -156,7 +156,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { let streamBase = try connection.makeRequest(requestOptions: http2RequestOptions) try streamBase.activate() XCTAssertFalse(onCompleteCalled) - let data = TEST_DOC_LINE.data(using: .utf8)! + let data = HTTPClientTestFixture.TEST_DOC_LINE.data(using: .utf8)! for chunk in data.chunked(into: 5) { try await streamBase.writeChunk(chunk: chunk, endOfStream: false) XCTAssertFalse(onCompleteCalled) @@ -167,7 +167,7 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { try await Task.sleep(nanoseconds: 5_000_000_000) XCTAssertFalse(onCompleteCalled) try await streamBase.writeChunk(chunk: Data(), endOfStream: true) - semaphore.wait() + await semaphore.wait() XCTAssertTrue(onCompleteCalled) XCTAssertNil(httpResponse.error) XCTAssertEqual(httpResponse.statusCode, 200) @@ -178,6 +178,6 @@ class HTTP2ClientConnectionTests: HTTPClientTestFixture { } let body: Response = try! JSONDecoder().decode(Response.self, from: httpResponse.body) - XCTAssertEqual(body.data, testBody + TEST_DOC_LINE) + XCTAssertEqual(body.data, testBody + HTTPClientTestFixture.TEST_DOC_LINE) } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTP2StreamManagerTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTP2StreamManagerTests.swift index 110a60761..b961ae369 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTP2StreamManagerTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTP2StreamManagerTests.swift @@ -4,7 +4,7 @@ import XCTest @testable import AwsCommonRuntimeKit -class HTT2StreamManagerTests: HTTPClientTestFixture { +class HTT2StreamManagerTests: XCBaseTestCase { let endpoint = "d1cz66xoahf9cl.cloudfront.net"; // Use cloudfront for HTTP/2 let path = "/random_32_byte.data"; @@ -102,16 +102,16 @@ class HTT2StreamManagerTests: HTTPClientTestFixture { func testHTTP2Stream() async throws { let streamManager = try makeStreamManger(host: endpoint) - _ = try await sendHTTP2Request(method: "GET", path: path, authority: endpoint, streamManager: streamManager) + _ = try await HTTPClientTestFixture.sendHTTP2Request(method: "GET", path: path, authority: endpoint, streamManager: streamManager) } func testHTTP2StreamUpload() async throws { let streamManager = try makeStreamManger(host: "nghttp2.org") - let semaphore = DispatchSemaphore(value: 0) + let semaphore = TestSemaphore(value: 0) var httpResponse = HTTPResponse() var onCompleteCalled = false let testBody = "testBody" - let http2RequestOptions = try getHTTP2RequestOptions( + let http2RequestOptions = try HTTPClientTestFixture.getHTTP2RequestOptions( method: "PUT", path: "/httpbin/put", authority: "nghttp2.org", @@ -128,7 +128,7 @@ class HTT2StreamManagerTests: HTTPClientTestFixture { let metrics = streamManager.fetchMetrics() XCTAssertTrue(metrics.availableConcurrency > 0) XCTAssertTrue(metrics.leasedConcurrency > 0) - let data = TEST_DOC_LINE.data(using: .utf8)! + let data = HTTPClientTestFixture.TEST_DOC_LINE.data(using: .utf8)! for chunk in data.chunked(into: 5) { try await stream.writeChunk(chunk: chunk, endOfStream: false) XCTAssertFalse(onCompleteCalled) @@ -139,7 +139,7 @@ class HTT2StreamManagerTests: HTTPClientTestFixture { try await Task.sleep(nanoseconds: 5_000_000_000) XCTAssertFalse(onCompleteCalled) try await stream.writeChunk(chunk: Data(), endOfStream: true) - semaphore.wait() + await semaphore.wait() XCTAssertTrue(onCompleteCalled) XCTAssertNil(httpResponse.error) XCTAssertEqual(httpResponse.statusCode, 200) @@ -150,13 +150,13 @@ class HTT2StreamManagerTests: HTTPClientTestFixture { } let body: Response = try! JSONDecoder().decode(Response.self, from: httpResponse.body) - XCTAssertEqual(body.data, testBody + TEST_DOC_LINE) + XCTAssertEqual(body.data, testBody + HTTPClientTestFixture.TEST_DOC_LINE) } // Test that the binding works not the actual functionality. C part has tests for functionality func testHTTP2StreamReset() async throws { let streamManager = try makeStreamManger(host: endpoint) - let http2RequestOptions = try getHTTP2RequestOptions( + let http2RequestOptions = try HTTPClientTestFixture.getHTTP2RequestOptions( method: "PUT", path: "/httpbin/put", authority: "nghttp2.org") @@ -171,18 +171,12 @@ class HTT2StreamManagerTests: HTTPClientTestFixture { func testHTTP2ParallelStreams(count: Int) async throws { let streamManager = try makeStreamManger(host: "nghttp2.org") - let requestCompleteExpectation = XCTestExpectation(description: "Request was completed successfully") - requestCompleteExpectation.expectedFulfillmentCount = count - await withTaskGroup(of: Void.self) { taskGroup in + return await withTaskGroup(of: Void.self) { taskGroup in for _ in 1...count { taskGroup.addTask { - _ = try! await self.sendHTTP2Request(method: "GET", path: "/httpbin/get", authority: "nghttp2.org", streamManager: streamManager, onComplete: { _ in - requestCompleteExpectation.fulfill() - }) + _ = try! await HTTPClientTestFixture.sendHTTP2Request(method: "GET", path: "/httpbin/get", authority: "nghttp2.org", streamManager: streamManager) } } } - wait(for: [requestCompleteExpectation], timeout: 15) - print("Request were successfully completed.") } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTPClientConnectionManagerTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTPClientConnectionManagerTests.swift index 748db1973..8c87bca56 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTPClientConnectionManagerTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTPClientConnectionManagerTests.swift @@ -37,6 +37,6 @@ class HTTPClientConnectionManagerTests: XCBaseTestCase { } _ = try HTTPClientConnectionManager(options: httpClientOptions) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTPClientTestFixture.swift b/Test/AwsCommonRuntimeKitTests/http/HTTPClientTestFixture.swift index 80b87f9d4..87f2d279f 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTPClientTestFixture.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTPClientTestFixture.swift @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0. import XCTest + @testable import AwsCommonRuntimeKit struct HTTPResponse { @@ -13,54 +14,56 @@ struct HTTPResponse { } class HTTPClientTestFixture: XCBaseTestCase { - let TEST_DOC_LINE: String = """ - This is a sample to prove that http downloads and uploads work. - It doesn't really matter what's in here, - we mainly just need to verify the downloads and uploads work. - """ - - func sendHTTPRequest(method: String, - endpoint: String, - path: String = "/", - body: String = "", - expectedStatus: Int = 200, - connectionManager: HTTPClientConnectionManager, - expectedVersion: HTTPVersion = HTTPVersion.version_1_1, - requestVersion: HTTPVersion = HTTPVersion.version_1_1, - numRetries: UInt = 2, - onResponse: HTTPRequestOptions.OnResponse? = nil, - onBody: HTTPRequestOptions.OnIncomingBody? = nil, - onComplete: HTTPRequestOptions.OnStreamComplete? = nil) async throws -> HTTPResponse { + static let TEST_DOC_LINE: String = """ + This is a sample to prove that http downloads and uploads work. + It doesn't really matter what's in here, + we mainly just need to verify the downloads and uploads work. + """ + + static func sendHTTPRequest( + method: String, + endpoint: String, + path: String = "/", + body: String = "", + expectedStatus: Int = 200, + connectionManager: HTTPClientConnectionManager, + expectedVersion: HTTPVersion = HTTPVersion.version_1_1, + requestVersion: HTTPVersion = HTTPVersion.version_1_1, + numRetries: UInt = 2, + onResponse: HTTPRequestOptions.OnResponse? = nil, + onBody: HTTPRequestOptions.OnIncomingBody? = nil, + onComplete: HTTPRequestOptions.OnStreamComplete? = nil + ) async throws -> HTTPResponse { var httpResponse = HTTPResponse() - let semaphore = DispatchSemaphore(value: 0) + let semaphore = TestSemaphore(value: 0) let httpRequestOptions: HTTPRequestOptions if requestVersion == HTTPVersion.version_2 { httpRequestOptions = try getHTTP2RequestOptions( - method: method, - path: path, - authority: endpoint, - body: body, - response: &httpResponse, - semaphore: semaphore, - onResponse: onResponse, - onBody: onBody, - onComplete: onComplete) + method: method, + path: path, + authority: endpoint, + body: body, + response: &httpResponse, + semaphore: semaphore, + onResponse: onResponse, + onBody: onBody, + onComplete: onComplete) } else { httpRequestOptions = try getHTTPRequestOptions( - method: method, - endpoint: endpoint, - path: path, - body: body, - response: &httpResponse, - semaphore: semaphore, - onResponse: onResponse, - onBody: onBody, - onComplete: onComplete) + method: method, + endpoint: endpoint, + path: path, + body: body, + response: &httpResponse, + semaphore: semaphore, + onResponse: onResponse, + onBody: onBody, + onComplete: onComplete) } - for i in 1...numRetries+1 where httpResponse.statusCode != expectedStatus { + for i in 1...numRetries + 1 where httpResponse.statusCode != expectedStatus { print("Attempt#\(i) to send an HTTP request") let connection = try await connectionManager.acquireConnection() XCTAssertTrue(connection.isOpen) @@ -68,7 +71,10 @@ class HTTPClientTestFixture: XCBaseTestCase { XCTAssertEqual(connection.httpVersion, expectedVersion) let stream = try connection.makeRequest(requestOptions: httpRequestOptions) try stream.activate() - semaphore.wait() + await semaphore.wait() + if httpResponse.statusCode != expectedStatus { + try? await Task.sleep(nanoseconds: 3_000_000_000) + } } XCTAssertNil(httpResponse.error) @@ -76,40 +82,42 @@ class HTTPClientTestFixture: XCBaseTestCase { return httpResponse } - func sendHTTP2Request(method: String, - path: String, - scheme: String = "https", - authority: String, - body: String = "", - expectedStatus: Int = 200, - streamManager: HTTP2StreamManager, - numRetries: UInt = 2, - http2ManualDataWrites: Bool = false, - onResponse: HTTPRequestOptions.OnResponse? = nil, - onBody: HTTPRequestOptions.OnIncomingBody? = nil, - onComplete: HTTPRequestOptions.OnStreamComplete? = nil) async throws -> HTTPResponse { + static func sendHTTP2Request( + method: String, + path: String, + scheme: String = "https", + authority: String, + body: String = "", + expectedStatus: Int = 200, + streamManager: HTTP2StreamManager, + numRetries: UInt = 2, + http2ManualDataWrites: Bool = false, + onResponse: HTTPRequestOptions.OnResponse? = nil, + onBody: HTTPRequestOptions.OnIncomingBody? = nil, + onComplete: HTTPRequestOptions.OnStreamComplete? = nil + ) async throws -> HTTPResponse { var httpResponse = HTTPResponse() - let semaphore = DispatchSemaphore(value: 0) + let semaphore = TestSemaphore(value: 0) let httpRequestOptions = try getHTTP2RequestOptions( - method: method, - path: path, - scheme: scheme, - authority: authority, - body: body, - response: &httpResponse, - semaphore: semaphore, - onResponse: onResponse, - onBody: onBody, - onComplete: onComplete, - http2ManualDataWrites: http2ManualDataWrites) + method: method, + path: path, + scheme: scheme, + authority: authority, + body: body, + response: &httpResponse, + semaphore: semaphore, + onResponse: onResponse, + onBody: onBody, + onComplete: onComplete, + http2ManualDataWrites: http2ManualDataWrites) - for i in 1...numRetries+1 where httpResponse.statusCode != expectedStatus { + for i in 1...numRetries + 1 where httpResponse.statusCode != expectedStatus { print("Attempt#\(i) to send an HTTP request") let stream = try await streamManager.acquireStream(requestOptions: httpRequestOptions) try stream.activate() - semaphore.wait() + await semaphore.wait() } XCTAssertNil(httpResponse.error) @@ -117,13 +125,15 @@ class HTTPClientTestFixture: XCBaseTestCase { return httpResponse } - func getHttpConnectionManager(endpoint: String, - ssh: Bool = true, - port: Int = 443, - alpnList: [String] = ["http/1.1"], - proxyOptions: HTTPProxyOptions? = nil, - monitoringOptions: HTTPMonitoringOptions? = nil, - socketOptions: SocketOptions = SocketOptions(socketType: .stream)) async throws -> HTTPClientConnectionManager { + static func getHttpConnectionManager( + endpoint: String, + ssh: Bool = true, + port: Int = 443, + alpnList: [String] = ["http/1.1"], + proxyOptions: HTTPProxyOptions? = nil, + monitoringOptions: HTTPMonitoringOptions? = nil, + socketOptions: SocketOptions = SocketOptions(socketType: .stream) + ) async throws -> HTTPClientConnectionManager { let tlsContextOptions = TLSContextOptions() tlsContextOptions.setAlpnList(alpnList) let tlsContext = try TLSContext(options: tlsContextOptions, mode: .client) @@ -134,105 +144,113 @@ class HTTPClientTestFixture: XCBaseTestCase { let hostResolver = try HostResolver(eventLoopGroup: elg, maxHosts: 8, maxTTL: 30) let bootstrap = try ClientBootstrap(eventLoopGroup: elg, hostResolver: hostResolver) - let httpClientOptions = HTTPClientConnectionOptions(clientBootstrap: bootstrap, - hostName: endpoint, - port: UInt32(port), - proxyOptions: proxyOptions, - socketOptions: socketOptions, - tlsOptions: ssh ? tlsConnectionOptions : nil, - monitoringOptions: monitoringOptions) + let httpClientOptions = HTTPClientConnectionOptions( + clientBootstrap: bootstrap, + hostName: endpoint, + port: UInt32(port), + proxyOptions: proxyOptions, + socketOptions: socketOptions, + tlsOptions: ssh ? tlsConnectionOptions : nil, + monitoringOptions: monitoringOptions) return try HTTPClientConnectionManager(options: httpClientOptions) } - func getRequestOptions(request: HTTPRequestBase, - response: UnsafeMutablePointer? = nil, - semaphore: DispatchSemaphore? = nil, - onResponse: HTTPRequestOptions.OnResponse? = nil, - onBody: HTTPRequestOptions.OnIncomingBody? = nil, - onComplete: HTTPRequestOptions.OnStreamComplete? = nil, - http2ManualDataWrites: Bool = false) -> HTTPRequestOptions { - HTTPRequestOptions(request: request, - onResponse: { status, headers in - response?.pointee.headers += headers - onResponse?(status, headers) - }, - - onIncomingBody: { bodyChunk in - response?.pointee.body += bodyChunk - onBody?(bodyChunk) - }, - onStreamComplete: { result in - switch result{ - case .success(let status): - response?.pointee.statusCode = Int(status) - case .failure(let error): - print("AWS_TEST_ERROR:\(String(describing: error))") - response?.pointee.error = error - } - onComplete?(result) - semaphore?.signal() - }, - http2ManualDataWrites: http2ManualDataWrites) - } + static func getRequestOptions( + request: HTTPRequestBase, + response: UnsafeMutablePointer? = nil, + semaphore: TestSemaphore? = nil, + onResponse: HTTPRequestOptions.OnResponse? = nil, + onBody: HTTPRequestOptions.OnIncomingBody? = nil, + onComplete: HTTPRequestOptions.OnStreamComplete? = nil, + http2ManualDataWrites: Bool = false + ) -> HTTPRequestOptions { + HTTPRequestOptions( + request: request, + onResponse: { status, headers in + response?.pointee.headers += headers + onResponse?(status, headers) + }, + onIncomingBody: { bodyChunk in + response?.pointee.body += bodyChunk + onBody?(bodyChunk) + }, + onStreamComplete: { result in + switch result { + case .success(let status): + response?.pointee.statusCode = Int(status) + case .failure(let error): + print("AWS_TEST_ERROR:\(String(describing: error))") + response?.pointee.error = error + } + onComplete?(result) + Task { await semaphore?.signal() } + }, + http2ManualDataWrites: http2ManualDataWrites) + } - func getHTTPRequestOptions(method: String, - endpoint: String, - path: String, - body: String = "", - response: UnsafeMutablePointer? = nil, - semaphore: DispatchSemaphore? = nil, - headers: [HTTPHeader] = [HTTPHeader](), - onResponse: HTTPRequestOptions.OnResponse? = nil, - onBody: HTTPRequestOptions.OnIncomingBody? = nil, - onComplete: HTTPRequestOptions.OnStreamComplete? = nil, - useChunkedEncoding: Bool = false + static func getHTTPRequestOptions( + method: String, + endpoint: String, + path: String, + body: String = "", + response: UnsafeMutablePointer? = nil, + semaphore: TestSemaphore? = nil, + headers: [HTTPHeader] = [HTTPHeader](), + onResponse: HTTPRequestOptions.OnResponse? = nil, + onBody: HTTPRequestOptions.OnIncomingBody? = nil, + onComplete: HTTPRequestOptions.OnStreamComplete? = nil, + useChunkedEncoding: Bool = false ) throws -> HTTPRequestOptions { - let httpRequest: HTTPRequest = try HTTPRequest(method: method, path: path, body: useChunkedEncoding ? nil : ByteBuffer(data: body.data(using: .utf8)!)) + let httpRequest: HTTPRequest = try HTTPRequest( + method: method, path: path, + body: useChunkedEncoding ? nil : ByteBuffer(data: body.data(using: .utf8)!)) httpRequest.addHeader(header: HTTPHeader(name: "Host", value: endpoint)) - if (useChunkedEncoding) { + if useChunkedEncoding { httpRequest.addHeader(header: HTTPHeader(name: "Transfer-Encoding", value: "chunked")) - } - else { - httpRequest.addHeader(header: HTTPHeader(name: "Content-Length", value: String(body.count))) + } else { + httpRequest.addHeader( + header: HTTPHeader(name: "Content-Length", value: String(body.count))) } httpRequest.addHeaders(headers: headers) return getRequestOptions( - request: httpRequest, - response: response, - semaphore: semaphore, - onResponse: onResponse, - onBody: onBody, - onComplete: onComplete) + request: httpRequest, + response: response, + semaphore: semaphore, + onResponse: onResponse, + onBody: onBody, + onComplete: onComplete) } - func getHTTP2RequestOptions(method: String, - path: String, - scheme: String = "https", - authority: String, - body: String = "", - manualDataWrites: Bool = false, - response: UnsafeMutablePointer? = nil, - semaphore: DispatchSemaphore? = nil, - onResponse: HTTPRequestOptions.OnResponse? = nil, - onBody: HTTPRequestOptions.OnIncomingBody? = nil, - onComplete: HTTPRequestOptions.OnStreamComplete? = nil, - http2ManualDataWrites: Bool = false) throws -> HTTPRequestOptions { + static func getHTTP2RequestOptions( + method: String, + path: String, + scheme: String = "https", + authority: String, + body: String = "", + manualDataWrites: Bool = false, + response: UnsafeMutablePointer? = nil, + semaphore: TestSemaphore? = nil, + onResponse: HTTPRequestOptions.OnResponse? = nil, + onBody: HTTPRequestOptions.OnIncomingBody? = nil, + onComplete: HTTPRequestOptions.OnStreamComplete? = nil, + http2ManualDataWrites: Bool = false + ) throws -> HTTPRequestOptions { let http2Request = try HTTP2Request(body: ByteBuffer(data: body.data(using: .utf8)!)) http2Request.addHeaders(headers: [ HTTPHeader(name: ":method", value: method), HTTPHeader(name: ":path", value: path), HTTPHeader(name: ":scheme", value: scheme), - HTTPHeader(name: ":authority", value: authority) + HTTPHeader(name: ":authority", value: authority), ]) return getRequestOptions( - request: http2Request, - response: response, - semaphore: semaphore, - onResponse: onResponse, - onBody: onBody, - onComplete: onComplete, - http2ManualDataWrites: http2ManualDataWrites) + request: http2Request, + response: response, + semaphore: semaphore, + onResponse: onResponse, + onBody: onBody, + onComplete: onComplete, + http2ManualDataWrites: http2ManualDataWrites) } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTPProxyTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTPProxyTests.swift index 59f75a9a8..dfea036c0 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTPProxyTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTPProxyTests.swift @@ -6,7 +6,7 @@ import AwsCAuth import Foundation @testable import AwsCommonRuntimeKit -class HTTPProxyTests: HTTPClientTestFixture { +class HTTPProxyTests: XCBaseTestCase { let HTTPProxyHost = ProcessInfo.processInfo.environment["AWS_TEST_HTTP_PROXY_HOST"] let HTTPProxyPort = ProcessInfo.processInfo.environment["AWS_TEST_HTTP_PROXY_PORT"] @@ -185,13 +185,13 @@ class HTTPProxyTests: HTTPClientTestFixture { let uri = getURIFromTestType(type: type) let port = getPortFromTestType(type: type) let proxyOptions = try getProxyOptions(type: type, authType: authType) - let manager = try await getHttpConnectionManager( + let manager = try await HTTPClientTestFixture.getHttpConnectionManager( endpoint: uri, ssh: getSSH(type: type), port: port, alpnList: ["http/1.1"], proxyOptions: proxyOptions) - _ = try await sendHTTPRequest(method: "GET", endpoint: uri, connectionManager: manager) + _ = try await HTTPClientTestFixture.sendHTTPRequest(method: "GET", endpoint: uri, connectionManager: manager) } } diff --git a/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift b/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift index c7c6ef796..09e9eab37 100644 --- a/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift +++ b/Test/AwsCommonRuntimeKitTests/http/HTTPTests.swift @@ -6,35 +6,35 @@ import XCTest import AwsCCommon import AwsCHttp -class HTTPTests: HTTPClientTestFixture { +class HTTPTests: XCBaseTestCase { let host = "postman-echo.com" let getPath = "/get" func testGetHTTPSRequest() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) - _ = try await sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager) - _ = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/delete", expectedStatus: 404, connectionManager: connectionManager) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + _ = try await HTTPClientTestFixture.sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager) + _ = try await HTTPClientTestFixture.sendHTTPRequest(method: "GET", endpoint: host, path: "/delete", expectedStatus: 404, connectionManager: connectionManager) } func testGetHTTPSRequestWithUtf8Header() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let utf8Header = HTTPHeader(name: "TestHeader", value: "TestValueWithEmoji🤯") - let headers = try await sendHTTPRequest(method: "GET", endpoint: host, path: "/response-headers?\(utf8Header.name)=\(utf8Header.value)", connectionManager: connectionManager).headers + let headers = try await HTTPClientTestFixture.sendHTTPRequest(method: "GET", endpoint: host, path: "/response-headers?\(utf8Header.name)=\(utf8Header.value)", connectionManager: connectionManager).headers XCTAssertTrue(headers.contains(where: {$0.name == utf8Header.name && $0.value==utf8Header.value})) } func testGetHTTPRequest() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: false, port: 80) - _ = try await sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: false, port: 80) + _ = try await HTTPClientTestFixture.sendHTTPRequest(method: "GET", endpoint: host, path: getPath, connectionManager: connectionManager) } func testPutHTTPRequest() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) - let response = try await sendHTTPRequest( + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let response = try await HTTPClientTestFixture.sendHTTPRequest( method: "PUT", endpoint: host, path: "/put", - body: TEST_DOC_LINE, + body: HTTPClientTestFixture.TEST_DOC_LINE, connectionManager: connectionManager) // Parse json body @@ -42,15 +42,15 @@ class HTTPTests: HTTPClientTestFixture { let data: String } let body: Response = try! JSONDecoder().decode(Response.self, from: response.body) - XCTAssertEqual(body.data, TEST_DOC_LINE) + XCTAssertEqual(body.data, HTTPClientTestFixture.TEST_DOC_LINE) } func testHTTPChunkTransferEncoding() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, alpnList: ["http/1.1"]) - let semaphore = DispatchSemaphore(value: 0) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, alpnList: ["http/1.1"]) + let semaphore = TestSemaphore(value: 0) var httpResponse = HTTPResponse() var onCompleteCalled = false - let httpRequestOptions = try getHTTPRequestOptions( + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions( method: "PUT", endpoint: host, path: "/put", @@ -67,7 +67,7 @@ class HTTPTests: HTTPClientTestFixture { let metrics = connectionManager.fetchMetrics() XCTAssertTrue(metrics.leasedConcurrency > 0) - let data = TEST_DOC_LINE.data(using: .utf8)! + let data = HTTPClientTestFixture.TEST_DOC_LINE.data(using: .utf8)! for chunk in data.chunked(into: 5) { try await streamBase.writeChunk(chunk: chunk, endOfStream: false) XCTAssertFalse(onCompleteCalled) @@ -78,7 +78,7 @@ class HTTPTests: HTTPClientTestFixture { try await Task.sleep(nanoseconds: 5_000_000_000) XCTAssertFalse(onCompleteCalled) try await streamBase.writeChunk(chunk: Data(), endOfStream: true) - semaphore.wait() + await semaphore.wait() XCTAssertTrue(onCompleteCalled) XCTAssertNil(httpResponse.error) XCTAssertEqual(httpResponse.statusCode, 200) @@ -89,15 +89,15 @@ class HTTPTests: HTTPClientTestFixture { } let body: Response = try! JSONDecoder().decode(Response.self, from: httpResponse.body) - XCTAssertEqual(body.data, TEST_DOC_LINE) + XCTAssertEqual(body.data, HTTPClientTestFixture.TEST_DOC_LINE) } func testHTTPChunkTransferEncodingWithDataInLastChunk() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, alpnList: ["http/1.1"]) - let semaphore = DispatchSemaphore(value: 0) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, alpnList: ["http/1.1"]) + let semaphore = TestSemaphore(value: 0) var httpResponse = HTTPResponse() var onCompleteCalled = false - let httpRequestOptions = try getHTTPRequestOptions( + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions( method: "PUT", endpoint: host, path: "/put", @@ -111,7 +111,7 @@ class HTTPTests: HTTPClientTestFixture { let streamBase = try connection.makeRequest(requestOptions: httpRequestOptions) try streamBase.activate() XCTAssertFalse(onCompleteCalled) - let data = TEST_DOC_LINE.data(using: .utf8)! + let data = HTTPClientTestFixture.TEST_DOC_LINE.data(using: .utf8)! for chunk in data.chunked(into: 5) { try await streamBase.writeChunk(chunk: chunk, endOfStream: false) XCTAssertFalse(onCompleteCalled) @@ -124,7 +124,7 @@ class HTTPTests: HTTPClientTestFixture { let lastChunkData = Data("last chunk data".utf8) try await streamBase.writeChunk(chunk: lastChunkData, endOfStream: true) - semaphore.wait() + await semaphore.wait() XCTAssertTrue(onCompleteCalled) XCTAssertNil(httpResponse.error) XCTAssertEqual(httpResponse.statusCode, 200) @@ -135,14 +135,14 @@ class HTTPTests: HTTPClientTestFixture { } let body: Response = try! JSONDecoder().decode(Response.self, from: httpResponse.body) - XCTAssertEqual(body.data, TEST_DOC_LINE + String(decoding: lastChunkData, as: UTF8.self)) + XCTAssertEqual(body.data, HTTPClientTestFixture.TEST_DOC_LINE + String(decoding: lastChunkData, as: UTF8.self)) } func testHTTPStreamIsReleasedIfNotActivated() async throws { do { - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() _ = try connection.makeRequest(requestOptions: httpRequestOptions) } catch let err { @@ -151,67 +151,67 @@ class HTTPTests: HTTPClientTestFixture { } func testStreamLivesUntilComplete() async throws { - let semaphore = DispatchSemaphore(value: 0) - + let semaphore = TestSemaphore(value: 0) do { - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() let stream = try connection.makeRequest(requestOptions: httpRequestOptions) try stream.activate() } - semaphore.wait() + await semaphore.wait() } func testManagerLivesUntilComplete() async throws { var connection: HTTPClientConnection! = nil - let semaphore = DispatchSemaphore(value: 0) + let semaphore = TestSemaphore(value: 0) do { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) connection = try await connectionManager.acquireConnection() } - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) let stream = try connection.makeRequest(requestOptions: httpRequestOptions) try stream.activate() - semaphore.wait() + await semaphore.wait() } func testConnectionLivesUntilComplete() async throws { var stream: HTTPStream! = nil - let semaphore = DispatchSemaphore(value: 0) + let semaphore = TestSemaphore(value: 0) + do { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath, semaphore: semaphore) stream = try connection.makeRequest(requestOptions: httpRequestOptions) } try stream.activate() - semaphore.wait() + await semaphore.wait() } func testConnectionCloseThrow() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() connection.close() - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) XCTAssertThrowsError( _ = try connection.makeRequest(requestOptions: httpRequestOptions)) } func testConnectionCloseActivateThrow() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) let stream = try connection.makeRequest(requestOptions: httpRequestOptions) connection.close() XCTAssertThrowsError(try stream.activate()) } func testConnectionCloseIsIdempotent() async throws { - let connectionManager = try await getHttpConnectionManager(endpoint: host, ssh: true, port: 443) + let connectionManager = try await HTTPClientTestFixture.getHttpConnectionManager(endpoint: host, ssh: true, port: 443) let connection = try await connectionManager.acquireConnection() - let httpRequestOptions = try getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) + let httpRequestOptions = try HTTPClientTestFixture.getHTTPRequestOptions(method: "GET", endpoint: host, path: getPath) let stream = try connection.makeRequest(requestOptions: httpRequestOptions) connection.close() connection.close() diff --git a/Test/AwsCommonRuntimeKitTests/io/BootstrapTests.swift b/Test/AwsCommonRuntimeKitTests/io/BootstrapTests.swift index 3e9d22d01..9bc5f0db8 100644 --- a/Test/AwsCommonRuntimeKitTests/io/BootstrapTests.swift +++ b/Test/AwsCommonRuntimeKitTests/io/BootstrapTests.swift @@ -33,6 +33,6 @@ class BootstrapTests: XCBaseTestCase { hostResolver: resolver, shutdownCallback: shutdownCallback) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/io/EventLoopGroupTests.swift b/Test/AwsCommonRuntimeKitTests/io/EventLoopGroupTests.swift index c7d4cbcbd..55ba8905e 100644 --- a/Test/AwsCommonRuntimeKitTests/io/EventLoopGroupTests.swift +++ b/Test/AwsCommonRuntimeKitTests/io/EventLoopGroupTests.swift @@ -11,7 +11,7 @@ class EventLoopGroupTests: XCBaseTestCase { _ = try EventLoopGroup() { shutdownWasCalled.fulfill() } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testCanCreateGroupWithThreads() throws { diff --git a/Test/AwsCommonRuntimeKitTests/io/HostResolverTests.swift b/Test/AwsCommonRuntimeKitTests/io/HostResolverTests.swift index 9ba7482dd..fe97bd88d 100644 --- a/Test/AwsCommonRuntimeKitTests/io/HostResolverTests.swift +++ b/Test/AwsCommonRuntimeKitTests/io/HostResolverTests.swift @@ -4,13 +4,13 @@ import XCTest @testable import AwsCommonRuntimeKit class HostResolverTests: XCBaseTestCase { - + func testCanResolveHosts() async throws { let elg = try EventLoopGroup() let resolver = try HostResolver(eventLoopGroup: elg, maxHosts: 8, maxTTL: 5) - + let addresses = try await resolver.resolveAddress(args: HostResolverArguments(hostName: "localhost")) XCTAssertNoThrow(addresses) XCTAssertNotNil(addresses.count) @@ -61,6 +61,6 @@ class HostResolverTests: XCBaseTestCase { maxTTL: 5, shutdownCallback: shutdownCallback) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/io/RetryerTests.swift b/Test/AwsCommonRuntimeKitTests/io/RetryerTests.swift index 13218b0a9..351143fe1 100644 --- a/Test/AwsCommonRuntimeKitTests/io/RetryerTests.swift +++ b/Test/AwsCommonRuntimeKitTests/io/RetryerTests.swift @@ -4,12 +4,12 @@ import XCTest @testable import AwsCommonRuntimeKit class RetryerTests: XCBaseTestCase { let expectation = XCTestExpectation(description: "Credentials callback was called") - + func testCreateAWSRetryer() throws { let elg = try EventLoopGroup(threadCount: 1) _ = try RetryStrategy(eventLoopGroup: elg) } - + func testAcquireToken() async throws { let elg = try EventLoopGroup(threadCount: 1) let retryer = try RetryStrategy(eventLoopGroup: elg) @@ -37,7 +37,7 @@ class RetryerTests: XCBaseTestCase { XCTAssertNotNil(token) _ = try await retryer.scheduleRetry(token: token, errorType: RetryError.serverError) } - wait(for: [shutdownWasCalled], timeout: 15) + await awaitExpectation([shutdownWasCalled]) } func testGenerateRandom() async throws { @@ -59,6 +59,6 @@ class RetryerTests: XCBaseTestCase { XCTAssertNotNil(token) _ = try await retryer.scheduleRetry(token: token, errorType: RetryError.serverError) } - wait(for: [generateRandomWasCalled, shutdownWasCalled], timeout: 15) + await awaitExpectation([generateRandomWasCalled, shutdownWasCalled]) } } diff --git a/Test/AwsCommonRuntimeKitTests/mqtt/Mqtt5ClientTests.swift b/Test/AwsCommonRuntimeKitTests/mqtt/Mqtt5ClientTests.swift index 927b0fb80..8c7c146b9 100644 --- a/Test/AwsCommonRuntimeKitTests/mqtt/Mqtt5ClientTests.swift +++ b/Test/AwsCommonRuntimeKitTests/mqtt/Mqtt5ClientTests.swift @@ -13,7 +13,7 @@ enum MqttTestError: Error { case stopFail } -class Mqtt5ClientTests: XCBaseTestCase { +class Mqtt5ClientTests: XCBaseTestCase, @unchecked Sendable { let credentialProviderShutdownWasCalled = XCTestExpectation(description: "Shutdown callback was called") @@ -25,47 +25,30 @@ class Mqtt5ClientTests: XCBaseTestCase { } /// start client and check for connection success - func connectClient(client: Mqtt5Client, testContext: MqttTestContext) throws -> Void { + func connectClient(client: Mqtt5Client, testContext: MqttTestContext) async throws{ try client.start() - if testContext.semaphoreConnectionSuccess.wait(timeout: .now() + 5) == .timedOut { - print("Connection Success Timed out after 5 seconds") - XCTFail("Connection Timed Out") - throw MqttTestError.connectionFail - } + await awaitExpectation([testContext.connectionSuccessExpectation], 5) + } /// stop client and check for discconnection and stopped lifecycle events - func disconnectClientCleanup(client: Mqtt5Client, testContext: MqttTestContext, disconnectPacket: DisconnectPacket? = nil) throws -> Void { + func disconnectClientCleanup(client: Mqtt5Client, testContext: MqttTestContext, disconnectPacket: DisconnectPacket? = nil) async throws -> Void { try client.stop(disconnectPacket: disconnectPacket) - - if testContext.semaphoreDisconnection.wait(timeout: .now() + 5) == .timedOut { - print("Disconnection timed out after 5 seconds") - XCTFail("Disconnection timed out") - throw MqttTestError.disconnectFail - } - - if testContext.semaphoreStopped.wait(timeout: .now() + 5) == .timedOut { - print("Stop timed out after 5 seconds") - XCTFail("Stop timed out") - throw MqttTestError.stopFail - } + await awaitExpectation([testContext.disconnectionExpectation], 5) + await awaitExpectation([testContext.stoppedExpecation], 5) } /// stop client and check for stopped lifecycle event - func stopClient(client: Mqtt5Client, testContext: MqttTestContext) throws -> Void { + func stopClient(client: Mqtt5Client, testContext: MqttTestContext) async throws -> Void { try client.stop() - if testContext.semaphoreStopped.wait(timeout: .now() + 5) == .timedOut { - print("Stop timed out after 5 seconds") - XCTFail("Stop timed out") - throw MqttTestError.stopFail - } + return await awaitExpectation([testContext.stoppedExpecation], 5) } func createClientId() -> String { return "test-aws-crt-swift-unit-" + UUID().uuidString } - class MqttTestContext { + class MqttTestContext : @unchecked Sendable{ public var contextName: String public var onPublishReceived: OnPublishReceived? @@ -76,12 +59,12 @@ class Mqtt5ClientTests: XCBaseTestCase { public var onLifecycleEventDisconnection: OnLifecycleEventDisconnection? public var onWebSocketHandshake: OnWebSocketHandshakeIntercept? - public let semaphorePublishReceived: DispatchSemaphore - public let semaphorePublishTargetReached: DispatchSemaphore - public let semaphoreConnectionSuccess: DispatchSemaphore - public let semaphoreConnectionFailure: DispatchSemaphore - public let semaphoreDisconnection: DispatchSemaphore - public let semaphoreStopped: DispatchSemaphore + public let publishReceivedExpectation: XCTestExpectation + public let publishTargetReachedExpectation: XCTestExpectation + public let connectionSuccessExpectation: XCTestExpectation + public let connectionFailureExpectation: XCTestExpectation + public let disconnectionExpectation: XCTestExpectation + public let stoppedExpecation: XCTestExpectation public var negotiatedSettings: NegotiatedSettings? public var connackPacket: ConnackPacket? @@ -105,12 +88,13 @@ class Mqtt5ClientTests: XCBaseTestCase { self.publishTarget = publishTarget self.publishCount = 0 - self.semaphorePublishReceived = DispatchSemaphore(value: 0) - self.semaphorePublishTargetReached = DispatchSemaphore(value: 0) - self.semaphoreConnectionSuccess = DispatchSemaphore(value: 0) - self.semaphoreConnectionFailure = DispatchSemaphore(value: 0) - self.semaphoreDisconnection = DispatchSemaphore(value: 0) - self.semaphoreStopped = DispatchSemaphore(value: 0) + + self.publishReceivedExpectation = XCTestExpectation(description: "Expect publish received.") + self.publishTargetReachedExpectation = XCTestExpectation(description: "Expect publish target reached") + self.connectionSuccessExpectation = XCTestExpectation(description: "Expect connection Success") + self.connectionFailureExpectation = XCTestExpectation(description: "Expect connection Failure") + self.disconnectionExpectation = XCTestExpectation(description: "Expect disconnect") + self.stoppedExpecation = XCTestExpectation(description: "Expect stopped") self.onPublishReceived = onPublishReceived self.onLifecycleEventStopped = onLifecycleEventStopped @@ -126,16 +110,16 @@ class Mqtt5ClientTests: XCBaseTestCase { print(contextName + " Mqtt5ClientTests: onPublishReceived. Topic:\'\(publishData.publishPacket.topic)\' QoS:\(publishData.publishPacket.qos)") } self.publishPacket = publishData.publishPacket - self.semaphorePublishReceived.signal() + self.publishReceivedExpectation.fulfill() self.publishCount += 1 if self.publishCount == self.publishTarget { - self.semaphorePublishTargetReached.signal() + self.publishTargetReachedExpectation.fulfill() } } self.onLifecycleEventStopped = onLifecycleEventStopped ?? { _ in print(contextName + " Mqtt5ClientTests: onLifecycleEventStopped") - self.semaphoreStopped.signal() + self.stoppedExpecation.fulfill() } self.onLifecycleEventAttemptingConnect = onLifecycleEventAttemptingConnect ?? { _ in print(contextName + " Mqtt5ClientTests: onLifecycleEventAttemptingConnect") @@ -144,17 +128,17 @@ class Mqtt5ClientTests: XCBaseTestCase { print(contextName + " Mqtt5ClientTests: onLifecycleEventConnectionSuccess") self.negotiatedSettings = successData.negotiatedSettings self.connackPacket = successData.connackPacket - self.semaphoreConnectionSuccess.signal() + self.connectionSuccessExpectation.fulfill() } self.onLifecycleEventConnectionFailure = onLifecycleEventConnectionFailure ?? { failureData in print(contextName + " Mqtt5ClientTests: onLifecycleEventConnectionFailure") self.lifecycleConnectionFailureData = failureData - self.semaphoreConnectionFailure.signal() + self.connectionFailureExpectation.fulfill() } self.onLifecycleEventDisconnection = onLifecycleEventDisconnection ?? { disconnectionData in print(contextName + " Mqtt5ClientTests: onLifecycleEventDisconnection") self.lifecycleDisconnectionData = disconnectionData - self.semaphoreDisconnection.signal() + self.disconnectionExpectation.fulfill() } } @@ -235,9 +219,9 @@ class Mqtt5ClientTests: XCBaseTestCase { } } - func withTimeout(client: Mqtt5Client, seconds: TimeInterval, operation: @escaping () async throws -> T) async throws -> T { + func withTimeout(client: Mqtt5Client, seconds: TimeInterval, operation: @escaping @Sendable () async throws -> T) async throws -> T { - let timeoutTask: () async throws -> T = { + let timeoutTask: @Sendable () async throws -> T = { try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) throw MqttTestError.timeout } @@ -353,7 +337,7 @@ class Mqtt5ClientTests: XCBaseTestCase { /* * [ConnDC-UC1] Happy path */ - func testMqtt5DirectConnectMinimum() throws { + func testMqtt5DirectConnectMinimum() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_PORT") @@ -363,14 +347,14 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnDC-UC2] Direct Connection with Basic Authentication */ - func testMqtt5DirectConnectWithBasicAuth() throws { + func testMqtt5DirectConnectWithBasicAuth() async throws { let inputUsername = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_BASIC_AUTH_USERNAME") let inputPassword = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_BASIC_AUTH_PASSWORD") @@ -390,14 +374,14 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnDC-UC3] Direct Connection with TLS */ - func testMqtt5DirectConnectWithTLS() throws { + func testMqtt5DirectConnectWithTLS() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_TLS_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_TLS_PORT") @@ -413,40 +397,41 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnDC-UC4] Direct Connection with mutual TLS */ - func testMqtt5DirectConnectWithMutualTLS() throws { + func testMqtt5DirectConnectWithMutualTLS() async throws { try skipIfPlatformDoesntSupportTLS() let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_HOST") let inputCert = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") let inputKey = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_RSA_KEY") - + let tlsOptions = try TLSContextOptions.makeMTLS( certificatePath: inputCert, privateKeyPath: inputKey ) let tlsContext = try TLSContext(options: tlsOptions, mode: .client) - + let clientOptions = MqttClientOptions( hostName: inputHost, port: UInt32(8883), tlsCtx: tlsContext) - + let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) + } /* * [ConnDC-UC5] Direct Connection with HttpProxy options and TLS */ - func testMqtt5DirectConnectWithHttpProxy() throws { + func testMqtt5DirectConnectWithHttpProxy() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_TLS_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_TLS_PORT") @@ -470,14 +455,14 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnDC-UC6] Direct Connection with all options set */ - func testMqtt5DirectConnectMaximum() throws { + func testMqtt5DirectConnectMaximum() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_PORT") @@ -527,8 +512,8 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /*=============================================================== @@ -537,7 +522,7 @@ class Mqtt5ClientTests: XCBaseTestCase { /* * [ConnWS-UC1] Happy path. Websocket connection with minimal configuration. */ - func testMqtt5WSConnectionMinimal() throws + func testMqtt5WSConnectionMinimal() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_PORT") @@ -549,15 +534,15 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() testContext.withWebsocketTransform(isSuccess: true) let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnWS-UC2] websocket connection with basic authentication */ - func testMqtt5WSConnectWithBasicAuth() throws { + func testMqtt5WSConnectWithBasicAuth() async throws { let inputUsername = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_BASIC_AUTH_USERNAME") let inputPassword = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_BASIC_AUTH_PASSWORD") @@ -578,15 +563,15 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() testContext.withWebsocketTransform(isSuccess: true) let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnWS-UC3] websocket connection with TLS */ - func testMqtt5WSConnectWithTLS() throws { + func testMqtt5WSConnectWithTLS() async throws { try skipIfPlatformDoesntSupportTLS() let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_TLS_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_TLS_PORT") @@ -605,14 +590,14 @@ class Mqtt5ClientTests: XCBaseTestCase { testContext.withWebsocketTransform(isSuccess: true) let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* * [ConnWS-UC4] websocket connection with TLS, using sigv4 */ - func testMqtt5WSConnectWithStaticCredentialProvider() throws { + func testMqtt5WSConnectWithStaticCredentialProvider() async throws { do{ try skipIfPlatformDoesntSupportTLS() @@ -642,7 +627,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let signingConfig = SigningConfig(algorithm: SigningAlgorithmType.signingV4, signatureType: SignatureType.requestQueryParams, service: "iotdevicegateway", - region: "us-east-1", + region: region, credentialsProvider: provider, omitSessionToken: true) @@ -677,8 +662,8 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try Mqtt5Client(clientOptions: clientOptions) XCTAssertNotNil(client) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) // Clean up the WebSocket handshake function to ensure the test context is properly released testContext.onWebSocketHandshake=nil } @@ -686,13 +671,13 @@ class Mqtt5ClientTests: XCBaseTestCase { // Fulfill the callback if the error self.credentialProviderShutdownWasCalled.fulfill() } - wait(for: [credentialProviderShutdownWasCalled], timeout: 15); + await awaitExpectation([credentialProviderShutdownWasCalled]) } /* * [ConnWS-UC5] Websocket connection with HttpProxy options */ - func testMqtt5WSConnectWithHttpProxy() throws { + func testMqtt5WSConnectWithHttpProxy() async throws { try skipIfPlatformDoesntSupportTLS() try skipifmacOS() @@ -721,7 +706,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let signingConfig = SigningConfig(algorithm: SigningAlgorithmType.signingV4, signatureType: SignatureType.requestQueryParams, service: "iotdevicegateway", - region: "us-east-1", + region: region, credentialsProvider: provider, omitSessionToken: true) @@ -758,11 +743,11 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try Mqtt5Client(clientOptions: clientOptions) XCTAssertNotNil(client) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } - func testMqtt5WSConnectFull() throws { + func testMqtt5WSConnectFull() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_PORT") @@ -812,8 +797,8 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() testContext.withWebsocketTransform(isSuccess: true) let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client:client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } func testMqttWebsocketWithCognitoCredentialProvider() async throws { @@ -878,8 +863,8 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try Mqtt5Client(clientOptions: clientOptions) XCTAssertNotNil(client) - try connectClient(client: client, testContext: testContext) - try disconnectClientCleanup(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) // Clean up the WebSocket handshake function to ensure the test context is properly released testContext.onWebSocketHandshake=nil } @@ -888,7 +873,7 @@ class Mqtt5ClientTests: XCBaseTestCase { print("catch error and fulfill the shutdown callback") self.credentialProviderShutdownWasCalled.fulfill() } - wait(for: [credentialProviderShutdownWasCalled], timeout: 15); + await awaitExpectation([credentialProviderShutdownWasCalled], 15); } @@ -899,7 +884,7 @@ class Mqtt5ClientTests: XCBaseTestCase { /* * [ConnNegativeID-UC1] Client connect with invalid host name */ - func testMqtt5DirectConnectWithInvalidHost() throws { + func testMqtt5DirectConnectWithInvalidHost() async throws { let clientOptions = MqttClientOptions( hostName: "badhost", @@ -910,11 +895,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { XCTAssertEqual(failureData.crtError.code, Int32(AWS_IO_DNS_INVALID_NAME.rawValue)) @@ -923,13 +904,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC2] Client connect with invalid port for direct connection */ - func testMqtt5DirectConnectWithInvalidPort() throws { + func testMqtt5DirectConnectWithInvalidPort() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let clientOptions = MqttClientOptions( @@ -941,11 +922,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { if failureData.crtError.code != Int32(AWS_IO_SOCKET_CONNECTION_REFUSED.rawValue) && @@ -958,13 +935,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC3] Client connect with invalid port for websocket connection */ - func testMqtt5WSInvalidPort() throws { + func testMqtt5WSInvalidPort() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let clientOptions = MqttClientOptions( @@ -977,11 +954,7 @@ class Mqtt5ClientTests: XCBaseTestCase { try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { if failureData.crtError.code != Int32(AWS_IO_SOCKET_CONNECTION_REFUSED.rawValue) && @@ -994,13 +967,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC4] Client connect with socket timeout */ - func testMqtt5DirectConnectWithSocketTimeout() throws { + func testMqtt5DirectConnectWithSocketTimeout() async throws { let clientOptions = MqttClientOptions( hostName: "www.example.com", port: UInt32(81)) @@ -1010,11 +983,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { XCTAssertEqual(failureData.crtError.code, Int32(AWS_IO_SOCKET_TIMEOUT.rawValue)) @@ -1023,13 +992,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC5] Client connect with incorrect basic authentication credentials */ - func testMqtt5DirectConnectWithIncorrectBasicAuthenticationCredentials() throws { + func testMqtt5DirectConnectWithIncorrectBasicAuthenticationCredentials() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_BASIC_AUTH_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_BASIC_AUTH_PORT") @@ -1042,11 +1011,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { XCTAssertEqual(failureData.crtError.code, Int32(AWS_ERROR_MQTT5_CONNACK_CONNECTION_REFUSED.rawValue)) @@ -1055,13 +1020,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC6] Client Websocket Handshake Failure test */ - func testMqtt5WSHandshakeFailure() throws { + func testMqtt5WSHandshakeFailure() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_WS_MQTT_PORT") @@ -1075,11 +1040,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) try client.start() - if testContext.semaphoreConnectionFailure.wait(timeout: .now() + 5) == .timedOut { - print("Connection Failure Timed out after 5 seconds") - XCTFail("Connection Timed Out") - return - } + await awaitExpectation([testContext.connectionFailureExpectation], 5) if let failureData = testContext.lifecycleConnectionFailureData { if failureData.crtError.code != Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue) { @@ -1091,14 +1052,14 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) } /* * [ConnNegativeID-UC7] Double Client ID Failure test */ - func testMqtt5MTLSConnectDoubleClientIdFailure() throws { + func testMqtt5MTLSConnectDoubleClientIdFailure() async throws { try skipIfPlatformDoesntSupportTLS() let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_HOST") let inputCert = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") @@ -1123,20 +1084,16 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext(contextName: "client1") let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) // Create a second client with the same client id let testContext2 = MqttTestContext(contextName: "client2") let client2 = try createClient(clientOptions: clientOptions, testContext: testContext2) // Connect with second client - try connectClient(client: client2, testContext: testContext2) + try await connectClient(client: client2, testContext: testContext2) - if testContext.semaphoreDisconnection.wait(timeout: .now() + 5) == .timedOut { - print("Disconnection due to duplicate client id timed out on client1") - XCTFail("Disconnection Timed Out") - return - } + await awaitExpectation([testContext.disconnectionExpectation], 5) if let disconnectionData = testContext.lifecycleDisconnectionData { print(disconnectionData.crtError) @@ -1151,8 +1108,8 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try stopClient(client: client, testContext: testContext) - try disconnectClientCleanup(client: client2, testContext: testContext2) + try await stopClient(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client2, testContext: testContext2) } /*=============================================================== @@ -1297,7 +1254,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let disconnectPacket = DisconnectPacket(sessionExpiryInterval: -1) do { @@ -1332,7 +1289,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let publishPacket = PublishPacket(qos: .atMostOnce, topic: "Test/Topic", @@ -1358,7 +1315,7 @@ class Mqtt5ClientTests: XCBaseTestCase { /* * [Negotiated-UC1] Happy path, minimal success test */ - func testMqtt5NegotiatedSettingsMinimalSettings() throws { + func testMqtt5NegotiatedSettingsMinimalSettings() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_PORT") @@ -1373,7 +1330,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) if let negotiatedSettings = testContext.negotiatedSettings { XCTAssertEqual(negotiatedSettings.sessionExpiryInterval, sessionExpirtyInterval) @@ -1382,13 +1339,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /* * [Negotiated-UC2] maximum success test */ - func testMqtt5NegotiatedSettingsMaximumSettings() throws { + func testMqtt5NegotiatedSettingsMaximumSettings() async throws { let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_HOST") let inputPort = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_DIRECT_MQTT_PORT") @@ -1408,7 +1365,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) if let negotiatedSettings = testContext.negotiatedSettings { XCTAssertEqual(negotiatedSettings.sessionExpiryInterval, sessionExpirtyInterval) @@ -1420,13 +1377,13 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /* * [Negotiated-UC3] server settings limit test */ - func testMqtt5NegotiatedSettingsServerLimit() throws { + func testMqtt5NegotiatedSettingsServerLimit() async throws { try skipIfPlatformDoesntSupportTLS() let inputHost = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_HOST") let inputCert = try getEnvironmentVarOrSkipTest(environmentVarName: "AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") @@ -1457,7 +1414,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) if let negotiatedSettings = testContext.negotiatedSettings { XCTAssertNotEqual(sessionExpiryInterval, negotiatedSettings.sessionExpiryInterval) @@ -1469,7 +1426,7 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /*=============================================================== @@ -1497,7 +1454,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let topic = "test/MQTT5_Binding_Swift_" + UUID().uuidString let subscribePacket = SubscribePacket(topicFilter: topic, qos: QoS.atLeastOnce, noLocal: false) @@ -1520,12 +1477,8 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - if testContext.semaphorePublishReceived.wait(timeout: .now() + 5) == .timedOut { - print("Publish not received after 5 seconds") - XCTFail("Publish packet not received on subscribed topic") - return - } - + await awaitExpectation([testContext.publishReceivedExpectation], 5) + let unsubscribePacket = UnsubscribePacket(topicFilter: topic) let unsubackPacket: UnsubackPacket = try await withTimeout(client: client, seconds: 2, operation: { @@ -1533,7 +1486,7 @@ class Mqtt5ClientTests: XCBaseTestCase { }) print("UnsubackPacket received with result \(unsubackPacket.reasonCodes[0])") - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /* @@ -1565,7 +1518,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContextPublisher = MqttTestContext(contextName: "Publisher") let clientPublisher = try createClient(clientOptions: clientOptions, testContext: testContextPublisher) - try connectClient(client: clientPublisher, testContext: testContextPublisher) + try await connectClient(client: clientPublisher, testContext: testContextPublisher) let clientIDSubscriber = createClientId() + "Subscriber" let testContextSubscriber = MqttTestContext(contextName: "Subscriber") @@ -1577,7 +1530,7 @@ class Mqtt5ClientTests: XCBaseTestCase { connectOptions: connectOptionsSubscriber) let clientSubscriber = try createClient(clientOptions: clientOptionsSubscriber, testContext: testContextSubscriber) - try connectClient(client: clientSubscriber, testContext: testContextSubscriber) + try await connectClient(client: clientSubscriber, testContext: testContextSubscriber) let subscribePacket = SubscribePacket(topicFilter: topic, qos: QoS.atLeastOnce, noLocal: false) let subackPacket: SubackPacket = @@ -1587,15 +1540,11 @@ class Mqtt5ClientTests: XCBaseTestCase { print("SubackPacket received with result \(subackPacket.reasonCodes[0])") let disconnectPacket = DisconnectPacket(reasonCode: .disconnectWithWillMessage) - try disconnectClientCleanup(client: clientPublisher, testContext: testContextPublisher, disconnectPacket: disconnectPacket) + try await disconnectClientCleanup(client: clientPublisher, testContext: testContextPublisher, disconnectPacket: disconnectPacket) - if testContextSubscriber.semaphorePublishReceived.wait(timeout: .now() + 5) == .timedOut { - print("Publish not received after 5 seconds") - XCTFail("Publish packet not received on subscribed topic") - return - } + await awaitExpectation([testContextSubscriber.publishReceivedExpectation], 5) - try disconnectClientCleanup(client:clientSubscriber, testContext: testContextSubscriber) + try await disconnectClientCleanup(client:clientSubscriber, testContext: testContextSubscriber) } /* @@ -1620,7 +1569,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let topic = "test/MQTT5_Binding_Swift_" + UUID().uuidString let subscribePacket = SubscribePacket(topicFilter: topic, qos: QoS.atLeastOnce, noLocal: false) @@ -1642,16 +1591,12 @@ class Mqtt5ClientTests: XCBaseTestCase { return } - if testContext.semaphorePublishReceived.wait(timeout: .now() + 5) == .timedOut { - print("Publish not received after 5 seconds") - XCTFail("Publish packet not received on subscribed topic") - return - } + await awaitExpectation([testContext.publishReceivedExpectation], 5) let publishReceived = testContext.publishPacket! XCTAssertEqual(publishReceived.payload, payloadData, "Binary data received as publish not equal to binary data used to generate publish") - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /* @@ -1667,7 +1612,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let topic1 = "test/MQTT5_Binding_Swift_" + UUID().uuidString let topic2 = "test/MQTT5_Binding_Swift_" + UUID().uuidString @@ -1699,7 +1644,7 @@ class Mqtt5ClientTests: XCBaseTestCase { print("Index:\(i) result:\(unsubackPacket.reasonCodes[i])") } - try disconnectClientCleanup(client: client, testContext: testContext) + try await disconnectClientCleanup(client: client, testContext: testContext) } /*=============================================================== @@ -1730,7 +1675,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let publishPacket = PublishPacket(qos: .atLeastOnce, topic: "") do { @@ -1739,7 +1684,7 @@ class Mqtt5ClientTests: XCBaseTestCase { XCTAssertEqual(crtError.code, Int32(AWS_ERROR_MQTT5_PUBLISH_OPTIONS_VALIDATION.rawValue)) } - try disconnectClientCleanup(client:client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* @@ -1761,7 +1706,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let subscribePacket = SubscribePacket(topicFilter: "", qos: .atLeastOnce) do { @@ -1770,7 +1715,7 @@ class Mqtt5ClientTests: XCBaseTestCase { XCTAssertEqual(crtError.code, Int32(AWS_ERROR_MQTT5_SUBSCRIBE_OPTIONS_VALIDATION.rawValue)) } - try disconnectClientCleanup(client:client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /* @@ -1792,7 +1737,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) let unsubscribePacket = UnsubscribePacket(topicFilter: "") do { @@ -1801,7 +1746,7 @@ class Mqtt5ClientTests: XCBaseTestCase { XCTAssertEqual(crtError.code, Int32(AWS_ERROR_MQTT5_UNSUBSCRIBE_OPTIONS_VALIDATION.rawValue)) } - try disconnectClientCleanup(client:client, testContext: testContext) + try await disconnectClientCleanup(client:client, testContext: testContext) } /*=============================================================== @@ -1831,7 +1776,7 @@ class Mqtt5ClientTests: XCBaseTestCase { connectOptions: connectOptions1) let testContext1 = MqttTestContext() let client1 = try createClient(clientOptions: clientOptions1, testContext: testContext1) - try connectClient(client: client1, testContext: testContext1) + try await connectClient(client: client1, testContext: testContext1) // Create and connect client2 let connectOptions2 = MqttConnectOptions(clientId: createClientId()) @@ -1842,7 +1787,7 @@ class Mqtt5ClientTests: XCBaseTestCase { connectOptions: connectOptions2) let testContext2 = MqttTestContext(publishTarget: 10) let client2 = try createClient(clientOptions: clientOptions2, testContext: testContext2) - try connectClient(client: client2, testContext: testContext2) + try await connectClient(client: client2, testContext: testContext2) let topic = "test/MQTT5_Binding_Swift_" + UUID().uuidString let subscribePacket = SubscribePacket(topicFilter: topic, qos: QoS.atLeastOnce, noLocal: false) @@ -1863,14 +1808,11 @@ class Mqtt5ClientTests: XCBaseTestCase { } // Wait for client2 to receive 10 publishes - if testContext2.semaphorePublishTargetReached.wait(timeout: .now() + 10) == .timedOut { - print("Expected Publish receive target not hit after 10 seconds") - XCTFail("Missing Publishes") - return - } + await awaitExpectation([testContext2.publishTargetReachedExpectation], 5) - try disconnectClientCleanup(client:client1, testContext: testContext1) - try disconnectClientCleanup(client:client2, testContext: testContext2) + + try await disconnectClientCleanup(client:client1, testContext: testContext1) + try await disconnectClientCleanup(client:client2, testContext: testContext2) } @@ -1902,7 +1844,7 @@ class Mqtt5ClientTests: XCBaseTestCase { connectOptions: connectOptions1) let testContext1 = MqttTestContext(contextName: "Client1") let client1 = try createClient(clientOptions: clientOptions1, testContext: testContext1) - try connectClient(client: client1, testContext: testContext1) + try await connectClient(client: client1, testContext: testContext1) // Create client2 let connectOptions2 = MqttConnectOptions(clientId: createClientId()) @@ -1947,15 +1889,13 @@ class Mqtt5ClientTests: XCBaseTestCase { } // connect client2 and subscribe to topic with retained client1 publish - try connectClient(client: client2, testContext: testContext2) + try await connectClient(client: client2, testContext: testContext2) _ = try await withTimeout(client: client2, seconds: 2, operation: { try await client2.subscribe(subscribePacket: subscribePacket) }) - if testContext2.semaphorePublishReceived.wait(timeout: .now() + 10) == .timedOut { - XCTFail("Expected retained Publish not received") - return - } + + await awaitExpectation([testContext2.publishReceivedExpectation], 10) XCTAssertEqual(testContext2.publishPacket?.payloadAsString(), publishPacket.payloadAsString()) @@ -1974,22 +1914,24 @@ class Mqtt5ClientTests: XCBaseTestCase { } // connect client3 and subscribe to topic to insure there is no client1 retained publish - try connectClient(client: client3, testContext: testContext3) + try await connectClient(client: client3, testContext: testContext3) _ = try await withTimeout(client: client3, seconds: 2, operation: { try await client3.subscribe(subscribePacket: subscribePacket) }) - if testContext3.semaphorePublishReceived.wait(timeout: .now() + 1) == .timedOut { + let waitResult = await awaitExpectationResult([testContext3.publishReceivedExpectation], 5) + if(waitResult == XCTWaiter.Result.timedOut){ print("no retained publish from client1") - } else { + }else{ XCTFail("Retained publish from client1 received when it should be cleared") return } + - try disconnectClientCleanup(client:client1, testContext: testContext1) - try disconnectClientCleanup(client:client2, testContext: testContext2) - try disconnectClientCleanup(client:client3, testContext: testContext3) + try await disconnectClientCleanup(client:client1, testContext: testContext1) + try await disconnectClientCleanup(client:client2, testContext: testContext2) + try await disconnectClientCleanup(client:client3, testContext: testContext3) } /*=============================================================== @@ -2008,7 +1950,7 @@ class Mqtt5ClientTests: XCBaseTestCase { let testContext = MqttTestContext() let client = try createClient(clientOptions: clientOptions, testContext: testContext) - try connectClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) } /* @@ -2027,13 +1969,23 @@ class Mqtt5ClientTests: XCBaseTestCase { let client = try createClient(clientOptions: clientOptions, testContext: testContext) // offline operation would never complete. Use close to force quit. defer { client.close() } - try connectClient(client: client, testContext: testContext) - try stopClient(client: client, testContext: testContext) + try await connectClient(client: client, testContext: testContext) + try await stopClient(client: client, testContext: testContext) let topic = "test/MQTT5_Binding_Swift_" + UUID().uuidString let publishPacket = PublishPacket(qos: QoS.atLeastOnce, topic: topic) - Task { - let _ = try? await client.publish(publishPacket: publishPacket) + + do{ + // An offline publish would not get the puback back as the operation could never get an ack back + // the operation should timeout + let _ = try await withTimeout(client: client, seconds: 2, operation: { + try await client.publish(publishPacket: publishPacket) + }) + }catch (let error) { + if(error as! MqttTestError != MqttTestError.timeout) { + XCTFail("Offline publish failed with \(error)") + } } + } } diff --git a/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointPropertyTests.swift b/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointPropertyTests.swift index 878b15b55..46fa3d346 100644 --- a/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointPropertyTests.swift +++ b/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointPropertyTests.swift @@ -9,37 +9,39 @@ class EndpointPropertyTests: XCTestCase { func testDecoderWithBool() throws { let data = "true".data(using: .utf8)! let actual = try JSONDecoder().decode(EndpointProperty.self, from: data) - XCTAssertEqual(true, actual.toAnyHashable()) + XCTAssertEqual(.bool(true), actual) } func testDecoderWithString() throws { let data = "\"hello\"".data(using: .utf8)! let actual = try JSONDecoder().decode(EndpointProperty.self, from: data) - XCTAssertEqual("hello", actual.toAnyHashable()) + XCTAssertEqual(.string("hello"), actual) } func testDecoderWithArray() throws { let data = "[\"hello\", \"world\"]".data(using: .utf8)! let actual = try JSONDecoder().decode(EndpointProperty.self, from: data) - XCTAssertEqual(["hello", "world"], actual.toAnyHashable()) + XCTAssertEqual(.array([.string("hello"), .string("world")]), actual) } func testDecoderWithDictionary() throws { let data = "{\"hello\": \"world\"}".data(using: .utf8)! let actual = try JSONDecoder().decode(EndpointProperty.self, from: data) - XCTAssertEqual(["hello": "world"], actual.toAnyHashable()) + XCTAssertEqual(.dictionary(["hello": .string("world")]), actual) } func testDecoderWithMixed() throws { let data = "{\"hello\": [\"world\", \"universe\"], \"isAlive\": true}".data(using: .utf8)! let actual = try JSONDecoder().decode(EndpointProperty.self, from: data) - let expected: [String: AnyHashable] = [ - "hello": [ - "world", - "universe" - ], - "isAlive": true - ] - XCTAssertEqual(expected, actual.toAnyHashable()) + + let expected: EndpointProperty = .dictionary([ + "hello": .array([ + .string("world"), + .string("universe") + ]), + "isAlive": .bool(true) + ]) + + XCTAssertEqual(expected, actual) } } diff --git a/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointsRuleEngineTests.swift b/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointsRuleEngineTests.swift index f3dae8ba8..1b9470ee6 100644 --- a/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointsRuleEngineTests.swift +++ b/Test/AwsCommonRuntimeKitTests/sdkutils/EndpointsRuleEngineTests.swift @@ -220,21 +220,19 @@ class EndpointsRuleEngineTests: XCBaseTestCase { try context.add(name: "Region", value: "us-west-2") let resolved = try engine.resolve(context: context) guard case ResolvedEndpoint.endpoint(url: let url, - headers: let headers, - properties: let properties) = resolved else { + headers: let headers, + properties: let properties) = resolved else { XCTFail("Endpoint resolved to an error") return } XCTAssertEqual("https://example.us-west-2.amazonaws.com", url) - let expectedProperties = [ - "authSchemes": [ - [ - "name": "sigv4", - "signingName": "serviceName", - "signingRegion": "us-west-2" - ] - ] - ] + let expectedProperties: [String: EndpointProperty] = ["authSchemes": .array([ + .dictionary([ + "name": .string("sigv4"), + "signingName": .string("serviceName"), + "signingRegion": .string("us-west-2") + ]) + ])] XCTAssertEqual(expectedProperties, properties) let expectedHeaders = [ "x-amz-region": [