Skip to content

Commit c5213b1

Browse files
authored
Redesign persistent token cache API (#23114)
1 parent 5df73f9 commit c5213b1

29 files changed

+690
-380
lines changed

sdk/azidentity/azidentity.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ var (
5353
errInvalidTenantID = errors.New("invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names")
5454
)
5555

56-
// TokenCachePersistenceOptions contains options for persistent token caching
57-
type TokenCachePersistenceOptions = internal.TokenCachePersistenceOptions
56+
// Cache represents a persistent cache that makes authentication data available across processes.
57+
// Construct one with [github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache.New]. This package's
58+
// [persistent user authentication example] shows how to use a persistent cache to reuse logins
59+
// across application runs.
60+
//
61+
// [persistent user authentication example]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/[email protected]#example-package-PersistentUserAuthentication
62+
type Cache = internal.Cache
5863

5964
// setAuthorityHost initializes the authority host for credentials. Precedence is:
6065
// 1. cloud.Configuration.ActiveDirectoryAuthorityHost value set by user

sdk/azidentity/azidentity_test.go

Lines changed: 130 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"os"
1919
"path/filepath"
2020
"reflect"
21-
"runtime"
2221
"strings"
2322
"testing"
2423
"time"
@@ -213,6 +212,17 @@ func TestTenantID(t *testing.T) {
213212
}
214213
}
215214

215+
type testCache []byte
216+
217+
func (c *testCache) Export(_ context.Context, m cache.Marshaler, _ cache.ExportHints) (err error) {
218+
*c, err = m.Marshal()
219+
return
220+
}
221+
222+
func (c *testCache) Replace(_ context.Context, u cache.Unmarshaler, _ cache.ReplaceHints) error {
223+
return u.Unmarshal(*c)
224+
}
225+
216226
func TestUserAuthentication(t *testing.T) {
217227
type authenticater interface {
218228
azcore.TokenCredential
@@ -221,30 +231,30 @@ func TestUserAuthentication(t *testing.T) {
221231
for _, credential := range []struct {
222232
name string
223233
interactive, recordable bool
224-
new func(*TokenCachePersistenceOptions, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error)
234+
new func(Cache, azcore.ClientOptions, AuthenticationRecord, bool) (authenticater, error)
225235
}{
226236
{
227237
name: credNameBrowser,
228-
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
238+
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
229239
return NewInteractiveBrowserCredential(&InteractiveBrowserCredentialOptions{
230240
AdditionallyAllowedTenants: []string{"*"},
231241
AuthenticationRecord: ar,
242+
Cache: c,
232243
ClientOptions: co,
233244
DisableAutomaticAuthentication: disableAutoAuth,
234-
TokenCachePersistenceOptions: tcpo,
235245
})
236246
},
237247
interactive: true,
238248
},
239249
{
240250
name: credNameDeviceCode,
241-
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
251+
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
242252
o := DeviceCodeCredentialOptions{
243253
AdditionallyAllowedTenants: []string{"*"},
244254
AuthenticationRecord: ar,
255+
Cache: c,
245256
ClientOptions: co,
246257
DisableAutomaticAuthentication: disableAutoAuth,
247-
TokenCachePersistenceOptions: tcpo,
248258
}
249259
if recording.GetRecordMode() == recording.PlaybackMode {
250260
o.UserPrompt = func(context.Context, DeviceCodeMessage) error { return nil }
@@ -256,12 +266,12 @@ func TestUserAuthentication(t *testing.T) {
256266
},
257267
{
258268
name: credNameUserPassword,
259-
new: func(tcpo *TokenCachePersistenceOptions, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
269+
new: func(c Cache, co azcore.ClientOptions, ar AuthenticationRecord, disableAutoAuth bool) (authenticater, error) {
260270
opts := UsernamePasswordCredentialOptions{
261-
AdditionallyAllowedTenants: []string{"*"},
262-
AuthenticationRecord: ar,
263-
ClientOptions: co,
264-
TokenCachePersistenceOptions: tcpo,
271+
AdditionallyAllowedTenants: []string{"*"},
272+
AuthenticationRecord: ar,
273+
Cache: c,
274+
ClientOptions: co,
265275
}
266276
return NewUsernamePasswordCredential(liveUser.tenantID, developerSignOnClientID, liveUser.username, liveUser.password, &opts)
267277
},
@@ -286,13 +296,13 @@ func TestUserAuthentication(t *testing.T) {
286296
}}
287297

288298
co := azcore.ClientOptions{Cloud: cc, Transport: &sts}
289-
cred, err := credential.new(nil, co, AuthenticationRecord{}, false)
299+
cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false)
290300
require.NoError(t, err)
291301
_, err = cred.Authenticate(context.Background(), nil)
292302
require.NoError(t, err)
293303

294304
t.Setenv(azureAuthorityHost, cc.ActiveDirectoryAuthorityHost)
295-
cred, err = credential.new(nil, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false)
305+
cred, err = credential.new(Cache{}, azcore.ClientOptions{Transport: &sts}, AuthenticationRecord{}, false)
296306
require.NoError(t, err)
297307
_, err = cred.Authenticate(context.Background(), nil)
298308
if cc.ActiveDirectoryAuthorityHost == customCloud.ActiveDirectoryAuthorityHost {
@@ -320,14 +330,14 @@ func TestUserAuthentication(t *testing.T) {
320330
counter := tokenRequestCountingPolicy{}
321331
co.PerCallPolicies = append(co.PerCallPolicies, &counter)
322332

323-
cred, err := credential.new(nil, co, AuthenticationRecord{}, false)
333+
cred, err := credential.new(Cache{}, co, AuthenticationRecord{}, false)
324334
require.NoError(t, err)
325335
ar, err := cred.Authenticate(context.Background(), &testTRO)
326336
require.NoError(t, err)
327337

328338
// some fields of the returned AuthenticationRecord should have specific values
329-
require.Equal(t, ar.ClientID, developerSignOnClientID)
330-
require.Equal(t, ar.Version, supportedAuthRecordVersions[0])
339+
require.Equal(t, developerSignOnClientID, ar.ClientID)
340+
require.Equal(t, supportedAuthRecordVersions[0], ar.Version)
331341
// all others should have nonempty values
332342
v := reflect.Indirect(reflect.ValueOf(&ar))
333343
for _, f := range reflect.VisibleFields(reflect.TypeOf(ar)) {
@@ -337,48 +347,47 @@ func TestUserAuthentication(t *testing.T) {
337347
require.Equal(t, 1, counter.count)
338348
})
339349

340-
t.Run("PersistentCache_Live/"+credential.name, func(t *testing.T) {
341-
switch recording.GetRecordMode() {
342-
case recording.LiveMode:
343-
if credential.interactive && !runManualTests {
344-
t.Skipf("set %s to run this test", azidentityRunManualTests)
345-
}
346-
case recording.PlaybackMode, recording.RecordingMode:
347-
if !credential.recordable {
348-
t.Skip("this test can't be recorded")
349-
}
350+
t.Run("PersistentCache/"+credential.name, func(t *testing.T) {
351+
if credential.name == credNameBrowser && !runManualTests {
352+
t.Skipf("set %s to run this test", azidentityRunManualTests)
350353
}
351-
if runtime.GOOS != "windows" {
352-
t.Skip("this test runs only on Windows")
353-
}
354-
p, err := internal.CacheFilePath(t.Name())
355-
require.NoError(t, err)
356-
os.Remove(p)
357-
co, stop := initRecording(t)
358-
defer stop()
359-
counter := tokenRequestCountingPolicy{}
360-
co.PerCallPolicies = append(co.PerCallPolicies, &counter)
361-
tcpo := TokenCachePersistenceOptions{Name: t.Name()}
354+
tokenReqs := 0
355+
c := internal.NewCache(func(bool) (cache.ExportReplace, error) {
356+
return &testCache{}, nil
357+
})
358+
co := azcore.ClientOptions{Transport: &mockSTS{
359+
tokenRequestCallback: func(*http.Request) *http.Response {
360+
tokenReqs++
361+
return nil
362+
},
363+
}}
362364

363-
cred, err := credential.new(&tcpo, co, AuthenticationRecord{}, true)
365+
cred, err := credential.new(c, co, AuthenticationRecord{}, false)
364366
require.NoError(t, err)
365-
record, err := cred.Authenticate(context.Background(), &testTRO)
367+
record, err := cred.Authenticate(ctx, &testTRO)
366368
require.NoError(t, err)
367-
defer os.Remove(p)
368-
tk, err := cred.GetToken(context.Background(), testTRO)
369+
_, err = cred.GetToken(ctx, testTRO)
369370
require.NoError(t, err)
370-
require.Equal(t, 1, counter.count)
371+
require.Equal(t, 1, tokenReqs)
371372

372-
cred2, err := credential.new(&tcpo, co, record, true)
373+
// cred2 should return the token cached by cred
374+
cred2, err := credential.new(c, co, record, true)
373375
require.NoError(t, err)
374-
tk2, err := cred2.GetToken(context.Background(), testTRO)
376+
_, err = cred2.GetToken(ctx, testTRO)
375377
require.NoError(t, err)
376-
require.Equal(t, tk.Token, tk2.Token)
378+
require.Equal(t, 1, tokenReqs)
379+
380+
// cred should request a new token because the cached one isn't a CAE token
381+
caeTRO := testTRO
382+
caeTRO.EnableCAE = true
383+
_, err = cred.GetToken(ctx, caeTRO)
384+
require.NoError(t, err)
385+
require.Equal(t, 2, tokenReqs)
377386
})
378387

379388
if credential.interactive {
380389
t.Run("DisableAutomaticAuthentication/"+credential.name, func(t *testing.T) {
381-
cred, err := credential.new(nil, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
390+
cred, err := credential.new(Cache{}, policy.ClientOptions{Transport: &mockSTS{}}, AuthenticationRecord{}, true)
382391
require.NoError(t, err)
383392
expected := policy.TokenRequestOptions{
384393
Claims: "claims",
@@ -402,7 +411,7 @@ func TestUserAuthentication(t *testing.T) {
402411
}
403412
})
404413
t.Run("DisableAutomaticAuthentication/ChainedTokenCredential/"+credential.name, func(t *testing.T) {
405-
cred, err := credential.new(nil, policy.ClientOptions{}, AuthenticationRecord{}, true)
414+
cred, err := credential.new(Cache{}, policy.ClientOptions{}, AuthenticationRecord{}, true)
406415
require.NoError(t, err)
407416
expected := azcore.AccessToken{ExpiresOn: time.Now().UTC(), Token: tokenValue}
408417
fake := NewFakeCredential()
@@ -1103,107 +1112,90 @@ func TestResolveTenant(t *testing.T) {
11031112
}
11041113
}
11051114

1106-
func TestTokenCachePersistenceOptions(t *testing.T) {
1107-
af := filepath.Join(t.TempDir(), t.Name()+credNameWorkloadIdentity)
1108-
if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil {
1109-
t.Fatal(err)
1110-
}
1111-
before := internal.NewCache
1112-
t.Cleanup(func() { internal.NewCache = before })
1113-
for _, test := range []struct {
1114-
desc string
1115-
options *TokenCachePersistenceOptions
1116-
err error
1115+
func TestConfidentialClientPersistentCache(t *testing.T) {
1116+
// for WorkloadIdentityCredential
1117+
tfp := filepath.Join(t.TempDir(), "tokenfile")
1118+
require.NoError(t, os.WriteFile(tfp, []byte("token"), 0600))
1119+
for _, credential := range []struct {
1120+
name string
1121+
new func(azcore.ClientOptions, Cache) (azcore.TokenCredential, error)
11171122
}{
11181123
{
1119-
desc: "nil options",
1124+
name: credNameAssertion,
1125+
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1126+
o := ClientAssertionCredentialOptions{Cache: c, ClientOptions: co}
1127+
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
1128+
},
11201129
},
1130+
// TODO: set SYSTEM_OIDC_REQUEST_URI, fake response
1131+
// {
1132+
// name: credNameAzurePipelines,
1133+
// new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1134+
// o := AzurePipelinesCredentialOptions{Cache: c, ClientOptions: co}
1135+
// return NewAzurePipelinesCredential(fakeTenantID, fakeClientID, "service-connection", tokenValue, &o)
1136+
// },
1137+
// },
11211138
{
1122-
desc: "default options",
1123-
options: &TokenCachePersistenceOptions{},
1139+
name: credNameCert,
1140+
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1141+
o := ClientCertificateCredentialOptions{Cache: c, ClientOptions: co}
1142+
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
1143+
},
11241144
},
11251145
{
1126-
desc: "all options set",
1127-
options: &TokenCachePersistenceOptions{AllowUnencryptedStorage: true, Name: "name"},
1146+
name: credNameSecret,
1147+
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1148+
o := ClientSecretCredentialOptions{Cache: c, ClientOptions: co}
1149+
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
1150+
},
11281151
},
1129-
} {
1130-
internal.NewCache = func(o *internal.TokenCachePersistenceOptions, _ bool) (cache.ExportReplace, error) {
1131-
if (test.options == nil) != (o == nil) {
1132-
t.Fatalf("expected %v, got %v", test.options, o)
1133-
}
1134-
if test.options != nil {
1135-
if test.options.AllowUnencryptedStorage != o.AllowUnencryptedStorage {
1136-
t.Fatalf("expected AllowUnencryptedStorage %v, got %v", test.options.AllowUnencryptedStorage, o.AllowUnencryptedStorage)
1137-
}
1138-
if test.options.Name != o.Name {
1139-
t.Fatalf("expected Name %q, got %q", test.options.Name, o.Name)
1152+
{
1153+
name: credNameWorkloadIdentity,
1154+
new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1155+
o := WorkloadIdentityCredentialOptions{
1156+
Cache: c,
1157+
ClientID: fakeClientID,
1158+
ClientOptions: co,
1159+
TenantID: fakeTenantID,
1160+
TokenFilePath: tfp,
11401161
}
1141-
}
1142-
return nil, nil
1143-
}
1144-
for _, subtest := range []struct {
1145-
ctor func(azcore.ClientOptions, *TokenCachePersistenceOptions) (azcore.TokenCredential, error)
1146-
env map[string]string
1147-
name string
1148-
}{
1149-
{
1150-
name: credNameAssertion,
1151-
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
1152-
o := ClientAssertionCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
1153-
return NewClientAssertionCredential(fakeTenantID, fakeClientID, func(context.Context) (string, error) { return "...", nil }, &o)
1154-
},
1155-
},
1156-
{
1157-
name: credNameCert,
1158-
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
1159-
o := ClientCertificateCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
1160-
return NewClientCertificateCredential(fakeTenantID, fakeClientID, allCertTests[0].certs, allCertTests[0].key, &o)
1161-
},
1162-
},
1163-
{
1164-
name: credNameDeviceCode,
1165-
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
1166-
o := DeviceCodeCredentialOptions{
1167-
ClientOptions: co,
1168-
TokenCachePersistenceOptions: tco,
1169-
UserPrompt: func(context.Context, DeviceCodeMessage) error { return nil },
1170-
}
1171-
return NewDeviceCodeCredential(&o)
1172-
},
1173-
},
1174-
{
1175-
name: credNameSecret,
1176-
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
1177-
o := ClientSecretCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
1178-
return NewClientSecretCredential(fakeTenantID, fakeClientID, fakeSecret, &o)
1179-
},
1180-
},
1181-
{
1182-
name: credNameUserPassword,
1183-
ctor: func(co azcore.ClientOptions, tco *TokenCachePersistenceOptions) (azcore.TokenCredential, error) {
1184-
o := UsernamePasswordCredentialOptions{ClientOptions: co, TokenCachePersistenceOptions: tco}
1185-
return NewUsernamePasswordCredential(fakeTenantID, fakeClientID, fakeUsername, "password", &o)
1186-
},
1162+
return NewWorkloadIdentityCredential(&o)
11871163
},
1188-
} {
1189-
t.Run(fmt.Sprintf("%s/%s", subtest.name, test.desc), func(t *testing.T) {
1190-
for k, v := range subtest.env {
1191-
t.Setenv(k, v)
1192-
}
1193-
c, err := subtest.ctor(policy.ClientOptions{Transport: &mockSTS{}}, test.options)
1194-
if err != nil {
1195-
t.Fatal(err)
1196-
}
1197-
_, err = c.GetToken(context.Background(), testTRO)
1198-
if err != nil {
1199-
if !errors.Is(err, test.err) {
1200-
t.Fatalf("expected %v, got %v", test.err, err)
1201-
}
1202-
} else if test.err != nil {
1203-
t.Fatal("expected an error")
1204-
}
1164+
},
1165+
} {
1166+
t.Run(credential.name, func(t *testing.T) {
1167+
tokenReqs := 0
1168+
c := internal.NewCache(func(bool) (cache.ExportReplace, error) {
1169+
return &testCache{}, nil
12051170
})
1206-
}
1171+
sts := mockSTS{
1172+
tokenRequestCallback: func(*http.Request) *http.Response {
1173+
tokenReqs++
1174+
return nil
1175+
},
1176+
}
1177+
cred, err := credential.new(policy.ClientOptions{Transport: &sts}, c)
1178+
require.NoError(t, err)
1179+
_, err = cred.GetToken(context.Background(), testTRO)
1180+
require.NoError(t, err)
1181+
_, err = cred.GetToken(ctx, testTRO)
1182+
require.NoError(t, err)
1183+
require.Equal(t, 1, tokenReqs)
1184+
1185+
// cred2 should return the token cached by cred
1186+
cred2, err := credential.new(policy.ClientOptions{Transport: &sts}, c)
1187+
require.NoError(t, err)
1188+
_, err = cred2.GetToken(ctx, testTRO)
1189+
require.NoError(t, err)
1190+
require.Equal(t, 1, tokenReqs)
1191+
1192+
// cred should request a new token because the cached one isn't a CAE token
1193+
caeTRO := testTRO
1194+
caeTRO.EnableCAE = true
1195+
_, err = cred.GetToken(ctx, caeTRO)
1196+
require.NoError(t, err)
1197+
require.Equal(t, 2, tokenReqs)
1198+
})
12071199
}
12081200
}
12091201

0 commit comments

Comments
 (0)