@@ -18,7 +18,6 @@ import (
1818 "os"
1919 "path/filepath"
2020 "reflect"
21- "runtime"
2221 "strings"
2322 "testing"
2423 "time"
@@ -213,6 +212,17 @@ func TestTenantID(t *testing.T) {
213212 }
214213}
215214
215+ type testCache []byte
216+
217+ func (c * testCache ) Export (_ context.Context , m cache.Marshaler , _ cache.ExportHints ) (err error ) {
218+ * c , err = m .Marshal ()
219+ return
220+ }
221+
222+ func (c * testCache ) Replace (_ context.Context , u cache.Unmarshaler , _ cache.ReplaceHints ) error {
223+ return u .Unmarshal (* c )
224+ }
225+
216226func TestUserAuthentication (t * testing.T ) {
217227 type authenticater interface {
218228 azcore.TokenCredential
@@ -221,30 +231,30 @@ func TestUserAuthentication(t *testing.T) {
221231 for _ , credential := range []struct {
222232 name string
223233 interactive , recordable bool
224- new func (* TokenCachePersistenceOptions , azcore.ClientOptions , AuthenticationRecord , bool ) (authenticater , error )
234+ new func (Cache , azcore.ClientOptions , AuthenticationRecord , bool ) (authenticater , error )
225235 }{
226236 {
227237 name : credNameBrowser ,
228- new : func (tcpo * TokenCachePersistenceOptions , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
238+ new : func (c Cache , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
229239 return NewInteractiveBrowserCredential (& InteractiveBrowserCredentialOptions {
230240 AdditionallyAllowedTenants : []string {"*" },
231241 AuthenticationRecord : ar ,
242+ Cache : c ,
232243 ClientOptions : co ,
233244 DisableAutomaticAuthentication : disableAutoAuth ,
234- TokenCachePersistenceOptions : tcpo ,
235245 })
236246 },
237247 interactive : true ,
238248 },
239249 {
240250 name : credNameDeviceCode ,
241- new : func (tcpo * TokenCachePersistenceOptions , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
251+ new : func (c Cache , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
242252 o := DeviceCodeCredentialOptions {
243253 AdditionallyAllowedTenants : []string {"*" },
244254 AuthenticationRecord : ar ,
255+ Cache : c ,
245256 ClientOptions : co ,
246257 DisableAutomaticAuthentication : disableAutoAuth ,
247- TokenCachePersistenceOptions : tcpo ,
248258 }
249259 if recording .GetRecordMode () == recording .PlaybackMode {
250260 o .UserPrompt = func (context.Context , DeviceCodeMessage ) error { return nil }
@@ -256,12 +266,12 @@ func TestUserAuthentication(t *testing.T) {
256266 },
257267 {
258268 name : credNameUserPassword ,
259- new : func (tcpo * TokenCachePersistenceOptions , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
269+ new : func (c Cache , co azcore.ClientOptions , ar AuthenticationRecord , disableAutoAuth bool ) (authenticater , error ) {
260270 opts := UsernamePasswordCredentialOptions {
261- AdditionallyAllowedTenants : []string {"*" },
262- AuthenticationRecord : ar ,
263- ClientOptions : co ,
264- TokenCachePersistenceOptions : tcpo ,
271+ AdditionallyAllowedTenants : []string {"*" },
272+ AuthenticationRecord : ar ,
273+ Cache : c ,
274+ ClientOptions : co ,
265275 }
266276 return NewUsernamePasswordCredential (liveUser .tenantID , developerSignOnClientID , liveUser .username , liveUser .password , & opts )
267277 },
@@ -286,13 +296,13 @@ func TestUserAuthentication(t *testing.T) {
286296 }}
287297
288298 co := azcore.ClientOptions {Cloud : cc , Transport : & sts }
289- cred , err := credential .new (nil , co , AuthenticationRecord {}, false )
299+ cred , err := credential .new (Cache {} , co , AuthenticationRecord {}, false )
290300 require .NoError (t , err )
291301 _ , err = cred .Authenticate (context .Background (), nil )
292302 require .NoError (t , err )
293303
294304 t .Setenv (azureAuthorityHost , cc .ActiveDirectoryAuthorityHost )
295- cred , err = credential .new (nil , azcore.ClientOptions {Transport : & sts }, AuthenticationRecord {}, false )
305+ cred , err = credential .new (Cache {} , azcore.ClientOptions {Transport : & sts }, AuthenticationRecord {}, false )
296306 require .NoError (t , err )
297307 _ , err = cred .Authenticate (context .Background (), nil )
298308 if cc .ActiveDirectoryAuthorityHost == customCloud .ActiveDirectoryAuthorityHost {
@@ -320,14 +330,14 @@ func TestUserAuthentication(t *testing.T) {
320330 counter := tokenRequestCountingPolicy {}
321331 co .PerCallPolicies = append (co .PerCallPolicies , & counter )
322332
323- cred , err := credential .new (nil , co , AuthenticationRecord {}, false )
333+ cred , err := credential .new (Cache {} , co , AuthenticationRecord {}, false )
324334 require .NoError (t , err )
325335 ar , err := cred .Authenticate (context .Background (), & testTRO )
326336 require .NoError (t , err )
327337
328338 // some fields of the returned AuthenticationRecord should have specific values
329- require .Equal (t , ar .ClientID , developerSignOnClientID )
330- require .Equal (t , ar . Version , supportedAuthRecordVersions [0 ])
339+ require .Equal (t , developerSignOnClientID , ar .ClientID )
340+ require .Equal (t , supportedAuthRecordVersions [0 ], ar . Version )
331341 // all others should have nonempty values
332342 v := reflect .Indirect (reflect .ValueOf (& ar ))
333343 for _ , f := range reflect .VisibleFields (reflect .TypeOf (ar )) {
@@ -337,48 +347,47 @@ func TestUserAuthentication(t *testing.T) {
337347 require .Equal (t , 1 , counter .count )
338348 })
339349
340- t .Run ("PersistentCache_Live/" + credential .name , func (t * testing.T ) {
341- switch recording .GetRecordMode () {
342- case recording .LiveMode :
343- if credential .interactive && ! runManualTests {
344- t .Skipf ("set %s to run this test" , azidentityRunManualTests )
345- }
346- case recording .PlaybackMode , recording .RecordingMode :
347- if ! credential .recordable {
348- t .Skip ("this test can't be recorded" )
349- }
350+ t .Run ("PersistentCache/" + credential .name , func (t * testing.T ) {
351+ if credential .name == credNameBrowser && ! runManualTests {
352+ t .Skipf ("set %s to run this test" , azidentityRunManualTests )
350353 }
351- if runtime .GOOS != "windows" {
352- t .Skip ("this test runs only on Windows" )
353- }
354- p , err := internal .CacheFilePath (t .Name ())
355- require .NoError (t , err )
356- os .Remove (p )
357- co , stop := initRecording (t )
358- defer stop ()
359- counter := tokenRequestCountingPolicy {}
360- co .PerCallPolicies = append (co .PerCallPolicies , & counter )
361- tcpo := TokenCachePersistenceOptions {Name : t .Name ()}
354+ tokenReqs := 0
355+ c := internal .NewCache (func (bool ) (cache.ExportReplace , error ) {
356+ return & testCache {}, nil
357+ })
358+ co := azcore.ClientOptions {Transport : & mockSTS {
359+ tokenRequestCallback : func (* http.Request ) * http.Response {
360+ tokenReqs ++
361+ return nil
362+ },
363+ }}
362364
363- cred , err := credential .new (& tcpo , co , AuthenticationRecord {}, true )
365+ cred , err := credential .new (c , co , AuthenticationRecord {}, false )
364366 require .NoError (t , err )
365- record , err := cred .Authenticate (context . Background () , & testTRO )
367+ record , err := cred .Authenticate (ctx , & testTRO )
366368 require .NoError (t , err )
367- defer os .Remove (p )
368- tk , err := cred .GetToken (context .Background (), testTRO )
369+ _ , err = cred .GetToken (ctx , testTRO )
369370 require .NoError (t , err )
370- require .Equal (t , 1 , counter . count )
371+ require .Equal (t , 1 , tokenReqs )
371372
372- cred2 , err := credential .new (& tcpo , co , record , true )
373+ // cred2 should return the token cached by cred
374+ cred2 , err := credential .new (c , co , record , true )
373375 require .NoError (t , err )
374- tk2 , err : = cred2 .GetToken (context . Background () , testTRO )
376+ _ , err = cred2 .GetToken (ctx , testTRO )
375377 require .NoError (t , err )
376- require .Equal (t , tk .Token , tk2 .Token )
378+ require .Equal (t , 1 , tokenReqs )
379+
380+ // cred should request a new token because the cached one isn't a CAE token
381+ caeTRO := testTRO
382+ caeTRO .EnableCAE = true
383+ _ , err = cred .GetToken (ctx , caeTRO )
384+ require .NoError (t , err )
385+ require .Equal (t , 2 , tokenReqs )
377386 })
378387
379388 if credential .interactive {
380389 t .Run ("DisableAutomaticAuthentication/" + credential .name , func (t * testing.T ) {
381- cred , err := credential .new (nil , policy.ClientOptions {Transport : & mockSTS {}}, AuthenticationRecord {}, true )
390+ cred , err := credential .new (Cache {} , policy.ClientOptions {Transport : & mockSTS {}}, AuthenticationRecord {}, true )
382391 require .NoError (t , err )
383392 expected := policy.TokenRequestOptions {
384393 Claims : "claims" ,
@@ -402,7 +411,7 @@ func TestUserAuthentication(t *testing.T) {
402411 }
403412 })
404413 t .Run ("DisableAutomaticAuthentication/ChainedTokenCredential/" + credential .name , func (t * testing.T ) {
405- cred , err := credential .new (nil , policy.ClientOptions {}, AuthenticationRecord {}, true )
414+ cred , err := credential .new (Cache {} , policy.ClientOptions {}, AuthenticationRecord {}, true )
406415 require .NoError (t , err )
407416 expected := azcore.AccessToken {ExpiresOn : time .Now ().UTC (), Token : tokenValue }
408417 fake := NewFakeCredential ()
@@ -1103,107 +1112,90 @@ func TestResolveTenant(t *testing.T) {
11031112 }
11041113}
11051114
1106- func TestTokenCachePersistenceOptions (t * testing.T ) {
1107- af := filepath .Join (t .TempDir (), t .Name ()+ credNameWorkloadIdentity )
1108- if err := os .WriteFile (af , []byte ("assertion" ), os .ModePerm ); err != nil {
1109- t .Fatal (err )
1110- }
1111- before := internal .NewCache
1112- t .Cleanup (func () { internal .NewCache = before })
1113- for _ , test := range []struct {
1114- desc string
1115- options * TokenCachePersistenceOptions
1116- err error
1115+ func TestConfidentialClientPersistentCache (t * testing.T ) {
1116+ // for WorkloadIdentityCredential
1117+ tfp := filepath .Join (t .TempDir (), "tokenfile" )
1118+ require .NoError (t , os .WriteFile (tfp , []byte ("token" ), 0600 ))
1119+ for _ , credential := range []struct {
1120+ name string
1121+ new func (azcore.ClientOptions , Cache ) (azcore.TokenCredential , error )
11171122 }{
11181123 {
1119- desc : "nil options" ,
1124+ name : credNameAssertion ,
1125+ new : func (co azcore.ClientOptions , c Cache ) (azcore.TokenCredential , error ) {
1126+ o := ClientAssertionCredentialOptions {Cache : c , ClientOptions : co }
1127+ return NewClientAssertionCredential (fakeTenantID , fakeClientID , func (context.Context ) (string , error ) { return "..." , nil }, & o )
1128+ },
11201129 },
1130+ // TODO: set SYSTEM_OIDC_REQUEST_URI, fake response
1131+ // {
1132+ // name: credNameAzurePipelines,
1133+ // new: func(co azcore.ClientOptions, c Cache) (azcore.TokenCredential, error) {
1134+ // o := AzurePipelinesCredentialOptions{Cache: c, ClientOptions: co}
1135+ // return NewAzurePipelinesCredential(fakeTenantID, fakeClientID, "service-connection", tokenValue, &o)
1136+ // },
1137+ // },
11211138 {
1122- desc : "default options" ,
1123- options : & TokenCachePersistenceOptions {},
1139+ name : credNameCert ,
1140+ new : func (co azcore.ClientOptions , c Cache ) (azcore.TokenCredential , error ) {
1141+ o := ClientCertificateCredentialOptions {Cache : c , ClientOptions : co }
1142+ return NewClientCertificateCredential (fakeTenantID , fakeClientID , allCertTests [0 ].certs , allCertTests [0 ].key , & o )
1143+ },
11241144 },
11251145 {
1126- desc : "all options set" ,
1127- options : & TokenCachePersistenceOptions {AllowUnencryptedStorage : true , Name : "name" },
1146+ name : credNameSecret ,
1147+ new : func (co azcore.ClientOptions , c Cache ) (azcore.TokenCredential , error ) {
1148+ o := ClientSecretCredentialOptions {Cache : c , ClientOptions : co }
1149+ return NewClientSecretCredential (fakeTenantID , fakeClientID , fakeSecret , & o )
1150+ },
11281151 },
1129- } {
1130- internal .NewCache = func (o * internal.TokenCachePersistenceOptions , _ bool ) (cache.ExportReplace , error ) {
1131- if (test .options == nil ) != (o == nil ) {
1132- t .Fatalf ("expected %v, got %v" , test .options , o )
1133- }
1134- if test .options != nil {
1135- if test .options .AllowUnencryptedStorage != o .AllowUnencryptedStorage {
1136- t .Fatalf ("expected AllowUnencryptedStorage %v, got %v" , test .options .AllowUnencryptedStorage , o .AllowUnencryptedStorage )
1137- }
1138- if test .options .Name != o .Name {
1139- t .Fatalf ("expected Name %q, got %q" , test .options .Name , o .Name )
1152+ {
1153+ name : credNameWorkloadIdentity ,
1154+ new : func (co azcore.ClientOptions , c Cache ) (azcore.TokenCredential , error ) {
1155+ o := WorkloadIdentityCredentialOptions {
1156+ Cache : c ,
1157+ ClientID : fakeClientID ,
1158+ ClientOptions : co ,
1159+ TenantID : fakeTenantID ,
1160+ TokenFilePath : tfp ,
11401161 }
1141- }
1142- return nil , nil
1143- }
1144- for _ , subtest := range []struct {
1145- ctor func (azcore.ClientOptions , * TokenCachePersistenceOptions ) (azcore.TokenCredential , error )
1146- env map [string ]string
1147- name string
1148- }{
1149- {
1150- name : credNameAssertion ,
1151- ctor : func (co azcore.ClientOptions , tco * TokenCachePersistenceOptions ) (azcore.TokenCredential , error ) {
1152- o := ClientAssertionCredentialOptions {ClientOptions : co , TokenCachePersistenceOptions : tco }
1153- return NewClientAssertionCredential (fakeTenantID , fakeClientID , func (context.Context ) (string , error ) { return "..." , nil }, & o )
1154- },
1155- },
1156- {
1157- name : credNameCert ,
1158- ctor : func (co azcore.ClientOptions , tco * TokenCachePersistenceOptions ) (azcore.TokenCredential , error ) {
1159- o := ClientCertificateCredentialOptions {ClientOptions : co , TokenCachePersistenceOptions : tco }
1160- return NewClientCertificateCredential (fakeTenantID , fakeClientID , allCertTests [0 ].certs , allCertTests [0 ].key , & o )
1161- },
1162- },
1163- {
1164- name : credNameDeviceCode ,
1165- ctor : func (co azcore.ClientOptions , tco * TokenCachePersistenceOptions ) (azcore.TokenCredential , error ) {
1166- o := DeviceCodeCredentialOptions {
1167- ClientOptions : co ,
1168- TokenCachePersistenceOptions : tco ,
1169- UserPrompt : func (context.Context , DeviceCodeMessage ) error { return nil },
1170- }
1171- return NewDeviceCodeCredential (& o )
1172- },
1173- },
1174- {
1175- name : credNameSecret ,
1176- ctor : func (co azcore.ClientOptions , tco * TokenCachePersistenceOptions ) (azcore.TokenCredential , error ) {
1177- o := ClientSecretCredentialOptions {ClientOptions : co , TokenCachePersistenceOptions : tco }
1178- return NewClientSecretCredential (fakeTenantID , fakeClientID , fakeSecret , & o )
1179- },
1180- },
1181- {
1182- name : credNameUserPassword ,
1183- ctor : func (co azcore.ClientOptions , tco * TokenCachePersistenceOptions ) (azcore.TokenCredential , error ) {
1184- o := UsernamePasswordCredentialOptions {ClientOptions : co , TokenCachePersistenceOptions : tco }
1185- return NewUsernamePasswordCredential (fakeTenantID , fakeClientID , fakeUsername , "password" , & o )
1186- },
1162+ return NewWorkloadIdentityCredential (& o )
11871163 },
1188- } {
1189- t .Run (fmt .Sprintf ("%s/%s" , subtest .name , test .desc ), func (t * testing.T ) {
1190- for k , v := range subtest .env {
1191- t .Setenv (k , v )
1192- }
1193- c , err := subtest .ctor (policy.ClientOptions {Transport : & mockSTS {}}, test .options )
1194- if err != nil {
1195- t .Fatal (err )
1196- }
1197- _ , err = c .GetToken (context .Background (), testTRO )
1198- if err != nil {
1199- if ! errors .Is (err , test .err ) {
1200- t .Fatalf ("expected %v, got %v" , test .err , err )
1201- }
1202- } else if test .err != nil {
1203- t .Fatal ("expected an error" )
1204- }
1164+ },
1165+ } {
1166+ t .Run (credential .name , func (t * testing.T ) {
1167+ tokenReqs := 0
1168+ c := internal .NewCache (func (bool ) (cache.ExportReplace , error ) {
1169+ return & testCache {}, nil
12051170 })
1206- }
1171+ sts := mockSTS {
1172+ tokenRequestCallback : func (* http.Request ) * http.Response {
1173+ tokenReqs ++
1174+ return nil
1175+ },
1176+ }
1177+ cred , err := credential .new (policy.ClientOptions {Transport : & sts }, c )
1178+ require .NoError (t , err )
1179+ _ , err = cred .GetToken (context .Background (), testTRO )
1180+ require .NoError (t , err )
1181+ _ , err = cred .GetToken (ctx , testTRO )
1182+ require .NoError (t , err )
1183+ require .Equal (t , 1 , tokenReqs )
1184+
1185+ // cred2 should return the token cached by cred
1186+ cred2 , err := credential .new (policy.ClientOptions {Transport : & sts }, c )
1187+ require .NoError (t , err )
1188+ _ , err = cred2 .GetToken (ctx , testTRO )
1189+ require .NoError (t , err )
1190+ require .Equal (t , 1 , tokenReqs )
1191+
1192+ // cred should request a new token because the cached one isn't a CAE token
1193+ caeTRO := testTRO
1194+ caeTRO .EnableCAE = true
1195+ _ , err = cred .GetToken (ctx , caeTRO )
1196+ require .NoError (t , err )
1197+ require .Equal (t , 2 , tokenReqs )
1198+ })
12071199 }
12081200}
12091201
0 commit comments