Skip to content

Commit 97cc3c7

Browse files
roninjin10codex
andcommitted
fix(auth): 0154 — serialize TokenManager.refresh to prevent double-redeem race
Co-authored-by: OpenAI Codex <codex@openai.com>
1 parent 3e289df commit 97cc3c7

3 files changed

Lines changed: 128 additions & 61 deletions

File tree

Shared/Sources/SmithersAuth/TokenManager.swift

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -87,51 +87,23 @@ public final class TokenManager {
8787
/// tokens hit the Keychain before this function resolves.
8888
@discardableResult
8989
public func refresh() async throws -> OAuth2Tokens {
90-
// De-duplicate concurrent refreshes — two parallel 401s must not
91-
// both redeem the same refresh token.
92-
let (existingTask, current): (Task<OAuth2Tokens, Error>?, OAuth2Tokens?) = {
93-
lock.lock(); defer { lock.unlock() }
94-
return (inFlightRefresh, cached)
95-
}()
96-
if let existing = existingTask {
97-
return try await existing.value
98-
}
99-
guard let current = current else {
100-
throw TokenManagerError.notSignedIn
101-
}
102-
let task = Task { [client, store] () throws -> OAuth2Tokens in
103-
let newTokens: OAuth2Tokens
104-
do {
105-
newTokens = try await client.refresh(refreshToken: current.refreshToken)
106-
} catch let err as OAuth2Error {
107-
if case .whitelistDenied(let msg) = err {
108-
throw TokenManagerError.whitelistDenied(msg)
109-
}
110-
throw TokenManagerError.refreshFailed(err)
90+
let task = withLock { () -> Task<OAuth2Tokens, Error>? in
91+
if let existing = inFlightRefresh {
92+
return existing
11193
}
112-
// WRITE-BEFORE-RETURN. If this throws, the caller treats the
113-
// user as signed out — see `refreshAndRetry` wiring.
114-
do {
115-
try store.save(newTokens)
116-
} catch let e as TokenStoreError {
117-
throw TokenManagerError.persistenceFailed(e)
94+
guard let current = cached else {
95+
return nil
11896
}
119-
return newTokens
97+
let task = makeRefreshTask(current: current)
98+
inFlightRefresh = task
99+
return task
120100
}
121-
setInflight(task)
122101

123-
do {
124-
let newTokens = try await task.value
125-
self.setCached(newTokens)
126-
self.clearInflight()
127-
return newTokens
128-
} catch {
129-
self.clearInflight()
130-
// A refresh failure (expired/rotated/revoked refresh token OR
131-
// a store write failure) locks the user out. Wipe silently.
132-
await self.localSignOut()
133-
throw error
102+
guard let task else {
103+
throw TokenManagerError.notSignedIn
134104
}
105+
106+
return try await task.value
135107
}
136108

137109
private func setCached(_ t: OAuth2Tokens) {
@@ -142,8 +114,44 @@ public final class TokenManager {
142114
lock.lock(); inFlightRefresh = nil; lock.unlock()
143115
}
144116

145-
private func setInflight(_ t: Task<OAuth2Tokens, Error>) {
146-
lock.lock(); inFlightRefresh = t; lock.unlock()
117+
private func withLock<T>(_ body: () throws -> T) rethrows -> T {
118+
lock.lock()
119+
defer { lock.unlock() }
120+
return try body()
121+
}
122+
123+
private func makeRefreshTask(current: OAuth2Tokens) -> Task<OAuth2Tokens, Error> {
124+
Task { [self, client, store] in
125+
defer { clearInflight() }
126+
127+
do {
128+
let newTokens: OAuth2Tokens
129+
do {
130+
newTokens = try await client.refresh(refreshToken: current.refreshToken)
131+
} catch let err as OAuth2Error {
132+
if case .whitelistDenied(let msg) = err {
133+
throw TokenManagerError.whitelistDenied(msg)
134+
}
135+
throw TokenManagerError.refreshFailed(err)
136+
}
137+
138+
// WRITE-BEFORE-RETURN. If this throws, the caller treats the
139+
// user as signed out — see `refreshAndRetry` wiring.
140+
do {
141+
try store.save(newTokens)
142+
} catch let e as TokenStoreError {
143+
throw TokenManagerError.persistenceFailed(e)
144+
}
145+
146+
setCached(newTokens)
147+
return newTokens
148+
} catch {
149+
// A refresh failure (expired/rotated/revoked refresh token OR
150+
// a store write failure) locks the user out. Wipe silently.
151+
await localSignOut()
152+
throw error
153+
}
154+
}
147155
}
148156

149157
/// 401-retry helper. `perform` receives a bearer access token; return

Shared/Tests/SmithersAuthTests/MockHTTPTransport.swift

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,33 @@ final class MockHTTPTransport: HTTPTransport {
3434
}
3535
}
3636

37+
private let lock = NSLock()
3738
var responses: [CannedResponse] = []
39+
var sendDelayNanoseconds: UInt64 = 0
3840
private(set) var recorded: [Recorded] = []
3941

4042
func send(_ request: URLRequest) async throws -> (Data, Int, [String: String]) {
41-
recorded.append(Recorded(
42-
url: request.url!,
43-
method: request.httpMethod ?? "GET",
44-
body: request.httpBody
45-
))
46-
guard !responses.isEmpty else {
47-
return (Data(), 500, [:])
43+
if sendDelayNanoseconds > 0 {
44+
try? await Task.sleep(nanoseconds: sendDelayNanoseconds)
4845
}
49-
let r = responses.removeFirst()
50-
return (r.body, r.status, r.headers)
46+
47+
return withLock {
48+
recorded.append(Recorded(
49+
url: request.url!,
50+
method: request.httpMethod ?? "GET",
51+
body: request.httpBody
52+
))
53+
guard !responses.isEmpty else {
54+
return (Data(), 500, [:])
55+
}
56+
let r = responses.removeFirst()
57+
return (r.body, r.status, r.headers)
58+
}
59+
}
60+
61+
private func withLock<T>(_ body: () throws -> T) rethrows -> T {
62+
lock.lock()
63+
defer { lock.unlock() }
64+
return try body()
5165
}
5266
}

Shared/Tests/SmithersAuthTests/MockedServerIntegrationTests.swift

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,23 +193,49 @@ final class MockedServerIntegrationTests: XCTestCase {
193193

194194
// MARK: - Concurrent 401s collapse into one refresh
195195

196-
@MainActor
197196
func test_concurrent_refreshes_deduplicate() async throws {
198197
let transport = MockHTTPTransport()
199-
transport.responses.append(.json(payload: [
200-
"access_token": "A2", "refresh_token": "R2",
201-
]))
198+
transport.sendDelayNanoseconds = 25_000_000
199+
for _ in 0..<2 {
200+
transport.responses.append(.json(payload: [
201+
"access_token": "A2", "refresh_token": "R2",
202+
]))
203+
}
202204
let client = OAuth2Client(config: makeConfig(), transport: transport)
203205
let store = InMemoryTokenStore(initial: OAuth2Tokens(accessToken: "A1", refreshToken: "R1"))
204206
let mgr = TokenManager(client: client, store: store)
205207

206-
async let a = mgr.refresh()
207-
async let b = mgr.refresh()
208-
let (ta, tb) = try await (a, b)
208+
let a = Task { try await mgr.refresh() }
209+
let b = Task { try await mgr.refresh() }
210+
211+
let ta = try await a.value
212+
let tb = try await b.value
213+
214+
XCTAssertEqual(ta, tb)
209215
XCTAssertEqual(ta.accessToken, "A2")
210-
XCTAssertEqual(tb.accessToken, "A2")
211-
// Only ONE refresh hit the wire.
212-
XCTAssertEqual(transport.recorded.count, 1)
216+
XCTAssertEqual(refreshRequests(in: transport).count, 1)
217+
}
218+
219+
func test_concurrent_refreshes_stress_deduplicate() async throws {
220+
let transport = MockHTTPTransport()
221+
transport.sendDelayNanoseconds = 25_000_000
222+
for _ in 0..<20 {
223+
transport.responses.append(.json(payload: [
224+
"access_token": "A2", "refresh_token": "R2",
225+
]))
226+
}
227+
let client = OAuth2Client(config: makeConfig(), transport: transport)
228+
let store = InMemoryTokenStore(initial: OAuth2Tokens(accessToken: "A1", refreshToken: "R1"))
229+
let mgr = TokenManager(client: client, store: store)
230+
231+
let tasks = (0..<20).map { _ in
232+
Task { try await mgr.refresh() }
233+
}
234+
let results = try await tasks.asyncMap { try await $0.value }
235+
let first = try XCTUnwrap(results.first)
236+
237+
XCTAssertTrue(results.allSatisfy { $0 == first })
238+
XCTAssertEqual(refreshRequests(in: transport).count, 1)
213239
}
214240

215241
// MARK: - CSRF state mismatch is rejected
@@ -231,3 +257,22 @@ private final class CountingWipeHandler: SessionWipeHandler {
231257
var wipeCount = 0
232258
func wipeAfterSignOut() { wipeCount += 1 }
233259
}
260+
261+
private func refreshRequests(in transport: MockHTTPTransport) -> [MockHTTPTransport.Recorded] {
262+
transport.recorded.filter {
263+
$0.url.absoluteString.hasSuffix("/api/oauth2/token")
264+
&& ($0.bodyString ?? "").contains("grant_type=refresh_token")
265+
}
266+
}
267+
268+
private extension Array {
269+
func asyncMap<T>(_ transform: (Element) async throws -> T) async rethrows -> [T] {
270+
var results: [T] = []
271+
results.reserveCapacity(count)
272+
for element in self {
273+
let value = try await transform(element)
274+
results.append(value)
275+
}
276+
return results
277+
}
278+
}

0 commit comments

Comments
 (0)